python/node.py
Matej Focko dc6905b5e2
fix(mypy): fix major issues with mypy
Signed-off-by: Matej Focko <mfocko@redhat.com>
2022-05-22 21:07:40 +02:00

151 lines
3.6 KiB
Python

from abc import abstractmethod
import enum
from typing import (
Iterable,
Iterator,
Optional,
Generic,
Tuple,
TypeVar,
Protocol,
)
class Comparable(Protocol):
    """Structural type for values that can be ordered with ``<``.

    Used as the bound of the key type variable, so any key type only has to
    provide ``__lt__`` to be accepted by the type checker.
    """

    @abstractmethod
    def __lt__(self: "Comparable", other: "Comparable") -> bool:
        """Return whether ``self`` orders strictly before ``other``."""
        ...
@enum.unique
class NodeType(enum.IntEnum):
    """Classification of a node; each value equals its number of children."""

    # No children.
    LEAF = 0
    # Exactly one child (left or right).
    UNARY = 1
    # Both children present.
    BINARY = 2
# Key type stored in the nodes.  The bound is written as a forward reference,
# like the other quoted annotations in this module.
T = TypeVar("T", bound="Comparable")


class Node(Generic[T], Iterable[T]):
    """Binary tree node holding a key, child/parent links and an integer rank.

    The lookup helpers (``search``, ``find_parent_node``) assume the usual
    search-tree ordering: smaller keys in the left subtree, larger keys in the
    right one.  Keys are compared with ``<`` only, which is all that the bound
    of ``T`` requires from them.

    A missing node (``None``) is treated as having rank ``-1`` and height
    ``-1`` by the static helpers.
    """

    def __init__(
        self,
        value: T,
        left: "Optional[Node[T]]" = None,
        right: "Optional[Node[T]]" = None,
        parent: "Optional[Node[T]]" = None,
    ):
        """Create a node with rank ``0``.

        Args:
            value: Key stored in the node.
            left: Left child, if any.
            right: Right child, if any.
            parent: Parent node, if any.

        The links are stored as given; the ``parent`` attribute of ``left``
        and ``right`` is not touched.
        """
        self.parent = parent
        self.left = left
        self.right = right
        self.value = value
        self.rank: int = 0

    @property
    def type(self) -> "NodeType":
        """Kind of the node, determined by how many children it has."""
        has_left = self.left is not None
        has_right = self.right is not None

        if has_left and has_right:
            return NodeType.BINARY
        if has_left or has_right:
            return NodeType.UNARY
        return NodeType.LEAF

    def __repr__(self) -> str:
        # Neighbours are interpolated through ``__str__`` (that is what an
        # f-string placeholder uses), so the output stays one level deep even
        # though parent and child reference each other.
        return (
            f"Node(value={self.value}, rank={self.rank}, "
            f"left={self.left}, right={self.right}, parent={self.parent})"
        )

    def __str__(self) -> str:
        return f"Node(value={self.value}, rank={self.rank})"

    def __iter__(self) -> Iterator[T]:
        """
        Yields:
            Keys from the subtree rooted at the node in an inorder fashion.
        """
        if self.left is not None:
            yield from self.left
        yield self.value
        if self.right is not None:
            yield from self.right

    @staticmethod
    def height(node: "Optional[Node[T]]") -> int:
        """Return the height of the subtree rooted at ``node``.

        A missing node has height ``-1``, hence a single node has height ``0``.
        """
        if node is None:
            return -1
        return 1 + max(Node.height(node.left), Node.height(node.right))

    @staticmethod
    def get_rank(node: "Optional[Node[T]]") -> int:
        """Return the rank of ``node``; a missing node has rank ``-1``."""
        return node.rank if node is not None else -1

    @staticmethod
    def difference(
        node: "Optional[Node[T]]", parent: "Optional[Node[T]]" = None
    ) -> int:
        """Return the rank of ``parent`` minus the rank of ``node``.

        Args:
            node: Node whose rank is subtracted; may be ``None``.
            parent: Node to measure against.  When omitted, ``node.parent`` is
                used (or ``None`` if ``node`` itself is ``None``).
        """
        if parent is None:
            parent = node.parent if node is not None else None
        return Node.get_rank(parent) - Node.get_rank(node)

    @staticmethod
    def differences(node: "Optional[Node[T]]") -> Tuple[int, int]:
        """Return the rank differences between ``node`` and both children.

        Returns:
            Pair ``(rank(node) - rank(left), rank(node) - rank(right))`` where
            a missing node counts with rank ``-1``.
        """
        node_rank = Node.get_rank(node)
        if node is None:
            left, right = None, None
        else:
            left, right = node.left, node.right
        return (
            node_rank - Node.get_rank(left),
            node_rank - Node.get_rank(right),
        )

    def promote(self) -> "Node[T]":
        """Increase the rank by one and return the node itself."""
        self.rank += 1
        return self

    def demote(self) -> "Node[T]":
        """Decrease the rank by one and return the node itself."""
        self.rank -= 1
        return self

    @staticmethod
    def find_parent_node(
        value: T, node: "Node[T]", missing: bool = True
    ) -> "Optional[Node[T]]":
        """Descend from ``node`` towards ``value`` and return the parent.

        Args:
            value: Key that is being looked for.
            node: Node where the descent starts.
            missing: Whether ``value`` is expected to be absent.

        Returns:
            * If the descent falls off the tree, the last visited node, i.e.
              the node under which ``value`` would be attached.
            * If ``value`` is found and ``missing`` is true, ``None``.
            * If ``value`` is found and ``missing`` is false, the node visited
              just before the match; when ``node`` itself holds ``value``,
              ``node`` is returned.
        """
        parent: "Node[T]" = node
        current: "Optional[Node[T]]" = node

        while current is not None:
            # Equality is derived from ``<`` alone, since that is the only
            # comparison the key type is required to implement.
            if value < current.value:
                child = current.left
            elif current.value < value:
                child = current.right
            elif missing:
                return None
            else:
                return parent
            parent, current = current, child

        return parent

    @staticmethod
    def search(value: T, node: "Optional[Node[T]]") -> "Optional[Node[T]]":
        """Return the node holding ``value`` in the subtree of ``node``.

        Returns ``None`` when the key is not present.  Only ``<`` is used to
        compare keys, so key types without ``__eq__`` are found as well.
        """
        while node is not None:
            if value < node.value:
                node = node.left
            elif node.value < value:
                node = node.right
            else:
                break
        return node

    @staticmethod
    def minimum(node: "Node[T]") -> "Node[T]":
        """Return the leftmost node of the subtree rooted at ``node``."""
        while node.left is not None:
            node = node.left
        return node

    @staticmethod
    def maximum(node: "Node[T]") -> "Node[T]":
        """Return the rightmost node of the subtree rooted at ``node``."""
        while node.right is not None:
            node = node.right
        return node