from abc import abstractmethod
import enum
from typing import (
    Iterable,
    Iterator,
    Optional,
    Generic,
    Tuple,
    TypeVar,
    Protocol,
    List,
)


class Comparable(Protocol):
    """Protocol for annotating comparable types."""

    @abstractmethod
    def __lt__(self: "Comparable", other: "Comparable") -> bool:
        """Return True if ``self`` orders strictly before ``other``."""
        pass


class NodeType(enum.IntEnum):
    """Shape of a node, classified by how many children it has."""

    LEAF = 0
    UNARY = 1
    BINARY = 2


T = TypeVar("T", bound=Comparable)


class Node(Generic[T], Iterable[T]):
    """Binary search tree node.

    A node holds a value, an integer ``rank`` and links to its left child,
    right child and parent. Smaller keys live in the left subtree and larger
    keys in the right subtree, which is the ordering the search helpers
    below rely on.

    The rank helpers (``get_rank``, ``difference``, ``differences``,
    ``promote``, ``demote``) only read or adjust ranks. No rebalancing
    happens in this class.
    """

    def __init__(
        self,
        value: T,
        left: "Optional[Node[T]]" = None,
        right: "Optional[Node[T]]" = None,
        parent: "Optional[Node[T]]" = None,
    ):
        """Create a node.

        The links are stored exactly as given. The constructor does not
        set ``left.parent`` or ``right.parent``; callers wire those up.
        """
        self.parent = parent
        self.left = left
        self.right = right
        self.value = value
        # New nodes start at rank 0; a missing node (None) counts as rank -1.
        self.rank: int = 0

    @property
    def type(self) -> NodeType:
        """Classify the node as LEAF, UNARY or BINARY by its children."""
        has_left = self.left is not None
        has_right = self.right is not None
        if has_left and has_right:
            return NodeType.BINARY
        if has_left or has_right:
            return NodeType.UNARY
        return NodeType.LEAF

    def __repr__(self) -> str:
        # Neighbours are interpolated via ``str()`` (f-string formatting),
        # which prints only value and rank. This keeps the output bounded
        # and avoids parent <-> child recursion.
        return (
            f"Node(value={self.value}, rank={self.rank}, "
            f"left={self.left}, right={self.right}, parent={self.parent})"
        )

    def __str__(self) -> str:
        return f"Node(value={self.value}, rank={self.rank})"

    def __iter__(self) -> Iterator[T]:
        """
        Yields:
            Keys from the subtree rooted at the node in an inorder fashion.

        The walk uses an explicit stack. Total work is O(n), and deep or
        degenerate subtrees cannot exceed the interpreter's recursion limit.
        """
        stack: "List[Node[T]]" = []
        current: "Optional[Node[T]]" = self
        while stack or current is not None:
            # Descend as far left as possible, remembering the path back up.
            while current is not None:
                stack.append(current)
                current = current.left
            current = stack.pop()
            yield current.value
            # Left subtree and node are done; continue with the right subtree.
            current = current.right

    @staticmethod
    def height(node: "Optional[Node[T]]") -> int:
        """Return the height of the subtree rooted at ``node``.

        A single node has height 0, and an empty subtree (``None``) has
        height -1. The walk is iterative, so deep subtrees are safe.
        """
        if node is None:
            return -1
        tallest = 0
        pending: "List[Tuple[Node[T], int]]" = [(node, 0)]
        while pending:
            current, depth = pending.pop()
            if depth > tallest:
                tallest = depth
            for child in (current.left, current.right):
                if child is not None:
                    pending.append((child, depth + 1))
        return tallest

    @staticmethod
    def get_rank(node: "Optional[Node[T]]") -> int:
        """Return ``node.rank``, or -1 when ``node`` is None."""
        return -1 if node is None else node.rank

    @staticmethod
    def difference(
        node: "Optional[Node[T]]", parent: "Optional[Node[T]]" = None
    ) -> int:
        """Return ``rank(parent) - rank(node)``.

        If ``parent`` is omitted, ``node.parent`` is used. Missing nodes
        count as rank -1, so ``difference(None)`` is 0.
        """
        if parent is None:
            parent = node.parent if node is not None else None
        return Node.get_rank(parent) - Node.get_rank(node)

    @staticmethod
    def differences(node: "Optional[Node[T]]") -> Tuple[int, int]:
        """Return the rank differences between ``node`` and its children.

        Returns:
            ``(rank(node) - rank(left), rank(node) - rank(right))``. Missing
            nodes count as rank -1, so ``differences(None)`` is ``(0, 0)``.
        """
        node_rank = Node.get_rank(node)
        left, right = (node.left, node.right) if node is not None else (None, None)
        return (
            node_rank - Node.get_rank(left),
            node_rank - Node.get_rank(right),
        )

    def promote(self) -> "Node[T]":
        """Increase the rank by one and return ``self`` for chaining."""
        self.rank += 1
        return self

    def demote(self) -> "Node[T]":
        """Decrease the rank by one and return ``self`` for chaining."""
        self.rank -= 1
        return self

    @staticmethod
    def find_parent_node(
        value: T, node: "Node[T]", missing: bool = True
    ) -> "Optional[Node[T]]":
        """Walk down from ``node`` following the search ordering toward ``value``.

        Args:
            value: Key to steer by.
            node: Subtree root to start from. Must not be None.
            missing: Whether ``value`` is expected to be absent from the
                subtree (default True).

        Returns:
            With ``missing=True``: the last node reached before stepping off
            the tree, i.e. the node under which ``value`` would hang. Returns
            None if a node whose key is neither less nor greater than
            ``value`` is met.

            With ``missing=False``: the parent of the first node whose key
            equals ``value``. If no such node exists, the last node visited
            is returned instead. NOTE(review): when ``node`` itself holds
            ``value``, ``node`` is returned rather than None, so callers must
            check for that case.
        """
        next_node: "Optional[Node[T]]" = node
        while next_node is not None and (missing or next_node.value != value):
            node = next_node
            if value < node.value:
                next_node = node.left
            elif node.value < value:
                next_node = node.right
            else:
                # Equal key already present.
                return None
        return node

    @staticmethod
    def search(value: T, node: "Optional[Node[T]]") -> "Optional[Node[T]]":
        """Return the node in the subtree holding ``value``, or None if absent."""
        while node is not None and node.value != value:
            node = node.left if value < node.value else node.right
        return node

    @staticmethod
    def minimum(node: "Node[T]") -> "Node[T]":
        """Return the leftmost (smallest-key) node of a non-empty subtree."""
        while node.left is not None:
            node = node.left
        return node

    @staticmethod
    def maximum(node: "Node[T]") -> "Node[T]":
        """Return the rightmost (largest-key) node of a non-empty subtree."""
        while node.right is not None:
            node = node.right
        return node