"""Red-black tree implemented on top of a rank-balanced tree.

Colours are not stored; they are derived from the rank difference between a
node and its parent (0 = red, 1 = black).
"""

import enum
import logging
from typing import Callable, Optional, Tuple, TypeVar

from node import Comparable, Node, RotateFunction
from ranked_tree import RankedTree

logger = logging.getLogger(__name__)
T = TypeVar("T", bound=Comparable)


class Colour(enum.IntEnum):
    """Red-black colour, encoded as the rank difference to the parent."""

    Red = 0
    Black = 1


def colour(x: Optional[Node[T]]) -> Colour:
    """Return the colour of *x* derived from its rank difference.

    ``None`` and the root (a node without a parent) are always black.  Any
    other node is red when its rank difference is 0 and black when it is 1.

    Raises:
        AssertionError: if the rank difference is neither 0 nor 1 (e.g. a
            node with difference 2 while a deletion is being rebalanced).
    """
    if not x or not x.parent:
        return Colour.Black

    diff = Node.difference(x)
    assert diff in (0, 1)
    return Colour(diff)


def _is_doubly_black(
    node: Optional[Node[T]], parent: Optional[Node[T]]
) -> bool:
    """Return True if *node* (possibly ``None``) has rank difference 2.

    The difference is read from *parent* via ``Node.differences`` so that an
    empty slot (``node is None``) is handled too.  A parentless node is never
    doubly black.
    """
    if parent is None:
        return False

    left_diff, right_diff = Node.differences(parent)
    return (left_diff if node == parent.left else right_diff) == 2


class RBTree(RankedTree[T]):
    """Red-black tree stored as a rank-balanced tree.

    Every child must have rank difference 0 (red) or 1 (black), and a red node
    must not have a red child; see :func:`colour` for the encoding.
    """

    def is_correct_node(
        self, node: Optional[Node[T]], recursive: bool = True
    ) -> bool:
        """Check the red-black rank rules at *node*.

        Args:
            node: subtree root to check; ``None`` is trivially correct.
            recursive: when True, also check both subtrees.

        Returns:
            False if a child has a rank difference other than 0/1, or if a red
            node has a red child; True otherwise.
        """
        if not node:
            return True

        left, right = Node.differences(node)
        if left not in (Colour.Red, Colour.Black):
            # left subtree has invalid difference
            return False
        elif right not in (Colour.Red, Colour.Black):
            # right subtree has invalid difference
            return False

        if Node.difference(node) == Colour.Red and (
            left == Colour.Red or right == Colour.Red
        ):
            # two consecutive red nodes
            return False

        return not recursive or (
            self.is_correct_node(node.left) and self.is_correct_node(node.right)
        )

    # region InsertRebalance

    def _insert_rebalance_step(
        self,
        x: Node[T],
        d: Optional[Node[T]],
        y: Node[T],
        rotate_left: RotateFunction,
        rotate_right: RotateFunction,
    ) -> Node[T]:
        """Perform one fix-up step for a red *x* whose parent is also red.

        Written for the case where x's parent is a left child; the caller
        passes mirrored arguments for the other side.

        Args:
            x: the red node whose parent is red.
            d: the uncle of *x* (sibling of its parent), possibly ``None``.
            y: the inner child of x's parent (``p.right`` when the parent is
                a left child).
            rotate_left: rotation towards the outside (mirrored by caller).
            rotate_right: rotation towards the inside (mirrored by caller).

        Returns:
            The node from which the fix-up continues.
        """
        p = x.parent
        pp = p.parent

        if colour(d) == Colour.Red:
            # Case 1: red uncle.  Promoting the grandparent raises the rank
            # difference of both the parent and the uncle from 0 to 1 (black);
            # the grandparent itself may now be red under a red node.
            pp.rank += 1
            x = pp
        elif x == y:
            # Case 2: x is the inner grandchild.  Rotate it outwards; the old
            # parent becomes the outer grandchild, handled as case 3 in the
            # next iteration.
            x = p
            rotate_left(tree=self, x=p)
        else:
            # Case 3: x is the outer grandchild; rotate at the grandparent.
            rotate_right(tree=self, x=pp)

        return x

    def _insert_rebalance(self, x: Node[T]) -> None:
        """Restore the red-black rules after inserting *x*."""
        while x.parent is not None and colour(x.parent) == Colour.Red:
            p = x.parent
            pp = p.parent
            # colour() treats a parentless node as black, so a red parent
            # always has a parent of its own
            assert pp

            if p == pp.left:
                x = self._insert_rebalance_step(
                    x, pp.right, p.right, Node.rotate_left, Node.rotate_right
                )
            else:
                x = self._insert_rebalance_step(
                    x, pp.left, p.left, Node.rotate_right, Node.rotate_left
                )

    # endregion InsertRebalance

    # region DeleteRebalance

    def _delete_rebalance_step(
        self,
        x: Optional[Node[T]],
        w: Node[T],
        parent: Node[T],
        left: Callable[[Node[T]], Node[T]],
        right: Callable[[Node[T]], Node[T]],
        rotate_left: RotateFunction,
        rotate_right: RotateFunction,
    ) -> Tuple[Optional[Node[T]], Optional[Node[T]]]:
        """Perform one fix-up step for a doubly-black *x* (rank difference 2).

        Written for *x* being the left child of *parent*; the caller passes
        mirrored accessors and rotations for the other side.

        Args:
            x: the doubly-black node, possibly ``None`` (an empty slot).
            w: the sibling of *x*.
            parent: the parent of *x*.
            left: accessor for the child on x's side.
            right: accessor for the child on the opposite side.
            rotate_left: rotation moving nodes towards x's side.
            rotate_right: rotation moving nodes away from x's side.

        Returns:
            The node to continue from and its parent.  Case 4 returns the
            root, which terminates the caller's loop.
        """
        if colour(w) == Colour.Red:
            logger.debug("RB-Delete: Case 1 -- x's sibling w is red")
            # w and parent share a rank (w is red), so after the rotation
            # parent is a red child of w with no rank change.
            rotate_left(tree=self, x=parent)
            w = right(parent)

        if colour(left(w)) == Colour.Black and colour(right(w)) == Colour.Black:
            logger.debug(
                "RB-Delete: Case 2 -- x's sibling w is black, and both of w's children are black"
            )
            # Demoting parent makes w red (difference 0) and x black
            # (difference 1).  Parent's own difference grows by one: a red
            # parent becomes black (done), a black one becomes doubly black.
            parent.rank -= 1
            x = parent
        else:
            if colour(right(w)) == Colour.Black:
                logger.debug(
                    "RB-Delete: Case 3 -- x's sibling w is black, w's left child is red, and w's right child is black"
                )
                # left(w) is red, i.e. has the same rank as w, so the rotation
                # needs no rank changes.
                rotate_right(tree=self, x=w)
                w = right(parent)

            logger.debug(
                "RB-Delete: Case 4 -- x's sibling w is black, and w's right child is red"
            )
            # w must take parent's rank and parent must drop by one.
            # NOTE(review): no explicit rank update happens here, so this
            # relies on Node.rotate_* giving the new subtree root the old
            # root's rank -- confirm against node.py.
            rotate_left(tree=self, x=parent)
            x = self.root

        return x, (x.parent if x else None)

    def _delete_rebalance(
        self, node: Optional[Node[T]], parent: Optional[Node[T]]
    ) -> None:
        """Restore the red-black rules after a deletion.

        Args:
            node: node to start from; may be ``None`` (an empty slot).
            parent: parent of *node*'s position, ``None`` at the root.

        Rebalancing happens only while *node* is doubly black, i.e. has a rank
        difference of 2 to *parent*.
        """
        if not node and not parent:
            return

        logger.debug(
            "RB-Delete: Called with node=‹%s› parent=‹%s›", node, parent
        )

        # Checked via the parent's differences so that a doubly-black node
        # (difference 2) does not trip the assertion inside colour().
        doubly_black = _is_doubly_black(node, parent)
        logger.debug("RB-Delete: Node is doubly black: %s", doubly_black)

        if not doubly_black:
            # we haven't deleted a black node
            logger.debug(
                "RB-Delete: No black node has been deleted; node: %s; diffs: %s",
                node,
                Node.differences(node),
            )
            return

        logger.debug("RB-Delete: Tree before rebalancing:\n%s", str(self))

        # Continue only while the current node still carries the extra black
        # (difference 2).  After a case-2 demotion of a red parent, that
        # parent has difference 1 and the tree is already valid.
        while node != self.root and _is_doubly_black(node, parent):
            logger.debug(
                "RB-Delete: Rebalancing node=‹%s› (%s-node) parent=‹%s› (%s-node)",
                node,
                Node.differences(node),
                parent,
                Node.differences(parent),
            )

            if node == parent.left:
                node, parent = self._delete_rebalance_step(
                    node,
                    parent.right,
                    parent,
                    lambda x: x.left,
                    lambda x: x.right,
                    Node.rotate_left,
                    Node.rotate_right,
                )
            else:
                node, parent = self._delete_rebalance_step(
                    node,
                    parent.left,
                    parent,
                    lambda x: x.right,
                    lambda x: x.left,
                    Node.rotate_right,
                    Node.rotate_left,
                )

    # endregion DeleteRebalance