python/rbt.py
Matej Focko 02ce49c86e
feat(rb): add mockup
Signed-off-by: Matej Focko <mfocko@redhat.com>
2022-05-20 11:50:56 +02:00

192 lines
5.8 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

from node import Node, Comparable, RotateFunction
from ranked_tree import RankedTree
import enum
import logging
from typing import Callable, Tuple, TypeVar, Optional
# Module-level logger used for the debug traces of the rebalancing code.
logger = logging.getLogger(name=__name__)

# Type of the values stored in the tree, bounded by ``Comparable`` from ``node``.
T = TypeVar("T", bound=Comparable)
class Colour(enum.IntEnum):
    """Colour of a red-black node, valued as the rank difference encoding it.

    ``colour()`` turns a node's rank difference to its parent directly into
    this enum via ``Colour(diff)``, so the numeric values are significant.
    """

    Red = 0  # rank difference 0 to the parent
    Black = 1  # rank difference 1 to the parent
def colour(x: Optional[Node[T]]) -> Colour:
    """Return the colour encoded by ``x``'s rank difference to its parent.

    A missing node, or a node without a parent (the root), is reported as
    black. For any other node the rank difference reported by
    ``Node.difference`` must be 0 (red) or 1 (black); anything else trips the
    assertion.
    """
    if x and x.parent:
        rank_diff = Node.difference(x)
        assert rank_diff in (0, 1)
        return Colour(rank_diff)

    return Colour.Black
class RBTree(RankedTree[T]):
    """Red-black tree expressed purely through node ranks.

    Colours are not stored; ``colour()`` derives them from the rank
    difference between a node and its parent: 0 is red, 1 is black, and the
    root and missing nodes are black. The textbook (CLRS) recolouring steps
    therefore become rank promotions/demotions.

    NOTE(review): the rank arithmetic below assumes ``Node.difference(x)`` is
    ``x.parent.rank - x.rank`` and that ``Node.rotate_*`` only relink nodes
    without touching ranks — TODO confirm against node.py.
    """

    def is_correct_node(
        self, node: Optional[Node[T]], recursive: bool = True
    ) -> bool:
        """Check the red-black rank rules at ``node``.

        Both values returned by ``Node.differences(node)`` (left and right
        subtree) must be 0 (red) or 1 (black), and a red node must not have a
        red child. When ``recursive`` is true, both subtrees are checked too.
        A missing node is always correct.
        """
        if not node:
            return True

        left, right = Node.differences(node)
        if left not in (Colour.Red, Colour.Black):
            # left subtree has invalid difference
            return False
        elif right not in (Colour.Red, Colour.Black):
            # right subtree has invalid difference
            return False

        if Node.difference(node) == Colour.Red and (
            left == Colour.Red or right == Colour.Red
        ):
            # two consecutive red nodes
            return False

        return not recursive or (
            self.is_correct_node(node.left)
            and self.is_correct_node(node.right)
        )

    # region InsertRebalance

    def _insert_rebalance_step(
        self,
        x: Node[T],
        d: Optional[Node[T]],
        y: Node[T],
        rotate_left: RotateFunction,
        rotate_right: RotateFunction,
    ) -> Node[T]:
        """Perform one step of the insertion fix-up for ``x`` with a red parent.

        Written for the parent being the left child of the grandparent; the
        mirrored case is handled by swapping the arguments.

        :param x: node whose parent is red
        :param d: the uncle of ``x`` (sibling of its parent), possibly missing
        :param y: the parent's child on the uncle's side ("inner" child)
        :param rotate_left: rotation applied to the parent in the inner case
        :param rotate_right: rotation applied to the grandparent in the outer case
        :returns: the node from which the fix-up continues
        """
        p = x.parent
        pp = p.parent

        if d is not None and colour(d) == Colour.Red:
            # Red uncle: promoting the grandparent makes parent and uncle
            # black and the grandparent red; the check continues from it.
            pp.rank += 1
            x = pp
        elif x == y:
            # Inner child: rotate it to the outside. The next loop iteration
            # then sees the outer case and finishes it.
            x = p
            rotate_left(tree=self, x=p)
        else:
            # Outer child: parent and grandparent share a rank, so the single
            # rotation recolours them without any rank change.
            rotate_right(tree=self, x=pp)

        return x

    def _insert_rebalance(self, x: Node[T]) -> None:
        """Restore the red-black rules after inserting ``x``.

        Repeats fix-up steps while ``x``'s parent is red, choosing the
        left/right orientation by which side of the grandparent the parent is.
        """
        while x.parent is not None and colour(x.parent) == Colour.Red:
            p = x.parent
            pp = p.parent
            # colour() reports a parentless node as black, so a red parent
            # always has a parent of its own
            assert pp

            if p == pp.left:
                x = self._insert_rebalance_step(
                    x, pp.right, p.right, Node.rotate_left, Node.rotate_right
                )
            else:
                x = self._insert_rebalance_step(
                    x, pp.left, p.left, Node.rotate_right, Node.rotate_left
                )

    # endregion InsertRebalance

    # region DeleteRebalance

    @staticmethod
    def _is_double_black(
        node: Optional[Node[T]], parent: Optional[Node[T]]
    ) -> bool:
        """Return True if ``node`` sits at rank difference 2 below ``parent``.

        Rank difference 2 is how a "doubly black" node (a black-height
        deficit) shows up in the rank encoding. ``colour()`` cannot be used
        for this because it asserts a difference of 0 or 1. ``node`` may be
        missing (None); its side is picked with ``node == parent.left``.
        """
        if parent is None:
            return False

        left, right = Node.differences(parent)
        return (left if node == parent.left else right) == 2

    def _delete_rebalance_step(
        self,
        x: Node[T],
        w: Node[T],
        parent: Node[T],
        left: Callable[[Node[T]], Node[T]],
        right: Callable[[Node[T]], Node[T]],
        rotate_left: RotateFunction,
        rotate_right: RotateFunction,
    ) -> Tuple[Optional[Node[T]], Optional[Node[T]]]:
        """Perform one step of the deletion fix-up (CLRS RB-Delete-Fixup).

        ``x`` is doubly black (rank difference 2 below ``parent``) and ``w``
        is its sibling. Written for ``x`` being the left child; the mirrored
        case swaps ``left``/``right`` and the rotations.

        :returns: ``(x, x.parent)`` to continue from. This is either the
            demoted parent (case 2, possibly doubly black now) or the root
            (case 4, fix-up finished).
        """
        if colour(w) == Colour.Red:
            logger.debug("RB-Delete: Case 1 -- x's sibling w is red")
            # w.colour = Colour.Black
            # x.parent.colour = Colour.Red
            # A red w has its parent's rank, so the rotation alone swaps
            # their colours.
            rotate_left(tree=self, x=parent)
            w = right(parent)

        if colour(w.left) == Colour.Black and colour(w.right) == Colour.Black:
            logger.debug(
                "RB-Delete: Case 2 -- x's sibling w is black, and both of w's children are black"
            )
            # w.colour = Colour.Red
            # Demoting the parent makes w red and x plain black; the deficit
            # moves up to the parent.
            parent.rank -= 1
            x = parent
        else:
            if colour(right(w)) == Colour.Black:
                logger.debug(
                    "RB-Delete: Case 3 -- x's sibling w is black, w's left child is red, and w's right child is black"
                )
                # left(w).colour = Colour.Black
                # w.colour = Colour.Red
                # left(w) is red, so it shares w's rank; the rotation alone
                # swaps their colours.
                rotate_right(tree=self, x=w)
                w = right(parent)

            logger.debug(
                "RB-Delete: Case 4 -- x's sibling w is black, and w's right child is red"
            )
            # w.colour = colour(x.parent)
            # x.parent.colour = Colour.Black
            # right(w).colour = Colour.Black
            rotate_left(tree=self, x=parent)
            # The rotation leaves ranks as they were, so the recolouring must
            # be done explicitly.
            # - w takes over the parent's old rank, keeping the parent's colour.
            # - The parent drops one rank and becomes a black child of w, which
            #   turns x plain black.
            # - right(w) keeps its rank and so becomes black under the promoted w.
            w.rank += 1
            parent.rank -= 1
            x = self.root

        return x, (x.parent if x else None)

    def _delete_rebalance(
        self, node: Optional[Node[T]], parent: Optional[Node[T]]
    ) -> None:
        """Restore the red-black rules after a deletion.

        ``node`` (possibly missing) is presumably the node that took the
        removed node's place, and ``parent`` is its parent. Nothing is done
        unless ``node`` ended up doubly black, i.e. at rank difference 2
        below ``parent``.
        """
        if not node and not parent:
            return

        logger.debug(
            "RB-Delete: Called with node=%s parent=%s", node, parent
        )
        # NOTE(review): colour() asserts a rank difference of 0 or 1, so this
        # relies on a present `node` never being doubly black here.
        logger.debug("RB-Delete: Colour of node: %s", colour(node))

        if not self._is_double_black(node, parent):
            # we haven't deleted a black node
            logger.debug(
                "RB-Delete: No black node has been deleted; node: %s; diffs: %s",
                node,
                Node.differences(node),
            )
            return

        logger.debug("RB-Delete: Tree before rebalancing:\n%s", str(self))

        # Loop only while the deficit persists (rank difference 2).
        # - A node with difference 1 is ordinary black and must stop the loop,
        #   e.g. after case 2 demoted a red parent.
        # - colour() would assert on a doubly black node.
        while node != self.root and self._is_double_black(node, parent):
            logger.debug(
                "RB-Delete: Rebalancing node=%s (%s-node) parent=%s (%s-node)",
                node,
                Node.differences(node),
                parent,
                Node.differences(parent),
            )

            if node == parent.left:
                node, parent = self._delete_rebalance_step(
                    node,
                    parent.right,
                    parent,
                    lambda x: x.left,
                    lambda x: x.right,
                    Node.rotate_left,
                    Node.rotate_right,
                )
            else:
                node, parent = self._delete_rebalance_step(
                    node,
                    parent.left,
                    parent,
                    lambda x: x.right,
                    lambda x: x.left,
                    Node.rotate_right,
                    Node.rotate_left,
                )

    # endregion DeleteRebalance