import logging
from typing import Optional, TypeVar

from node import Comparable, Node
from ranked_tree import RankedTree

logger = logging.getLogger(__name__)

T = TypeVar("T", bound=Comparable)


def nodes_eq(
    left_node: Optional[Node[T]],
    right_node: Optional[Node[T]],
    same: bool = True,
) -> bool:
    """Recursively compare two (sub)trees node by node.

    Args:
        left_node: Root of the first subtree, or ``None`` for an empty subtree.
        right_node: Root of the second subtree, or ``None`` for an empty subtree.
        same: When ``True``, nodes must also have equal ``rank`` at every
            level; when ``False``, only values and tree shape are compared.

    Returns:
        ``True`` if both subtrees have the same shape and matching values
        (and matching ranks when ``same`` is set), otherwise ``False``.
    """
    if left_node is None or right_node is None:
        # Equal only if both subtrees are empty. Identity avoids any custom
        # ``Node.__eq__``.
        return left_node is right_node

    return (
        left_node.value == right_node.value
        and (not same or left_node.rank == right_node.rank)
        # Propagate ``same``: previously the recursion fell back to the
        # default ``True``, so rank was ignored only at the root.
        and nodes_eq(left_node.left, right_node.left, same)
        and nodes_eq(left_node.right, right_node.right, same)
    )


class Comparator:
    """Drive two ranked trees in lockstep and compare their structure."""

    def __init__(self, left: RankedTree[T], right: RankedTree[T]) -> None:
        self.left = left
        self.right = right

    def insert(self, value: T) -> None:
        """Insert ``value`` into both trees (left first, then right)."""
        self.left.insert(value)
        self.right.insert(value)

    def delete(self, value: T) -> None:
        """Delete ``value`` from both trees (left first, then right)."""
        self.left.delete(value)
        self.right.delete(value)

    def __str__(self) -> str:
        """Render both trees as a single Graphviz ``digraph``.

        Left-tree nodes are prefixed ``lNode`` and colored red; right-tree
        nodes are prefixed ``rNode`` and colored blue. The prefixes keep node
        identifiers from colliding between the two trees.
        """
        # NOTE(review): these are plain text substitutions over the
        # DOT fragment returned by ``_get_unwrapped_graph``. They assume
        # "Node" appears only in node identifiers and "]" only closes
        # attribute lists. Confirm against ranked_tree's output format.
        left = (
            self.left._get_unwrapped_graph()
            .replace("Node", "lNode")
            .replace("]", ', color="red"]')
        )
        right = (
            self.right._get_unwrapped_graph()
            .replace("Node", "rNode")
            .replace("]", ', color="blue"]')
        )
        return "digraph {\n" + left + "\n\n" + right + "}\n"

    @property
    def are_equal(self) -> bool:
        """``True`` if both trees match in shape, values, and ranks."""
        return nodes_eq(self.left.root, self.right.root)

    @property
    def are_similar(self) -> bool:
        """``True`` if both trees match in shape and values (ranks ignored)."""
        return nodes_eq(self.left.root, self.right.root, same=False)