diff --git a/avl.py b/avl.py new file mode 100644 index 0000000..5445168 --- /dev/null +++ b/avl.py @@ -0,0 +1,165 @@ +from node import Node, Comparable, RotateFunction +from ranked_tree import RankedTree + +import logging +from typing import TypeVar, Optional + +logger = logging.getLogger(__name__) +T = TypeVar("T", bound=Comparable) + + +def _balance_factor(node: Optional[Node[T]]) -> int: + if not node: + return 0 + + left, right = Node.get_rank(node.left), Node.get_rank(node.right) + return right - left + + +def _update_rank(node: Node[T]) -> None: + left, right = Node.get_rank(node.left), Node.get_rank(node.right) + # logger.debug(f"[_update_rank] on {node} = ({left}, {right})") + node.rank = 1 + max(left, right) + + +class AVLTree(RankedTree[T]): + def is_correct_node( + self, node: Optional[Node[T]], recursive: bool = True + ) -> bool: + if not node: + return True + + if not (-1 <= _balance_factor(node) <= 1): + return False + + return not recursive or ( + self.is_correct_node(node.left) + and self.is_correct_node(node.right) + ) + + # region InsertRebalance + + def __fix_0_child( + self, + x: Node[T], + y: Optional[Node[T]], + z: Node[T], + rotate_left: RotateFunction[T], + rotate_right: RotateFunction[T], + ) -> Optional[Node[T]]: + new_root = x.parent + + if not y or Node.difference(y) == 2: + new_root = rotate_right(z) + z.demote() + elif Node.difference(y) == 1: + rotate_left(x) + new_root = rotate_right(z) + + y.promote() + x.demote() + z.demote() + + return new_root + + def _insert_rebalance(self, x: Node[T]) -> None: + diffs = Node.differences(x.parent) + while x.parent and (diffs == (0, 1) or diffs == (1, 0)): + x.parent.promote() + x = x.parent + + diffs = Node.differences(x.parent) + + if not x.parent: + return + + rotating_around_root = x.parent.parent is None + new_root: Optional[Node[T]] = x.parent + + rank_difference = Node.difference(x) + if rank_difference != 0: + return + + x_parent = x.parent + assert x_parent is not None + if rank_difference == 0 and x.parent.left is x: + new_root = self.__fix_0_child( + x, x.right, x_parent, Node.rotate_left, Node.rotate_right + ) + elif rank_difference == 0 and x.parent.right is x: + new_root = self.__fix_0_child( + x, x.left, x_parent, Node.rotate_right, Node.rotate_left + ) + + if rotating_around_root: + self.root = new_root + + # endregion InsertRebalance + + # region DeleteRebalance + + def __delete_rotate( + self, + x: Node[T], + y: Node[T], + leaning: int, + rotating_around_root: bool, + rotate_left: RotateFunction[T], + rotate_right: RotateFunction[T], + ) -> bool: + new_root = x + + factor = _balance_factor(y) + if factor in (0, leaning): + new_root = rotate_left(x) + else: + rotate_right(y) + new_root = rotate_left(x) + + for n in filter(None, (new_root.left, new_root.right, new_root)): + _update_rank(n) + + if rotating_around_root: + self.root = new_root + return factor == 0 + + def __delete_fixup( + self, y: Optional[Node[T]], parent: Optional[Node[T]] = None + ) -> bool: + x = y if y else parent + assert x + + factor = _balance_factor(x) + if factor == 0: + _update_rank(x) + return False + elif factor in (-1, 1): + return True + + rotating_around_root = x.parent is None + y, leaning, to_left, to_right = ( + (x.right, 1, Node.rotate_left, Node.rotate_right) + if factor == 2 + else (x.left, -1, Node.rotate_right, Node.rotate_left) + ) + assert y + return self.__delete_rotate( + x, + y, + leaning, + rotating_around_root, + to_left, + to_right, + ) + + def _delete_rebalance( + self, node: Optional[Node[T]], parent: Optional[Node[T]] + ) -> None: + while node or parent: + # TODO: Check if it is possible to not propagate all the way up. + # if self.__delete_fixup(node, parent): + # return + self.__delete_fixup(node, parent) + node, parent = parent, (parent.parent if parent else None) + + # endregion DeleteRebalance diff --git a/ranked_tree.py b/ranked_tree.py new file mode 100644 index 0000000..92ded74 --- /dev/null +++ b/ranked_tree.py @@ -0,0 +1,146 @@ +from abc import abstractmethod +from node import Node, Comparable + +from collections import deque +import logging +from typing import Deque, Optional, Tuple, TypeVar, Generic + +logger = logging.getLogger(__name__) +T = TypeVar("T", bound=Comparable) + + +class RankedTree(Generic[T]): + def __init__(self) -> None: + self.root: Optional[Node[T]] = None + + def __str__(self) -> str: + result = "digraph {\n" + + queue: Deque[Optional[Node[T]]] = deque() + queue.append(self.root) + + edges = [] + while queue: + node = queue.popleft() + if not node: + continue + + result += f'"{str(node)}" [label="{node.value}, {node.rank}"];\n' + for child in (node.left, node.right): + if not child: + continue + + edges.append((node, child)) + queue.append(child) + + for from_node, to_node in edges: + label = f'[label="{Node.difference(to_node)}"]' + result += f'"{str(from_node)}" -> "{str(to_node)}" {label}\n' + + return result + "}\n" + + # region TreeSpecificMethods + + @abstractmethod + def is_correct_node(self, node: Optional[Node[T]]) -> bool: + pass + + @abstractmethod + def _insert_rebalance(self, x: Node[T]) -> None: + pass + + @abstractmethod + def _delete_rebalance( + self, node: Optional[Node[T]], parent: Optional[Node[T]] + ) -> None: + pass + + # endregion TreeSpecificMethods + + @property + def rank(self) -> int: + return Node.get_rank(self.root) + + @property + def is_correct(self) -> bool: + return self.is_correct_node(self.root) + + def search( + self, value: T, node: Optional[Node[T]] = None + ) -> Optional[Node[T]]: + if not node: + node = self.root + + return Node.search(value, node) + + def insert(self, value: T) -> None: + inserted_node = Node(value) + + if not self.root: + self.root = inserted_node + return + + parent = Node.find_parent_node(value, self.root) + inserted_node.parent = parent + + if value < parent.value: + parent.left = inserted_node + else: + parent.right = inserted_node + + self._insert_rebalance(inserted_node) + + def _transplant(self, u: Node[T], v: Optional[Node[T]]) -> None: + if not u.parent: + self.root = v + elif u.parent.left is u: + u.parent.left = v + else: + u.parent.right = v + + if v: + v.rank = u.rank + v.parent = u.parent + + def _delete_node( + self, node: Optional[Node[T]] + ) -> Optional[Tuple[Optional[Node[T]], Optional[Node[T]]]]: + if node is None: + return None + + y, parent = None, node.parent + + if not node.left: + y = node.right + self._transplant(node, node.right) + elif not node.right: + y = node.left + self._transplant(node, node.left) + else: + n = Node.minimum(node.right) + y, parent = None, (n.parent if n.parent is not node else n) + + if n.parent is not node: + parent = n.right if n.right else n.parent + self._transplant(n, n.right) + n.right = node.right + n.right.parent = n + + self._transplant(node, n) + n.left = node.left + n.left.parent = n + + return (y, parent) + + def _delete( + self, value: T + ) -> Optional[Tuple[Optional[Node[T]], Optional[Node[T]]]]: + node = self.root + while node is not None and node.value != value: + node = node.left if value < node.value else node.right + return self._delete_node(node) + + def delete(self, value: T) -> None: + if to_be_rebalanced := self._delete(value): + y, parent = to_be_rebalanced + self._delete_rebalance(y, parent) diff --git a/wavl.py b/wavl.py index 177e676..20916f3 100644 --- a/wavl.py +++ b/wavl.py @@ -1,25 +1,17 @@ -from node import Node, NodeType +from avl import AVLTree +from node import Node, NodeType, Comparable, RotateFunction -from collections import deque import logging +from typing import TypeVar, Optional logger = logging.getLogger(__name__) +T = TypeVar("T", bound=Comparable) -class WAVLTree: - def __init__(self): - self.root = None - - @property - def rank(self): - return Node.rank(self.root) - - @property - def is_correct(self): - return WAVLTree.is_correct_node(self.root) - - @staticmethod - def is_correct_node(node, recursive=True): +class WAVLTree(AVLTree[T]): + def is_correct_node( + self, node: Optional[Node[T]], recursive: bool = True + ) -> bool: if not node: return True @@ -31,96 +23,21 @@ class WAVLTree: return node.rank == 0 return not recursive or ( - WAVLTree.is_correct_node(node.left) and WAVLTree.is_correct_node(node.right) + self.is_correct_node(node.left) + and self.is_correct_node(node.right) ) - def search(self, value, node=None): - if not node: - node = self.root - - return Node.search(value, node) - - def __fix_0_child(self, x, y, z, rotate_left, rotate_right): - new_root = x.parent - - if not y or Node.difference(y) == 2: - new_root = rotate_right(z) - z.demote() - elif Node.difference(y) == 1: - rotate_left(x) - new_root = rotate_right(z) - - y.promote() - x.demote() - z.demote() - - return new_root - - def __bottomup_rebalance(self, x): - diffs = Node.differences(x.parent) - # if diffs != (0, 1) and diffs != (1, 0): - # return - - while x.parent and (diffs == (0, 1) or diffs == (1, 0)): - x.parent.promote() - x = x.parent - - diffs = Node.differences(x.parent) - - if not x.parent: - return - - rotating_around_root = x.parent.parent is None - new_root = x.parent - - rank_difference = Node.difference(x) - - if rank_difference != 0: - return - - if rank_difference == 0 and x.parent.left is x: - new_root = self.__fix_0_child( - x, x.right, x.parent, Node.rotate_left, Node.rotate_right - ) - elif rank_difference == 0 and x.parent.right is x: - new_root = self.__fix_0_child( - x, x.left, x.parent, Node.rotate_right, Node.rotate_left - ) - - if rotating_around_root: - self.root = new_root - - def insert(self, value): - inserted_node = Node(value) - - if not self.root: - self.root = inserted_node - return - - parent = Node.find_parent_node(value, self.root) - inserted_node.parent = parent - - if value < parent.value: - parent.left = inserted_node - else: - parent.right = inserted_node - - self.__bottomup_rebalance(inserted_node) - - def __transplant(self, u, v): - if not u.parent: - self.root = v - elif u.parent.left is u: - u.parent.left = v - else: - u.parent.right = v - - if v: - v.rank = u.rank - v.parent = u.parent + # region DeleteRebalance @staticmethod - def __fix_delete(x, y, z, reversed, rotate_left, rotate_right): + def __fix_delete( + x: Optional[Node[T]], + y: Node[T], + z: Node[T], + reversed: bool, + rotate_left: RotateFunction[T], + rotate_right: RotateFunction[T], + ) -> Optional[Node[T]]: new_root = x v = y.left w = y.right @@ -132,6 +49,8 @@ class WAVLTree: w_diff = Node.difference(w, y) logger.debug(f"w_diff = {w_diff}") + + assert v if w_diff == 1 and y.parent: logger.debug(f"y.parent = {y.parent}") new_root = rotate_left(y.parent) @@ -152,7 +71,9 @@ class WAVLTree: return new_root - def __bottomup_delete(self, x, parent): + def __bottomup_delete( + self, x: Optional[Node[T]], parent: Optional[Node[T]] + ) -> None: x_diff = Node.difference(x, parent) if x_diff != 3 or not parent: return @@ -183,11 +104,12 @@ class WAVLTree: return rotating_around_root = parent.parent is None - new_root = parent + new_root: Optional[Node[T]] = parent parent_node_diffs = Node.differences(parent) if parent_node_diffs in ((1, 3), (3, 1)): if parent.left is x: + assert parent.right new_root = WAVLTree.__fix_delete( x, parent.right, @@ -197,6 +119,7 @@ class WAVLTree: Node.rotate_right, ) else: + assert parent.left new_root = WAVLTree.__fix_delete( x, parent.left, @@ -209,92 +132,34 @@ class WAVLTree: if rotating_around_root: self.root = new_root - def __delete_fixup(self, y, parent=None): + def __delete_fixup( + self, y: Optional[Node[T]], parent: Optional[Node[T]] = None + ) -> None: logger.debug(f"[__delete_fixup] y = {y}, parent = {parent}") z = y if y else parent logger.debug( - f"[z.demote()] Node.differences({repr(z)}) == (2, 2) ~>* {Node.differences(z)} == (2, 2)" + f"[z.demote()] Node.differences({repr(z)}) == (2, 2) ~>*" + f"{Node.differences(z)} == (2, 2)" ) + assert z if Node.differences(z) == (2, 2): z.demote() if parent: for y in (parent.left, parent.right): logger.debug( - f"[bottom-up delete] Node.difference({y}, {parent}) == 3 ~>* {Node.difference(y, parent)} == 3" + f"[bottom-up delete] Node.difference({y}, {parent}) == 3" + f"~>* {Node.difference(y, parent)} == 3" ) if Node.difference(y, parent) == 3: self.__bottomup_delete(y, parent) - def __fix_after_delete(self, node, parent): + def _delete_rebalance( + self, node: Optional[Node[T]], parent: Optional[Node[T]] + ) -> None: while node or parent: self.__delete_fixup(node, parent) node, parent = parent, (parent.parent if parent else None) - def delete_node(self, node): - y, parent = None, node.parent - - if not node.left: - logger.debug("node.left is None") - y = node.right - self.__transplant(node, node.right) - elif not node.right: - logger.debug("node.right is None") - y = node.left - self.__transplant(node, node.left) - else: - logger.debug("taking successor") - n = Node.minimum(node.right) - y, parent = None, (n.parent if n.parent is not node else n) - - if n.parent is not node: - parent = n.right if n.right else n.parent - self.__transplant(n, n.right) - n.right = node.right - n.right.parent = n - - self.__transplant(node, n) - n.left = node.left - n.left.parent = n - - return (y, parent) - - def __delete(self, value, node=None): - if node is None: - return - - if node.value != value: - return self.__delete(value, node.left if value < node.value else node.right) - - (y, parent) = self.delete_node(node) - self.__fix_after_delete(y, parent) - - def delete(self, value): - logger.debug(f"[DELETE] {value}") - self.__delete(value, self.root) - - def __str__(self): - result = "digraph {\n" - - queue = deque() - queue.append(self.root) - - edges = [] - while queue: - node = queue.popleft() - if not node: - continue - - result += f'"{str(node)}" [label="{node.value}, {node.rank}"];\n' - for child in (node.left, node.right): - if not child: - continue - - edges.append((node, child)) - queue.append(child) - - for fromNode, toNode in edges: - result += f'"{str(fromNode)}" -> "{str(toNode)}" [label="{Node.difference(toNode)}"]\n' - - return result + "}\n" + # endregion DeleteRebalance