diff --git a/wavl.py b/wavl.py index 7bbb27a..d236649 100644 --- a/wavl.py +++ b/wavl.py @@ -29,15 +29,16 @@ class WAVLTree(AVLTree[T]): # region DeleteRebalance - @staticmethod def __fix_delete( + self, x: Optional[Node[T]], y: Node[T], z: Node[T], reversed: bool, + rotating_around_root: bool, rotate_left: RotateFunction[T], rotate_right: RotateFunction[T], - ) -> Optional[Node[T]]: + ) -> None: new_root = x v = y.left w = y.right @@ -54,6 +55,8 @@ class WAVLTree(AVLTree[T]): if w_diff == 1 and y.parent: logger.debug(f"y.parent = {y.parent}") new_root = rotate_left(y.parent) + if rotating_around_root: + self.root = new_root y.promote() z.demote() @@ -64,13 +67,13 @@ class WAVLTree(AVLTree[T]): logger.debug(f"v.parent = {v.parent}") rotate_right(v.parent) new_root = rotate_left(v.parent) + if rotating_around_root: + self.root = new_root v.promote().promote() y.demote() z.demote().demote() - return new_root - def __bottomup_delete( self, x: Optional[Node[T]], parent: Optional[Node[T]] ) -> None: @@ -104,57 +107,54 @@ class WAVLTree(AVLTree[T]): return rotating_around_root = parent.parent is None - new_root: Optional[Node[T]] = parent parent_node_diffs = Node.differences(parent) if parent_node_diffs in ((1, 3), (3, 1)): if parent.left is x: assert parent.right - new_root = WAVLTree.__fix_delete( + self.__fix_delete( x, parent.right, parent, False, + rotating_around_root, Node.rotate_left, Node.rotate_right, ) else: assert parent.left - new_root = WAVLTree.__fix_delete( + self.__fix_delete( x, parent.left, parent, True, + rotating_around_root, Node.rotate_right, Node.rotate_left, ) - if rotating_around_root: - self.root = new_root - def __delete_fixup( self, y: Optional[Node[T]], parent: Optional[Node[T]] = None ) -> None: logger.debug(f"[__delete_fixup] y = {y}, parent = {parent}") - z = y if y else parent - logger.debug( - f"[z.demote()] Node.differences({repr(z)}) == (2, 2) ~>*" - f"{Node.differences(z)} == (2, 2)" - ) - assert z - # FIXME: In combination with propagation below, we get AVL tree - if Node.differences(z) == (2, 2): - z.demote() + if Node.differences(y) == (2, 2): + y.demote() + parent = y.parent + elif Node.differences(parent) == (2, 2): + parent.demote() + parent = parent.parent - if parent: - for y in (parent.left, parent.right): - logger.debug( - f"[bottom-up delete] Node.difference({y}, {parent}) == 3" - f"~>* {Node.difference(y, parent)} == 3" - ) - if Node.difference(y, parent) == 3: - self.__bottomup_delete(y, parent) + if not parent: + return + + for y in (parent.left, parent.right): + logger.debug( + f"[bottom-up delete] Node.difference({y}, {parent}) == 3" + f"~>* {Node.difference(y, parent)} == 3" + ) + if Node.difference(y, parent) == 3: + self.__bottomup_delete(y, parent) def _delete_rebalance( self, node: Optional[Node[T]], parent: Optional[Node[T]]