From d39e11713cf8debc22bce6b2fdcced4837e6a23b Mon Sep 17 00:00:00 2001 From: Matej Focko Date: Fri, 20 May 2022 11:47:38 +0200 Subject: [PATCH] feat(rbt): finish delete Signed-off-by: Matej Focko --- rbt.py | 122 +++++++++++++++++++++------------------------------- test_rbt.py | 43 ++++++++++++++++++ 2 files changed, 91 insertions(+), 74 deletions(-) create mode 100644 test_rbt.py diff --git a/rbt.py b/rbt.py index c6d7c33..9f9203a 100644 --- a/rbt.py +++ b/rbt.py @@ -14,14 +14,8 @@ class Colour(enum.IntEnum): Black = 1 -def colour(x: Optional[Node[T]]) -> Colour: - if not x or not x.parent: - return Colour.Black - - diff = Node.difference(x) - assert diff in (0, 1) - - return Colour(diff) +def is_double_black(x: Optional[Node]) -> bool: + return x and 2 in Node.differences(x) class RBTree(RankedTree[T]): @@ -54,39 +48,48 @@ class RBTree(RankedTree[T]): def _insert_rebalance_step( self, - x: Node[T], - d: Optional[Node[T]], - y: Node[T], + z: Node[T], + y: Optional[Node[T]], + right_child: Node[T], rotate_left: RotateFunction, rotate_right: RotateFunction, ) -> Node[T]: - p = x.parent + p = z.parent pp = p.parent - if d is not None and colour(d) == Colour.Red: + if y is not None and Node.difference(y) == Colour.Red: + # Case 1 + # ====== + # z’s uncle y is red pp.rank += 1 - x = pp - elif x == y: - x = p + z = pp + elif z == right_child: + # Case 2 + # ====== + # z’s uncle y is black and z is a right child + z = p rotate_left(tree=self, x=p) else: + # Case 3 + # ====== + # z’s uncle y is black and z is a left child rotate_right(tree=self, x=pp) - return x + return z - def _insert_rebalance(self, x: Node[T]) -> None: - while x.parent is not None and colour(x.parent) == Colour.Red: - p = x.parent - pp = x.parent.parent + def _insert_rebalance(self, z: Node[T]) -> None: + while z.parent is not None and Node.difference(z.parent) == Colour.Red: + p = z.parent + pp = p.parent assert pp if p == pp.left: - x = self._insert_rebalance_step( - x, pp.right, p.right, Node.rotate_left, Node.rotate_right + z = self._insert_rebalance_step( + z, pp.right, p.right, Node.rotate_left, Node.rotate_right ) else: - x = self._insert_rebalance_step( - x, pp.left, p.left, Node.rotate_right, Node.rotate_left + z = self._insert_rebalance_step( + z, pp.left, p.left, Node.rotate_right, Node.rotate_left ) # endregion InsertRebalance @@ -98,40 +101,37 @@ class RBTree(RankedTree[T]): x: Node[T], w: Node[T], parent: Node[T], - left: Callable[[Node[T]], Node[T]], right: Callable[[Node[T]], Node[T]], rotate_left: RotateFunction, rotate_right: RotateFunction, ) -> Tuple[Optional[Node[T]], Optional[Node[T]]]: - if colour(w) == Colour.Red: - logger.debug("RB-Delete: Case 1 -- x's sibling w is red") - # w.colour = Colour.Black - # x.parent.colour = Colour.Red + if Node.difference(w) == 0: + # Case 1 + # ====== + # x’s sibling w is red rotate_left(tree=self, x=parent) w = right(parent) - if colour(w.left) == Colour.Black and colour(w.right) == Colour.Black: - logger.debug( - "RB-Delete: Case 2 -- x's sibling w is black, and both of w's children are black" - ) - # w.colour = Colour.Red + if Node.differences(w) == (1, 1): + # Case 2 + # ====== + # x’s sibling w is black, and both of w’s children are black parent.rank -= 1 x = parent else: - if colour(right(w)) == Colour.Black: - logger.debug( - "RB-Delete: Case 3 -- x's sibling w is black, w's left child is red, and w's right child is black" - ) - # left(w).colour = Colour.Black - # w.colour = Colour.Red + # Case 3 + # ====== + # x’s sibling w is black, + # w’s left child is red, and w’s right child is black + if Node.difference(right(w), w) == 1: rotate_right(tree=self, x=w) w = right(parent) - logger.debug( - "RB-Delete: Case 4 -- x's sibling w is black, and w's right child is red" - ) - # w.colour = colour(x.parent) - # x.parent.colour = Colour.Black - # right(w).colour = Colour.Black + + # Case 4 + # ====== + # x’s sibling w is black, and w’s right child is red + parent.rank -= 1 + w.rank += 1 rotate_left(tree=self, x=parent) x = self.root @@ -143,37 +143,12 @@ class RBTree(RankedTree[T]): if not node and not parent: return - logger.debug( - "RB-Delete: Called with node=‹%s› parent=‹%s›", node, parent - ) - logger.debug("RB-Delete: Colour of node: %s", colour(node)) - if (node and 2 not in Node.differences(node)) or ( - 2 not in Node.differences(parent) - ): - # we haven't deleted a black node - logger.debug( - "RB-Delete: No black node has been deleted; node: %s; diffs: %s", - node, - Node.differences(node), - ) - return - - logger.debug("RB-Delete: Tree before rebalancing:\n%s", str(self)) - - while node != self.root and colour(node) == Colour.Black: - logger.debug( - "RB-Delete: Rebalancing node=‹%s› (%s-node) parent=‹%s› (%s-node)", - node, - Node.differences(node), - parent, - Node.differences(parent), - ) + while node != self.root and is_double_black(parent): if node == parent.left: node, parent = self._delete_rebalance_step( node, parent.right, parent, - lambda x: x.left, lambda x: x.right, Node.rotate_left, Node.rotate_right, @@ -183,7 +158,6 @@ class RBTree(RankedTree[T]): node, parent.left, parent, - lambda x: x.right, lambda x: x.left, Node.rotate_right, Node.rotate_left, diff --git a/test_rbt.py b/test_rbt.py new file mode 100644 index 0000000..fbdf14f --- /dev/null +++ b/test_rbt.py @@ -0,0 +1,43 @@ +from rbt import RBTree + +import logging +import pytest + +logger = logging.getLogger(__name__) + + +@pytest.mark.parametrize( + "values, delete_order", + [ + ([0, 1, -1], [0, -1, 1]), + ([0, 1, 2, 3, 4, 5], [4, 2, 1, 0, 5, 3]), + ( + [32, 1, 2, 3, 4, 5, 0, -2, -4, -3, -1], + [-4, -3, 1, 2, 5, 3, -2, 4, 32, -1, 0], + ), + ([0, 1, 2, -2, -3, -1], [-3, 2, 1, 0, -1, -2]), + ([0, 1, 2, 3, 4, 5, -1], [4, 2, 1, 0, 5, 3, -1]), + ], +) +def test_delete_minimal(values, delete_order): + tree = RBTree() + + for value in values: + tree.insert(value) + + for value in delete_order: + logger.info("Deleting %s", value) + + before = str(tree) + tree.delete(value) + after = str(tree) + + try: + assert tree.is_correct + except AssertionError: + logger.info( + f"[FAIL] Delete {value} from {values} in order {delete_order}" + ) + logger.info(f"Before:\n{before}") + logger.info(f"After:\n{after}") + raise