feat: implement comparator for trees

Signed-off-by: Matej Focko <mfocko@redhat.com>
This commit is contained in:
Matej Focko 2022-02-03 11:35:12 +01:00
parent a4c896edbf
commit db35a07738
No known key found for this signature in database
GPG key ID: 332171FADF1DB90B
3 changed files with 100 additions and 1 deletions

59
comparator.py Normal file
View file

@ -0,0 +1,59 @@
from node import Node, Comparable
from ranked_tree import RankedTree
import logging
from typing import TypeVar, Optional
logger = logging.getLogger(__name__)
T = TypeVar("T", bound=Comparable)
def nodes_eq(
left_node: Optional[Node[T]],
right_node: Optional[Node[T]],
same: bool = True,
) -> bool:
if left_node is None or right_node is None:
return left_node == right_node
return (
left_node.value == right_node.value
and (not same or left_node.rank == right_node.rank)
and nodes_eq(left_node.left, right_node.left)
and nodes_eq(left_node.right, right_node.right)
)
class Comparator:
def __init__(self, left: RankedTree[T], right: RankedTree[T]) -> None:
self.left = left
self.right = right
def insert(self, value: T) -> None:
self.left.insert(value)
self.right.insert(value)
def delete(self, value: T) -> None:
self.left.delete(value)
self.right.delete(value)
def __str__(self) -> str:
left = (
self.left._get_unwrapped_graph()
.replace("Node", "lNode")
.replace("]", ', color="red"]')
)
right = (
self.right._get_unwrapped_graph()
.replace("Node", "rNode")
.replace("]", ', color="blue"]')
)
return "digraph {\n" + left + "\n\n" + right + "}\n"
@property
def are_equal(self) -> bool:
return nodes_eq(self.left.root, self.right.root)
@property
def are_similar(self) -> bool:
return nodes_eq(self.left.root, self.right.root, False)

37
test_generate.py Normal file
View file

@ -0,0 +1,37 @@
from avl import AVLTree
from wavl import WAVLTree
from comparator import Comparator
from test_wavl import delete_strategy
import hypothesis
def report_different(before, deleted, comparator):
h = abs(hash(before))
with (
open(f"trees/{h}_before.dot", "w") as b,
open(f"trees/{h}_d{deleted}_after.dot", "w") as a,
):
print(before, file=b)
print(comparator, file=a)
@hypothesis.settings(max_examples=10000, deadline=None)
@hypothesis.given(config=delete_strategy())
def test_delete(config):
values, order = config
comparator = Comparator(AVLTree(), WAVLTree())
for value in values:
comparator.insert(value)
for value in order:
before = str(comparator.left)
comparator.delete(value)
try:
assert comparator.are_equal
except AssertionError:
report_different(before, value, comparator)
raise

View file

@ -60,7 +60,7 @@ class WAVLTree(AVLTree[T]):
if z.type == NodeType.LEAF:
z.demote()
elif w_diff == 2 and v.parent:
elif w_diff == 2 and v and v.parent:
logger.debug(f"v.parent = {v.parent}")
rotate_right(v.parent)
new_root = rotate_left(v.parent)
@ -143,6 +143,7 @@ class WAVLTree(AVLTree[T]):
f"{Node.differences(z)} == (2, 2)"
)
assert z
# FIXME: In combination with propagation below, we get AVL tree
if Node.differences(z) == (2, 2):
z.demote()
@ -158,6 +159,8 @@ class WAVLTree(AVLTree[T]):
def _delete_rebalance(
self, node: Optional[Node[T]], parent: Optional[Node[T]]
) -> None:
# FIXME: Do not go all the way up, just to the replaced nodes and then
# check if rank rule is broken.
while node or parent:
self.__delete_fixup(node, parent)
node, parent = parent, (parent.parent if parent else None)