166 lines
4.4 KiB
Python
166 lines
4.4 KiB
Python
|
from node import Node, Comparable, RotateFunction
|
||
|
from ranked_tree import RankedTree
|
||
|
|
||
|
import logging
|
||
|
from typing import TypeVar, Optional
|
||
|
|
||
|
# Module-level logger named after this module.
logger = logging.getLogger(__name__)
|
||
|
# Type parameter shared by ``Node``, ``RankedTree`` and ``AVLTree`` in this
# module; bounded by ``Comparable``.
T = TypeVar("T", bound=Comparable)
|
||
|
|
||
|
|
||
|
def _balance_factor(node: Optional[Node[T]]) -> int:
    """Return the balance factor of *node*: right child rank minus left.

    A missing (falsy) node counts as balanced and yields ``0``. A positive
    result means the right child has the higher rank; a negative result
    means the left child does. Child ranks come from ``Node.get_rank``.
    """
    if node:
        left_rank = Node.get_rank(node.left)
        right_rank = Node.get_rank(node.right)
        return right_rank - left_rank
    return 0
|
||
|
|
||
|
|
||
|
def _update_rank(node: Node[T]) -> None:
    """Recompute ``node.rank`` from its children.

    The new rank is one more than the larger of the two child ranks, as
    reported by ``Node.get_rank`` for ``node.left`` and ``node.right``.
    """
    child_ranks = (Node.get_rank(node.left), Node.get_rank(node.right))
    node.rank = max(child_ranks) + 1
|
||
|
|
||
|
|
||
|
class AVLTree(RankedTree[T]):
    """AVL tree on top of ``RankedTree``, kept balanced through node ranks.

    Insertion rebalancing works with rank differences between nodes and
    their parents (``Node.difference`` / ``Node.differences``) together with
    ``promote``/``demote``. Deletion rebalancing recomputes ranks from the
    children (``_update_rank``) and rotates wherever a balance factor
    reaches ±2.
    """

    def is_correct_node(
        self, node: Optional[Node[T]], recursive: bool = True
    ) -> bool:
        """Check the AVL balance condition at *node*.

        Args:
            node: Root of the subtree to check; a missing node is trivially
                correct.
            recursive: When true (the default), every descendant is checked
                as well; otherwise only *node* itself.

        Returns:
            ``True`` if every checked node has a balance factor (computed
            from the stored child ranks) in ``[-1, 1]``.
        """
        if not node:
            return True

        if not -1 <= _balance_factor(node) <= 1:
            return False

        return not recursive or (
            self.is_correct_node(node.left)
            and self.is_correct_node(node.right)
        )

    # region InsertRebalance

    def __fix_0_child(
        self,
        x: Node[T],
        y: Optional[Node[T]],
        z: Node[T],
        rotate_left: RotateFunction[T],
        rotate_right: RotateFunction[T],
    ) -> Optional[Node[T]]:
        """Repair the 0-child *x* of *z* with a single or double rotation.

        Written for *x* being the left child of *z*. The caller handles the
        mirrored case by passing the other inner child and swapping the two
        rotation functions.

        Args:
            x: Child of *z* whose rank difference is 0.
            y: Inner child of *x* (``x.right`` when *x* is a left child);
                may be missing.
            z: Parent of *x*.
            rotate_left: Rotation applied to *x* in the double-rotation case.
            rotate_right: Rotation applied to *z*.

        Returns:
            The node returned by ``rotate_right(z)``, i.e. the new root of
            the rebalanced subtree. If *y*'s rank difference is neither 2
            nor 1, nothing is rotated and ``x.parent`` is returned.
        """
        if not y or Node.difference(y) == 2:
            # Inner child missing or a 2-child: one rotation at z suffices.
            new_root = rotate_right(z)
            z.demote()
            return new_root

        if Node.difference(y) == 1:
            # Inner child is a 1-child: rotate it above x, then above z.
            rotate_left(x)
            new_root = rotate_right(z)

            y.promote()
            x.demote()
            z.demote()
            return new_root

        return x.parent

    def _insert_rebalance(self, x: Node[T]) -> None:
        """Restore the rank rule on the path above the inserted node *x*.

        While *x*'s parent has child rank differences ``(0, 1)`` or
        ``(1, 0)``, the parent is promoted and the walk moves up. When the
        walk stops below the root and *x* is a 0-child of its parent, one
        call to ``__fix_0_child`` performs the final rotation(s).
        """
        diffs = Node.differences(x.parent)
        while x.parent and diffs in ((0, 1), (1, 0)):
            x.parent.promote()
            x = x.parent

            diffs = Node.differences(x.parent)

        parent = x.parent
        if not parent or Node.difference(x) != 0:
            # x is the root, or is not a 0-child: nothing left to fix.
            return

        # Decided before rotating: when the rotated node is the root, the
        # rotation's result must become the tree's new root.
        rotating_around_root = parent.parent is None

        new_root: Optional[Node[T]] = parent
        if parent.left is x:
            new_root = self.__fix_0_child(
                x, x.right, parent, Node.rotate_left, Node.rotate_right
            )
        elif parent.right is x:
            new_root = self.__fix_0_child(
                x, x.left, parent, Node.rotate_right, Node.rotate_left
            )

        if rotating_around_root:
            self.root = new_root

    # endregion InsertRebalance

    # region DeleteRebalance

    def __delete_rotate(
        self,
        x: Node[T],
        y: Node[T],
        leaning: int,
        rotating_around_root: bool,
        rotate_left: RotateFunction[T],
        rotate_right: RotateFunction[T],
    ) -> bool:
        """Rotate around the doubly-heavy node *x* and refresh ranks.

        Written for a right-heavy *x* (``leaning == 1``, *y* is ``x.right``).
        The caller passes ``leaning == -1`` and mirrored rotations for the
        left-heavy case.

        Args:
            x: Node whose balance factor is ±2.
            y: Child of *x* on the heavy side.
            leaning: Sign of *x*'s imbalance (``1`` or ``-1``).
            rotating_around_root: Whether *x* is the tree root; if so,
                ``self.root`` is replaced by the new subtree root.
            rotate_left: Rotation that lifts the heavy side above *x*.
            rotate_right: Rotation applied to *y* first when *y* leans the
                opposite way (double rotation).

        Returns:
            ``True`` if *y* had balance factor 0 before rotating (a single
            rotation was done), ``False`` otherwise.
        """
        factor = _balance_factor(y)
        if factor not in (0, leaning):
            # y leans away from x's heavy side: straighten it first.
            rotate_right(y)
        new_root = rotate_left(x)

        # Children before the subtree root, since a rank is derived from
        # the ranks directly below it.
        for node in (new_root.left, new_root.right, new_root):
            if node:
                _update_rank(node)

        if rotating_around_root:
            self.root = new_root

        return factor == 0

    def __delete_fixup(
        self, y: Optional[Node[T]], parent: Optional[Node[T]] = None
    ) -> bool:
        """Rebalance a single node on the path up from a deletion.

        The examined node is *y*, or *parent* when *y* is missing.

        Returns:
            ``True`` if the node's balance factor is ±1, or if a rotation
            was done on a balanced heavy child; ``False`` otherwise.
        """
        x = y if y else parent
        assert x

        factor = _balance_factor(x)
        if -1 <= factor <= 1:
            # Recomputing is a no-op for a correct rank. Doing it for ±1 as
            # well (not only 0) keeps a stale rank from surviving here and
            # skewing the balance factors of the ancestors.
            _update_rank(x)
            return factor != 0

        rotating_around_root = x.parent is None
        heavy_child, leaning, to_left, to_right = (
            (x.right, 1, Node.rotate_left, Node.rotate_right)
            if factor == 2
            else (x.left, -1, Node.rotate_right, Node.rotate_left)
        )
        # A node that is doubly heavy on one side must have a child there.
        assert heavy_child
        return self.__delete_rotate(
            x,
            heavy_child,
            leaning,
            rotating_around_root,
            to_left,
            to_right,
        )

    def _delete_rebalance(
        self, node: Optional[Node[T]], parent: Optional[Node[T]]
    ) -> None:
        """Rebalance every node from a deletion point up to the root.

        Args:
            node: Presumably the node now occupying the deleted position;
                may be missing (TODO confirm against ``RankedTree``).
            parent: Parent of that position; examined first when *node* is
                missing.
        """
        while node or parent:
            # TODO: Check whether the walk can stop once __delete_fixup
            # returns True instead of always propagating to the root.
            self.__delete_fixup(node, parent)
            # parent.parent is read after the fixup, so after a rotation at
            # parent the walk continues from that rotation's neighbourhood;
            # this may repeat work but never skips an ancestor.
            node, parent = parent, (parent.parent if parent else None)

    # endregion DeleteRebalance
|