From 70261aed506741e96e49a4b1e3e6d908fdf50409 Mon Sep 17 00:00:00 2001 From: Matej Focko Date: Mon, 4 Jan 2021 10:46:12 +0100 Subject: [PATCH] Add python implementation and tests Signed-off-by: Matej Focko --- .vscode/settings.json | 3 + Pipfile | 15 +++ Pipfile.lock | 99 ++++++++++++++ node.py | 139 +++++++++++++++++++ test_node.py | 15 +++ test_wavl.py | 185 ++++++++++++++++++++++++++ wavl.py | 300 ++++++++++++++++++++++++++++++++++++++++++ 7 files changed, 756 insertions(+) create mode 100644 .vscode/settings.json create mode 100644 Pipfile create mode 100644 Pipfile.lock create mode 100644 node.py create mode 100644 test_node.py create mode 100644 test_wavl.py create mode 100644 wavl.py diff --git a/.vscode/settings.json b/.vscode/settings.json new file mode 100644 index 0000000..de288e1 --- /dev/null +++ b/.vscode/settings.json @@ -0,0 +1,3 @@ +{ + "python.formatting.provider": "black" +} \ No newline at end of file diff --git a/Pipfile b/Pipfile new file mode 100644 index 0000000..36a1a87 --- /dev/null +++ b/Pipfile @@ -0,0 +1,15 @@ +[[source]] +url = "https://pypi.org/simple" +verify_ssl = true +name = "pypi" + +[packages] +pytest = "*" +hypothesis = "*" +black = "*" +pylint = "*" + +[dev-packages] + +[requires] +python_version = "3.9" diff --git a/Pipfile.lock b/Pipfile.lock new file mode 100644 index 0000000..007d162 --- /dev/null +++ b/Pipfile.lock @@ -0,0 +1,99 @@ +{ + "_meta": { + "hash": { + "sha256": "95de7fc3f7d4accf18f255d512e95bd44ee4e34800f8a38f99be490af2d75e6b" + }, + "pipfile-spec": 6, + "requires": { + "python_version": "3.9" + }, + "sources": [ + { + "name": "pypi", + "url": "https://pypi.org/simple", + "verify_ssl": true + } + ] + }, + "default": { + "attrs": { + "hashes": [ + "sha256:31b2eced602aa8423c2aea9c76a724617ed67cf9513173fd3a4f03e3a929c7e6", + "sha256:832aa3cde19744e49938b91fea06d69ecb9e649c93ba974535d08ad92164f700" + ], + "markers": "python_version >= '2.7' and python_version not in '3.0, 3.1, 3.2, 3.3'", + "version": "==20.3.0" + 
}, + "hypothesis": { + "hashes": [ + "sha256:6094f922c1d2bef808b8e6cf3ab8cdb8ad76faa12519a25ac4250eed524e25c0", + "sha256:6dfc4c233f3903f1524242c1e1c92b5e06fb4b689d3fd2abbd22783af25d3c27" + ], + "index": "pypi", + "version": "==5.43.4" + }, + "iniconfig": { + "hashes": [ + "sha256:011e24c64b7f47f6ebd835bb12a743f2fbe9a26d4cecaa7f53bc4f35ee9da8b3", + "sha256:bc3af051d7d14b2ee5ef9969666def0cd1a000e121eaea580d4a313df4b37f32" + ], + "version": "==1.1.1" + }, + "packaging": { + "hashes": [ + "sha256:24e0da08660a87484d1602c30bb4902d74816b6985b93de36926f5bc95741858", + "sha256:78598185a7008a470d64526a8059de9aaa449238f280fc9eb6b13ba6c4109093" + ], + "markers": "python_version >= '2.7' and python_version not in '3.0, 3.1, 3.2, 3.3'", + "version": "==20.8" + }, + "pluggy": { + "hashes": [ + "sha256:15b2acde666561e1298d71b523007ed7364de07029219b604cf808bfa1c765b0", + "sha256:966c145cd83c96502c3c3868f50408687b38434af77734af1e9ca461a4081d2d" + ], + "markers": "python_version >= '2.7' and python_version not in '3.0, 3.1, 3.2, 3.3'", + "version": "==0.13.1" + }, + "py": { + "hashes": [ + "sha256:21b81bda15b66ef5e1a777a21c4dcd9c20ad3efd0b3f817e7a809035269e1bd3", + "sha256:3b80836aa6d1feeaa108e046da6423ab8f6ceda6468545ae8d02d9d58d18818a" + ], + "markers": "python_version >= '2.7' and python_version not in '3.0, 3.1, 3.2, 3.3'", + "version": "==1.10.0" + }, + "pyparsing": { + "hashes": [ + "sha256:c203ec8783bf771a155b207279b9bccb8dea02d8f0c9e5f8ead507bc3246ecc1", + "sha256:ef9d7589ef3c200abe66653d3f1ab1033c3c419ae9b9bdb1240a85b024efc88b" + ], + "markers": "python_version >= '2.6' and python_version not in '3.0, 3.1, 3.2, 3.3'", + "version": "==2.4.7" + }, + "pytest": { + "hashes": [ + "sha256:1969f797a1a0dbd8ccf0fecc80262312729afea9c17f1d70ebf85c5e76c6f7c8", + "sha256:66e419b1899bc27346cb2c993e12c5e5e8daba9073c1fbce33b9807abc95c306" + ], + "index": "pypi", + "version": "==6.2.1" + }, + "sortedcontainers": { + "hashes": [ + 
import enum


class NodeType(enum.IntEnum):
    """Kind of a node; each member's value equals the node's child count."""

    LEAF = 0
    UNARY = 1
    BINARY = 2


class Node:
    """Binary-search-tree node carrying a rank and a parent pointer.

    Most helpers are static methods that take the node as an argument, so
    that ``None`` (a missing child / an empty subtree) can be passed too.
    """

    def __init__(self, value, left=None, right=None, parent=None):
        """Create a node of rank 0.

        The links are stored exactly as given; the ``parent`` pointers of
        ``left`` and ``right`` are not touched.
        """
        self.parent = parent
        self.left = left
        self.right = right

        self.value = value
        # NOTE: this instance attribute hides the ``rank`` static method for
        # attribute access on instances: ``node.rank`` is the integer, while
        # the ``None``-safe helper is only reachable as ``Node.rank(node)``.
        self.rank = 0

    @property
    def type(self):
        """Return the ``NodeType`` matching the number of children."""
        return NodeType(bool(self.left) + bool(self.right))

    def __repr__(self):
        # Children and parent are rendered through ``__str__`` (f-string
        # formatting), so the output does not recurse through the tree.
        return (
            f"Node(value={self.value}, rank={self.rank}, "
            f"left={self.left}, right={self.right}, parent={self.parent})"
        )

    def __str__(self):
        return f"Node(value={self.value}, rank={self.rank})"

    @staticmethod
    def height(node):
        """Return the height of the subtree at ``node`` (-1 for ``None``)."""
        if not node:
            return -1
        return 1 + max(Node.height(node.left), Node.height(node.right))

    @staticmethod
    def rank(node):
        """Return the rank of ``node``; a missing node has rank -1."""
        return -1 if not node else node.rank

    @staticmethod
    def difference(node, parent=None):
        """Return ``rank(parent) - rank(node)``.

        ``parent`` defaults to ``node.parent``. It has to be passed
        explicitly when ``node`` is ``None``, since a missing child cannot
        name its own parent.
        """
        if not parent:
            parent = node.parent if node else None
        return Node.rank(parent) - Node.rank(node)

    @staticmethod
    def differences(node):
        """Return the rank differences of ``node``'s (left, right) children."""
        node_rank = Node.rank(node)
        left, right = (node.left, node.right) if node else (None, None)

        return (node_rank - Node.rank(left), node_rank - Node.rank(right))

    def promote(self):
        """Increase the rank by one and return ``self`` for chaining."""
        self.rank += 1
        return self

    def demote(self):
        """Decrease the rank by one and return ``self`` for chaining."""
        self.rank -= 1
        return self

    @staticmethod
    def rotate_right(x):
        """Rotate right around ``x`` and return the new subtree root.

        The left child of ``x`` takes its place (including the link from
        ``x``'s former parent). Ranks are left unchanged. ``x`` must have a
        left child.
        """
        parent = x.parent
        y = x.left

        assert y is not None, "rotate_right requires a left child"
        if parent:
            if parent.left is x:
                parent.left = y
            else:
                parent.right = y

        # The inner subtree of ``y`` changes sides and hangs under ``x``.
        x.left = y.right
        if x.left:
            x.left.parent = x

        y.right = x
        x.parent = y
        y.parent = parent

        return y

    @staticmethod
    def rotate_left(x):
        """Rotate left around ``x`` and return the new subtree root.

        Mirror image of ``rotate_right``; ``x`` must have a right child.
        """
        parent = x.parent
        z = x.right

        assert z is not None, "rotate_left requires a right child"
        if parent:
            if parent.left is x:
                parent.left = z
            else:
                parent.right = z

        # The inner subtree of ``z`` changes sides and hangs under ``x``.
        x.right = z.left
        if x.right:
            x.right.parent = x

        z.left = x
        x.parent = z
        z.parent = parent

        return z

    @staticmethod
    def find_parent_node(value, node, missing=True):
        """Descend from ``node`` towards ``value`` and return the last node
        stepped through.

        With ``missing=True`` the walk continues until it falls off the
        tree, so the result is the node under which ``value`` would be
        attached (equal values descend to the right). With
        ``missing=False`` the walk also stops in front of a node holding
        ``value``, so the result is that node's parent. If the starting
        node is ``None``, or already holds ``value`` when
        ``missing=False``, it is returned unchanged.
        """
        new_node = node

        while new_node and (missing or new_node.value != value):
            node = new_node
            new_node = node.left if value < node.value else node.right

        return node

    @staticmethod
    def search(value, node):
        """Return the node holding ``value`` below ``node``, or ``None``."""
        while node and node.value != value:
            node = node.left if value < node.value else node.right

        return node

    @staticmethod
    def minimum(node):
        """Return the leftmost node of the subtree (``None`` if empty)."""
        # Guarding ``node`` itself keeps this helper ``None``-safe like the
        # other static helpers instead of raising AttributeError.
        while node and node.left:
            node = node.left

        return node

    @staticmethod
    def maximum(node):
        """Return the rightmost node of the subtree (``None`` if empty)."""
        while node and node.right:
            node = node.right

        return node
from node import Node, NodeType
from wavl import WAVLTree

import hypothesis
import hypothesis.strategies as st
import logging
import pytest
import random
import unittest

logger = logging.getLogger(__name__)


def test_default_rank():
    """A fresh node is an unlinked leaf of rank 0."""
    node = Node(0)

    assert 0 == node.value
    assert node.parent is None
    assert node.left is None
    assert node.right is None

    assert 0 == node.rank
    assert 0 == Node.rank(node)
    assert -1 == Node.difference(node)
    assert (1, 1) == Node.differences(node)


def test_empty():
    tree = WAVLTree()

    assert tree.root is None
    assert tree.is_correct


def test_one_node():
    tree = WAVLTree()
    tree.insert(1)

    assert tree.root is not None
    assert 1 == tree.root.value
    assert 0 == tree.root.rank
    assert tree.root.left is None
    assert tree.root.right is None

    assert tree.is_correct


@pytest.mark.parametrize("values", [[1, 2], [1, 2, 0]])
def test_no_rebalance_needed(values):
    tree = WAVLTree()

    for value in values:
        tree.insert(value)
        assert tree.is_correct


def test_three_nodes_rebalanced():
    tree = WAVLTree()

    for value in (1, 2, 3):
        tree.insert(value)

    assert tree.is_correct


def test_bigger_tree():
    tree = WAVLTree()

    for i in range(50):
        tree.insert(i)
        assert tree.is_correct


def test_bigger_tree_reversed():
    tree = WAVLTree()

    for i in range(50):
        tree.insert(-i)
        assert tree.is_correct


def test_promote():
    tree = WAVLTree()

    for value in (0, 1, -1):
        tree.insert(value)
        assert tree.is_correct


@pytest.mark.parametrize(
    "values",
    [
        [0, 1, -2, -1, -3],
        [0, 1, -2, -1, -3, -4, 4, 9, 7],
        [0, 1, -2, -1, -3, -4, 4, 9, 7, 5, -5, 8],
    ],
)
def test_rotate(values):
    tree = WAVLTree()

    for value in values:
        tree.insert(value)
        assert tree.is_correct


@hypothesis.settings(max_examples=1000, deadline=None)
@hypothesis.given(values=st.sets(st.integers()))
def test_insert_random(values):
    tree = WAVLTree()

    for value in values:
        tree.insert(value)
        assert tree.is_correct
        assert tree.search(value) is not None


def test_search_empty():
    tree = WAVLTree()
    assert tree.search(0) is None


@hypothesis.given(values=st.sets(st.integers()))
def test_search_random(values):
    tree = WAVLTree()

    for value in values:
        tree.insert(value)
        assert tree.is_correct

    for value in values:
        assert tree.search(value) is not None


@st.composite
def delete_strategy(draw):
    """Draw a set of distinct integers and an order in which to delete them.

    Returns ``(values, delete_order)``; ``values`` is also the insertion
    order.
    """
    values = list(draw(st.sets(st.integers())))
    # Draw the order through Hypothesis rather than ``random.shuffle`` so
    # that a failing example can be replayed and shrunk.
    delete_order = draw(st.permutations(values))
    return (values, delete_order)


def _check_deletions(values, delete_order):
    """Insert ``values``, delete in ``delete_order``, and check the tree
    after every deletion.

    On failure the tree is logged as rendered before and after the
    offending deletion, then the assertion error is re-raised.
    """
    tree = WAVLTree()

    for value in values:
        tree.insert(value)

    for value in delete_order:
        before = str(tree)
        tree.delete(value)
        after = str(tree)

        try:
            assert tree.is_correct
        except AssertionError:
            logger.info(
                f"[FAIL] Deleting of {value} from {values} in order {delete_order}"
            )
            logger.info(f"Before:\n{before}")
            logger.info(f"After:\n{after}")
            raise


@hypothesis.settings(max_examples=100000, deadline=None)
@hypothesis.given(config=delete_strategy())
def test_delete_random(config):
    values, delete_order = config
    _check_deletions(values, delete_order)


# Fixed cases kept for replicating Hypothesis findings; this test used to
# carry a disabled ``unittest.skip("Only for replicating hypothesis")``.
@pytest.mark.parametrize(
    "values, delete_order",
    [
        ([0, 1, 2], [0, 2, 1]),
        ([0, 1, 2, -1], [2, 0, 1, -1]),
        ([0, 1, -1], [0, -1, 1]),
        ([0, 1], [0, 1]),
        ([0, 1, 32, -1], [32, 0, 1, -1]),
    ],
)
def test_delete_minimal(values, delete_order):
    _check_deletions(values, delete_order)
from node import Node, NodeType

from collections import deque
import logging

logger = logging.getLogger(__name__)


class WAVLTree:
    """Rank-balanced binary search tree built from ``Node`` objects.

    The invariant enforced by ``is_correct`` is: every node's rank exceeds
    each child's rank by 1 or 2 (a missing child has rank -1), and every
    leaf has rank 0.
    """

    def __init__(self):
        self.root = None

    @property
    def rank(self):
        """Rank of the root (-1 for an empty tree)."""
        return Node.rank(self.root)

    @property
    def is_correct(self):
        """Whether every node of the tree satisfies the rank rules."""
        return WAVLTree.is_correct_node(self.root)

    @staticmethod
    def is_correct_node(node, recursive=True):
        """Check the rank rules at ``node``.

        With ``recursive=True`` (the default) the whole subtree is checked;
        otherwise only ``node`` itself. A missing node is always correct.
        """
        if not node:
            return True

        if any(diff not in (1, 2) for diff in Node.differences(node)):
            return False

        if node.type == NodeType.LEAF:
            return node.rank == 0

        return not recursive or (
            WAVLTree.is_correct_node(node.left)
            and WAVLTree.is_correct_node(node.right)
        )

    def search(self, value, node=None):
        """Return the node holding ``value``, or ``None``.

        The search starts at ``node`` when given, otherwise at the root.
        """
        if not node:
            node = self.root

        return Node.search(value, node)

    def __fix_0_child(self, x, y, z, rotate_left, rotate_right):
        """Repair a 0-child ``x`` of ``z`` by a single or double rotation.

        ``y`` is the child of ``x`` on the side facing ``z``'s other child.
        The rotation callables are given from the point of view of ``x``
        being a left child; the caller swaps them for the mirrored case.
        Returns the root of the repaired subtree.
        """
        new_root = x.parent

        if not y or Node.difference(y) == 2:
            # Single rotation: x moves above z.
            new_root = rotate_right(z)
            z.demote()
        elif Node.difference(y) == 1:
            # Double rotation: y moves above both x and z.
            rotate_left(x)
            new_root = rotate_right(z)

            y.promote()
            x.demote()
            z.demote()

        return new_root

    def __bottomup_rebalance(self, x):
        """Restore the rank rules after ``x`` was attached as a new leaf."""
        diffs = Node.differences(x.parent)

        # Promote while the parent has one 0-child and one 1-child.
        while x.parent and (diffs == (0, 1) or diffs == (1, 0)):
            x.parent.promote()
            x = x.parent

            diffs = Node.differences(x.parent)

        if not x.parent:
            return

        # Only a remaining 0-child needs a rotation.
        if Node.difference(x) != 0:
            return

        rotating_around_root = x.parent.parent is None
        new_root = x.parent

        if x.parent.left is x:
            new_root = self.__fix_0_child(
                x, x.right, x.parent, Node.rotate_left, Node.rotate_right
            )
        elif x.parent.right is x:
            new_root = self.__fix_0_child(
                x, x.left, x.parent, Node.rotate_right, Node.rotate_left
            )

        if rotating_around_root:
            self.root = new_root

    def insert(self, value):
        """Insert ``value`` as a new leaf and rebalance.

        Values equal to an existing one are not rejected; they are attached
        in the right subtree of the existing node.
        """
        inserted_node = Node(value)

        if not self.root:
            self.root = inserted_node
            return

        parent = Node.find_parent_node(value, self.root)
        inserted_node.parent = parent

        if value < parent.value:
            parent.left = inserted_node
        else:
            parent.right = inserted_node

        self.__bottomup_rebalance(inserted_node)

    def __transplant(self, u, v):
        """Put ``v`` (possibly ``None``) into ``u``'s place under its parent.

        ``v`` also takes over ``u``'s rank. ``u``'s own links are left
        untouched.
        """
        if not u.parent:
            self.root = v
        elif u.parent.left is u:
            u.parent.left = v
        else:
            u.parent.right = v

        if v:
            v.rank = u.rank
            v.parent = u.parent

    @staticmethod
    def __fix_delete(x, y, z, reversed, rotate_left, rotate_right):
        """Repair a 3-child ``x`` of ``z`` by rotating its sibling ``y`` up.

        ``reversed`` is true when ``x`` is the right child; the rotation
        callables are then passed swapped by the caller. Returns the root of
        the repaired subtree.
        """
        # When no case below applies nothing is rotated, so the subtree root
        # stays ``z``. Falling back to ``x`` (possibly ``None``) would let
        # the caller install the wrong node as the tree root.
        new_root = z
        v = y.left
        w = y.right

        if reversed:
            v, w = w, v

        logger.debug(f"__fix_delete({x}, {y}, {z}, {reversed})")

        w_diff = Node.difference(w, y)
        logger.debug(f"w_diff = {w_diff}")
        if w_diff == 1 and y.parent:
            logger.debug(f"y.parent = {y.parent}")
            new_root = rotate_left(y.parent)

            y.promote()
            z.demote()

            if z.type == NodeType.LEAF:
                z.demote()
        elif w_diff == 2 and v.parent:
            logger.debug(f"v.parent = {v.parent}")
            rotate_right(v.parent)
            new_root = rotate_left(v.parent)

            v.promote().promote()
            y.demote()
            z.demote().demote()

        return new_root

    def __bottomup_delete(self, x, parent):
        """Repair ``x`` when it is a 3-child of ``parent``.

        ``x`` may be ``None`` (a missing child), which is why its parent is
        passed separately.
        """
        x_diff = Node.difference(x, parent)
        if x_diff != 3 or not parent:
            return

        y = parent.right if parent.left is x else parent.left
        y_diff = Node.difference(y, parent)

        # Demote upwards while that alone repairs the 3-child.
        while (
            parent
            and x_diff == 3
            and y
            and (y_diff == 2 or Node.differences(y) == (2, 2))
        ):
            parent.demote()
            if y_diff != 2:
                y.demote()

            x = parent
            parent = x.parent
            if not parent:
                return
            y = parent.right if parent.left is x else parent.left

            x_diff = Node.difference(x, parent)
            y_diff = Node.difference(y, parent)

        if not parent:
            return

        rotating_around_root = parent.parent is None
        new_root = parent

        parent_node_diffs = Node.differences(parent)
        if parent_node_diffs in ((1, 3), (3, 1)):
            if parent.left is x:
                new_root = WAVLTree.__fix_delete(
                    x,
                    parent.right,
                    parent,
                    False,
                    Node.rotate_left,
                    Node.rotate_right,
                )
            else:
                new_root = WAVLTree.__fix_delete(
                    x,
                    parent.left,
                    parent,
                    True,
                    Node.rotate_right,
                    Node.rotate_left,
                )

        if rotating_around_root:
            self.root = new_root

    def __delete_fixup(self, y, parent=None):
        """Repair the rank rules around one level after a deletion.

        ``y`` is the node now standing at that level (may be ``None``) and
        ``parent`` is the node above it; at least one of them must be given.
        """
        logger.debug(f"[__delete_fixup] y = {y}, parent = {parent}")

        z = y if y else parent
        logger.debug(
            f"[z.demote()] Node.differences({repr(z)}) == (2, 2) ~>* {Node.differences(z)} == (2, 2)"
        )
        if Node.differences(z) == (2, 2):
            z.demote()

        if parent:
            # Both children are read before any repair runs, as before.
            for child in (parent.left, parent.right):
                logger.debug(
                    f"[bottom-up delete] Node.difference({child}, {parent}) == 3 ~>* {Node.difference(child, parent)} == 3"
                )
                if Node.difference(child, parent) == 3:
                    self.__bottomup_delete(child, parent)

    def __fix_after_delete(self, node, parent):
        """Run ``__delete_fixup`` on every level from ``node`` up to the root."""
        while node or parent:
            self.__delete_fixup(node, parent)
            node, parent = parent, (parent.parent if parent else None)

    def delete_node(self, node):
        """Unlink ``node`` from the tree without rebalancing.

        Returns ``(y, parent)``: the child that took ``node``'s place (or
        ``None``) and the node from which rebalancing should start.
        """
        y, parent = None, node.parent

        if not node.left:
            logger.debug("node.left is None")
            y = node.right
            self.__transplant(node, node.right)
        elif not node.right:
            logger.debug("node.right is None")
            y = node.left
            self.__transplant(node, node.left)
        else:
            logger.debug("taking successor")
            n = Node.minimum(node.right)

            if n.parent is node:
                parent = n
            else:
                # Must be chosen before the successor is cut out below.
                parent = n.right if n.right else n.parent
                self.__transplant(n, n.right)
                n.right = node.right
                n.right.parent = n

            self.__transplant(node, n)
            n.left = node.left
            n.left.parent = n

        return (y, parent)

    def __delete(self, value, node=None):
        """Delete ``value`` from the subtree at ``node``; no-op when absent."""
        node = Node.search(value, node)
        if node is None:
            return

        (y, parent) = self.delete_node(node)
        self.__fix_after_delete(y, parent)

    def delete(self, value):
        """Remove one node holding ``value``; do nothing if there is none."""
        logger.debug(f"[DELETE] {value}")
        self.__delete(value, self.root)

    def __str__(self):
        """Render the tree as Graphviz DOT, edges labelled by rank difference.

        Nodes are listed in breadth-first order, followed by all edges.
        """
        parts = ["digraph {\n"]
        edges = []

        queue = deque([self.root])
        while queue:
            node = queue.popleft()
            if not node:
                continue

            parts.append(f'"{str(node)}" [label="{node.value}, {node.rank}"];\n')
            for child in (node.left, node.right):
                if not child:
                    continue

                edges.append((node, child))
                queue.append(child)

        for from_node, to_node in edges:
            parts.append(
                f'"{str(from_node)}" -> "{str(to_node)}" [label="{Node.difference(to_node)}"]\n'
            )

        parts.append("}\n")
        return "".join(parts)