python/rbt.py

167 lines
4.8 KiB
Python
Raw Normal View History

from node import Node, Comparable, RotateFunction
from ranked_tree import RankedTree
import enum
import logging
from typing import Callable, Tuple, TypeVar, Optional
# Module-level logger, named after this module via ``__name__``.
logger = logging.getLogger(__name__)
# Type parameter used for ``Node[T]`` / ``RankedTree[T]`` below; bound to the
# ``Comparable`` type imported from ``node``.
T = TypeVar("T", bound=Comparable)
class Colour(enum.IntEnum):
    """
    Node colour expressed as an integer.

    The members are plain integers on purpose: elsewhere in this file they
    are compared against the values returned by ``Node.difference`` /
    ``Node.differences`` and added to or subtracted from ``rank``.
    """

    # A difference of 0 is treated as a red node.
    Red = 0
    # A difference of 1 is treated as a black node.
    Black = 1
def is_double_black(x: Optional[Node]) -> bool:
    """
    Tell whether one of the differences reported for ``x`` equals 2.

    :param x: node to inspect; may be ``None``.
    :returns: ``True`` when ``x`` is truthy and ``2`` occurs among the values
        returned by ``Node.differences(x)``; ``False`` otherwise.
    """
    # ``x and ...`` would hand back ``x`` itself (e.g. ``None``) for a falsy
    # argument, which contradicts the declared ``bool`` return type.
    if not x:
        return False
    return 2 in Node.differences(x)
class RBTree(RankedTree[T]):
    """
    Red-black tree implemented on top of ``RankedTree``.

    Colours are not stored explicitly: a node is treated as red or black
    according to the value ``Node.difference`` reports for it (see
    ``Colour``), and recolouring is done by adjusting ``rank``.
    """

    def is_correct_node(
        self, node: Optional[Node[T]], recursive: bool = True
    ) -> bool:
        """
        Check the red-black conditions at ``node``.

        :param node: node to validate; a falsy node is always correct.
        :param recursive: when true, both subtrees are validated as well.
        :returns: ``False`` if a child difference is neither red nor black,
            or if a red node has a red child; ``True`` otherwise.
        """
        if not node:
            return True

        allowed = (Colour.Red, Colour.Black)
        left_diff, right_diff = Node.differences(node)
        if left_diff not in allowed or right_diff not in allowed:
            # one of the subtrees has an invalid difference
            return False

        red_child = left_diff == Colour.Red or right_diff == Colour.Red
        if Node.difference(node) == Colour.Red and red_child:
            # two consecutive red nodes
            return False

        if not recursive:
            return True
        return self.is_correct_node(node.left) and self.is_correct_node(
            node.right
        )

    # region InsertRebalance
    def _insert_rebalance_step(
        self,
        z: Node[T],
        y: Optional[Node[T]],
        right_child: Node[T],
        rotate_left: RotateFunction,
        rotate_right: RotateFunction,
    ) -> Node[T]:
        """
        Perform one step of the insertion fix-up.

        The caller picks the orientation, so "left"/"right" here are relative
        and are swapped for the mirrored situation.

        :param z: node currently violating the red-red rule.
        :param y: uncle of ``z`` (may be ``None``).
        :param right_child: the child of ``z``'s parent that counts as the
            "right" one in the current orientation.
        :param rotate_left: rotation used in Case 2.
        :param rotate_right: rotation used in Case 3.
        :returns: node from which the fix-up loop should continue.
        """
        parent = z.parent
        grandparent = parent.parent

        if y is not None and Node.difference(y) == Colour.Red:
            # Case 1
            # ======
            # z's uncle y is red
            grandparent.rank += Colour.Black
            return grandparent

        if z == right_child:
            # Case 2
            # ======
            # z's uncle y is black and z is a right child
            rotate_left(tree=self, x=parent)
            return parent

        # Case 3
        # ======
        # z's uncle y is black and z is a left child
        rotate_right(tree=self, x=grandparent)
        return z

    def _insert_rebalance(self, z: Node[T]) -> None:
        """
        Restore the red-black conditions after inserting ``z``.

        Repeats single fix-up steps for as long as the parent of the current
        node is red.

        :param z: freshly inserted node.
        """
        while z.parent is not None and Node.difference(z.parent) == Colour.Red:
            parent = z.parent
            grandparent = parent.parent
            assert grandparent

            # Choose the orientation: (uncle, "right" child, rotations).
            if parent == grandparent.left:
                orientation = (
                    grandparent.right,
                    parent.right,
                    Node.rotate_left,
                    Node.rotate_right,
                )
            else:
                orientation = (
                    grandparent.left,
                    parent.left,
                    Node.rotate_right,
                    Node.rotate_left,
                )
            z = self._insert_rebalance_step(z, *orientation)

    # endregion InsertRebalance
    # region DeleteRebalance
    def _delete_rebalance_step(
        self,
        x: Node[T],
        w: Node[T],
        parent: Node[T],
        right: Callable[[Node[T]], Node[T]],
        rotate_left: RotateFunction,
        rotate_right: RotateFunction,
    ) -> Tuple[Optional[Node[T]], Optional[Node[T]]]:
        """
        Perform one step of the deletion fix-up.

        As with insertion, the caller decides the orientation by passing the
        matching accessor and rotations.

        :param x: node being fixed; its incoming value is not consulted, the
            step only decides which node comes next.
        :param w: sibling of ``x``.
        :param parent: common parent of ``x`` and ``w``.
        :param right: accessor returning the "right" child in the current
            orientation.
        :param rotate_left: rotation towards ``x``'s side.
        :param rotate_right: rotation away from ``x``'s side.
        :returns: pair of the next node to process and its parent (``None``
            when that node is falsy).
        """
        sibling = w
        if Node.difference(sibling) == Colour.Red:
            # Case 1
            # ======
            # x's sibling w is red
            rotate_left(tree=self, x=parent)
            sibling = right(parent)

        if Node.differences(sibling) == (Colour.Black, Colour.Black):
            # Case 2
            # ======
            # x's sibling w is black, and both of w's children are black
            parent.rank -= Colour.Black
            successor = parent
        else:
            # Case 3
            # ======
            # x's sibling w is black,
            # w's left child is red, and w's right child is black
            if Node.difference(right(sibling), sibling) == Colour.Black:
                rotate_right(tree=self, x=sibling)
                sibling = right(parent)

            # Case 4
            # ======
            # x's sibling w is black, and w's right child is red
            parent.rank -= Colour.Black
            sibling.rank += Colour.Black
            rotate_left(tree=self, x=parent)
            successor = self.root

        if successor:
            return successor, successor.parent
        return successor, None

    def _delete_rebalance(
        self, node: Optional[Node[T]], parent: Optional[Node[T]]
    ) -> None:
        """
        Restore the red-black conditions after a deletion.

        :param node: node that took the place of the removed one (may be
            ``None``).
        :param parent: parent of ``node`` (may be ``None``).
        """
        if not (node or parent):
            return

        while node != self.root and is_double_black(parent):
            # Choose the orientation: (sibling, accessor, rotations).
            if node == parent.left:
                orientation = (
                    parent.right,
                    parent,
                    lambda n: n.right,
                    Node.rotate_left,
                    Node.rotate_right,
                )
            else:
                orientation = (
                    parent.left,
                    parent,
                    lambda n: n.left,
                    Node.rotate_right,
                    Node.rotate_left,
                )
            node, parent = self._delete_rebalance_step(node, *orientation)

    # endregion DeleteRebalance