From bc981a8bb4c6d6d1556365c58a9636937d2c94df Mon Sep 17 00:00:00 2001 From: Matej Focko Date: Sun, 30 Jan 2022 15:02:28 +0100 Subject: [PATCH] node: add type signatures Signed-off-by: Matej Focko --- node.py | 88 ++++++++++++++++++++++++++++++++++++++++----------------- 1 file changed, 62 insertions(+), 26 deletions(-) diff --git a/node.py b/node.py index ec0fb3c..e2eafad 100644 --- a/node.py +++ b/node.py @@ -1,4 +1,14 @@ +from abc import abstractmethod import enum +from typing import Callable, Optional, Generic, Tuple, TypeVar, Protocol + + +class Comparable(Protocol): + """Protocol for annotating comparable types.""" + + @abstractmethod + def __lt__(self: "Comparable", other: "Comparable") -> bool: + pass class NodeType(enum.IntEnum): @@ -7,17 +17,26 @@ class NodeType(enum.IntEnum): BINARY = 2 -class Node: - def __init__(self, value, left=None, right=None, parent=None): +T = TypeVar("T", bound=Comparable) + + +class Node(Generic[T]): + def __init__( + self, + value: T, + left: "Optional[Node[T]]" = None, + right: "Optional[Node[T]]" = None, + parent: "Optional[Node[T]]" = None, + ): self.parent = parent self.left = left self.right = right self.value = value - self.rank = 0 + self.rank: int = 0 @property - def type(self): + def type(self) -> NodeType: if self.left and self.right: return NodeType.BINARY @@ -26,46 +45,58 @@ class Node: return NodeType.LEAF - def __repr__(self): - return f"Node(value={self.value}, rank={self.rank}, left={self.left}, right={self.right}, parent={self.parent})" + def __repr__(self) -> str: + return ( + f"Node(value={self.value}, rank={self.rank}, " + f"left={self.left}, right={self.right}, parent={self.parent})" + ) - def __str__(self): + def __str__(self) -> str: return f"Node(value={self.value}, rank={self.rank})" @staticmethod - def height(node): - return 1 + max(Node.height(node.left), Node.height(node.right)) if node else -1 + def height(node: "Optional[Node[T]]") -> int: + return ( + 1 + max(Node.height(node.left), Node.height(node.right)) + if node + else -1 + ) @staticmethod - def rank(node): + def get_rank(node: "Optional[Node[T]]") -> int: return -1 if not node else node.rank @staticmethod - def difference(node, parent=None): + def difference( + node: "Optional[Node[T]]", parent: "Optional[Node[T]]" = None + ) -> int: if not parent: parent = node.parent if node else None - return Node.rank(parent) - Node.rank(node) + return Node.get_rank(parent) - Node.get_rank(node) @staticmethod - def differences(node): - node_rank = Node.rank(node) + def differences(node: "Optional[Node[T]]") -> Tuple[int, int]: + node_rank = Node.get_rank(node) (left, right) = (node.left, node.right) if node else (None, None) - return (node_rank - Node.rank(left), node_rank - Node.rank(right)) + return ( + node_rank - Node.get_rank(left), + node_rank - Node.get_rank(right), + ) - def promote(self): + def promote(self) -> "Node[T]": self.rank += 1 return self - def demote(self): + def demote(self) -> "Node[T]": self.rank -= 1 return self @staticmethod - def rotate_right(x): + def rotate_right(x: "Node[T]") -> "Node[T]": parent = x.parent y = x.left - z = x.right + # z = x.right assert y is not None if parent: @@ -85,9 +116,9 @@ class Node: return y @staticmethod - def rotate_left(x): + def rotate_left(x: "Node[T]") -> "Node[T]": parent = x.parent - y = x.left + # y = x.left z = x.right assert z is not None @@ -108,8 +139,10 @@ class Node: return z @staticmethod - def find_parent_node(value, node, missing=True): - new_node = node + def find_parent_node( + value: T, node: "Node[T]", missing: bool = True + ) -> "Node[T]": + new_node: "Optional[Node[T]]" = node while new_node and (missing or new_node.value != value): node = new_node @@ -118,22 +151,25 @@ class Node: return node @staticmethod - def search(value, node): + def search(value: T, node: "Optional[Node[T]]") -> "Optional[Node[T]]": while node and node.value != value: node = node.left if value < node.value else node.right return node @staticmethod - def minimum(node): + def minimum(node: "Node[T]") -> "Node[T]": while node.left: node = node.left return node @staticmethod - def maximum(node): + def maximum(node: "Node[T]") -> "Node[T]": while node.right: node = node.right return node + + +RotateFunction = Callable[[Node[T]], Node[T]]