diff --git a/wavl.py b/wavl.py index dcc28a0..11e3621 100644 --- a/wavl.py +++ b/wavl.py @@ -48,7 +48,6 @@ class WAVLTree(AVLTree[T]): w_diff = Node.difference(w, y) - # assert v if w_diff == 1 and y.parent: new_root = rotate_left(y.parent) if rotating_around_root: @@ -59,7 +58,7 @@ class WAVLTree(AVLTree[T]): if z.type == NodeType.LEAF: z.demote() - elif w_diff == 2 and v and v.parent: + elif w_diff == 2 and v.parent: rotate_right(v.parent) new_root = rotate_left(v.parent) if rotating_around_root: @@ -85,9 +84,9 @@ class WAVLTree(AVLTree[T]): and y and (y_diff == 2 or Node.differences(y) == (2, 2)) ): - parent.demote() if y_diff != 2: y.demote() + parent.demote() x = parent parent = x.parent @@ -98,9 +97,6 @@ class WAVLTree(AVLTree[T]): x_diff = Node.difference(x, parent) y_diff = Node.difference(y, parent) - if not parent: - return - rotating_around_root = parent.parent is None parent_node_diffs = Node.differences(parent) @@ -128,8 +124,8 @@ class WAVLTree(AVLTree[T]): Node.rotate_left, ) - def __delete_fixup( - self, y: Optional[Node[T]], parent: Optional[Node[T]] = None + def _delete_rebalance( + self, y: Optional[Node[T]], parent: Optional[Node[T]] ) -> None: if Node.differences(y) == (2, 2): y.demote() @@ -144,10 +140,6 @@ class WAVLTree(AVLTree[T]): for y in (parent.left, parent.right): if Node.difference(y, parent) == 3: self.__bottomup_delete(y, parent) - - def _delete_rebalance( - self, node: Optional[Node[T]], parent: Optional[Node[T]] - ) -> None: - self.__delete_fixup(node, parent) + return # endregion DeleteRebalance