# 300 lines, 8.2 KiB, Python
import logging
from collections import deque

from node import Node, NodeType

# Module-level logger; the deletion code path (delete, delete_node and its
# fix-up helpers) emits its tracing through logger.debug.
logger = logging.getLogger(__name__)


class WAVLTree:
    """Weak AVL (WAVL) tree built from ``node.Node`` objects.

    Every node carries an integer rank. ``is_correct_node`` spells out the
    invariant this class is meant to maintain:
      * every rank difference reported by ``Node.differences`` is 1 or 2;
      * every node whose ``type`` is ``NodeType.LEAF`` has rank 0.

    Insertion restores the invariant with promotions and rotations
    (``__bottomup_rebalance``). Deletion restores it with demotions and
    rotations (``__fix_after_delete``).
    """

    def __init__(self):
        self.root = None

    @property
    def rank(self):
        """Rank of the root, as computed by ``Node.rank``.

        For an empty tree this passes ``None`` to ``Node.rank``. Presumably
        that yields the missing-node rank; confirm in node.py.
        """
        return Node.rank(self.root)

    @property
    def is_correct(self):
        """True if the whole tree satisfies the WAVL rank rules."""
        return WAVLTree.is_correct_node(self.root)

    @staticmethod
    def is_correct_node(node, recursive=True):
        """Check the WAVL rank rules at ``node`` and, by default, below it.

        Args:
            node: Subtree root to check. ``None`` is always valid.
            recursive: When False, only ``node`` itself is checked.

        Returns:
            False if any rank difference of ``node`` is not 1 or 2, or if
            ``node`` is a leaf whose rank is not 0. Otherwise, the result of
            checking both children (or True when ``recursive`` is False).
        """
        if node is None:
            return True
        if any(diff not in (1, 2) for diff in Node.differences(node)):
            return False
        if node.type == NodeType.LEAF:
            # Leaves are accepted or rejected here without recursing.
            return node.rank == 0
        if not recursive:
            return True
        return WAVLTree.is_correct_node(node.left) and WAVLTree.is_correct_node(
            node.right
        )

    def search(self, value, node=None):
        """Look up ``value`` via ``Node.search``.

        Args:
            value: Value to look for.
            node: Subtree to search. The root is used when this is ``None``.

        Returns:
            Whatever ``Node.search(value, node)`` returns.
        """
        if node is None:
            node = self.root
        return Node.search(value, node)

    def __fix_0_child(self, x, y, z, rotate_left, rotate_right):
        """Rebalance after insertion when ``x`` is a 0-child of its parent ``z``.

        ``y`` is x's inner child: ``x.right`` when x is z's left child, and
        ``x.left`` when x is z's right child. For the right-side case the
        caller passes the two rotation functions swapped, so the same code
        handles both mirror images.

        * ``y`` missing or a 2-child: single rotation at ``z``, then demote ``z``.
        * ``y`` a 1-child: double rotation (at ``x``, then at ``z``), then
          promote ``y`` and demote ``x`` and ``z``.

        Returns:
            The node returned by the rotation at ``z`` (the new subtree root),
            or ``z`` itself if no rotation was performed.
        """
        new_root = z  # z is x.parent at entry; unchanged if nothing rotates
        if y is None or Node.difference(y) == 2:
            new_root = rotate_right(z)
            z.demote()
        elif Node.difference(y) == 1:
            rotate_left(x)
            new_root = rotate_right(z)
            y.promote()
            x.demote()
            z.demote()
        return new_root

    def __bottomup_rebalance(self, x):
        """Restore the rank rules on the path above a freshly inserted ``x``.

        While x's parent has rank differences (0, 1) or (1, 0), it is promoted
        and the walk moves up. If the walk stops with ``x`` a 0-child of its
        parent, ``__fix_0_child`` rotates. ``self.root`` is updated when the
        rotated subtree was rooted at the tree root.
        """
        # Promote step. Checking x.parent first avoids calling
        # Node.differences(None) once the promotions reach the root.
        while x.parent is not None and Node.differences(x.parent) in ((0, 1), (1, 0)):
            x.parent.promote()
            x = x.parent

        if x.parent is None or Node.difference(x) != 0:
            return

        z = x.parent
        rotating_around_root = z.parent is None
        new_root = z

        if z.left is x:
            new_root = self.__fix_0_child(
                x, x.right, z, Node.rotate_left, Node.rotate_right
            )
        elif z.right is x:
            new_root = self.__fix_0_child(
                x, x.left, z, Node.rotate_right, Node.rotate_left
            )

        if rotating_around_root:
            self.root = new_root

    def insert(self, value):
        """Insert ``value`` as a new node and rebalance.

        The new node is attached under ``Node.find_parent_node(value, root)``.
        It goes to the left of that parent when ``value < parent.value``,
        otherwise to the right.
        """
        inserted_node = Node(value)

        if self.root is None:
            self.root = inserted_node
            return

        parent = Node.find_parent_node(value, self.root)
        inserted_node.parent = parent

        if value < parent.value:
            parent.left = inserted_node
        else:
            parent.right = inserted_node

        self.__bottomup_rebalance(inserted_node)

    def __transplant(self, u, v):
        """Put ``v`` (possibly ``None``) in u's place under u's parent.

        When ``u`` is the root, ``v`` becomes the root. A non-None ``v`` also
        takes over u's parent pointer and u's rank. NOTE(review): copying the
        rank can leave a promoted leaf. This presumably relies on
        ``__delete_fixup`` demoting (2, 2) nodes afterwards; verify against
        ``Node.differences``.
        """
        if u.parent is None:
            self.root = v
        elif u.parent.left is u:
            u.parent.left = v
        else:
            u.parent.right = v

        if v is not None:
            v.rank = u.rank
            v.parent = u.parent

    @staticmethod
    def __fix_delete(x, y, z, mirrored, rotate_left, rotate_right):
        """Rotate to repair ``z`` after deletion, when ``x`` is its 3-child.

        Args:
            x: Child of ``z`` whose rank difference is 3 (per the caller,
                which invokes this when z's differences are (1, 3) or (3, 1)).
            y: The sibling of ``x``.
            z: Parent of ``x`` and ``y``.
            mirrored: False when ``x`` is z's left child. True when it is
                z's right child, in which case y's children are swapped so
                ``v`` is always y's inner child and ``w`` its outer child.
            rotate_left, rotate_right: Rotation functions. The caller passes
                them swapped for the mirrored case.

        Returns:
            The node returned by the final rotation at ``z`` (the new subtree
            root), or ``z`` itself if no rotation was performed.
        """
        v, w = (y.right, y.left) if mirrored else (y.left, y.right)

        logger.debug("__fix_delete(%s, %s, %s, %s)", x, y, z, mirrored)

        # Fall back to z, which is still the subtree root when nothing
        # rotates. The old fallback (x) would be installed as self.root by
        # the caller.
        new_root = z

        w_diff = Node.difference(w, y)
        logger.debug("w_diff = %s", w_diff)

        if w_diff == 1 and y.parent is not None:
            # Single rotation at z.
            logger.debug("y.parent = %s", y.parent)
            new_root = rotate_left(y.parent)

            y.promote()
            z.demote()

            # A z that ends up as a leaf must drop to rank 0.
            if z.type == NodeType.LEAF:
                z.demote()
        elif w_diff == 2 and v.parent is not None:
            # Double rotation: first at y (v's parent), then at z (v's new parent).
            logger.debug("v.parent = %s", v.parent)
            rotate_right(v.parent)
            new_root = rotate_left(v.parent)

            # promote()/demote() return the node, so they can be chained.
            v.promote().promote()
            y.demote()
            z.demote().demote()

        return new_root

    def __bottomup_delete(self, x, parent):
        """Repair a 3-child ``x`` of ``parent`` after deletion.

        Demote step: while ``x`` is a 3-child and its sibling ``y`` is either
        a 2-child or a (2, 2) node, demote ``parent`` (and ``y`` too when
        ``y`` was not a 2-child), then move up one level. If the walk ends at
        a parent with differences (1, 3) or (3, 1), ``__fix_delete`` rotates.
        ``self.root`` is updated when that parent was the tree root.
        """
        if parent is None:
            return
        x_diff = Node.difference(x, parent)
        if x_diff != 3:
            return

        y = parent.right if parent.left is x else parent.left
        y_diff = Node.difference(y, parent)

        while (
            x_diff == 3
            and y is not None
            and (y_diff == 2 or Node.differences(y) == (2, 2))
        ):
            parent.demote()
            if y_diff != 2:
                # y entered the loop as a (2, 2) 1-child, so it drops as well.
                y.demote()

            x = parent
            parent = x.parent
            if parent is None:
                return
            y = parent.right if parent.left is x else parent.left

            x_diff = Node.difference(x, parent)
            y_diff = Node.difference(y, parent)

        rotating_around_root = parent.parent is None
        new_root = parent

        if Node.differences(parent) in ((1, 3), (3, 1)):
            if parent.left is x:
                new_root = WAVLTree.__fix_delete(
                    x,
                    parent.right,
                    parent,
                    False,
                    Node.rotate_left,
                    Node.rotate_right,
                )
            else:
                new_root = WAVLTree.__fix_delete(
                    x,
                    parent.left,
                    parent,
                    True,
                    Node.rotate_right,
                    Node.rotate_left,
                )

        if rotating_around_root:
            self.root = new_root

    def __delete_fixup(self, y, parent=None):
        """Run one level of the post-deletion fix-up.

        Checks ``y`` (or ``parent`` when ``y`` is None) and demotes it if its
        rank differences are (2, 2). Then any child of ``parent`` whose rank
        difference is 3 is handed to ``__bottomup_delete``.
        """
        logger.debug("[__delete_fixup] y = %s, parent = %s", y, parent)

        z = y if y is not None else parent
        z_diffs = Node.differences(z)
        logger.debug(
            "[z.demote()] Node.differences(%r) == (2, 2) ~>* %s == (2, 2)",
            z,
            z_diffs,
        )
        if z_diffs == (2, 2):
            z.demote()

        if parent is None:
            return

        for child in (parent.left, parent.right):
            child_diff = Node.difference(child, parent)
            logger.debug(
                "[bottom-up delete] Node.difference(%s, %s) == 3 ~>* %s == 3",
                child,
                parent,
                child_diff,
            )
            if child_diff == 3:
                self.__bottomup_delete(child, parent)

    def __fix_after_delete(self, node, parent):
        """Apply ``__delete_fixup`` level by level from ``(node, parent)`` to the root."""
        while node is not None or parent is not None:
            self.__delete_fixup(node, parent)
            node, parent = parent, (parent.parent if parent is not None else None)

    def delete_node(self, node):
        """Unlink ``node`` from the tree (structural BST deletion only).

        * No left child: ``node`` is replaced by its right child.
        * No right child: ``node`` is replaced by its left child.
        * Two children: ``node`` is replaced by ``Node.minimum(node.right)``
          (its in-order successor), which adopts node's children.

        Returns:
            A ``(y, parent)`` pair to start ``__fix_after_delete`` from. ``y``
            is the promoted child in the one-child cases (``None`` in the
            successor case). ``parent`` is the node to start checking from.
        """
        y, parent = None, node.parent

        if node.left is None:
            logger.debug("node.left is None")
            y = node.right
            self.__transplant(node, node.right)
        elif node.right is None:
            logger.debug("node.right is None")
            y = node.left
            self.__transplant(node, node.left)
        else:
            logger.debug("taking successor")
            successor = Node.minimum(node.right)

            if successor.parent is node:
                parent = successor
            else:
                # Start the fix-up at the node that takes the successor's
                # place (its right child), or at its old parent if it had none.
                parent = (
                    successor.right
                    if successor.right is not None
                    else successor.parent
                )
                self.__transplant(successor, successor.right)
                successor.right = node.right
                successor.right.parent = successor

            self.__transplant(node, successor)
            successor.left = node.left
            successor.left.parent = successor

        return (y, parent)

    def __delete(self, value, node=None):
        """Find the node holding ``value`` under ``node`` and delete it.

        Does nothing if the value is not found.
        """
        while node is not None and node.value != value:
            node = node.left if value < node.value else node.right

        if node is None:
            return

        y, parent = self.delete_node(node)
        self.__fix_after_delete(y, parent)

    def delete(self, value):
        """Delete one node holding ``value``. Missing values are ignored."""
        logger.debug("[DELETE] %s", value)
        self.__delete(value, self.root)

    def __str__(self):
        """Render the tree as a Graphviz DOT ``digraph``.

        Nodes are emitted in breadth-first order, labelled ``"value, rank"``.
        Each parent-to-child edge is labelled with ``Node.difference(child)``.
        """
        parts = ["digraph {\n"]
        edges = []

        queue = deque([self.root])
        while queue:
            node = queue.popleft()
            if node is None:
                continue

            parts.append(f'"{node!s}" [label="{node.value}, {node.rank}"];\n')
            for child in (node.left, node.right):
                if child is None:
                    continue
                edges.append((node, child))
                queue.append(child)

        for from_node, to_node in edges:
            parts.append(
                f'"{from_node!s}" -> "{to_node!s}" '
                f'[label="{Node.difference(to_node)}"]\n'
            )

        parts.append("}\n")
        return "".join(parts)