From 02ce49c86e6e0a1c22b5388efc81aff334d218a8 Mon Sep 17 00:00:00 2001 From: Matej Focko Date: Fri, 20 May 2022 00:01:04 +0200 Subject: [PATCH] feat(rb): add mockup Signed-off-by: Matej Focko --- rbt.py | 192 +++++++++++++++++++++++++++++++++++++++++++++ test_properties.py | 7 +- 2 files changed, 195 insertions(+), 4 deletions(-) create mode 100644 rbt.py diff --git a/rbt.py b/rbt.py new file mode 100644 index 0000000..c6d7c33 --- /dev/null +++ b/rbt.py @@ -0,0 +1,192 @@ +from node import Node, Comparable, RotateFunction +from ranked_tree import RankedTree + +import enum +import logging +from typing import Callable, Tuple, TypeVar, Optional + +logger = logging.getLogger(__name__) +T = TypeVar("T", bound=Comparable) + + +class Colour(enum.IntEnum): + Red = 0 + Black = 1 + + +def colour(x: Optional[Node[T]]) -> Colour: + if not x or not x.parent: + return Colour.Black + + diff = Node.difference(x) + assert diff in (0, 1) + + return Colour(diff) + + +class RBTree(RankedTree[T]): + def is_correct_node( + self, node: Optional[Node[T]], recursive: bool = True + ) -> bool: + if not node: + return True + + left, right = Node.differences(node) + if left not in (Colour.Red, Colour.Black): + # left subtree has invalid difference + return False + elif right not in (Colour.Red, Colour.Black): + # right subtree has invalid difference + return False + + if Node.difference(node) == Colour.Red and ( + left == Colour.Red or right == Colour.Red + ): + # two consecutive red nodes + return False + + return not recursive or ( + self.is_correct_node(node.left) + and self.is_correct_node(node.right) + ) + + # region InsertRebalance + + def _insert_rebalance_step( + self, + x: Node[T], + d: Optional[Node[T]], + y: Node[T], + rotate_left: RotateFunction, + rotate_right: RotateFunction, + ) -> Node[T]: + p = x.parent + pp = p.parent + + if d is not None and colour(d) == Colour.Red: + pp.rank += 1 + x = pp + elif x == y: + x = p + rotate_left(tree=self, x=p) + else: + rotate_right(tree=self, x=pp) + + return x + + def _insert_rebalance(self, x: Node[T]) -> None: + while x.parent is not None and colour(x.parent) == Colour.Red: + p = x.parent + pp = x.parent.parent + + assert pp + if p == pp.left: + x = self._insert_rebalance_step( + x, pp.right, p.right, Node.rotate_left, Node.rotate_right + ) + else: + x = self._insert_rebalance_step( + x, pp.left, p.left, Node.rotate_right, Node.rotate_left + ) + + # endregion InsertRebalance + + # region DeleteRebalance + + def _delete_rebalance_step( + self, + x: Node[T], + w: Node[T], + parent: Node[T], + left: Callable[[Node[T]], Node[T]], + right: Callable[[Node[T]], Node[T]], + rotate_left: RotateFunction, + rotate_right: RotateFunction, + ) -> Tuple[Optional[Node[T]], Optional[Node[T]]]: + if colour(w) == Colour.Red: + logger.debug("RB-Delete: Case 1 -- x's sibling w is red") + # w.colour = Colour.Black + # x.parent.colour = Colour.Red + rotate_left(tree=self, x=parent) + w = right(parent) + + if colour(w.left) == Colour.Black and colour(w.right) == Colour.Black: + logger.debug( + "RB-Delete: Case 2 -- x's sibling w is black, and both of w's children are black" + ) + # w.colour = Colour.Red + parent.rank -= 1 + x = parent + else: + if colour(right(w)) == Colour.Black: + logger.debug( + "RB-Delete: Case 3 -- x's sibling w is black, w's left child is red, and w's right child is black" + ) + # left(w).colour = Colour.Black + # w.colour = Colour.Red + rotate_right(tree=self, x=w) + w = right(parent) + logger.debug( + "RB-Delete: Case 4 -- x's sibling w is black, and w's right child is red" + ) + # w.colour = colour(x.parent) + # x.parent.colour = Colour.Black + # right(w).colour = Colour.Black + rotate_left(tree=self, x=parent) + x = self.root + + return x, (x.parent if x else None) + + def _delete_rebalance( + self, node: Optional[Node[T]], parent: Optional[Node[T]] + ) -> None: + if not node and not parent: + return + + logger.debug( + "RB-Delete: Called with node=‹%s› parent=‹%s›", node, parent + ) + logger.debug("RB-Delete: Colour of node: %s", colour(node)) + if (node and 2 not in Node.differences(node)) or ( + 2 not in Node.differences(parent) + ): + # we haven't deleted a black node + logger.debug( + "RB-Delete: No black node has been deleted; node: %s; diffs: %s", + node, + Node.differences(node), + ) + return + + logger.debug("RB-Delete: Tree before rebalancing:\n%s", str(self)) + + while node != self.root and colour(node) == Colour.Black: + logger.debug( + "RB-Delete: Rebalancing node=‹%s› (%s-node) parent=‹%s› (%s-node)", + node, + Node.differences(node), + parent, + Node.differences(parent), + ) + if node == parent.left: + node, parent = self._delete_rebalance_step( + node, + parent.right, + parent, + lambda x: x.left, + lambda x: x.right, + Node.rotate_left, + Node.rotate_right, + ) + else: + node, parent = self._delete_rebalance_step( + node, + parent.left, + parent, + lambda x: x.right, + lambda x: x.left, + Node.rotate_right, + Node.rotate_left, + ) + + # endregion DeleteRebalance diff --git a/test_properties.py b/test_properties.py index 0026387..fe926eb 100644 --- a/test_properties.py +++ b/test_properties.py @@ -1,6 +1,7 @@ from avl import AVLTree from wavl import WAVLTree from ravl import RAVLTree +from rbt import RBTree import logging import random @@ -13,8 +14,7 @@ logger = logging.getLogger(__name__) @pytest.mark.parametrize( - ("RankBalancedTree"), - [AVLTree, WAVLTree, RAVLTree] + ("RankBalancedTree"), [RBTree, AVLTree, WAVLTree, RAVLTree] ) @hypothesis.settings(max_examples=10000) @hypothesis.given(values=st.sets(st.integers())) @@ -44,8 +44,7 @@ def _report(t_before: str, t_after: str) -> None: @pytest.mark.parametrize( - ("RankBalancedTree"), - [AVLTree, WAVLTree, RAVLTree] + ("RankBalancedTree"), [RBTree, AVLTree, WAVLTree, RAVLTree] ) @hypothesis.settings(max_examples=10000) @hypothesis.given(config=delete_strategy())