from node import Node, Comparable, RotateFunction
from ranked_tree import RankedTree

import logging
from typing import TypeVar, Optional

logger = logging.getLogger(__name__)
T = TypeVar("T", bound=Comparable)


def _balance_factor(node: Optional[Node[T]]) -> int:
    """Return the rank of the right child minus the rank of the left child.

    A missing (falsy) node is treated as perfectly balanced and yields 0.
    Child ranks are obtained through ``Node.get_rank``.
    """
    if not node:
        return 0

    left_rank = Node.get_rank(node.left)
    right_rank = Node.get_rank(node.right)
    return right_rank - left_rank


def _update_rank(node: Node[T]) -> None:
    """Recompute ``node.rank`` as one more than the larger child rank."""
    left_rank = Node.get_rank(node.left)
    right_rank = Node.get_rank(node.right)
    node.rank = max(left_rank, right_rank) + 1


class AVLTree(RankedTree[T]):
    """AVL tree expressed on top of the rank-based ``RankedTree``."""

    def is_correct_node(
        self, node: Optional[Node[T]], recursive: bool = True
    ) -> bool:
        """Check the AVL rule: the balance factor must lie in [-1, 1].

        Args:
            node: Node to validate; a missing node is always valid.
            recursive: When true, both subtrees are validated as well
                (the nested calls always run recursively).

        Returns:
            ``True`` when every inspected node satisfies the rule.
        """
        if not node:
            return True

        if abs(_balance_factor(node)) > 1:
            return False

        if not recursive:
            return True

        return self.is_correct_node(node.left) and self.is_correct_node(
            node.right
        )

    # region InsertRebalance

    def __fix_0_child(
        self,
        x: Node[T],
        y: Optional[Node[T]],
        z: Node[T],
        rotate_left: RotateFunction[T],
        rotate_right: RotateFunction[T],
    ) -> Optional[Node[T]]:
        """Repair the subtree rooted at ``z`` with one or two rotations.

        The caller passes ``x`` (the node under repair), ``y`` (the child
        of ``x`` on the side facing ``z``) and ``z`` (the parent of ``x``).
        For the mirrored case the caller swaps the two rotation functions,
        so the names ``rotate_left``/``rotate_right`` describe the case
        where ``x`` is the left child of ``z``.

        NOTE(review): ``Node.difference`` presumably yields the rank
        difference between a node and its parent - confirm in ``node``.

        Returns:
            Whatever the final rotation around ``z`` returns, or the
            parent ``x`` had on entry when neither case applies.
        """
        unchanged_root = x.parent

        # Single rotation: `y` is missing or hangs two ranks below `x`.
        if not y or Node.difference(y) == 2:
            subtree_root = rotate_right(z)
            z.demote()
            return subtree_root

        # Double rotation: `y` sits one rank below `x`.
        if Node.difference(y) == 1:
            rotate_left(x)
            subtree_root = rotate_right(z)

            y.promote()
            x.demote()
            z.demote()
            return subtree_root

        return unchanged_root

    def _insert_rebalance(self, x: Node[T]) -> None:
        """Restore the rank rule on the path above the inserted node ``x``.

        Climbs while the parent's ``Node.differences`` pair is ``(0, 1)``
        or ``(1, 0)``, promoting each such parent. If the climb ends below
        the root at a node whose ``Node.difference`` is 0, the subtree is
        repaired by rotation and ``self.root`` is updated when the
        rotation happened around the tree root.
        """
        promote_cases = ((0, 1), (1, 0))

        while True:
            parent = x.parent
            diffs = Node.differences(parent)
            if not parent or diffs not in promote_cases:
                break

            parent.promote()
            x = parent

        if not parent:
            return

        # Must be captured before any rotation re-links the parent.
        rotating_around_root = parent.parent is None

        if Node.difference(x) != 0:
            return

        new_root: Optional[Node[T]] = parent
        if parent.left is x:
            new_root = self.__fix_0_child(
                x, x.right, parent, Node.rotate_left, Node.rotate_right
            )
        elif parent.right is x:
            new_root = self.__fix_0_child(
                x, x.left, parent, Node.rotate_right, Node.rotate_left
            )

        if rotating_around_root:
            self.root = new_root

    # endregion InsertRebalance

    # region DeleteRebalance

    def __delete_rotate(
        self,
        x: Node[T],
        y: Node[T],
        leaning: int,
        rotating_around_root: bool,
        rotate_left: RotateFunction[T],
        rotate_right: RotateFunction[T],
    ) -> bool:
        """Rotate the unbalanced node ``x`` towards its lighter side.

        Args:
            x: Node whose balance factor is out of range.
            y: Child of ``x`` on the heavier side.
            leaning: ``1`` when the right side is heavier, ``-1`` otherwise.
            rotating_around_root: Whether ``x`` was the tree root.
            rotate_left: Rotation applied to ``x`` (swapped by the caller
                for the mirrored case).
            rotate_right: Rotation applied to ``y`` first when ``y`` leans
                the opposite way.

        Returns:
            ``True`` when ``y`` had balance factor 0 before rotating.
        """
        factor = _balance_factor(y)

        # `y` leaning against `x` needs a preliminary rotation around `y`.
        if factor not in (0, leaning):
            rotate_right(y)
        new_root = rotate_left(x)

        # Children first, so the new subtree root sees refreshed ranks.
        for affected in (new_root.left, new_root.right, new_root):
            if affected:
                _update_rank(affected)

        if rotating_around_root:
            self.root = new_root

        return factor == 0

    def __delete_fixup(
        self, y: Optional[Node[T]], parent: Optional[Node[T]] = None
    ) -> bool:
        """Rebalance a single node on the deletion path.

        Works on ``y`` when it is present, otherwise on ``parent``; at
        least one of them must be given.

        Returns:
            ``False`` when the node ended up with balance factor 0 (its
            rank is recomputed), ``True`` when the factor is -1 or 1, and
            otherwise the result of the rotation step.
        """
        x = y or parent
        assert x

        factor = _balance_factor(x)
        if factor == 0:
            _update_rank(x)
            return False
        if abs(factor) == 1:
            return True

        rotating_around_root = x.parent is None

        if factor == 2:
            heavy_child, leaning = x.right, 1
            to_left, to_right = Node.rotate_left, Node.rotate_right
        else:
            heavy_child, leaning = x.left, -1
            to_left, to_right = Node.rotate_right, Node.rotate_left

        assert heavy_child
        return self.__delete_rotate(
            x,
            heavy_child,
            leaning,
            rotating_around_root,
            to_left,
            to_right,
        )

    def _delete_rebalance(
        self, node: Optional[Node[T]], parent: Optional[Node[T]]
    ) -> None:
        """Apply the deletion fix-up to every node up to the root.

        The boolean returned by the fix-up is currently ignored, so the
        walk always continues to the top of the tree.
        """
        while node or parent:
            # TODO: Check if it is possible to not propagate all the way up.
            # if self.__delete_fixup(node, parent):
            #     return
            self.__delete_fixup(node, parent)

            grandparent = parent.parent if parent else None
            node, parent = parent, grandparent

    # endregion DeleteRebalance