"""AVL tree built on top of the generic rank-balanced tree."""

import logging
from typing import Optional, TypeVar

from node import Comparable, Node
from ranked_tree import RankedTree, RotateFunction

logger = logging.getLogger(__name__)

T = TypeVar("T", bound=Comparable)


def _balance_factor(node: Optional[Node[T]]) -> int:
    """Return the rank of the right child minus the rank of the left child.

    A missing (falsy) node has a balance factor of 0.  Child ranks are
    obtained through ``Node.get_rank``.
    """
    if not node:
        return 0

    left_rank = Node.get_rank(node.left)
    right_rank = Node.get_rank(node.right)
    return right_rank - left_rank


def _update_rank(node: Node[T]) -> None:
    """Set ``node.rank`` to one more than the larger rank of its children."""
    child_ranks = (Node.get_rank(node.left), Node.get_rank(node.right))
    node.rank = max(child_ranks) + 1


class AVLTree(RankedTree[T]):
    """Rank-balanced tree that maintains the AVL rule."""

    def is_correct_node(
        self, node: Optional[Node[T]], recursive: bool = True
    ) -> bool:
        """Check that ``node`` satisfies the AVL rule.

        A node is correct when its balance factor lies in ``[-1, 1]`` and
        its rank equals one plus the larger rank of its children.  A missing
        node is always correct.

        :param node: node to check, may be ``None``
        :param recursive: when true, both subtrees are checked as well
        :returns: ``True`` if the checked node(s) are correct
        """
        if not node:
            return True

        if not -1 <= _balance_factor(node) <= 1:
            return False

        expected_rank = 1 + max(
            Node.get_rank(node.left), Node.get_rank(node.right)
        )
        if node.rank != expected_rank:
            return False

        if not recursive:
            return True

        return self.is_correct_node(node.left) and self.is_correct_node(
            node.right
        )

    # region InsertRebalance

    def __fix_0_child(
        self,
        x: Node[T],
        y: Optional[Node[T]],
        z: Node[T],
        rotate_left: RotateFunction[T],
        rotate_right: RotateFunction[T],
    ) -> None:
        """Repair a child whose rank difference is 0 by rotating.

        As called from ``_insert_rebalance``: ``x`` is the offending child,
        ``z`` is its parent and ``y`` is the child of ``x`` on the side
        facing away from ``z``'s edge.  For a right child the caller passes
        the two rotation functions swapped, so the names here describe the
        left-child case.

        If ``y`` exists and its rank difference is neither 1 nor 2, nothing
        is done.
        """
        # Single rotation: ``y`` is missing or has rank difference 2.
        if not y or Node.difference(y) == 2:
            rotate_right(z)
            z.demote()
            return

        if Node.difference(y) != 1:
            return

        # Double rotation: ``y`` ends up above both ``x`` and ``z``.
        rotate_left(x)
        rotate_right(z)

        y.promote()
        x.demote()
        z.demote()

    def _insert_rebalance(self, x: Node[T]) -> None:
        """Restore the rank rule on the path from ``x`` towards the root.

        Parents whose child rank differences are ``(0, 1)`` or ``(1, 0)``
        are promoted while walking upwards.  If the walk stops at a node
        that still has rank difference 0 to its parent, the rotation case
        is resolved by ``__fix_0_child``.
        """
        current = x

        # Promotion phase.
        while True:
            parent = current.parent
            diffs = Node.differences(parent)
            if not parent or diffs not in ((0, 1), (1, 0)):
                break

            parent.promote()
            current = parent

        if not parent:
            # Walked past the root; nothing left to fix.
            return

        if Node.difference(current) != 0:
            return

        # Rotation phase; the right-child case mirrors the left one.
        if parent.left is current:
            self.__fix_0_child(
                current,
                current.right,
                parent,
                self.rotate_left,
                self.rotate_right,
            )
        elif parent.right is current:
            self.__fix_0_child(
                current,
                current.left,
                parent,
                self.rotate_right,
                self.rotate_left,
            )

    # endregion InsertRebalance

    # region DeleteRebalance

    def __delete_rotate(
        self,
        x: Node[T],
        y: Node[T],
        leaning: int,
        rotate_left: RotateFunction[T],
        rotate_right: RotateFunction[T],
    ) -> bool:
        """Rotate the unbalanced node ``x`` and refresh the affected ranks.

        ``y`` is the child of ``x`` on the heavier side and ``leaning`` is
        the balance factor sign of that side (1 for right, -1 for left).
        For the left-heavy case the caller passes the rotation functions
        swapped.

        :returns: ``True`` if rebalancing should continue above the rotated
            subtree, i.e. when the balance factor of ``y`` was not 0
        """
        factor = _balance_factor(y)

        # ``y`` leaning the opposite way needs a double rotation.
        if factor not in (0, leaning):
            rotate_right(y)
        # The rotation's return value is used as the new subtree root.
        subtree_root = rotate_left(x)

        # Children first, so the root's rank is based on fresh values.
        for touched in (subtree_root.left, subtree_root.right, subtree_root):
            if touched:
                _update_rank(touched)

        return factor != 0

    def __delete_fixup(
        self, x: Node[T], parent: Optional[Node[T]] = None
    ) -> bool:
        """Fix the node ``x`` after a deletion below it.

        ``parent`` is accepted but not used here.

        :returns: ``True`` if the caller should continue with the parent,
            ``False`` if rebalancing can stop
        """
        factor = _balance_factor(x)

        if factor == 0:
            _update_rank(x)
            return True

        if factor in (-1, 1):
            return False

        if factor == 2:
            heavy_child, lean = x.right, 1
            to_left, to_right = self.rotate_left, self.rotate_right
        else:
            heavy_child, lean = x.left, -1
            to_left, to_right = self.rotate_right, self.rotate_left

        assert heavy_child
        return self.__delete_rotate(x, heavy_child, lean, to_left, to_right)

    def _delete_rebalance(
        self, node: Optional[Node[T]], parent: Optional[Node[T]]
    ) -> None:
        """Rebalance upwards after a deletion.

        Starts at ``node``, or at ``parent`` when ``node`` is missing, and
        keeps moving one level up for as long as ``__delete_fixup`` reports
        that the level above may be affected.
        """
        if not node:
            if not parent:
                return
            node, parent = parent, parent.parent

        while node and self.__delete_fixup(node, parent):
            node = parent
            parent = parent.parent if parent else None

    # endregion DeleteRebalance