python/avl.py

166 lines
4.4 KiB
Python
Raw Normal View History

from node import Node, Comparable, RotateFunction
from ranked_tree import RankedTree
import logging
from typing import TypeVar, Optional
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
# Type of the keys stored in tree nodes; restricted to the project's
# ``Comparable`` protocol so keys can be ordered.
T = TypeVar("T", bound=Comparable)
def _balance_factor(node: Optional[Node[T]]) -> int:
    """Return the balance factor of ``node``.

    The balance factor is the rank of the right child minus the rank of the
    left child (missing children are ranked by ``Node.get_rank``). A missing
    node has a balance factor of ``0``. Positive values mean the right side
    has the higher rank, negative values the left side.
    """
    if node:
        left_rank = Node.get_rank(node.left)
        right_rank = Node.get_rank(node.right)
        return right_rank - left_rank
    return 0


def _update_rank(node: Node[T]) -> None:
    """Recompute ``node.rank`` as one more than the larger child rank.

    Child ranks are read with ``Node.get_rank``, so a missing child
    contributes whatever rank that helper assigns to ``None``.
    """
    node.rank = max(Node.get_rank(node.left), Node.get_rank(node.right)) + 1


class AVLTree(RankedTree[T]):
    """AVL tree expressed in the rank-balanced framework of ``RankedTree``.

    The AVL invariant is checked through stored node ranks: the balance
    factor (right child rank minus left child rank) of every node must lie
    in ``[-1, 1]``. Rebalancing is done in the ``_insert_rebalance`` and
    ``_delete_rebalance`` hooks.

    NOTE(review): these hooks are presumably invoked by ``RankedTree``
    after a plain BST insert/delete; confirm against the base class.
    """

    def is_correct_node(
        self, node: Optional[Node[T]], recursive: bool = True
    ) -> bool:
        """Check the AVL balance condition at ``node``.

        Args:
            node: Root of the subtree to check. ``None`` (an empty subtree)
                is always correct.
            recursive: If true, also check every node in both subtrees.
                Otherwise only ``node`` itself is checked.

        Returns:
            ``True`` iff every checked node has a balance factor in
            ``[-1, 1]``. Only the stored ranks are compared; ranks are not
            recomputed from actual subtree heights here.
        """
        if not node:
            return True
        if not (-1 <= _balance_factor(node) <= 1):
            return False
        return not recursive or (
            self.is_correct_node(node.left)
            and self.is_correct_node(node.right)
        )

    # region InsertRebalance

    def __fix_0_child(
        self,
        x: Node[T],
        y: Optional[Node[T]],
        z: Node[T],
        rotate_left: RotateFunction[T],
        rotate_right: RotateFunction[T],
    ) -> Optional[Node[T]]:
        """Restore balance after ``x`` became a 0-child of its parent ``z``.

        The logic is written for ``x`` being the *left* child of ``z``. The
        caller handles the mirrored case by swapping the rotation functions.
        ``y`` is the inner child of ``x``, i.e. the child on the side facing
        ``x``'s sibling. ``Node.difference(y)`` is presumably the rank of
        ``y``'s parent minus ``y``'s rank; TODO confirm in ``node.py``.

        The two cases are:

        * ``y`` is missing or a 2-child: single rotation at ``z``, then
          demote ``z``.
        * ``y`` is a 1-child: double rotation (at ``x``, then at ``z``),
          then promote ``y`` and demote both ``x`` and ``z``.

        Returns:
            The value returned by the final rotation at ``z``. This is
            presumably the node now occupying ``z``'s former position.
            If neither case applies, ``x.parent`` is returned unchanged.
        """
        new_root = x.parent
        if not y or Node.difference(y) == 2:
            new_root = rotate_right(z)
            z.demote()
        elif Node.difference(y) == 1:
            rotate_left(x)
            new_root = rotate_right(z)
            y.promote()
            x.demote()
            z.demote()
        return new_root

    def _insert_rebalance(self, x: Node[T]) -> None:
        """Rebalance bottom-up, starting at the freshly inserted node ``x``.

        First, walk upwards promoting the parent while its child rank
        differences are ``(0, 1)`` or ``(1, 0)``. If the walk stops below
        the root and ``x`` is still a 0-child, a single or double rotation
        (``__fix_0_child``) finishes the job.
        """
        diffs = Node.differences(x.parent)
        while x.parent and diffs in ((0, 1), (1, 0)):
            x.parent.promote()
            x = x.parent
            diffs = Node.differences(x.parent)

        z = x.parent
        if not z:
            # Promotions propagated past the root; nothing left to fix.
            return
        if Node.difference(x) != 0:
            # x is not a 0-child of z, so no rotation is needed.
            return

        # Decide this before rotating: afterwards z is no longer the root.
        rotating_around_root = z.parent is None
        if z.left is x:
            new_root = self.__fix_0_child(
                x, x.right, z, Node.rotate_left, Node.rotate_right
            )
        else:
            # Mirror image: x is the right child, so swap the rotations.
            new_root = self.__fix_0_child(
                x, x.left, z, Node.rotate_right, Node.rotate_left
            )
        if rotating_around_root:
            self.root = new_root

    # endregion InsertRebalance

    # region DeleteRebalance

    def __delete_rotate(
        self,
        x: Node[T],
        y: Node[T],
        leaning: int,
        rotating_around_root: bool,
        rotate_left: RotateFunction[T],
        rotate_right: RotateFunction[T],
    ) -> bool:
        """Rotate at the doubly unbalanced node ``x`` after a deletion.

        The logic is written for ``x`` leaning right (balance factor ``+2``)
        with ``y`` as its right child. The caller mirrors it by passing
        ``leaning=-1`` and swapping the rotation functions.

        Args:
            x: Node whose balance factor is ``+2`` or ``-2``.
            y: Child of ``x`` on the heavier side.
            leaning: ``1`` if ``x`` leans right, ``-1`` if it leans left.
            rotating_around_root: Whether ``x`` is the tree root. If so,
                ``self.root`` is replaced by the new subtree root.
            rotate_left: Rotation towards the light side (possibly mirrored).
            rotate_right: Rotation towards the heavy side (possibly mirrored).

        Returns:
            ``True`` iff ``y``'s balance factor was ``0``. In standard AVL
            deletion, that is the case where the rotated subtree keeps its
            height.
        """
        factor = _balance_factor(y)
        if factor not in (0, leaning):
            # y leans away from x's heavy side: straighten it first, which
            # turns the single rotation below into a double rotation.
            rotate_right(y)
        new_root = rotate_left(x)

        # Recompute the children first, then their new parent, so that the
        # parent's rank is computed from fresh child ranks.
        for n in filter(None, (new_root.left, new_root.right, new_root)):
            _update_rank(n)

        if rotating_around_root:
            self.root = new_root
        return factor == 0

    def __delete_fixup(
        self, y: Optional[Node[T]], parent: Optional[Node[T]] = None
    ) -> bool:
        """Check and fix one node on the path from a deletion towards the root.

        The node inspected is ``y`` when it is given, otherwise ``parent``.

        Returns:
            * ``False`` when the node's balance factor is ``0``; its rank is
              recomputed in that case.
            * ``True`` when the factor is ``-1`` or ``1``; the rank is left
              untouched.
            * Otherwise, the result of ``__delete_rotate`` after rotating.
        """
        x = y if y else parent
        assert x
        factor = _balance_factor(x)
        if factor == 0:
            _update_rank(x)
            return False
        if factor in (-1, 1):
            return True

        rotating_around_root = x.parent is None
        # The factor is +2 or -2 here; anything other than +2 is handled as
        # a left-heavy node. Pick the heavy child and orient the rotations.
        if factor == 2:
            heavy_child, leaning = x.right, 1
            to_left, to_right = Node.rotate_left, Node.rotate_right
        else:
            heavy_child, leaning = x.left, -1
            to_left, to_right = Node.rotate_right, Node.rotate_left
        assert heavy_child
        return self.__delete_rotate(
            x,
            heavy_child,
            leaning,
            rotating_around_root,
            to_left,
            to_right,
        )

    def _delete_rebalance(
        self, node: Optional[Node[T]], parent: Optional[Node[T]]
    ) -> None:
        """Rebalance after a deletion, walking all the way up to the root.

        The first step inspects ``node`` if it is not ``None``, otherwise
        ``parent``. Each later step inspects the previous ``parent`` and then
        moves one level up. When ``node`` is ``None``, the first ``parent``
        is therefore inspected twice; this is harmless because a second
        pass over a balanced node changes nothing.
        """
        while node or parent:
            # TODO: stop early using __delete_fixup's result.
            # NOTE(review): returning on the very first step looks unsafe
            # when ``node`` is a subtree the deletion did not shorten. Its
            # ``±1`` factor would then say nothing about ``parent``, which
            # still needs checking. Confirm against the caller in
            # RankedTree before enabling an early exit.
            self.__delete_fixup(node, parent)
            node, parent = parent, (parent.parent if parent else None)

    # endregion DeleteRebalance