import logging
from collections import deque

from node import Node, NodeType

logger = logging.getLogger(__name__)


class WAVLTree:
    """Weak AVL (WAVL) binary search tree.

    Each node carries a rank.  ``Node.difference`` / ``Node.differences``
    report parent-minus-child rank differences, and the tree is valid when
    every difference is 1 or 2 and every leaf has rank 0 (see
    ``is_correct_node``).  Node-level primitives -- rank bookkeeping
    (``promote``/``demote``), rotations, search, minimum -- live in the
    ``node`` module.
    """

    def __init__(self):
        self.root = None

    @property
    def rank(self):
        """Rank of the root as reported by ``Node.rank`` (called with None when empty)."""
        return Node.rank(self.root)

    @property
    def is_correct(self):
        """True when the whole tree satisfies the WAVL rank rules."""
        return WAVLTree.is_correct_node(self.root)

    @staticmethod
    def is_correct_node(node, recursive=True):
        """Check the WAVL rank rules at ``node``.

        A node is correct when each of its rank differences is 1 or 2 and,
        if it is a leaf, its rank is 0.  With ``recursive=True`` both subtrees
        are checked too.  A falsy ``node`` (empty subtree) is always correct.
        """
        if not node:
            return True
        if any(diff not in (1, 2) for diff in Node.differences(node)):
            return False
        if node.type == NodeType.LEAF:
            return node.rank == 0
        return not recursive or (
            WAVLTree.is_correct_node(node.left)
            and WAVLTree.is_correct_node(node.right)
        )

    def search(self, value, node=None):
        """Look up ``value`` via ``Node.search``, starting at ``node`` (default: the root).

        Any falsy ``node`` argument also falls back to the root.
        """
        if not node:
            node = self.root
        return Node.search(value, node)

    def __fix_0_child(self, x, y, z, rotate_left, rotate_right):
        """Fix a 0-child ``x`` of ``z`` with one single or double rotation.

        ``y`` is the child of ``x`` on the side facing away from ``x``'s own
        position (``x.right`` when ``x`` is a left child, and vice versa).
        ``rotate_left``/``rotate_right`` are passed already mirrored for the
        right-child case, so the body is written for ``x`` as a left child.

        Returns the node to be used as the new subtree root.
        """
        new_root = x.parent
        if not y or Node.difference(y) == 2:
            # Single rotation at z.
            new_root = rotate_right(z)
            z.demote()
        elif Node.difference(y) == 1:
            # Double rotation: y ends up on top.
            rotate_left(x)
            new_root = rotate_right(z)
            y.promote()
            x.demote()
            z.demote()
        return new_root

    def __bottomup_rebalance(self, x):
        """Restore the rank rules after ``x`` was attached as a new leaf.

        While ``x``'s parent has rank differences (0, 1) or (1, 0), the parent
        is promoted and the walk moves one level up.  If ``x`` is then still a
        0-child, ``__fix_0_child`` performs the final rotation(s).
        """
        diffs = Node.differences(x.parent)
        while x.parent and diffs in ((0, 1), (1, 0)):
            x.parent.promote()
            x = x.parent
            diffs = Node.differences(x.parent)
        if not x.parent:
            return
        if Node.difference(x) != 0:
            return
        z = x.parent
        # Captured before rotating: afterwards z is no longer the subtree root.
        rotating_around_root = z.parent is None
        new_root = z
        if z.left is x:
            new_root = self.__fix_0_child(
                x, x.right, z, Node.rotate_left, Node.rotate_right
            )
        elif z.right is x:
            new_root = self.__fix_0_child(
                x, x.left, z, Node.rotate_right, Node.rotate_left
            )
        if rotating_around_root:
            self.root = new_root

    def insert(self, value):
        """Insert ``value`` as a new leaf and rebalance.

        Values not less than the attachment parent's value go to its right.
        """
        inserted_node = Node(value)
        if not self.root:
            self.root = inserted_node
            return
        # Presumably returns the node under which ``value`` belongs -- TODO confirm.
        parent = Node.find_parent_node(value, self.root)
        inserted_node.parent = parent
        if value < parent.value:
            parent.left = inserted_node
        else:
            parent.right = inserted_node
        self.__bottomup_rebalance(inserted_node)

    def __transplant(self, u, v):
        """Put ``v`` where ``u`` hangs in the tree; ``v`` inherits ``u``'s rank and parent."""
        if not u.parent:
            self.root = v
        elif u.parent.left is u:
            u.parent.left = v
        else:
            u.parent.right = v
        if v:
            v.rank = u.rank
            v.parent = u.parent

    @staticmethod
    def __fix_delete(x, y, z, mirrored, rotate_left, rotate_right):
        """Rotate to fix a 3-child ``x`` of ``z`` whose sibling ``y`` is a 1-child.

        ``w`` is ``y``'s outer child and ``v`` its inner child.  ``mirrored``
        is True when ``y`` is ``z``'s left child; it swaps ``v``/``w``, and the
        rotation helpers are passed already mirrored.

        Returns the node produced by the final rotation (the new subtree
        root), or ``z`` unchanged when no rotation applies.
        """
        # The subtree root stays z unless a rotation below replaces it.
        # (Returning x here would make the caller install the 3-child as root.)
        new_root = z
        v = y.left
        w = y.right
        if mirrored:
            v, w = w, v
        logger.debug("__fix_delete(%s, %s, %s, %s)", x, y, z, mirrored)
        w_diff = Node.difference(w, y)
        logger.debug("w_diff = %s", w_diff)
        if w_diff == 1 and y.parent:
            # Single rotation at z.
            logger.debug("y.parent = %s", y.parent)
            new_root = rotate_left(y.parent)
            y.promote()
            z.demote()
            if z.type == NodeType.LEAF:
                # A leaf must have rank 0, so z drops one more rank.
                z.demote()
        elif w_diff == 2 and v.parent:
            # Double rotation: v ends up on top.
            logger.debug("v.parent = %s", v.parent)
            rotate_right(v.parent)
            new_root = rotate_left(v.parent)
            v.promote().promote()
            y.demote()
            z.demote().demote()
        return new_root

    def __bottomup_delete(self, x, parent):
        """Rebalance upward from ``x`` when it is a 3-child of ``parent``.

        While the sibling ``y`` is a 2-child, or has rank differences (2, 2),
        ``parent`` is demoted (and ``y`` too in the (2, 2) case) and the walk
        moves up.  If ``parent`` then has differences (1, 3) or (3, 1),
        ``__fix_delete`` rotates.
        """
        if not parent:
            return
        x_diff = Node.difference(x, parent)
        if x_diff != 3:
            return
        y = parent.right if parent.left is x else parent.left
        y_diff = Node.difference(y, parent)
        while x_diff == 3 and y and (y_diff == 2 or Node.differences(y) == (2, 2)):
            parent.demote()
            if y_diff != 2:
                # y had differences (2, 2): it is demoted together with parent.
                y.demote()
            x, parent = parent, parent.parent
            if not parent:
                return
            y = parent.right if parent.left is x else parent.left
            x_diff = Node.difference(x, parent)
            y_diff = Node.difference(y, parent)

        # Captured before rotating: afterwards parent is no longer the subtree root.
        rotating_around_root = parent.parent is None
        new_root = parent
        if Node.differences(parent) in ((1, 3), (3, 1)):
            if parent.left is x:
                new_root = WAVLTree.__fix_delete(
                    x, parent.right, parent, False,
                    Node.rotate_left, Node.rotate_right,
                )
            else:
                new_root = WAVLTree.__fix_delete(
                    x, parent.left, parent, True,
                    Node.rotate_right, Node.rotate_left,
                )
        if rotating_around_root:
            self.root = new_root

    def __delete_fixup(self, y, parent=None):
        """One local repair step after a deletion.

        ``z`` is ``y`` when given, otherwise ``parent``; ``z`` is demoted if
        its rank differences are (2, 2).  Then every child of ``parent`` that
        is a 3-child is handed to ``__bottomup_delete``.
        """
        logger.debug("[__delete_fixup] y = %s, parent = %s", y, parent)
        z = y if y else parent
        z_diffs = Node.differences(z)
        logger.debug(
            "[z.demote()] Node.differences(%r) == (2, 2) ~>* %s == (2, 2)",
            z, z_diffs,
        )
        if z_diffs == (2, 2):
            z.demote()
        if parent:
            # The children are captured once, before any restructuring below.
            for child in (parent.left, parent.right):
                child_diff = Node.difference(child, parent)
                logger.debug(
                    "[bottom-up delete] Node.difference(%s, %s) == 3 ~>* %s == 3",
                    child, parent, child_diff,
                )
                if child_diff == 3:
                    self.__bottomup_delete(child, parent)

    def __fix_after_delete(self, node, parent):
        """Run ``__delete_fixup`` at every level from ``(node, parent)`` up to the root."""
        while node or parent:
            self.__delete_fixup(node, parent)
            node, parent = parent, (parent.parent if parent else None)

    def delete_node(self, node):
        """Unlink ``node`` from the tree, without rebalancing.

        Returns ``(y, parent)``: the starting point for the post-deletion
        fix-up, which the caller passes to ``__fix_after_delete``.
        """
        y, parent = None, node.parent
        if not node.left:
            logger.debug("node.left is None")
            y = node.right
            self.__transplant(node, node.right)
        elif not node.right:
            logger.debug("node.right is None")
            y = node.left
            self.__transplant(node, node.left)
        else:
            # Two children: splice in the in-order successor n.
            logger.debug("taking successor")
            n = Node.minimum(node.right)
            parent = n.parent if n.parent is not node else n
            if n.parent is not node:
                # NOTE(review): the fix-up starts at n.right when it exists --
                # confirm this is the intended starting point.
                parent = n.right if n.right else n.parent
                self.__transplant(n, n.right)
                n.right = node.right
                n.right.parent = n
            self.__transplant(node, n)
            n.left = node.left
            n.left.parent = n
        return (y, parent)

    def __delete(self, value, node=None):
        """Find ``value`` in the subtree at ``node``, delete it, and rebalance.

        Does nothing when ``value`` is absent.
        """
        while node is not None and node.value != value:
            node = node.left if value < node.value else node.right
        if node is None:
            return
        y, parent = self.delete_node(node)
        self.__fix_after_delete(y, parent)

    def delete(self, value):
        """Delete ``value`` from the tree (no-op when absent)."""
        logger.debug("[DELETE] %s", value)
        self.__delete(value, self.root)

    def __str__(self):
        """Render the tree as a Graphviz DOT ``digraph``.

        Nodes are emitted in breadth-first order and labelled "value, rank";
        edges are labelled with the child's rank difference.
        """
        parts = ["digraph {\n"]
        edges = []
        queue = deque([self.root])
        while queue:
            node = queue.popleft()
            if not node:
                continue
            parts.append(f'"{node!s}" [label="{node.value}, {node.rank}"];\n')
            for child in (node.left, node.right):
                if child:
                    edges.append((node, child))
                    queue.append(child)
        for from_node, to_node in edges:
            parts.append(
                f'"{from_node!s}" -> "{to_node!s}" '
                f'[label="{Node.difference(to_node)}"]\n'
            )
        parts.append("}\n")
        return "".join(parts)