rb-tree #5

Merged
mfocko merged 3 commits from rb-tree into main 2022-05-20 15:48:38 +02:00
2 changed files with 195 additions and 4 deletions
Showing only changes of commit 02ce49c86e - Show all commits

192
rbt.py Normal file
View file

@ -0,0 +1,192 @@
from node import Node, Comparable, RotateFunction
from ranked_tree import RankedTree
import enum
import logging
from typing import Callable, Tuple, TypeVar, Optional
logger = logging.getLogger(__name__)
T = TypeVar("T", bound=Comparable)
class Colour(enum.IntEnum):
Red = 0
Black = 1
def colour(x: Optional[Node[T]]) -> Colour:
if not x or not x.parent:
return Colour.Black
diff = Node.difference(x)
assert diff in (0, 1)
return Colour(diff)
class RBTree(RankedTree[T]):
def is_correct_node(
self, node: Optional[Node[T]], recursive: bool = True
) -> bool:
if not node:
return True
left, right = Node.differences(node)
if left not in (Colour.Red, Colour.Black):
# left subtree has invalid difference
return False
elif right not in (Colour.Red, Colour.Black):
# right subtree has invalid difference
return False
if Node.difference(node) == Colour.Red and (
left == Colour.Red or right == Colour.Red
):
# two consecutive red nodes
return False
return not recursive or (
self.is_correct_node(node.left)
and self.is_correct_node(node.right)
)
# region InsertRebalance
def _insert_rebalance_step(
self,
x: Node[T],
d: Optional[Node[T]],
y: Node[T],
rotate_left: RotateFunction,
rotate_right: RotateFunction,
) -> Node[T]:
p = x.parent
pp = p.parent
if d is not None and colour(d) == Colour.Red:
pp.rank += 1
x = pp
elif x == y:
x = p
rotate_left(tree=self, x=p)
else:
rotate_right(tree=self, x=pp)
return x
def _insert_rebalance(self, x: Node[T]) -> None:
while x.parent is not None and colour(x.parent) == Colour.Red:
p = x.parent
pp = x.parent.parent
assert pp
if p == pp.left:
x = self._insert_rebalance_step(
x, pp.right, p.right, Node.rotate_left, Node.rotate_right
)
else:
x = self._insert_rebalance_step(
x, pp.left, p.left, Node.rotate_right, Node.rotate_left
)
# endregion InsertRebalance
# region DeleteRebalance
def _delete_rebalance_step(
self,
x: Node[T],
w: Node[T],
parent: Node[T],
left: Callable[[Node[T]], Node[T]],
right: Callable[[Node[T]], Node[T]],
rotate_left: RotateFunction,
rotate_right: RotateFunction,
) -> Tuple[Optional[Node[T]], Optional[Node[T]]]:
if colour(w) == Colour.Red:
logger.debug("RB-Delete: Case 1 -- x's sibling w is red")
# w.colour = Colour.Black
# x.parent.colour = Colour.Red
rotate_left(tree=self, x=parent)
w = right(parent)
if colour(w.left) == Colour.Black and colour(w.right) == Colour.Black:
logger.debug(
"RB-Delete: Case 2 -- x's sibling w is black, and both of w's children are black"
)
# w.colour = Colour.Red
parent.rank -= 1
x = parent
else:
if colour(right(w)) == Colour.Black:
logger.debug(
"RB-Delete: Case 3 -- x's sibling w is black, w's left child is red, and w's right child is black"
)
# left(w).colour = Colour.Black
# w.colour = Colour.Red
rotate_right(tree=self, x=w)
w = right(parent)
logger.debug(
"RB-Delete: Case 4 -- x's sibling w is black, and w's right child is red"
)
# w.colour = colour(x.parent)
# x.parent.colour = Colour.Black
# right(w).colour = Colour.Black
rotate_left(tree=self, x=parent)
x = self.root
return x, (x.parent if x else None)
def _delete_rebalance(
self, node: Optional[Node[T]], parent: Optional[Node[T]]
) -> None:
if not node and not parent:
return
logger.debug(
"RB-Delete: Called with node=%s parent=%s", node, parent
)
logger.debug("RB-Delete: Colour of node: %s", colour(node))
if (node and 2 not in Node.differences(node)) or (
2 not in Node.differences(parent)
):
# we haven't deleted a black node
logger.debug(
"RB-Delete: No black node has been deleted; node: %s; diffs: %s",
node,
Node.differences(node),
)
return
logger.debug("RB-Delete: Tree before rebalancing:\n%s", str(self))
while node != self.root and colour(node) == Colour.Black:
logger.debug(
"RB-Delete: Rebalancing node=%s (%s-node) parent=%s (%s-node)",
node,
Node.differences(node),
parent,
Node.differences(parent),
)
if node == parent.left:
node, parent = self._delete_rebalance_step(
node,
parent.right,
parent,
lambda x: x.left,
lambda x: x.right,
Node.rotate_left,
Node.rotate_right,
)
else:
node, parent = self._delete_rebalance_step(
node,
parent.left,
parent,
lambda x: x.right,
lambda x: x.left,
Node.rotate_right,
Node.rotate_left,
)
# endregion DeleteRebalance

View file

@ -1,6 +1,7 @@
from avl import AVLTree from avl import AVLTree
from wavl import WAVLTree from wavl import WAVLTree
from ravl import RAVLTree from ravl import RAVLTree
from rbt import RBTree
import logging import logging
import random import random
@ -13,8 +14,7 @@ logger = logging.getLogger(__name__)
@pytest.mark.parametrize( @pytest.mark.parametrize(
("RankBalancedTree"), ("RankBalancedTree"), [RBTree, AVLTree, WAVLTree, RAVLTree]
[AVLTree, WAVLTree, RAVLTree]
) )
@hypothesis.settings(max_examples=10000) @hypothesis.settings(max_examples=10000)
@hypothesis.given(values=st.sets(st.integers())) @hypothesis.given(values=st.sets(st.integers()))
@ -44,8 +44,7 @@ def _report(t_before: str, t_after: str) -> None:
@pytest.mark.parametrize( @pytest.mark.parametrize(
("RankBalancedTree"), ("RankBalancedTree"), [RBTree, AVLTree, WAVLTree, RAVLTree]
[AVLTree, WAVLTree, RAVLTree]
) )
@hypothesis.settings(max_examples=10000) @hypothesis.settings(max_examples=10000)
@hypothesis.given(config=delete_strategy()) @hypothesis.given(config=delete_strategy())