"""Base class for binary search trees whose nodes carry a rank."""

import logging
from abc import abstractmethod
from collections import deque
from typing import (
    Callable,
    Deque,
    Generic,
    Iterable,
    Iterator,
    List,
    Optional,
    Tuple,
    TypeVar,
)

from node import Comparable, Node

logger = logging.getLogger(__name__)

T = TypeVar("T", bound=Comparable)


class RankedTree(Generic[T]):
    """Binary search tree of ranked nodes with subclass-provided rebalancing.

    This class implements the shared BST machinery:
    - search, insert and delete
    - left/right rotations
    - in-order iteration
    - a Graphviz DOT rendering via ``str()``

    Subclasses supply the balancing rules through ``is_correct_node``,
    ``_insert_rebalance`` and ``_delete_rebalance``.

    NOTE(review): the hooks are decorated with ``@abstractmethod`` but the
    class does not derive from ``abc.ABC``. The decorator is therefore not
    enforced, and ``RankedTree`` itself can still be instantiated.
    """

    def __init__(self) -> None:
        self.root: Optional[Node[T]] = None

    def _get_unwrapped_graph(self) -> str:
        """Return the DOT statements for the tree without the ``digraph`` wrapper.

        Nodes are visited breadth-first from the root. All node statements
        come first, each labelled ``"<value>, <rank>"``. All parent -> child
        edge statements follow, each labelled with ``Node.difference(child)``.
        """
        parts: List[str] = []
        edges: List[Tuple[Node[T], Node[T]]] = []
        queue: Deque[Optional[Node[T]]] = deque([self.root])

        while queue:
            node = queue.popleft()
            # Only an empty tree puts None in the queue, via ``self.root``.
            if not node:
                continue

            parts.append(
                f'"{str(node)}" [label="{node.value}, {node.rank}"];\n'
            )

            for child in (node.left, node.right):
                if not child:
                    continue
                edges.append((node, child))
                queue.append(child)

        for from_node, to_node in edges:
            label = f'[label="{Node.difference(to_node)}"]'
            parts.append(f'"{str(from_node)}" -> "{str(to_node)}" {label}\n')

        # Join once instead of repeated ``+=``, which is quadratic in the
        # worst case.
        return "".join(parts)

    def __str__(self) -> str:
        """Render the whole tree as a Graphviz ``digraph`` document."""
        return "digraph {\n" + self._get_unwrapped_graph() + "}\n"

    # region TreeSpecificMethods

    @abstractmethod
    def is_correct_node(self, node: Optional[Node[T]]) -> bool:
        """Check the subtree rooted at *node* against the subclass's rules.

        ``is_correct`` calls this with ``self.root``. The base implementation
        does nothing and returns ``None``.
        """

    @abstractmethod
    def _insert_rebalance(self, x: Node[T]) -> None:
        """Rebalance after ``insert`` attached *x* as a new leaf.

        ``insert`` does not call this when *x* became the root of an empty
        tree.
        """

    @abstractmethod
    def _delete_rebalance(
        self, node: Optional[Node[T]], parent: Optional[Node[T]]
    ) -> None:
        """Rebalance after a node was removed by ``delete``.

        Args:
            node: The child that took the removed node's place. It may be
                None.
            parent: The removed node's former parent. It is None when the
                root was removed.
        """

    # endregion TreeSpecificMethods

    # region Rotations

    def rotate_right(self, x: "Node[T]") -> "Node[T]":
        """Rotate right around *x*, lifting its left child into x's position.

        Parent links and ``self.root`` are updated. Ranks are left untouched.

        Returns:
            The former left child of *x*, which now sits where *x* was.
        """
        parent = x.parent
        y = x.left
        assert y is not None

        # Hook y into x's old slot: a parent link, or the tree root.
        if parent:
            if parent.left is x:
                parent.left = y
            else:
                parent.right = y
        else:
            self.root = y

        # y's right subtree becomes x's left subtree.
        x.left = y.right
        if x.left:
            x.left.parent = x

        y.right = x
        x.parent = y
        y.parent = parent
        return y

    def rotate_left(self, x: "Node[T]") -> "Node[T]":
        """Rotate left around *x*, lifting its right child into x's position.

        Parent links and ``self.root`` are updated. Ranks are left untouched.

        Returns:
            The former right child of *x*, which now sits where *x* was.
        """
        parent = x.parent
        z = x.right
        assert z is not None

        # Hook z into x's old slot: a parent link, or the tree root.
        if parent:
            if parent.left is x:
                parent.left = z
            else:
                parent.right = z
        else:
            self.root = z

        # z's left subtree becomes x's right subtree.
        x.right = z.left
        if x.right:
            x.right.parent = x

        z.left = x
        x.parent = z
        z.parent = parent
        return z

    # endregion Rotations

    @property
    def rank(self) -> int:
        """Rank of the tree, as given by ``Node.get_rank(self.root)``."""
        return Node.get_rank(self.root)

    @property
    def is_correct(self) -> bool:
        """Whether the whole tree passes ``is_correct_node``."""
        return self.is_correct_node(self.root)

    def search(
        self, value: T, node: Optional[Node[T]] = None
    ) -> Optional[Node[T]]:
        """Look up *value* with ``Node.search``.

        The search starts at *node*, or at the root when *node* is falsy.
        """
        if not node:
            node = self.root
        return Node.search(value, node)

    def insert(self, value: T) -> None:
        """Insert *value* as a new leaf, then call ``_insert_rebalance``.

        An empty tree simply gets the new node as its root, without
        rebalancing. If ``Node.find_parent_node`` yields no parent, nothing is
        inserted. NOTE(review): this presumably happens for duplicate values;
        confirm in ``Node``.
        """
        inserted_node = Node(value)

        if not self.root:
            self.root = inserted_node
            return

        parent = Node.find_parent_node(value, self.root)
        if not parent:
            return

        inserted_node.parent = parent
        # Smaller values go left; everything else goes right.
        if value < parent.value:
            parent.left = inserted_node
        else:
            parent.right = inserted_node

        self._insert_rebalance(inserted_node)

    def _transplant(self, u: Node[T], v: Optional[Node[T]]) -> None:
        """Put *v* where *u* hangs: in u's parent, or as the tree root.

        If *v* is given, it also takes over u's rank and parent. *u* itself
        is not modified.
        """
        if not u.parent:
            self.root = v
        elif u.parent.left is u:
            u.parent.left = v
        else:
            u.parent.right = v

        if v:
            v.rank = u.rank
            v.parent = u.parent

    def _delete_node(
        self, node: Optional[Node[T]]
    ) -> Optional[Tuple[Optional[Node[T]], Optional[Node[T]]]]:
        """Unlink *node* from the tree.

        A node with two children is not unlinked directly. Its value is
        swapped with the minimum of its right subtree, and that node is
        removed instead.

        Returns:
            None if *node* is None. Otherwise a tuple
            ``(replacement, parent)``. ``replacement`` is the child that took
            the unlinked node's place (possibly None). ``parent`` is the
            unlinked node's former parent.
        """
        if node is None:
            return None

        parent = node.parent

        if not node.left:
            replacement = node.right
            self._transplant(node, node.right)
        elif not node.right:
            replacement = node.left
            self._transplant(node, node.left)
        else:
            # Two children: the minimum of the right subtree has no left
            # child, so deleting it falls into one of the branches above.
            successor = Node.minimum(node.right)
            node.value, successor.value = successor.value, node.value
            return self._delete_node(successor)

        return (replacement, parent)

    def _delete(
        self, value: T
    ) -> Optional[Tuple[Optional[Node[T]], Optional[Node[T]]]]:
        """Find the node holding *value* and unlink it with ``_delete_node``.

        Returns None if *value* is absent, including when the tree is empty.
        """
        node = self.root
        while node is not None and node.value != value:
            node = node.left if value < node.value else node.right

        return self._delete_node(node)

    def delete(self, value: T) -> None:
        """Remove *value* if present, then call ``_delete_rebalance``."""
        if to_be_rebalanced := self._delete(value):
            y, parent = to_be_rebalanced
            self._delete_rebalance(y, parent)

    def __iter__(self) -> Iterator[T]:
        """
        Yields:
            Keys from the tree in an inorder fashion.
        """
        if self.root:
            yield from self.root


# Type of a rotation method bound to a tree: it takes a pivot node and
# returns the node that replaced it (matching rotate_left / rotate_right).
RotateFunction = Callable[[Node[T]], Node[T]]