node: add type signatures
Signed-off-by: Matej Focko <mfocko@redhat.com>
This commit is contained in:
parent
cea0ac2ba2
commit
bc981a8bb4
1 changed files with 62 additions and 26 deletions
88
node.py
88
node.py
|
@ -1,4 +1,14 @@
|
|||
from abc import abstractmethod
|
||||
import enum
|
||||
from typing import Callable, Optional, Generic, Tuple, TypeVar, Protocol
|
||||
|
||||
|
||||
class Comparable(Protocol):
|
||||
"""Protocol for annotating comparable types."""
|
||||
|
||||
@abstractmethod
|
||||
def __lt__(self: "Comparable", other: "Comparable") -> bool:
|
||||
pass
|
||||
|
||||
|
||||
class NodeType(enum.IntEnum):
|
||||
|
@ -7,17 +17,26 @@ class NodeType(enum.IntEnum):
|
|||
BINARY = 2
|
||||
|
||||
|
||||
class Node:
|
||||
def __init__(self, value, left=None, right=None, parent=None):
|
||||
T = TypeVar("T", bound=Comparable)
|
||||
|
||||
|
||||
class Node(Generic[T]):
|
||||
def __init__(
|
||||
self,
|
||||
value: T,
|
||||
left: "Optional[Node[T]]" = None,
|
||||
right: "Optional[Node[T]]" = None,
|
||||
parent: "Optional[Node[T]]" = None,
|
||||
):
|
||||
self.parent = parent
|
||||
self.left = left
|
||||
self.right = right
|
||||
|
||||
self.value = value
|
||||
self.rank = 0
|
||||
self.rank: int = 0
|
||||
|
||||
@property
|
||||
def type(self):
|
||||
def type(self) -> NodeType:
|
||||
if self.left and self.right:
|
||||
return NodeType.BINARY
|
||||
|
||||
|
@ -26,46 +45,58 @@ class Node:
|
|||
|
||||
return NodeType.LEAF
|
||||
|
||||
def __repr__(self):
|
||||
return f"Node(value={self.value}, rank={self.rank}, left={self.left}, right={self.right}, parent={self.parent})"
|
||||
def __repr__(self) -> str:
|
||||
return (
|
||||
f"Node(value={self.value}, rank={self.rank}, "
|
||||
f"left={self.left}, right={self.right}, parent={self.parent})"
|
||||
)
|
||||
|
||||
def __str__(self):
|
||||
def __str__(self) -> str:
|
||||
return f"Node(value={self.value}, rank={self.rank})"
|
||||
|
||||
@staticmethod
|
||||
def height(node):
|
||||
return 1 + max(Node.height(node.left), Node.height(node.right)) if node else -1
|
||||
def height(node: "Optional[Node[T]]") -> int:
|
||||
return (
|
||||
1 + max(Node.height(node.left), Node.height(node.right))
|
||||
if node
|
||||
else -1
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def rank(node):
|
||||
def get_rank(node: "Optional[Node[T]]") -> int:
|
||||
return -1 if not node else node.rank
|
||||
|
||||
@staticmethod
|
||||
def difference(node, parent=None):
|
||||
def difference(
|
||||
node: "Optional[Node[T]]", parent: "Optional[Node[T]]" = None
|
||||
) -> int:
|
||||
if not parent:
|
||||
parent = node.parent if node else None
|
||||
return Node.rank(parent) - Node.rank(node)
|
||||
return Node.get_rank(parent) - Node.get_rank(node)
|
||||
|
||||
@staticmethod
|
||||
def differences(node):
|
||||
node_rank = Node.rank(node)
|
||||
def differences(node: "Optional[Node[T]]") -> Tuple[int, int]:
|
||||
node_rank = Node.get_rank(node)
|
||||
(left, right) = (node.left, node.right) if node else (None, None)
|
||||
|
||||
return (node_rank - Node.rank(left), node_rank - Node.rank(right))
|
||||
return (
|
||||
node_rank - Node.get_rank(left),
|
||||
node_rank - Node.get_rank(right),
|
||||
)
|
||||
|
||||
def promote(self):
|
||||
def promote(self) -> "Node[T]":
|
||||
self.rank += 1
|
||||
return self
|
||||
|
||||
def demote(self):
|
||||
def demote(self) -> "Node[T]":
|
||||
self.rank -= 1
|
||||
return self
|
||||
|
||||
@staticmethod
|
||||
def rotate_right(x):
|
||||
def rotate_right(x: "Node[T]") -> "Node[T]":
|
||||
parent = x.parent
|
||||
y = x.left
|
||||
z = x.right
|
||||
# z = x.right
|
||||
|
||||
assert y is not None
|
||||
if parent:
|
||||
|
@ -85,9 +116,9 @@ class Node:
|
|||
return y
|
||||
|
||||
@staticmethod
|
||||
def rotate_left(x):
|
||||
def rotate_left(x: "Node[T]") -> "Node[T]":
|
||||
parent = x.parent
|
||||
y = x.left
|
||||
# y = x.left
|
||||
z = x.right
|
||||
|
||||
assert z is not None
|
||||
|
@ -108,8 +139,10 @@ class Node:
|
|||
return z
|
||||
|
||||
@staticmethod
|
||||
def find_parent_node(value, node, missing=True):
|
||||
new_node = node
|
||||
def find_parent_node(
|
||||
value: T, node: "Node[T]", missing: bool = True
|
||||
) -> "Node[T]":
|
||||
new_node: "Optional[Node[T]]" = node
|
||||
|
||||
while new_node and (missing or new_node.value != value):
|
||||
node = new_node
|
||||
|
@ -118,22 +151,25 @@ class Node:
|
|||
return node
|
||||
|
||||
@staticmethod
|
||||
def search(value, node):
|
||||
def search(value: T, node: "Optional[Node[T]]") -> "Optional[Node[T]]":
|
||||
while node and node.value != value:
|
||||
node = node.left if value < node.value else node.right
|
||||
|
||||
return node
|
||||
|
||||
@staticmethod
|
||||
def minimum(node):
|
||||
def minimum(node: "Node[T]") -> "Node[T]":
|
||||
while node.left:
|
||||
node = node.left
|
||||
|
||||
return node
|
||||
|
||||
@staticmethod
|
||||
def maximum(node):
|
||||
def maximum(node: "Node[T]") -> "Node[T]":
|
||||
while node.right:
|
||||
node = node.right
|
||||
|
||||
return node
|
||||
|
||||
|
||||
RotateFunction = Callable[[Node[T]], Node[T]]
|
||||
|
|
Loading…
Reference in a new issue