from node import Node, Comparable, RotateFunction
from ranked_tree import RankedTree

import enum
import logging
from typing import Callable, Tuple, TypeVar, Optional

logger = logging.getLogger(__name__)

T = TypeVar("T", bound=Comparable)


class Colour(enum.IntEnum):
    """Node colours, expressed as rank differences.

    The tree never stores a colour on a node. Instead, the rank difference
    reported by ``Node.difference`` / ``Node.differences`` is read as a
    colour: a difference of 0 means red and 1 means black. A difference of
    2 is the transient "double black" state handled during deletion (see
    :func:`is_double_black`).
    """

    Red = 0
    Black = 1


def is_double_black(x: Optional[Node]) -> bool:
    """Return True if *x* is a node that has a child at rank difference 2.

    Args:
        x: Node to inspect; a falsy value (e.g. ``None``) yields False.

    Returns:
        A real ``bool``. The previous ``x and ...`` form leaked ``x`` itself
        (e.g. ``None``) out of a function annotated ``-> bool``.
    """
    return bool(x) and 2 in Node.differences(x)


class RBTree(RankedTree[T]):
    """Red-black tree implemented as a rank-balanced tree.

    Colours are derived from rank differences between a node and its
    children or parent (see :class:`Colour`). Rebalancing therefore works
    only by adjusting ``rank`` fields and rotating; there is no explicit
    recolouring step. Node storage, ranks, rotations and ``self.root`` are
    provided by :class:`RankedTree` / :class:`Node`, defined elsewhere.
    """

    def is_correct_node(
        self, node: Optional[Node[T]], recursive: bool = True
    ) -> bool:
        """Check the red-black invariants at *node*.

        Args:
            node: Root of the subtree to check; a falsy node is always valid.
            recursive: When True (the default), also check every node of
                both subtrees; when False, check *node* alone.

        Returns:
            True when both child rank differences are red (0) or black (1),
            and *node* is not a red node with a red child.
        """
        if not node:
            return True

        left, right = Node.differences(node)
        if left not in (Colour.Red, Colour.Black):
            # left subtree has invalid difference
            return False
        if right not in (Colour.Red, Colour.Black):
            # right subtree has invalid difference
            return False

        if Node.difference(node) == Colour.Red and (
            left == Colour.Red or right == Colour.Red
        ):
            # two consecutive red nodes
            return False

        return not recursive or (
            self.is_correct_node(node.left) and self.is_correct_node(node.right)
        )

    # region InsertRebalance
    def _insert_rebalance_step(
        self,
        z: Node[T],
        y: Optional[Node[T]],
        right_child: Node[T],
        rotate_left: RotateFunction,
        rotate_right: RotateFunction,
    ) -> Node[T]:
        """Perform one iteration of the insertion fix-up loop.

        The case comments are written for the orientation in which z's
        parent ``p`` is a left child. :meth:`_insert_rebalance` obtains the
        mirrored case by passing the mirrored uncle, inner child and
        rotations.

        Args:
            z: Red node whose parent is also red.
            y: z's uncle (the sibling of z's parent); may be ``None``.
            right_child: The child of z's parent on the same side as the
                uncle. If z is this node, z is an "inner" grandchild.
            rotate_left: Rotation applied at ``p`` in Case 2 (called as
                ``rotate_left(tree=self, x=...)``).
            rotate_right: Rotation applied at the grandparent in Case 3.

        Returns:
            The node from which the fix-up loop should continue.
        """
        p = z.parent
        pp = p.parent

        if y is not None and Node.difference(y) == Colour.Red:
            # Case 1
            # ======
            # z's uncle y is red. p and y both sit at pp's rank, so
            # promoting pp makes both of them black (difference 1). pp
            # itself may now be red under a red parent, so continue from pp.
            pp.rank += 1
            return pp

        if z == right_child:
            # Case 2
            # ======
            # z's uncle y is black and z is a right (inner) child. Rotate
            # at p so that p becomes z's child; continuing from p lets the
            # next loop iteration handle it as Case 3.
            rotate_left(tree=self, x=p)
            return p

        # Case 3
        # ======
        # z's uncle y is black and z is a left (outer) child. Rotate at the
        # grandparent with ranks unchanged. No recolouring step is needed
        # because colours are read from rank differences.
        rotate_right(tree=self, x=pp)
        return z

    def _insert_rebalance(self, z: Node[T]) -> None:
        """Restore the red-black invariants after *z* has been inserted.

        Loops while z's parent is red (a red-red violation). Each iteration
        calls :meth:`_insert_rebalance_step` with left/right roles mirrored
        according to which side of the grandparent z's parent is on.
        """
        while z.parent is not None and Node.difference(z.parent) == Colour.Red:
            p = z.parent
            pp = p.parent
            # p is red, so it is assumed not to be the root; a grandparent
            # must therefore exist.
            assert pp

            if p == pp.left:
                z = self._insert_rebalance_step(
                    z, pp.right, p.right, Node.rotate_left, Node.rotate_right
                )
            else:
                z = self._insert_rebalance_step(
                    z, pp.left, p.left, Node.rotate_right, Node.rotate_left
                )

    # endregion InsertRebalance

    # region DeleteRebalance
    def _delete_rebalance_step(
        self,
        x: Optional[Node[T]],
        w: Node[T],
        parent: Node[T],
        right: Callable[[Node[T]], Node[T]],
        rotate_left: RotateFunction,
        rotate_right: RotateFunction,
    ) -> Tuple[Optional[Node[T]], Optional[Node[T]]]:
        """Perform one iteration of the deletion fix-up loop.

        The case comments are written for the orientation in which x is
        the left child of *parent*. :meth:`_delete_rebalance` obtains the
        mirrored case by passing a mirrored ``right`` accessor and
        mirrored rotations.

        Args:
            x: The node sitting at rank difference 2 below *parent*; may be
                ``None``. Its incoming value is not read; it is only
                reassigned.
            w: x's sibling.
            parent: Parent of x and w.
            right: Returns the child of a node on the side away from x
                (``n.right`` in the unmirrored orientation).
            rotate_left: Rotation toward x's side (called as
                ``rotate_left(tree=self, x=...)``).
            rotate_right: Rotation away from x's side.

        Returns:
            ``(x, x.parent)`` for the next iteration. In Case 4 this is
            ``(self.root, self.root.parent)``, which terminates the loop.
        """
        if Node.difference(w) == Colour.Red:
            # Case 1
            # ======
            # x's sibling w is red: rotate it above parent. x's new
            # sibling is then black, so fall through to Cases 2-4.
            rotate_left(tree=self, x=parent)
            w = right(parent)

        if Node.differences(w) == (Colour.Black, Colour.Black):
            # Case 2
            # ======
            # x's sibling w is black, and both of w's children are black.
            # Demoting parent makes w red and lowers x's difference from
            # 2 to 1; the deficit moves up to parent.
            parent.rank -= 1
            x = parent
        else:
            # Case 3
            # ======
            # x's sibling w is black,
            # w's left child is red, and w's right child is black
            if Node.difference(right(w), w) == Colour.Black:
                rotate_right(tree=self, x=w)
                w = right(parent)

            # Case 4
            # ======
            # x's sibling w is black, and w's right child is red.
            # Swap ranks between parent and w, then rotate: this
            # terminates the fix-up.
            parent.rank -= 1
            w.rank += 1
            rotate_left(tree=self, x=parent)
            x = self.root

        return x, (x.parent if x else None)

    def _delete_rebalance(
        self, node: Optional[Node[T]], parent: Optional[Node[T]]
    ) -> None:
        """Restore the red-black invariants after a deletion.

        Args:
            node: Node that replaced the removed one; may be ``None``.
            parent: Parent of *node* (needed because *node* may be ``None``).

        The loop runs while *node* is not the root and *parent* has a child
        at rank difference 2 (a "double black").
        """
        if not node and not parent:
            return

        while node != self.root and is_double_black(parent):
            if node == parent.left:
                node, parent = self._delete_rebalance_step(
                    node,
                    parent.right,
                    parent,
                    lambda x: x.right,
                    Node.rotate_left,
                    Node.rotate_right,
                )
            else:
                node, parent = self._delete_rebalance_step(
                    node,
                    parent.left,
                    parent,
                    lambda x: x.left,
                    Node.rotate_right,
                    Node.rotate_left,
                )

    # endregion DeleteRebalance