from node import Node, Comparable
from ranked_tree import RankedTree, RotateFunction

import enum
import logging
from typing import Callable, Tuple, TypeVar, Optional

logger = logging.getLogger(__name__)
T = TypeVar("T", bound=Comparable)


class Colour(enum.IntEnum):
    """
    Colour assigned to a node or to an edge of the tree.

    The integer values are what this module compares rank differences
    against: `Red` is 0 and `Black` is 1.
    """

    Red = 0
    Black = 1


def has_double_black(x: Optional[Node[T]]) -> bool:
    """
    Tells whether any child of `x` is double black.

    Args:
        x: Node whose children are inspected; may be `None`.

    Returns:
        `True` if `x` is not `None` and the value 2 occurs among
        `Node.differences(x)`, `False` otherwise.
    """
    if x is None:
        return False
    return 2 in Node.differences(x)


class RBTree(RankedTree[T]):
    """
    Red-black tree expressed through the rank differences of `RankedTree`.

    A difference equal to `Colour.Red` is treated as red and a difference
    equal to `Colour.Black` as black.
    """

    def is_correct_node(
        self, node: Optional[Node[T]], recursive: bool = True
    ) -> bool:
        """
        Validates the rank differences around `node`.

        A missing node is considered correct. Otherwise both values returned
        by `Node.differences(node)` must be `Colour.Red` or `Colour.Black`,
        and when `Node.difference(node)` is red, neither of them may be red.

        Args:
            node: Node to validate; a falsy value is accepted as correct.
            recursive: When true, both subtrees of `node` are validated too.

        Returns:
            `True` if all performed checks pass, `False` otherwise.
        """
        if not node:
            return True

        allowed = (Colour.Red, Colour.Black)
        left_diff, right_diff = Node.differences(node)
        if left_diff not in allowed or right_diff not in allowed:
            # at least one subtree hangs on an invalid difference
            return False

        if Node.difference(node) == Colour.Red and Colour.Red in (
            left_diff,
            right_diff,
        ):
            # a red node must not have a red child
            return False

        if not recursive:
            return True
        return self.is_correct_node(node.left) and self.is_correct_node(
            node.right
        )

    # region InsertRebalance

    def _insert_rebalance_step(
        self,
        z: Node[T],
        y: Optional[Node[T]],
        right_child: Node[T],
        rotate_left: RotateFunction[T],
        rotate_right: RotateFunction[T],
    ) -> Node[T]:
        """
        Performs a single rebalancing step after an insertion.

        The caller supplies `right_child` and both rotations already
        oriented for the side being fixed; the mirrored side swaps them.

        Args:
            z: Node being fixed; its parent and grandparent are dereferenced.
            y: Uncle of `z`, or `None`.
            right_child: Child of the parent of `z` that selects case 2
                when it is equal to `z`.
            rotate_left: Rotation applied to the parent of `z` in case 2.
            rotate_right: Rotation applied to the grandparent in case 3.

        Returns:
            Node the caller continues from: the grandparent (case 1),
            the former parent (case 2) or `z` itself (case 3).
        """
        parent = z.parent
        grandparent = parent.parent

        uncle_is_red = y is not None and Node.difference(y) == Colour.Red
        if uncle_is_red:
            # Case 1: the uncle y of z is red,
            # promote the grandparent and continue from it
            grandparent.rank += Colour.Black
            return grandparent

        if z == right_child:
            # Case 2: the uncle y of z is black and z is a right child,
            # rotate at the parent and continue from it
            rotate_left(parent)
            return parent

        # Case 3: the uncle y of z is black and z is a left child
        rotate_right(grandparent)
        return z

    def _insert_rebalance(self, z: Node[T]) -> None:
        """
        Rebalances the tree after `z` has been inserted.

        Steps are repeated for as long as the parent of the current node
        exists and `Node.difference` of that parent equals `Colour.Red`.

        Args:
            z: Node at which the rebalancing starts.
        """
        while z.parent is not None and Node.difference(z.parent) == Colour.Red:
            parent = z.parent
            grandparent = parent.parent

            assert grandparent
            if parent == grandparent.left:
                oriented = (
                    grandparent.right,
                    parent.right,
                    self.rotate_left,
                    self.rotate_right,
                )
            else:
                # mirrored situation: sides and rotations are swapped
                oriented = (
                    grandparent.left,
                    parent.left,
                    self.rotate_right,
                    self.rotate_left,
                )
            z = self._insert_rebalance_step(z, *oriented)

    # endregion InsertRebalance

    # region DeleteRebalance

    def _delete_rebalance_step(
        self,
        x: Node[T],
        w: Node[T],
        parent: Node[T],
        right: Callable[[Node[T]], Optional[Node[T]]],
        rotate_left: RotateFunction[T],
        rotate_right: RotateFunction[T],
    ) -> Tuple[Optional[Node[T]], Optional[Node[T]]]:
        """
        Performs a single rebalancing step after a deletion.

        The caller supplies `right` and both rotations already oriented for
        the side being fixed; the mirrored side swaps them.

        Args:
            x: Node being fixed.
            w: Sibling of `x`.
            parent: Common parent of `x` and `w`.
            right: Accessor returning the child of a node on the side of `w`.
            rotate_left: Rotation applied to `parent` in cases 1 and 4.
            rotate_right: Rotation applied to the sibling in case 3.

        Returns:
            Pair of the node to continue from and its parent; the parent is
            `None` when that node is falsy.
        """
        sibling = w
        if Node.difference(sibling) == Colour.Red:
            # Case 1: the sibling of x is red,
            # rotate at the parent and pick up the new sibling
            rotate_left(parent)
            sibling = right(parent)

        if Node.differences(sibling) == (Colour.Black, Colour.Black):
            # Case 2: the sibling of x is black and so are both of its
            # children, demote the parent and continue from it
            parent.rank -= Colour.Black
            successor = parent
        else:
            if Node.difference(right(sibling), sibling) == Colour.Black:
                # Case 3: the sibling of x is black, its left child is red
                # and its right child is black
                rotate_right(sibling)
                sibling = right(parent)

            # Case 4: the sibling of x is black and its right child is red
            parent.rank -= Colour.Black
            sibling.rank += Colour.Black
            rotate_left(parent)
            successor = self.root

        if successor:
            return successor, successor.parent
        return successor, None

    def _delete_rebalance(
        self, node: Optional[Node[T]], parent: Optional[Node[T]]
    ) -> None:
        """
        Rebalances the tree after a deletion.

        Nothing is done when both `node` and `parent` are falsy. Otherwise
        steps are repeated while `node` differs from the root and `parent`
        has a double black child.

        Args:
            node: Node at which the rebalancing starts; may be `None`.
            parent: Parent of `node`; may be `None`.
        """
        if not (node or parent):
            return

        def right_of(n: Node[T]) -> Optional[Node[T]]:
            return n.right

        def left_of(n: Node[T]) -> Optional[Node[T]]:
            return n.left

        while node != self.root and has_double_black(parent):
            if node == parent.left:
                oriented = (
                    parent.right,
                    parent,
                    right_of,
                    self.rotate_left,
                    self.rotate_right,
                )
            else:
                # mirrored situation: sides and rotations are swapped
                oriented = (
                    parent.left,
                    parent,
                    left_of,
                    self.rotate_right,
                    self.rotate_left,
                )
            node, parent = self._delete_rebalance_step(node, *oriented)

    # endregion DeleteRebalance