Add python implementation and tests

Signed-off-by: Matej Focko <me@mfocko.xyz>
This commit is contained in:
Matej Focko 2021-01-04 10:46:12 +01:00
parent 329e59ca15
commit 70261aed50
No known key found for this signature in database
GPG key ID: 625A0E78B2D1D899
7 changed files with 756 additions and 0 deletions

3
.vscode/settings.json vendored Normal file
View file

@ -0,0 +1,3 @@
{
"python.formatting.provider": "black"
}

15
Pipfile Normal file
View file

@ -0,0 +1,15 @@
[[source]]
url = "https://pypi.org/simple"
verify_ssl = true
name = "pypi"
[packages]
pytest = "*"
hypothesis = "*"
black = "*"
pylint = "*"
[dev-packages]
[requires]
python_version = "3.9"

99
Pipfile.lock generated Normal file
View file

@ -0,0 +1,99 @@
{
"_meta": {
"hash": {
"sha256": "95de7fc3f7d4accf18f255d512e95bd44ee4e34800f8a38f99be490af2d75e6b"
},
"pipfile-spec": 6,
"requires": {
"python_version": "3.9"
},
"sources": [
{
"name": "pypi",
"url": "https://pypi.org/simple",
"verify_ssl": true
}
]
},
"default": {
"attrs": {
"hashes": [
"sha256:31b2eced602aa8423c2aea9c76a724617ed67cf9513173fd3a4f03e3a929c7e6",
"sha256:832aa3cde19744e49938b91fea06d69ecb9e649c93ba974535d08ad92164f700"
],
"markers": "python_version >= '2.7' and python_version not in '3.0, 3.1, 3.2, 3.3'",
"version": "==20.3.0"
},
"hypothesis": {
"hashes": [
"sha256:6094f922c1d2bef808b8e6cf3ab8cdb8ad76faa12519a25ac4250eed524e25c0",
"sha256:6dfc4c233f3903f1524242c1e1c92b5e06fb4b689d3fd2abbd22783af25d3c27"
],
"index": "pypi",
"version": "==5.43.4"
},
"iniconfig": {
"hashes": [
"sha256:011e24c64b7f47f6ebd835bb12a743f2fbe9a26d4cecaa7f53bc4f35ee9da8b3",
"sha256:bc3af051d7d14b2ee5ef9969666def0cd1a000e121eaea580d4a313df4b37f32"
],
"version": "==1.1.1"
},
"packaging": {
"hashes": [
"sha256:24e0da08660a87484d1602c30bb4902d74816b6985b93de36926f5bc95741858",
"sha256:78598185a7008a470d64526a8059de9aaa449238f280fc9eb6b13ba6c4109093"
],
"markers": "python_version >= '2.7' and python_version not in '3.0, 3.1, 3.2, 3.3'",
"version": "==20.8"
},
"pluggy": {
"hashes": [
"sha256:15b2acde666561e1298d71b523007ed7364de07029219b604cf808bfa1c765b0",
"sha256:966c145cd83c96502c3c3868f50408687b38434af77734af1e9ca461a4081d2d"
],
"markers": "python_version >= '2.7' and python_version not in '3.0, 3.1, 3.2, 3.3'",
"version": "==0.13.1"
},
"py": {
"hashes": [
"sha256:21b81bda15b66ef5e1a777a21c4dcd9c20ad3efd0b3f817e7a809035269e1bd3",
"sha256:3b80836aa6d1feeaa108e046da6423ab8f6ceda6468545ae8d02d9d58d18818a"
],
"markers": "python_version >= '2.7' and python_version not in '3.0, 3.1, 3.2, 3.3'",
"version": "==1.10.0"
},
"pyparsing": {
"hashes": [
"sha256:c203ec8783bf771a155b207279b9bccb8dea02d8f0c9e5f8ead507bc3246ecc1",
"sha256:ef9d7589ef3c200abe66653d3f1ab1033c3c419ae9b9bdb1240a85b024efc88b"
],
"markers": "python_version >= '2.6' and python_version not in '3.0, 3.1, 3.2, 3.3'",
"version": "==2.4.7"
},
"pytest": {
"hashes": [
"sha256:1969f797a1a0dbd8ccf0fecc80262312729afea9c17f1d70ebf85c5e76c6f7c8",
"sha256:66e419b1899bc27346cb2c993e12c5e5e8daba9073c1fbce33b9807abc95c306"
],
"index": "pypi",
"version": "==6.2.1"
},
"sortedcontainers": {
"hashes": [
"sha256:37257a32add0a3ee490bb170b599e93095eed89a55da91fa9f48753ea12fd73f",
"sha256:59cc937650cf60d677c16775597c89a960658a09cf7c1a668f86e1e4464b10a1"
],
"version": "==2.3.0"
},
"toml": {
"hashes": [
"sha256:806143ae5bfb6a3c6e736a764057db0e6a0e05e338b5630894a5f779cabb4f9b",
"sha256:b3bda1d108d5dd99f4a20d24d9c348e91c4db7ab1b749200bded2f839ccbe68f"
],
"markers": "python_version >= '2.6' and python_version not in '3.0, 3.1, 3.2, 3.3'",
"version": "==0.10.2"
}
},
"develop": {}
}

139
node.py Normal file
View file

@ -0,0 +1,139 @@
import enum
class NodeType(enum.IntEnum):
LEAF = 0
UNARY = 1
BINARY = 2
class Node:
def __init__(self, value, left=None, right=None, parent=None):
self.parent = parent
self.left = left
self.right = right
self.value = value
self.rank = 0
@property
def type(self):
if self.left and self.right:
return NodeType.BINARY
if self.left or self.right:
return NodeType.UNARY
return NodeType.LEAF
def __repr__(self):
return f"Node(value={self.value}, rank={self.rank}, left={self.left}, right={self.right}, parent={self.parent})"
def __str__(self):
return f"Node(value={self.value}, rank={self.rank})"
@staticmethod
def height(node):
return 1 + max(Node.height(node.left), Node.height(node.right)) if node else -1
@staticmethod
def rank(node):
return -1 if not node else node.rank
@staticmethod
def difference(node, parent=None):
if not parent:
parent = node.parent if node else None
return Node.rank(parent) - Node.rank(node)
@staticmethod
def differences(node):
node_rank = Node.rank(node)
(left, right) = (node.left, node.right) if node else (None, None)
return (node_rank - Node.rank(left), node_rank - Node.rank(right))
def promote(self):
self.rank += 1
return self
def demote(self):
self.rank -= 1
return self
@staticmethod
def rotate_right(x):
parent = x.parent
y = x.left
z = x.right
assert y is not None
if parent:
if parent.left is x:
parent.left = y
else:
parent.right = y
x.left = y.right
if x.left:
x.left.parent = x
y.right = x
x.parent = y
y.parent = parent
return y
@staticmethod
def rotate_left(x):
parent = x.parent
y = x.left
z = x.right
assert z is not None
if parent:
if parent.left is x:
parent.left = z
else:
parent.right = z
x.right = z.left
if x.right:
x.right.parent = x
z.left = x
x.parent = z
z.parent = parent
return z
@staticmethod
def find_parent_node(value, node, missing=True):
new_node = node
while new_node and (missing or new_node.value != value):
node = new_node
new_node = node.left if value < node.value else node.right
return node
@staticmethod
def search(value, node):
while node and node.value != value:
node = node.left if value < node.value else node.right
return node
@staticmethod
def minimum(node):
while node.left:
node = node.left
return node
@staticmethod
def maximum(node):
while node.right:
node = node.right
return node

15
test_node.py Normal file
View file

@ -0,0 +1,15 @@
from node import Node, NodeType
def test_default_rank():
node = Node(0)
assert 0 == node.value
assert node.parent is None
assert node.left is None
assert node.right is None
assert 0 == node.rank
assert 0 == Node.rank(node)
assert -1 == Node.difference(node)
assert (1, 1) == Node.differences(node)

185
test_wavl.py Normal file
View file

@ -0,0 +1,185 @@
from node import Node, NodeType
from wavl import WAVLTree
import hypothesis
import hypothesis.strategies as st
import logging
import pytest
import random
import unittest
logger = logging.getLogger(__name__)
def test_empty():
tree = WAVLTree()
assert tree.root is None
assert tree.is_correct
def test_one_node():
tree = WAVLTree()
tree.insert(1)
assert tree.root is not None
assert 1 == tree.root.value
assert 0 == tree.root.rank
assert tree.root.left is None
assert tree.root.right is None
assert tree.is_correct
@pytest.mark.parametrize("values", [[1, 2], [1, 2, 0]])
def test_no_rebalance_needed(values):
tree = WAVLTree()
for value in values:
tree.insert(value)
assert tree.is_correct
def test_three_nodes_rebalanced():
tree = WAVLTree()
for value in (1, 2, 3):
print(tree)
tree.insert(value)
assert tree.is_correct
def test_bigger_tree():
tree = WAVLTree()
for i in range(50):
tree.insert(i)
assert tree.is_correct
def test_bigger_tree_reversed():
tree = WAVLTree()
for i in range(50):
tree.insert(-i)
assert tree.is_correct
def test_promote():
tree = WAVLTree()
for value in (0, 1, -1):
tree.insert(value)
assert tree.is_correct
@pytest.mark.parametrize(
"values",
[
[0, 1, -2, -1, -3],
[0, 1, -2, -1, -3, -4, 4, 9, 7],
[0, 1, -2, -1, -3, -4, 4, 9, 7, 5, -5, 8],
],
)
def test_rotate(values):
tree = WAVLTree()
for value in values:
tree.insert(value)
assert tree.is_correct
@hypothesis.settings(max_examples=1000, deadline=None)
@hypothesis.given(values=st.sets(st.integers()))
def test_insert_random(values):
tree = WAVLTree()
for value in values:
tree.insert(value)
assert tree.is_correct
assert tree.search(value) is not None
def test_search_empty():
tree = WAVLTree()
assert tree.search(0) is None
@hypothesis.given(values=st.sets(st.integers()))
def test_search_random(values):
tree = WAVLTree()
for value in values:
tree.insert(value)
assert tree.is_correct
for value in values:
assert tree.search(value) is not None
@st.composite
def delete_strategy(draw):
values = list(draw(st.sets(st.integers())))
delete_order = values.copy()
random.shuffle(delete_order)
return (values, delete_order)
@hypothesis.settings(max_examples=100000, deadline=None)
@hypothesis.given(config=delete_strategy())
def test_delete_random(config):
values, delete_order = config
tree = WAVLTree()
for value in values:
tree.insert(value)
for value in delete_order:
before = str(tree)
tree.delete(value)
after = str(tree)
try:
assert tree.is_correct
except AssertionError:
logger.info(
f"[FAIL] Deleting of {value} from {values} in order {delete_order}"
)
logger.info(f"Before:\n{before}")
logger.info(f"After:\n{after}")
raise
# @unittest.skip("Only for replicating hypothesis")
@pytest.mark.parametrize(
"values, delete_order",
[
([0, 1, 2], [0, 2, 1]),
([0, 1, 2, -1], [2, 0, 1, -1]),
([0, 1, -1], [0, -1, 1]),
([0, 1], [0, 1]),
([0, 1, 32, -1], [32, 0, 1, -1]),
],
)
def test_delete_minimal(values, delete_order):
tree = WAVLTree()
for value in values:
tree.insert(value)
for value in delete_order:
before = str(tree)
tree.delete(value)
after = str(tree)
try:
assert tree.is_correct
except AssertionError:
logger.info(
f"[FAIL] Deleting of {value} from {values} in order {delete_order}"
)
logger.info(f"Before:\n{before}")
logger.info(f"After:\n{after}")
raise

300
wavl.py Normal file
View file

@ -0,0 +1,300 @@
from node import Node, NodeType
from collections import deque
import logging
logger = logging.getLogger(__name__)
class WAVLTree:
def __init__(self):
self.root = None
@property
def rank(self):
return Node.rank(self.root)
@property
def is_correct(self):
return WAVLTree.is_correct_node(self.root)
@staticmethod
def is_correct_node(node, recursive=True):
if not node:
return True
for child_rank in Node.differences(node):
if child_rank not in (1, 2):
return False
if node.type == NodeType.LEAF:
return node.rank == 0
return not recursive or (
WAVLTree.is_correct_node(node.left) and WAVLTree.is_correct_node(node.right)
)
def search(self, value, node=None):
if not node:
node = self.root
return Node.search(value, node)
def __fix_0_child(self, x, y, z, rotate_left, rotate_right):
new_root = x.parent
if not y or Node.difference(y) == 2:
new_root = rotate_right(z)
z.demote()
elif Node.difference(y) == 1:
rotate_left(x)
new_root = rotate_right(z)
y.promote()
x.demote()
z.demote()
return new_root
def __bottomup_rebalance(self, x):
diffs = Node.differences(x.parent)
# if diffs != (0, 1) and diffs != (1, 0):
# return
while x.parent and (diffs == (0, 1) or diffs == (1, 0)):
x.parent.promote()
x = x.parent
diffs = Node.differences(x.parent)
if not x.parent:
return
rotating_around_root = x.parent.parent is None
new_root = x.parent
rank_difference = Node.difference(x)
if rank_difference != 0:
return
if rank_difference == 0 and x.parent.left is x:
new_root = self.__fix_0_child(
x, x.right, x.parent, Node.rotate_left, Node.rotate_right
)
elif rank_difference == 0 and x.parent.right is x:
new_root = self.__fix_0_child(
x, x.left, x.parent, Node.rotate_right, Node.rotate_left
)
if rotating_around_root:
self.root = new_root
def insert(self, value):
inserted_node = Node(value)
if not self.root:
self.root = inserted_node
return
parent = Node.find_parent_node(value, self.root)
inserted_node.parent = parent
if value < parent.value:
parent.left = inserted_node
else:
parent.right = inserted_node
self.__bottomup_rebalance(inserted_node)
def __transplant(self, u, v):
if not u.parent:
self.root = v
elif u.parent.left is u:
u.parent.left = v
else:
u.parent.right = v
if v:
v.rank = u.rank
v.parent = u.parent
@staticmethod
def __fix_delete(x, y, z, reversed, rotate_left, rotate_right):
new_root = x
v = y.left
w = y.right
if reversed:
v, w = w, v
logger.debug(f"__fix_delete({x}, {y}, {z}, {reversed})")
w_diff = Node.difference(w, y)
logger.debug(f"w_diff = {w_diff}")
if w_diff == 1 and y.parent:
logger.debug(f"y.parent = {y.parent}")
new_root = rotate_left(y.parent)
y.promote()
z.demote()
if z.type == NodeType.LEAF:
z.demote()
elif w_diff == 2 and v.parent:
logger.debug(f"v.parent = {v.parent}")
rotate_right(v.parent)
new_root = rotate_left(v.parent)
v.promote().promote()
y.demote()
z.demote().demote()
return new_root
def __bottomup_delete(self, x, parent):
x_diff = Node.difference(x, parent)
if x_diff != 3 or not parent:
return
y = parent.right if parent.left is x else parent.left
y_diff = Node.difference(y, parent)
while (
parent
and x_diff == 3
and y
and (y_diff == 2 or Node.differences(y) == (2, 2))
):
parent.demote()
if y_diff != 2:
y.demote()
x = parent
parent = x.parent
if not parent:
return
y = parent.right if parent.left is x else parent.left
x_diff = Node.difference(x, parent)
y_diff = Node.difference(y, parent)
if not parent:
return
rotating_around_root = parent.parent is None
new_root = parent
parent_node_diffs = Node.differences(parent)
if parent_node_diffs in ((1, 3), (3, 1)):
if parent.left is x:
new_root = WAVLTree.__fix_delete(
x,
parent.right,
parent,
False,
Node.rotate_left,
Node.rotate_right,
)
else:
new_root = WAVLTree.__fix_delete(
x,
parent.left,
parent,
True,
Node.rotate_right,
Node.rotate_left,
)
if rotating_around_root:
self.root = new_root
def __delete_fixup(self, y, parent=None):
logger.debug(f"[__delete_fixup] y = {y}, parent = {parent}")
z = y if y else parent
logger.debug(
f"[z.demote()] Node.differences({repr(z)}) == (2, 2) ~>* {Node.differences(z)} == (2, 2)"
)
if Node.differences(z) == (2, 2):
z.demote()
if parent:
for y in (parent.left, parent.right):
logger.debug(
f"[bottom-up delete] Node.difference({y}, {parent}) == 3 ~>* {Node.difference(y, parent)} == 3"
)
if Node.difference(y, parent) == 3:
self.__bottomup_delete(y, parent)
def __fix_after_delete(self, node, parent):
while node or parent:
self.__delete_fixup(node, parent)
node, parent = parent, (parent.parent if parent else None)
def delete_node(self, node):
y, parent = None, node.parent
if not node.left:
logger.debug("node.left is None")
y = node.right
self.__transplant(node, node.right)
elif not node.right:
logger.debug("node.right is None")
y = node.left
self.__transplant(node, node.left)
else:
logger.debug("taking successor")
n = Node.minimum(node.right)
y, parent = None, (n.parent if n.parent is not node else n)
if n.parent is not node:
parent = n.right if n.right else n.parent
self.__transplant(n, n.right)
n.right = node.right
n.right.parent = n
self.__transplant(node, n)
n.left = node.left
n.left.parent = n
return (y, parent)
def __delete(self, value, node=None):
if node is None:
return
if node.value != value:
return self.__delete(value, node.left if value < node.value else node.right)
(y, parent) = self.delete_node(node)
self.__fix_after_delete(y, parent)
def delete(self, value):
logger.debug(f"[DELETE] {value}")
self.__delete(value, self.root)
def __str__(self):
result = "digraph {\n"
queue = deque()
queue.append(self.root)
edges = []
while queue:
node = queue.popleft()
if not node:
continue
result += f'"{str(node)}" [label="{node.value}, {node.rank}"];\n'
for child in (node.left, node.right):
if not child:
continue
edges.append((node, child))
queue.append(child)
for fromNode, toNode in edges:
result += f'"{str(fromNode)}" -> "{str(toNode)}" [label="{Node.difference(toNode)}"]\n'
return result + "}\n"