import logging
from typing import Optional, TypeVar

from ranked_tree import RotateFunction
from avl import AVLTree
from node import Node, NodeType, Comparable

logger = logging.getLogger(__name__)
T = TypeVar("T", bound=Comparable)


class WAVLTree(AVLTree[T]):
    """Weak AVL (WAVL) tree built on top of :class:`AVLTree`.

    Only the correctness check and the delete rebalancing are defined
    here; the rest comes from ``AVLTree``.

    Rank rule checked by :meth:`is_correct_node`: both child rank
    differences of every node are 1 or 2, and every leaf has rank 0.
    """

    def is_correct_node(
        self, node: Optional[Node[T]], recursive: bool = True
    ) -> bool:
        """Check whether ``node`` satisfies the WAVL rank rule.

        Args:
            node: Root of the subtree to check; ``None`` is always valid.
            recursive: When true, also check all descendants.

        Returns:
            ``True`` if every checked node has both child rank differences
            in {1, 2} and every checked leaf has rank 0.
        """
        if not node:
            return True

        if any(diff not in (1, 2) for diff in Node.differences(node)):
            return False

        # For a leaf the rank-0 rule is the last thing to verify.
        if node.type == NodeType.LEAF:
            return node.rank == 0

        return not recursive or (
            self.is_correct_node(node.left)
            and self.is_correct_node(node.right)
        )

    # region DeleteRebalance

    def __fix_delete(
        self,
        x: Optional[Node[T]],
        y: Node[T],
        z: Node[T],
        mirrored: bool,
        rotate_left: RotateFunction[T],
        rotate_right: RotateFunction[T],
    ) -> None:
        """Finish delete rebalancing with a single or double rotation.

        Called from ``__bottomup_delete`` when ``z`` has child rank
        differences (1, 3) or (3, 1): ``x`` is the 3-child and ``y`` its
        1-child sibling, which is not a 2,2 node.

        Args:
            x: The 3-child of ``z`` (not used here).
            y: Sibling of ``x``.
            z: Parent of ``x`` and ``y``.
            mirrored: ``False`` when ``y`` is ``z.right``; ``True`` for the
                mirror image, in which case the caller also passes the
                two rotation functions swapped.
            rotate_left: Rotation applied at ``z`` to bring ``y`` up
                (``self.rotate_right`` in the mirrored case).
            rotate_right: The opposite rotation, used first in the
                double-rotation case.
        """
        # ``w`` is y's outer child (on the same side y hangs off z),
        # ``v`` the inner one.
        v, w = y.left, y.right
        if mirrored:
            v, w = w, v

        w_diff = Node.difference(w, y)
        if w_diff == 1 and y.parent:
            # Single rotation at z (= y.parent).
            rotate_left(y.parent)

            y.promote()
            z.demote()

            # Leaves must have rank 0, so a leaf z is demoted once more.
            if z.type == NodeType.LEAF:
                z.demote()
        elif w_diff == 2:
            # y is not 2,2 here, so with w a 2-child the inner child v must
            # be y's 1-child and therefore exist.
            assert v is not None
            if v.parent:
                # Double rotation: at v's parent (y), then at v's new
                # parent, lifting v to the top of this subtree.
                rotate_right(v.parent)
                rotate_left(v.parent)

                v.promote().promote()
                y.demote()
                z.demote().demote()

    def __bottomup_delete(
        self, x: Optional[Node[T]], parent: Optional[Node[T]]
    ) -> None:
        """Repair a 3-child ``x`` of ``parent``, walking towards the root.

        While ``x`` is a 3-child and its sibling ``y`` is a 2-child or a
        2,2 node, demotions suffice and the violation moves one level up.
        If the walk stops at a node whose child rank differences are
        (1, 3) or (3, 1), ``__fix_delete`` finishes with a rotation.

        Args:
            x: Child of ``parent`` to examine (may be ``None``).
            parent: Parent of ``x``; nothing is done when it is ``None``.
        """
        # Check ``parent`` before computing a rank difference against it.
        if not parent:
            return
        x_diff = Node.difference(x, parent)
        if x_diff != 3:
            return

        y = parent.right if parent.left is x else parent.left
        y_diff = Node.difference(y, parent)

        while (
            parent
            and x_diff == 3
            and y
            and (y_diff == 2 or Node.differences(y) == (2, 2))
        ):
            # A sibling that is not a 2-child (hence a 2,2 node) is
            # demoted together with the parent.
            if y_diff != 2:
                y.demote()
            parent.demote()

            # The possible 3-child is now ``parent`` itself; move up.
            x = parent
            parent = x.parent
            if not parent:
                return

            y = parent.right if parent.left is x else parent.left
            x_diff = Node.difference(x, parent)
            y_diff = Node.difference(y, parent)

        if Node.differences(parent) in ((1, 3), (3, 1)):
            if parent.left is x:
                assert parent.right
                self.__fix_delete(
                    x,
                    parent.right,
                    parent,
                    False,
                    self.rotate_left,
                    self.rotate_right,
                )
            else:
                assert parent.left
                self.__fix_delete(
                    x,
                    parent.left,
                    parent,
                    True,
                    self.rotate_right,
                    self.rotate_left,
                )

    def _delete_rebalance(
        self, y: Optional[Node[T]], parent: Optional[Node[T]]
    ) -> None:
        """Restore the WAVL rank rule after a deletion.

        Presumably invoked by the deletion code inherited from ``AVLTree``
        with ``y`` being the node left in the removed node's place and
        ``parent`` its parent -- confirm against ``AVLTree``.

        A 2,2 ``y`` (checked first) or 2,2 ``parent`` is demoted and the
        check moves to the next ancestor; any 3-child of the resulting
        ``parent`` is then handed to ``__bottomup_delete``.
        """
        if Node.differences(y) == (2, 2):
            y.demote()
            parent = y.parent
        elif Node.differences(parent) == (2, 2):
            # NOTE(review): this demotes any 2,2 parent, not only 2,2
            # leaves. is_correct_node accepts non-leaf 2,2 nodes, so the
            # demotion may be unnecessary; a 3-child it creates above is
            # repaired by the loop below.
            parent.demote()
            parent = parent.parent

        if not parent:
            return

        # ``child`` rather than ``y`` so the argument is not shadowed.
        for child in (parent.left, parent.right):
            if Node.difference(child, parent) == 3:
                self.__bottomup_delete(child, parent)
                return

    # endregion DeleteRebalance