diff --git a/comparator.py b/comparator.py new file mode 100644 index 0000000..c1d7843 --- /dev/null +++ b/comparator.py @@ -0,0 +1,59 @@ +from node import Node, Comparable +from ranked_tree import RankedTree + +import logging +from typing import TypeVar, Optional + +logger = logging.getLogger(__name__) +T = TypeVar("T", bound=Comparable) + + +def nodes_eq( + left_node: Optional[Node[T]], + right_node: Optional[Node[T]], + same: bool = True, +) -> bool: + if left_node is None or right_node is None: + return left_node == right_node + + return ( + left_node.value == right_node.value + and (not same or left_node.rank == right_node.rank) + and nodes_eq(left_node.left, right_node.left) + and nodes_eq(left_node.right, right_node.right) + ) + + +class Comparator: + def __init__(self, left: RankedTree[T], right: RankedTree[T]) -> None: + self.left = left + self.right = right + + def insert(self, value: T) -> None: + self.left.insert(value) + self.right.insert(value) + + def delete(self, value: T) -> None: + self.left.delete(value) + self.right.delete(value) + + def __str__(self) -> str: + left = ( + self.left._get_unwrapped_graph() + .replace("Node", "lNode") + .replace("]", ', color="red"]') + ) + right = ( + self.right._get_unwrapped_graph() + .replace("Node", "rNode") + .replace("]", ', color="blue"]') + ) + return "digraph {\n" + left + "\n\n" + right + "}\n" + + @property + def are_equal(self) -> bool: + return nodes_eq(self.left.root, self.right.root) + + @property + def are_similar(self) -> bool: + return nodes_eq(self.left.root, self.right.root, False) diff --git a/test_generate.py b/test_generate.py new file mode 100644 index 0000000..fcbcac1 --- /dev/null +++ b/test_generate.py @@ -0,0 +1,37 @@ +from avl import AVLTree +from wavl import WAVLTree +from comparator import Comparator +from test_wavl import delete_strategy + +import hypothesis + + +def report_different(before, deleted, comparator): + h = abs(hash(before)) + with ( + open(f"trees/{h}_before.dot", "w") as b, + open(f"trees/{h}_d{deleted}_after.dot", "w") as a, + ): + print(before, file=b) + print(comparator, file=a) + + +@hypothesis.settings(max_examples=10000, deadline=None) +@hypothesis.given(config=delete_strategy()) +def test_delete(config): + values, order = config + + comparator = Comparator(AVLTree(), WAVLTree()) + + for value in values: + comparator.insert(value) + + for value in order: + before = str(comparator.left) + comparator.delete(value) + + try: + assert comparator.are_equal + except AssertionError: + report_different(before, value, comparator) + raise diff --git a/wavl.py b/wavl.py index 5d1ef71..7bbb27a 100644 --- a/wavl.py +++ b/wavl.py @@ -60,7 +60,7 @@ class WAVLTree(AVLTree[T]): if z.type == NodeType.LEAF: z.demote() - elif w_diff == 2 and v.parent: + elif w_diff == 2 and v and v.parent: logger.debug(f"v.parent = {v.parent}") rotate_right(v.parent) new_root = rotate_left(v.parent) @@ -143,6 +143,7 @@ class WAVLTree(AVLTree[T]): f"{Node.differences(z)} == (2, 2)" ) assert z + # FIXME: In combination with propagation below, we get AVL tree if Node.differences(z) == (2, 2): z.demote() @@ -158,6 +159,8 @@ class WAVLTree(AVLTree[T]): def _delete_rebalance( self, node: Optional[Node[T]], parent: Optional[Node[T]] ) -> None: + # FIXME: Do not go all the way up, just to the replaced nodes and then + # check if rank rule is broken. while node or parent: self.__delete_fixup(node, parent) node, parent = parent, (parent.parent if parent else None)