diff --git a/Makefile b/Makefile index 13d6097..78d20ce 100644 --- a/Makefile +++ b/Makefile @@ -10,7 +10,7 @@ flake8: flake8 *.py mypy: - mypy --strict --disallow-any-explicit node.py avl.py wavl.py ravl.py + mypy --strict --disallow-any-explicit node.py ranked_tree.py avl.py wavl.py ravl.py rbt.py clean: rm -rf __pycache__ .hypothesis .mypy_cache .pytest_cache diff --git a/avl.py b/avl.py index 6b798ac..6b20141 100644 --- a/avl.py +++ b/avl.py @@ -114,7 +114,7 @@ class AVLTree(RankedTree[T]): return factor != 0 def __delete_fixup( - self, x: Optional[Node[T]], parent: Optional[Node[T]] = None + self, x: Node[T], parent: Optional[Node[T]] = None ) -> bool: factor = _balance_factor(x) if factor == 0: @@ -143,7 +143,7 @@ class AVLTree(RankedTree[T]): if not node and not parent: return - if not node: + if not node and parent: node, parent = parent, parent.parent while node and self.__delete_fixup(node, parent): diff --git a/node.py b/node.py index 7494858..584951b 100644 --- a/node.py +++ b/node.py @@ -2,6 +2,7 @@ from abc import abstractmethod import enum from typing import ( Iterable, + Iterator, Optional, Generic, Tuple, @@ -27,7 +28,7 @@ class NodeType(enum.IntEnum): T = TypeVar("T", bound=Comparable) -class Node(Generic[T]): +class Node(Generic[T], Iterable[T]): def __init__( self, value: T, @@ -61,7 +62,7 @@ class Node(Generic[T]): def __str__(self) -> str: return f"Node(value={self.value}, rank={self.rank})" - def __iter__(self) -> Iterable[T]: + def __iter__(self) -> Iterator[T]: """ Yields: Keys from the subtree rooted at the node in an inorder fashion. @@ -113,7 +114,7 @@ class Node(Generic[T]): @staticmethod def find_parent_node( value: T, node: "Node[T]", missing: bool = True - ) -> "Node[T]": + ) -> "Optional[Node[T]]": new_node: "Optional[Node[T]]" = node while new_node and (missing or new_node.value != value): diff --git a/ranked_tree.py b/ranked_tree.py index 6b8ce7a..a3b6547 100644 --- a/ranked_tree.py +++ b/ranked_tree.py @@ -174,8 +174,7 @@ class RankedTree(Generic[T]): self._transplant(node, node.left) else: n = Node.minimum(node.right) - node.value = n.value - n.value = None + node.value, n.value = n.value, node.value return self._delete_node(n) return (y, parent) diff --git a/rbt.py b/rbt.py index 0d24ab9..e7b0819 100644 --- a/rbt.py +++ b/rbt.py @@ -18,7 +18,7 @@ class Colour(enum.IntEnum): Black = 1 -def has_double_black(x: Optional[Node]) -> bool: +def has_double_black(x: Optional[Node[T]]) -> bool: """ Checks for double black child of x. @@ -64,8 +64,8 @@ class RBTree(RankedTree[T]): z: Node[T], y: Optional[Node[T]], right_child: Node[T], - rotate_left: RotateFunction, - rotate_right: RotateFunction, + rotate_left: RotateFunction[T], + rotate_right: RotateFunction[T], ) -> Node[T]: p = z.parent pp = p.parent @@ -114,9 +114,9 @@ class RBTree(RankedTree[T]): x: Node[T], w: Node[T], parent: Node[T], - right: Callable[[Node[T]], Node[T]], - rotate_left: RotateFunction, - rotate_right: RotateFunction, + right: Callable[[Node[T]], Optional[Node[T]]], + rotate_left: RotateFunction[T], + rotate_right: RotateFunction[T], ) -> Tuple[Optional[Node[T]], Optional[Node[T]]]: if Node.difference(w) == Colour.Red: # Case 1