diff --git a/node.py b/node.py index 0669d4b..491403f 100644 --- a/node.py +++ b/node.py @@ -1,6 +1,14 @@ from abc import abstractmethod import enum -from typing import Callable, Optional, Generic, Tuple, TypeVar, Protocol +from typing import ( + Callable, + Iterable, + Optional, + Generic, + Tuple, + TypeVar, + Protocol, +) class Comparable(Protocol): @@ -54,6 +62,17 @@ class Node(Generic[T]): def __str__(self) -> str: return f"Node(value={self.value}, rank={self.rank})" + def __iter__(self) -> Iterable[T]: + """ + Yields: + Keys from the subtree rooted at the node in an inorder fashion. + """ + if self.left: + yield from self.left + yield self.value + if self.right: + yield from self.right + @staticmethod def height(node: "Optional[Node[T]]") -> int: return ( diff --git a/ranked_tree.py b/ranked_tree.py index e3ad76b..80c49b4 100644 --- a/ranked_tree.py +++ b/ranked_tree.py @@ -3,7 +3,7 @@ from node import Node, Comparable from abc import abstractmethod from collections import deque import logging -from typing import Deque, Optional, Tuple, TypeVar, Generic +from typing import Deque, Iterable, Optional, Tuple, TypeVar, Generic logger = logging.getLogger(__name__) T = TypeVar("T", bound=Comparable) @@ -140,3 +140,11 @@ class RankedTree(Generic[T]): if to_be_rebalanced := self._delete(value): y, parent = to_be_rebalanced self._delete_rebalance(y, parent) + + def __iter__(self) -> Iterable[T]: + """ + Yields: + Keys from the tree in an inorder fashion. + """ + if self.root: + yield from self.root