python/wavl.py

137 lines
3.5 KiB
Python
Raw Normal View History

from ranked_tree import RotateFunction
from avl import AVLTree
from node import Node, NodeType, Comparable
import logging
from typing import TypeVar, Optional
# Module-level logger, named after this module.
logger: logging.Logger = logging.getLogger(name=__name__)
# Type parameter for the values held by the tree's nodes; bound to Comparable.
T = TypeVar("T", bound=Comparable)
class WAVLTree(AVLTree[T]):
    """Weak AVL (WAVL) tree layered on top of ``AVLTree``.

    Overrides the node-validity check and the rebalancing performed after a
    deletion. Everything not defined here is inherited from ``AVLTree``.

    Terminology used below: a node is an *i-child* when its rank difference
    to its parent, as computed by ``Node.difference(child, parent)``, is i.
    """

    def is_correct_node(
        self, node: Optional[Node[T]], recursive: bool = True
    ) -> bool:
        """Return whether *node* (and optionally its subtrees) obeys the rank rule.

        Every rank difference reported by ``Node.differences(node)`` must be
        1 or 2, and a leaf must have rank 0. An empty subtree is correct.

        Args:
            node: root of the subtree to check; ``None`` is accepted.
            recursive: when False, only *node* itself is checked.

        Returns:
            True if the checked nodes satisfy the rule, False otherwise.
        """
        if not node:
            return True

        if any(diff not in (1, 2) for diff in Node.differences(node)):
            return False

        if node.type == NodeType.LEAF:
            return node.rank == 0

        return not recursive or (
            self.is_correct_node(node.left, recursive)
            and self.is_correct_node(node.right, recursive)
        )

    # region DeleteRebalance

    def __fix_delete(
        self,
        y: Node[T],
        z: Node[T],
        mirrored: bool,
        rotate_left: RotateFunction[T],
        rotate_right: RotateFunction[T],
    ) -> None:
        """Finish a deletion at *z* with a single or a double rotation.

        Called by ``__bottomup_delete`` when *z* has rank differences (1, 3)
        or (3, 1). The 3-child is on one side and *y*, its sibling, is the
        1-child on the other side.

        Args:
            y: child of *z* opposite the 3-child.
            z: node whose rank differences are (1, 3) or (3, 1).
            mirrored: False when *y* is z's right child, True when it is the
                left child. In the mirrored case y's children are swapped so
                that ``w`` always names y's child on the far side from the
                3-child and ``v`` the child on the near side.
            rotate_left: rotation toward the 3-child's side. Callers pass
                ``self.rotate_right`` here when ``mirrored`` is True.
            rotate_right: the opposite rotation.
        """
        v, w = y.left, y.right
        if mirrored:
            v, w = w, v

        w_diff = Node.difference(w, y)
        if w_diff == 1 and y.parent:
            # Single rotation at z (y's parent): y takes z's place.
            rotate_left(y.parent)
            y.promote()
            z.demote()
            if z.type == NodeType.LEAF:
                # is_correct_node requires leaves to have rank 0; demote again.
                z.demote()
        elif w_diff == 2 and v.parent:
            # Double rotation: first at y (v's parent) ...
            rotate_right(v.parent)
            # ... then at v's new parent, presumably z after the first rotation.
            rotate_left(v.parent)
            v.promote().promote()
            y.demote()
            z.demote().demote()

    def __bottomup_delete(
        self, x: Optional[Node[T]], parent: Optional[Node[T]]
    ) -> None:
        """Propagate a 3-child violation upward after a deletion.

        While *x* is a 3-child of *parent* and the sibling *y* is either a
        2-child or has rank differences (2, 2), demote *parent* (and *y* in
        the latter case) and continue one level up. If the walk stops at a
        node with differences (1, 3) or (3, 1), finish with ``__fix_delete``.

        Args:
            x: child of *parent* to start from; may be None (empty subtree).
            parent: parent of *x*. Nothing happens if it is None or if *x* is
                not a 3-child of it.
        """
        # Check the parent before asking Node.difference about it.
        if not parent:
            return
        x_diff = Node.difference(x, parent)
        if x_diff != 3:
            return

        y = parent.right if parent.left is x else parent.left
        y_diff = Node.difference(y, parent)

        # `parent` is always set here: it is checked before the loop and the
        # loop returns as soon as it walks past the root.
        while (
            x_diff == 3
            and y
            and (y_diff == 2 or Node.differences(y) == (2, 2))
        ):
            if y_diff != 2:
                # y is not a 2-child, so it has differences (2, 2): demote it.
                y.demote()
            parent.demote()

            x = parent
            parent = x.parent
            if not parent:
                return

            y = parent.right if parent.left is x else parent.left
            x_diff = Node.difference(x, parent)
            y_diff = Node.difference(y, parent)

        if Node.differences(parent) in ((1, 3), (3, 1)):
            if parent.left is x:
                assert parent.right
                self.__fix_delete(
                    parent.right,
                    parent,
                    mirrored=False,
                    rotate_left=self.rotate_left,
                    rotate_right=self.rotate_right,
                )
            else:
                assert parent.left
                self.__fix_delete(
                    parent.left,
                    parent,
                    mirrored=True,
                    rotate_left=self.rotate_right,
                    rotate_right=self.rotate_left,
                )

    def _delete_rebalance(
        self, y: Optional[Node[T]], parent: Optional[Node[T]]
    ) -> None:
        """Rebalance after a node has been removed.

        First demotes a node with rank differences (2, 2): *y* if it has
        them, otherwise *parent*. Then, if a child of the (possibly updated)
        parent is a 3-child, hands it to ``__bottomup_delete``.

        Args:
            y: node supplied by the inherited deletion code. It is presumably
                the node that took the removed node's place and may be None.
                TODO confirm against AVLTree.
            parent: presumably the parent of the removed position; may be None.
        """
        # NOTE(review): this demotes any (2, 2) node, not only (2, 2) leaves.
        # The result still satisfies the rank rule, but it may do more
        # demotions than strictly necessary; confirm this is intended.
        if y is not None and Node.differences(y) == (2, 2):
            y.demote()
            parent = y.parent
        elif parent is not None and Node.differences(parent) == (2, 2):
            parent.demote()
            parent = parent.parent

        if not parent:
            return

        for child in (parent.left, parent.right):
            if Node.difference(child, parent) == 3:
                self.__bottomup_delete(child, parent)
                break

    # endregion DeleteRebalance