rb-tree #5

Merged
mfocko merged 3 commits from rb-tree into main 2022-05-20 15:48:38 +02:00
2 changed files with 91 additions and 74 deletions
Showing only changes of commit d39e11713c - Show all commits

122
rbt.py
View file

@ -14,14 +14,8 @@ class Colour(enum.IntEnum):
Black = 1 Black = 1
def colour(x: Optional[Node[T]]) -> Colour: def is_double_black(x: Optional[Node]) -> bool:
if not x or not x.parent: return x and 2 in Node.differences(x)
return Colour.Black
diff = Node.difference(x)
assert diff in (0, 1)
return Colour(diff)
class RBTree(RankedTree[T]): class RBTree(RankedTree[T]):
@ -54,39 +48,48 @@ class RBTree(RankedTree[T]):
def _insert_rebalance_step( def _insert_rebalance_step(
self, self,
x: Node[T], z: Node[T],
d: Optional[Node[T]], y: Optional[Node[T]],
y: Node[T], right_child: Node[T],
rotate_left: RotateFunction, rotate_left: RotateFunction,
rotate_right: RotateFunction, rotate_right: RotateFunction,
) -> Node[T]: ) -> Node[T]:
p = x.parent p = z.parent
pp = p.parent pp = p.parent
if d is not None and colour(d) == Colour.Red: if y is not None and Node.difference(y) == Colour.Red:
# Case 1
# ======
# zs uncle y is red
pp.rank += 1 pp.rank += 1
x = pp z = pp
elif x == y: elif z == right_child:
x = p # Case 2
# ======
# zs uncle y is black and z is a right child
z = p
rotate_left(tree=self, x=p) rotate_left(tree=self, x=p)
else: else:
# Case 3
# ======
# zs uncle y is black and z is a left child
rotate_right(tree=self, x=pp) rotate_right(tree=self, x=pp)
return x return z
def _insert_rebalance(self, x: Node[T]) -> None: def _insert_rebalance(self, z: Node[T]) -> None:
while x.parent is not None and colour(x.parent) == Colour.Red: while z.parent is not None and Node.difference(z.parent) == Colour.Red:
p = x.parent p = z.parent
pp = x.parent.parent pp = p.parent
assert pp assert pp
if p == pp.left: if p == pp.left:
x = self._insert_rebalance_step( z = self._insert_rebalance_step(
x, pp.right, p.right, Node.rotate_left, Node.rotate_right z, pp.right, p.right, Node.rotate_left, Node.rotate_right
) )
else: else:
x = self._insert_rebalance_step( z = self._insert_rebalance_step(
x, pp.left, p.left, Node.rotate_right, Node.rotate_left z, pp.left, p.left, Node.rotate_right, Node.rotate_left
) )
# endregion InsertRebalance # endregion InsertRebalance
@ -98,40 +101,37 @@ class RBTree(RankedTree[T]):
x: Node[T], x: Node[T],
w: Node[T], w: Node[T],
parent: Node[T], parent: Node[T],
left: Callable[[Node[T]], Node[T]],
right: Callable[[Node[T]], Node[T]], right: Callable[[Node[T]], Node[T]],
rotate_left: RotateFunction, rotate_left: RotateFunction,
rotate_right: RotateFunction, rotate_right: RotateFunction,
) -> Tuple[Optional[Node[T]], Optional[Node[T]]]: ) -> Tuple[Optional[Node[T]], Optional[Node[T]]]:
if colour(w) == Colour.Red: if Node.difference(w) == 0:
logger.debug("RB-Delete: Case 1 -- x's sibling w is red") # Case 1
# w.colour = Colour.Black # ======
# x.parent.colour = Colour.Red # xs sibling w is red
rotate_left(tree=self, x=parent) rotate_left(tree=self, x=parent)
w = right(parent) w = right(parent)
if colour(w.left) == Colour.Black and colour(w.right) == Colour.Black: if Node.differences(w) == (1, 1):
logger.debug( # Case 2
"RB-Delete: Case 2 -- x's sibling w is black, and both of w's children are black" # ======
) # xs sibling w is black, and both of ws children are black
# w.colour = Colour.Red
parent.rank -= 1 parent.rank -= 1
x = parent x = parent
else: else:
if colour(right(w)) == Colour.Black: # Case 3
logger.debug( # ======
"RB-Delete: Case 3 -- x's sibling w is black, w's left child is red, and w's right child is black" # xs sibling w is black,
) # ws left child is red, and ws right child is black
# left(w).colour = Colour.Black if Node.difference(right(w), w) == 1:
# w.colour = Colour.Red
rotate_right(tree=self, x=w) rotate_right(tree=self, x=w)
w = right(parent) w = right(parent)
logger.debug(
"RB-Delete: Case 4 -- x's sibling w is black, and w's right child is red" # Case 4
) # ======
# w.colour = colour(x.parent) # xs sibling w is black, and ws right child is red
# x.parent.colour = Colour.Black parent.rank -= 1
# right(w).colour = Colour.Black w.rank += 1
rotate_left(tree=self, x=parent) rotate_left(tree=self, x=parent)
x = self.root x = self.root
@ -143,37 +143,12 @@ class RBTree(RankedTree[T]):
if not node and not parent: if not node and not parent:
return return
logger.debug( while node != self.root and is_double_black(parent):
"RB-Delete: Called with node=%s parent=%s", node, parent
)
logger.debug("RB-Delete: Colour of node: %s", colour(node))
if (node and 2 not in Node.differences(node)) or (
2 not in Node.differences(parent)
):
# we haven't deleted a black node
logger.debug(
"RB-Delete: No black node has been deleted; node: %s; diffs: %s",
node,
Node.differences(node),
)
return
logger.debug("RB-Delete: Tree before rebalancing:\n%s", str(self))
while node != self.root and colour(node) == Colour.Black:
logger.debug(
"RB-Delete: Rebalancing node=%s (%s-node) parent=%s (%s-node)",
node,
Node.differences(node),
parent,
Node.differences(parent),
)
if node == parent.left: if node == parent.left:
node, parent = self._delete_rebalance_step( node, parent = self._delete_rebalance_step(
node, node,
parent.right, parent.right,
parent, parent,
lambda x: x.left,
lambda x: x.right, lambda x: x.right,
Node.rotate_left, Node.rotate_left,
Node.rotate_right, Node.rotate_right,
@ -183,7 +158,6 @@ class RBTree(RankedTree[T]):
node, node,
parent.left, parent.left,
parent, parent,
lambda x: x.right,
lambda x: x.left, lambda x: x.left,
Node.rotate_right, Node.rotate_right,
Node.rotate_left, Node.rotate_left,

43
test_rbt.py Normal file
View file

@ -0,0 +1,43 @@
from rbt import RBTree
import logging
import pytest
logger = logging.getLogger(__name__)
@pytest.mark.parametrize(
"values, delete_order",
[
([0, 1, -1], [0, -1, 1]),
([0, 1, 2, 3, 4, 5], [4, 2, 1, 0, 5, 3]),
(
[32, 1, 2, 3, 4, 5, 0, -2, -4, -3, -1],
[-4, -3, 1, 2, 5, 3, -2, 4, 32, -1, 0],
),
([0, 1, 2, -2, -3, -1], [-3, 2, 1, 0, -1, -2]),
([0, 1, 2, 3, 4, 5, -1], [4, 2, 1, 0, 5, 3, -1]),
],
)
def test_delete_minimal(values, delete_order):
tree = RBTree()
for value in values:
tree.insert(value)
for value in delete_order:
logger.info("Deleting %s", value)
before = str(tree)
tree.delete(value)
after = str(tree)
try:
assert tree.is_correct
except AssertionError:
logger.info(
f"[FAIL] Delete {value} from {values} in order {delete_order}"
)
logger.info(f"Before:\n{before}")
logger.info(f"After:\n{after}")
raise