node: add type signatures

Signed-off-by: Matej Focko <mfocko@redhat.com>
This commit is contained in:
Matej Focko 2022-01-30 15:02:28 +01:00
parent cea0ac2ba2
commit bc981a8bb4
No known key found for this signature in database
GPG key ID: 332171FADF1DB90B

88
node.py
View file

@ -1,4 +1,14 @@
from abc import abstractmethod
import enum import enum
from typing import Callable, Optional, Generic, Tuple, TypeVar, Protocol
class Comparable(Protocol):
"""Protocol for annotating comparable types."""
@abstractmethod
def __lt__(self: "Comparable", other: "Comparable") -> bool:
pass
class NodeType(enum.IntEnum): class NodeType(enum.IntEnum):
@ -7,17 +17,26 @@ class NodeType(enum.IntEnum):
BINARY = 2 BINARY = 2
class Node: T = TypeVar("T", bound=Comparable)
def __init__(self, value, left=None, right=None, parent=None):
class Node(Generic[T]):
def __init__(
self,
value: T,
left: "Optional[Node[T]]" = None,
right: "Optional[Node[T]]" = None,
parent: "Optional[Node[T]]" = None,
):
self.parent = parent self.parent = parent
self.left = left self.left = left
self.right = right self.right = right
self.value = value self.value = value
self.rank = 0 self.rank: int = 0
@property @property
def type(self): def type(self) -> NodeType:
if self.left and self.right: if self.left and self.right:
return NodeType.BINARY return NodeType.BINARY
@ -26,46 +45,58 @@ class Node:
return NodeType.LEAF return NodeType.LEAF
def __repr__(self): def __repr__(self) -> str:
return f"Node(value={self.value}, rank={self.rank}, left={self.left}, right={self.right}, parent={self.parent})" return (
f"Node(value={self.value}, rank={self.rank}, "
f"left={self.left}, right={self.right}, parent={self.parent})"
)
def __str__(self): def __str__(self) -> str:
return f"Node(value={self.value}, rank={self.rank})" return f"Node(value={self.value}, rank={self.rank})"
@staticmethod @staticmethod
def height(node): def height(node: "Optional[Node[T]]") -> int:
return 1 + max(Node.height(node.left), Node.height(node.right)) if node else -1 return (
1 + max(Node.height(node.left), Node.height(node.right))
if node
else -1
)
@staticmethod @staticmethod
def rank(node): def get_rank(node: "Optional[Node[T]]") -> int:
return -1 if not node else node.rank return -1 if not node else node.rank
@staticmethod @staticmethod
def difference(node, parent=None): def difference(
node: "Optional[Node[T]]", parent: "Optional[Node[T]]" = None
) -> int:
if not parent: if not parent:
parent = node.parent if node else None parent = node.parent if node else None
return Node.rank(parent) - Node.rank(node) return Node.get_rank(parent) - Node.get_rank(node)
@staticmethod @staticmethod
def differences(node): def differences(node: "Optional[Node[T]]") -> Tuple[int, int]:
node_rank = Node.rank(node) node_rank = Node.get_rank(node)
(left, right) = (node.left, node.right) if node else (None, None) (left, right) = (node.left, node.right) if node else (None, None)
return (node_rank - Node.rank(left), node_rank - Node.rank(right)) return (
node_rank - Node.get_rank(left),
node_rank - Node.get_rank(right),
)
def promote(self): def promote(self) -> "Node[T]":
self.rank += 1 self.rank += 1
return self return self
def demote(self): def demote(self) -> "Node[T]":
self.rank -= 1 self.rank -= 1
return self return self
@staticmethod @staticmethod
def rotate_right(x): def rotate_right(x: "Node[T]") -> "Node[T]":
parent = x.parent parent = x.parent
y = x.left y = x.left
z = x.right # z = x.right
assert y is not None assert y is not None
if parent: if parent:
@ -85,9 +116,9 @@ class Node:
return y return y
@staticmethod @staticmethod
def rotate_left(x): def rotate_left(x: "Node[T]") -> "Node[T]":
parent = x.parent parent = x.parent
y = x.left # y = x.left
z = x.right z = x.right
assert z is not None assert z is not None
@ -108,8 +139,10 @@ class Node:
return z return z
@staticmethod @staticmethod
def find_parent_node(value, node, missing=True): def find_parent_node(
new_node = node value: T, node: "Node[T]", missing: bool = True
) -> "Node[T]":
new_node: "Optional[Node[T]]" = node
while new_node and (missing or new_node.value != value): while new_node and (missing or new_node.value != value):
node = new_node node = new_node
@ -118,22 +151,25 @@ class Node:
return node return node
@staticmethod @staticmethod
def search(value, node): def search(value: T, node: "Optional[Node[T]]") -> "Optional[Node[T]]":
while node and node.value != value: while node and node.value != value:
node = node.left if value < node.value else node.right node = node.left if value < node.value else node.right
return node return node
@staticmethod @staticmethod
def minimum(node): def minimum(node: "Node[T]") -> "Node[T]":
while node.left: while node.left:
node = node.left node = node.left
return node return node
@staticmethod @staticmethod
def maximum(node): def maximum(node: "Node[T]") -> "Node[T]":
while node.right: while node.right:
node = node.right node = node.right
return node return node
RotateFunction = Callable[[Node[T]], Node[T]]