diff --git a/rbt.py b/rbt.py new file mode 100644 index 0000000..3adbf0d --- /dev/null +++ b/rbt.py @@ -0,0 +1,166 @@ +from node import Node, Comparable, RotateFunction +from ranked_tree import RankedTree + +import enum +import logging +from typing import Callable, Tuple, TypeVar, Optional + +logger = logging.getLogger(__name__) +T = TypeVar("T", bound=Comparable) + + +class Colour(enum.IntEnum): + Red = 0 + Black = 1 + + +def is_double_black(x: Optional[Node]) -> bool: + return x and 2 in Node.differences(x) + + +class RBTree(RankedTree[T]): + def is_correct_node( + self, node: Optional[Node[T]], recursive: bool = True + ) -> bool: + if not node: + return True + + left, right = Node.differences(node) + if left not in (Colour.Red, Colour.Black): + # left subtree has invalid difference + return False + elif right not in (Colour.Red, Colour.Black): + # right subtree has invalid difference + return False + + if Node.difference(node) == Colour.Red and ( + left == Colour.Red or right == Colour.Red + ): + # two consecutive red nodes + return False + + return not recursive or ( + self.is_correct_node(node.left) + and self.is_correct_node(node.right) + ) + + # region InsertRebalance + + def _insert_rebalance_step( + self, + z: Node[T], + y: Optional[Node[T]], + right_child: Node[T], + rotate_left: RotateFunction, + rotate_right: RotateFunction, + ) -> Node[T]: + p = z.parent + pp = p.parent + + if y is not None and Node.difference(y) == Colour.Red: + # Case 1 + # ====== + # z’s uncle y is red + pp.rank += Colour.Black + z = pp + elif z == right_child: + # Case 2 + # ====== + # z’s uncle y is black and z is a right child + z = p + rotate_left(tree=self, x=p) + else: + # Case 3 + # ====== + # z’s uncle y is black and z is a left child + rotate_right(tree=self, x=pp) + + return z + + def _insert_rebalance(self, z: Node[T]) -> None: + while z.parent is not None and Node.difference(z.parent) == Colour.Red: + p = z.parent + pp = p.parent + + assert pp + if p == pp.left: + z = self._insert_rebalance_step( + z, pp.right, p.right, Node.rotate_left, Node.rotate_right + ) + else: + z = self._insert_rebalance_step( + z, pp.left, p.left, Node.rotate_right, Node.rotate_left + ) + + # endregion InsertRebalance + + # region DeleteRebalance + + def _delete_rebalance_step( + self, + x: Node[T], + w: Node[T], + parent: Node[T], + right: Callable[[Node[T]], Node[T]], + rotate_left: RotateFunction, + rotate_right: RotateFunction, + ) -> Tuple[Optional[Node[T]], Optional[Node[T]]]: + if Node.difference(w) == Colour.Red: + # Case 1 + # ====== + # x’s sibling w is red + rotate_left(tree=self, x=parent) + w = right(parent) + + if Node.differences(w) == (Colour.Black, Colour.Black): + # Case 2 + # ====== + # x’s sibling w is black, and both of w’s children are black + parent.rank -= Colour.Black + x = parent + else: + # Case 3 + # ====== + # x’s sibling w is black, + # w’s left child is red, and w’s right child is black + if Node.difference(right(w), w) == Colour.Black: + rotate_right(tree=self, x=w) + w = right(parent) + + # Case 4 + # ====== + # x’s sibling w is black, and w’s right child is red + parent.rank -= Colour.Black + w.rank += Colour.Black + rotate_left(tree=self, x=parent) + x = self.root + + return x, (x.parent if x else None) + + def _delete_rebalance( + self, node: Optional[Node[T]], parent: Optional[Node[T]] + ) -> None: + if not node and not parent: + return + + while node != self.root and is_double_black(parent): + if node == parent.left: + node, parent = self._delete_rebalance_step( + node, + parent.right, + parent, + lambda x: x.right, + Node.rotate_left, + Node.rotate_right, + ) + else: + node, parent = self._delete_rebalance_step( + node, + parent.left, + parent, + lambda x: x.left, + Node.rotate_right, + Node.rotate_left, + ) + + # endregion DeleteRebalance diff --git a/test_properties.py b/test_properties.py index 0026387..fe926eb 100644 --- a/test_properties.py +++ b/test_properties.py @@ -1,6 +1,7 @@ from avl import AVLTree from wavl import WAVLTree from ravl import RAVLTree +from rbt import RBTree import logging import random @@ -13,8 +14,7 @@ logger = logging.getLogger(__name__) @pytest.mark.parametrize( - ("RankBalancedTree"), - [AVLTree, WAVLTree, RAVLTree] + ("RankBalancedTree"), [RBTree, AVLTree, WAVLTree, RAVLTree] ) @hypothesis.settings(max_examples=10000) @hypothesis.given(values=st.sets(st.integers())) @@ -44,8 +44,7 @@ def _report(t_before: str, t_after: str) -> None: @pytest.mark.parametrize( - ("RankBalancedTree"), - [AVLTree, WAVLTree, RAVLTree] + ("RankBalancedTree"), [RBTree, AVLTree, WAVLTree, RAVLTree] ) @hypothesis.settings(max_examples=10000) @hypothesis.given(config=delete_strategy()) diff --git a/test_rbt.py b/test_rbt.py new file mode 100644 index 0000000..fbdf14f --- /dev/null +++ b/test_rbt.py @@ -0,0 +1,43 @@ +from rbt import RBTree + +import logging +import pytest + +logger = logging.getLogger(__name__) + + +@pytest.mark.parametrize( + "values, delete_order", + [ + ([0, 1, -1], [0, -1, 1]), + ([0, 1, 2, 3, 4, 5], [4, 2, 1, 0, 5, 3]), + ( + [32, 1, 2, 3, 4, 5, 0, -2, -4, -3, -1], + [-4, -3, 1, 2, 5, 3, -2, 4, 32, -1, 0], + ), + ([0, 1, 2, -2, -3, -1], [-3, 2, 1, 0, -1, -2]), + ([0, 1, 2, 3, 4, 5, -1], [4, 2, 1, 0, 5, 3, -1]), + ], +) +def test_delete_minimal(values, delete_order): + tree = RBTree() + + for value in values: + tree.insert(value) + + for value in delete_order: + logger.info("Deleting %s", value) + + before = str(tree) + tree.delete(value) + after = str(tree) + + try: + assert tree.is_correct + except AssertionError: + logger.info( + f"[FAIL] Delete {value} from {values} in order {delete_order}" + ) + logger.info(f"Before:\n{before}") + logger.info(f"After:\n{after}") + raise