Add python implementation and tests
Signed-off-by: Matej Focko <me@mfocko.xyz>
This commit is contained in:
parent
329e59ca15
commit
70261aed50
7 changed files with 756 additions and 0 deletions
3
.vscode/settings.json
vendored
Normal file
3
.vscode/settings.json
vendored
Normal file
|
@ -0,0 +1,3 @@
|
|||
{
|
||||
"python.formatting.provider": "black"
|
||||
}
|
15
Pipfile
Normal file
15
Pipfile
Normal file
|
@ -0,0 +1,15 @@
|
|||
[[source]]
|
||||
url = "https://pypi.org/simple"
|
||||
verify_ssl = true
|
||||
name = "pypi"
|
||||
|
||||
[packages]
|
||||
pytest = "*"
|
||||
hypothesis = "*"
|
||||
black = "*"
|
||||
pylint = "*"
|
||||
|
||||
[dev-packages]
|
||||
|
||||
[requires]
|
||||
python_version = "3.9"
|
99
Pipfile.lock
generated
Normal file
99
Pipfile.lock
generated
Normal file
|
@ -0,0 +1,99 @@
|
|||
{
|
||||
"_meta": {
|
||||
"hash": {
|
||||
"sha256": "95de7fc3f7d4accf18f255d512e95bd44ee4e34800f8a38f99be490af2d75e6b"
|
||||
},
|
||||
"pipfile-spec": 6,
|
||||
"requires": {
|
||||
"python_version": "3.9"
|
||||
},
|
||||
"sources": [
|
||||
{
|
||||
"name": "pypi",
|
||||
"url": "https://pypi.org/simple",
|
||||
"verify_ssl": true
|
||||
}
|
||||
]
|
||||
},
|
||||
"default": {
|
||||
"attrs": {
|
||||
"hashes": [
|
||||
"sha256:31b2eced602aa8423c2aea9c76a724617ed67cf9513173fd3a4f03e3a929c7e6",
|
||||
"sha256:832aa3cde19744e49938b91fea06d69ecb9e649c93ba974535d08ad92164f700"
|
||||
],
|
||||
"markers": "python_version >= '2.7' and python_version not in '3.0, 3.1, 3.2, 3.3'",
|
||||
"version": "==20.3.0"
|
||||
},
|
||||
"hypothesis": {
|
||||
"hashes": [
|
||||
"sha256:6094f922c1d2bef808b8e6cf3ab8cdb8ad76faa12519a25ac4250eed524e25c0",
|
||||
"sha256:6dfc4c233f3903f1524242c1e1c92b5e06fb4b689d3fd2abbd22783af25d3c27"
|
||||
],
|
||||
"index": "pypi",
|
||||
"version": "==5.43.4"
|
||||
},
|
||||
"iniconfig": {
|
||||
"hashes": [
|
||||
"sha256:011e24c64b7f47f6ebd835bb12a743f2fbe9a26d4cecaa7f53bc4f35ee9da8b3",
|
||||
"sha256:bc3af051d7d14b2ee5ef9969666def0cd1a000e121eaea580d4a313df4b37f32"
|
||||
],
|
||||
"version": "==1.1.1"
|
||||
},
|
||||
"packaging": {
|
||||
"hashes": [
|
||||
"sha256:24e0da08660a87484d1602c30bb4902d74816b6985b93de36926f5bc95741858",
|
||||
"sha256:78598185a7008a470d64526a8059de9aaa449238f280fc9eb6b13ba6c4109093"
|
||||
],
|
||||
"markers": "python_version >= '2.7' and python_version not in '3.0, 3.1, 3.2, 3.3'",
|
||||
"version": "==20.8"
|
||||
},
|
||||
"pluggy": {
|
||||
"hashes": [
|
||||
"sha256:15b2acde666561e1298d71b523007ed7364de07029219b604cf808bfa1c765b0",
|
||||
"sha256:966c145cd83c96502c3c3868f50408687b38434af77734af1e9ca461a4081d2d"
|
||||
],
|
||||
"markers": "python_version >= '2.7' and python_version not in '3.0, 3.1, 3.2, 3.3'",
|
||||
"version": "==0.13.1"
|
||||
},
|
||||
"py": {
|
||||
"hashes": [
|
||||
"sha256:21b81bda15b66ef5e1a777a21c4dcd9c20ad3efd0b3f817e7a809035269e1bd3",
|
||||
"sha256:3b80836aa6d1feeaa108e046da6423ab8f6ceda6468545ae8d02d9d58d18818a"
|
||||
],
|
||||
"markers": "python_version >= '2.7' and python_version not in '3.0, 3.1, 3.2, 3.3'",
|
||||
"version": "==1.10.0"
|
||||
},
|
||||
"pyparsing": {
|
||||
"hashes": [
|
||||
"sha256:c203ec8783bf771a155b207279b9bccb8dea02d8f0c9e5f8ead507bc3246ecc1",
|
||||
"sha256:ef9d7589ef3c200abe66653d3f1ab1033c3c419ae9b9bdb1240a85b024efc88b"
|
||||
],
|
||||
"markers": "python_version >= '2.6' and python_version not in '3.0, 3.1, 3.2, 3.3'",
|
||||
"version": "==2.4.7"
|
||||
},
|
||||
"pytest": {
|
||||
"hashes": [
|
||||
"sha256:1969f797a1a0dbd8ccf0fecc80262312729afea9c17f1d70ebf85c5e76c6f7c8",
|
||||
"sha256:66e419b1899bc27346cb2c993e12c5e5e8daba9073c1fbce33b9807abc95c306"
|
||||
],
|
||||
"index": "pypi",
|
||||
"version": "==6.2.1"
|
||||
},
|
||||
"sortedcontainers": {
|
||||
"hashes": [
|
||||
"sha256:37257a32add0a3ee490bb170b599e93095eed89a55da91fa9f48753ea12fd73f",
|
||||
"sha256:59cc937650cf60d677c16775597c89a960658a09cf7c1a668f86e1e4464b10a1"
|
||||
],
|
||||
"version": "==2.3.0"
|
||||
},
|
||||
"toml": {
|
||||
"hashes": [
|
||||
"sha256:806143ae5bfb6a3c6e736a764057db0e6a0e05e338b5630894a5f779cabb4f9b",
|
||||
"sha256:b3bda1d108d5dd99f4a20d24d9c348e91c4db7ab1b749200bded2f839ccbe68f"
|
||||
],
|
||||
"markers": "python_version >= '2.6' and python_version not in '3.0, 3.1, 3.2, 3.3'",
|
||||
"version": "==0.10.2"
|
||||
}
|
||||
},
|
||||
"develop": {}
|
||||
}
|
139
node.py
Normal file
139
node.py
Normal file
|
@ -0,0 +1,139 @@
|
|||
import enum
|
||||
|
||||
|
||||
class NodeType(enum.IntEnum):
|
||||
LEAF = 0
|
||||
UNARY = 1
|
||||
BINARY = 2
|
||||
|
||||
|
||||
class Node:
|
||||
def __init__(self, value, left=None, right=None, parent=None):
|
||||
self.parent = parent
|
||||
self.left = left
|
||||
self.right = right
|
||||
|
||||
self.value = value
|
||||
self.rank = 0
|
||||
|
||||
@property
|
||||
def type(self):
|
||||
if self.left and self.right:
|
||||
return NodeType.BINARY
|
||||
|
||||
if self.left or self.right:
|
||||
return NodeType.UNARY
|
||||
|
||||
return NodeType.LEAF
|
||||
|
||||
def __repr__(self):
|
||||
return f"Node(value={self.value}, rank={self.rank}, left={self.left}, right={self.right}, parent={self.parent})"
|
||||
|
||||
def __str__(self):
|
||||
return f"Node(value={self.value}, rank={self.rank})"
|
||||
|
||||
@staticmethod
|
||||
def height(node):
|
||||
return 1 + max(Node.height(node.left), Node.height(node.right)) if node else -1
|
||||
|
||||
@staticmethod
|
||||
def rank(node):
|
||||
return -1 if not node else node.rank
|
||||
|
||||
@staticmethod
|
||||
def difference(node, parent=None):
|
||||
if not parent:
|
||||
parent = node.parent if node else None
|
||||
return Node.rank(parent) - Node.rank(node)
|
||||
|
||||
@staticmethod
|
||||
def differences(node):
|
||||
node_rank = Node.rank(node)
|
||||
(left, right) = (node.left, node.right) if node else (None, None)
|
||||
|
||||
return (node_rank - Node.rank(left), node_rank - Node.rank(right))
|
||||
|
||||
def promote(self):
|
||||
self.rank += 1
|
||||
return self
|
||||
|
||||
def demote(self):
|
||||
self.rank -= 1
|
||||
return self
|
||||
|
||||
@staticmethod
|
||||
def rotate_right(x):
|
||||
parent = x.parent
|
||||
y = x.left
|
||||
z = x.right
|
||||
|
||||
assert y is not None
|
||||
if parent:
|
||||
if parent.left is x:
|
||||
parent.left = y
|
||||
else:
|
||||
parent.right = y
|
||||
|
||||
x.left = y.right
|
||||
if x.left:
|
||||
x.left.parent = x
|
||||
|
||||
y.right = x
|
||||
x.parent = y
|
||||
y.parent = parent
|
||||
|
||||
return y
|
||||
|
||||
@staticmethod
|
||||
def rotate_left(x):
|
||||
parent = x.parent
|
||||
y = x.left
|
||||
z = x.right
|
||||
|
||||
assert z is not None
|
||||
if parent:
|
||||
if parent.left is x:
|
||||
parent.left = z
|
||||
else:
|
||||
parent.right = z
|
||||
|
||||
x.right = z.left
|
||||
if x.right:
|
||||
x.right.parent = x
|
||||
|
||||
z.left = x
|
||||
x.parent = z
|
||||
z.parent = parent
|
||||
|
||||
return z
|
||||
|
||||
@staticmethod
|
||||
def find_parent_node(value, node, missing=True):
|
||||
new_node = node
|
||||
|
||||
while new_node and (missing or new_node.value != value):
|
||||
node = new_node
|
||||
new_node = node.left if value < node.value else node.right
|
||||
|
||||
return node
|
||||
|
||||
@staticmethod
|
||||
def search(value, node):
|
||||
while node and node.value != value:
|
||||
node = node.left if value < node.value else node.right
|
||||
|
||||
return node
|
||||
|
||||
@staticmethod
|
||||
def minimum(node):
|
||||
while node.left:
|
||||
node = node.left
|
||||
|
||||
return node
|
||||
|
||||
@staticmethod
|
||||
def maximum(node):
|
||||
while node.right:
|
||||
node = node.right
|
||||
|
||||
return node
|
15
test_node.py
Normal file
15
test_node.py
Normal file
|
@ -0,0 +1,15 @@
|
|||
from node import Node, NodeType
|
||||
|
||||
|
||||
def test_default_rank():
|
||||
node = Node(0)
|
||||
|
||||
assert 0 == node.value
|
||||
assert node.parent is None
|
||||
assert node.left is None
|
||||
assert node.right is None
|
||||
|
||||
assert 0 == node.rank
|
||||
assert 0 == Node.rank(node)
|
||||
assert -1 == Node.difference(node)
|
||||
assert (1, 1) == Node.differences(node)
|
185
test_wavl.py
Normal file
185
test_wavl.py
Normal file
|
@ -0,0 +1,185 @@
|
|||
from node import Node, NodeType
|
||||
from wavl import WAVLTree
|
||||
|
||||
import hypothesis
|
||||
import hypothesis.strategies as st
|
||||
import logging
|
||||
import pytest
|
||||
import random
|
||||
import unittest
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def test_empty():
|
||||
tree = WAVLTree()
|
||||
|
||||
assert tree.root is None
|
||||
assert tree.is_correct
|
||||
|
||||
|
||||
def test_one_node():
|
||||
tree = WAVLTree()
|
||||
tree.insert(1)
|
||||
|
||||
assert tree.root is not None
|
||||
assert 1 == tree.root.value
|
||||
assert 0 == tree.root.rank
|
||||
assert tree.root.left is None
|
||||
assert tree.root.right is None
|
||||
|
||||
assert tree.is_correct
|
||||
|
||||
|
||||
@pytest.mark.parametrize("values", [[1, 2], [1, 2, 0]])
|
||||
def test_no_rebalance_needed(values):
|
||||
tree = WAVLTree()
|
||||
|
||||
for value in values:
|
||||
tree.insert(value)
|
||||
assert tree.is_correct
|
||||
|
||||
|
||||
def test_three_nodes_rebalanced():
|
||||
tree = WAVLTree()
|
||||
|
||||
for value in (1, 2, 3):
|
||||
print(tree)
|
||||
tree.insert(value)
|
||||
|
||||
assert tree.is_correct
|
||||
|
||||
|
||||
def test_bigger_tree():
|
||||
tree = WAVLTree()
|
||||
|
||||
for i in range(50):
|
||||
tree.insert(i)
|
||||
assert tree.is_correct
|
||||
|
||||
|
||||
def test_bigger_tree_reversed():
|
||||
tree = WAVLTree()
|
||||
|
||||
for i in range(50):
|
||||
tree.insert(-i)
|
||||
assert tree.is_correct
|
||||
|
||||
|
||||
def test_promote():
|
||||
tree = WAVLTree()
|
||||
|
||||
for value in (0, 1, -1):
|
||||
tree.insert(value)
|
||||
assert tree.is_correct
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"values",
|
||||
[
|
||||
[0, 1, -2, -1, -3],
|
||||
[0, 1, -2, -1, -3, -4, 4, 9, 7],
|
||||
[0, 1, -2, -1, -3, -4, 4, 9, 7, 5, -5, 8],
|
||||
],
|
||||
)
|
||||
def test_rotate(values):
|
||||
tree = WAVLTree()
|
||||
|
||||
for value in values:
|
||||
tree.insert(value)
|
||||
assert tree.is_correct
|
||||
|
||||
|
||||
@hypothesis.settings(max_examples=1000, deadline=None)
|
||||
@hypothesis.given(values=st.sets(st.integers()))
|
||||
def test_insert_random(values):
|
||||
tree = WAVLTree()
|
||||
|
||||
for value in values:
|
||||
tree.insert(value)
|
||||
assert tree.is_correct
|
||||
assert tree.search(value) is not None
|
||||
|
||||
|
||||
def test_search_empty():
|
||||
tree = WAVLTree()
|
||||
assert tree.search(0) is None
|
||||
|
||||
|
||||
@hypothesis.given(values=st.sets(st.integers()))
|
||||
def test_search_random(values):
|
||||
tree = WAVLTree()
|
||||
|
||||
for value in values:
|
||||
tree.insert(value)
|
||||
assert tree.is_correct
|
||||
|
||||
for value in values:
|
||||
assert tree.search(value) is not None
|
||||
|
||||
|
||||
@st.composite
|
||||
def delete_strategy(draw):
|
||||
values = list(draw(st.sets(st.integers())))
|
||||
delete_order = values.copy()
|
||||
random.shuffle(delete_order)
|
||||
return (values, delete_order)
|
||||
|
||||
|
||||
@hypothesis.settings(max_examples=100000, deadline=None)
|
||||
@hypothesis.given(config=delete_strategy())
|
||||
def test_delete_random(config):
|
||||
values, delete_order = config
|
||||
|
||||
tree = WAVLTree()
|
||||
|
||||
for value in values:
|
||||
tree.insert(value)
|
||||
|
||||
for value in delete_order:
|
||||
before = str(tree)
|
||||
tree.delete(value)
|
||||
after = str(tree)
|
||||
|
||||
try:
|
||||
assert tree.is_correct
|
||||
except AssertionError:
|
||||
logger.info(
|
||||
f"[FAIL] Deleting of {value} from {values} in order {delete_order}"
|
||||
)
|
||||
logger.info(f"Before:\n{before}")
|
||||
logger.info(f"After:\n{after}")
|
||||
raise
|
||||
|
||||
|
||||
# @unittest.skip("Only for replicating hypothesis")
|
||||
@pytest.mark.parametrize(
|
||||
"values, delete_order",
|
||||
[
|
||||
([0, 1, 2], [0, 2, 1]),
|
||||
([0, 1, 2, -1], [2, 0, 1, -1]),
|
||||
([0, 1, -1], [0, -1, 1]),
|
||||
([0, 1], [0, 1]),
|
||||
([0, 1, 32, -1], [32, 0, 1, -1]),
|
||||
],
|
||||
)
|
||||
def test_delete_minimal(values, delete_order):
|
||||
tree = WAVLTree()
|
||||
|
||||
for value in values:
|
||||
tree.insert(value)
|
||||
|
||||
for value in delete_order:
|
||||
before = str(tree)
|
||||
tree.delete(value)
|
||||
after = str(tree)
|
||||
|
||||
try:
|
||||
assert tree.is_correct
|
||||
except AssertionError:
|
||||
logger.info(
|
||||
f"[FAIL] Deleting of {value} from {values} in order {delete_order}"
|
||||
)
|
||||
logger.info(f"Before:\n{before}")
|
||||
logger.info(f"After:\n{after}")
|
||||
raise
|
300
wavl.py
Normal file
300
wavl.py
Normal file
|
@ -0,0 +1,300 @@
|
|||
from node import Node, NodeType
|
||||
|
||||
from collections import deque
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class WAVLTree:
|
||||
def __init__(self):
|
||||
self.root = None
|
||||
|
||||
@property
|
||||
def rank(self):
|
||||
return Node.rank(self.root)
|
||||
|
||||
@property
|
||||
def is_correct(self):
|
||||
return WAVLTree.is_correct_node(self.root)
|
||||
|
||||
@staticmethod
|
||||
def is_correct_node(node, recursive=True):
|
||||
if not node:
|
||||
return True
|
||||
|
||||
for child_rank in Node.differences(node):
|
||||
if child_rank not in (1, 2):
|
||||
return False
|
||||
|
||||
if node.type == NodeType.LEAF:
|
||||
return node.rank == 0
|
||||
|
||||
return not recursive or (
|
||||
WAVLTree.is_correct_node(node.left) and WAVLTree.is_correct_node(node.right)
|
||||
)
|
||||
|
||||
def search(self, value, node=None):
|
||||
if not node:
|
||||
node = self.root
|
||||
|
||||
return Node.search(value, node)
|
||||
|
||||
def __fix_0_child(self, x, y, z, rotate_left, rotate_right):
|
||||
new_root = x.parent
|
||||
|
||||
if not y or Node.difference(y) == 2:
|
||||
new_root = rotate_right(z)
|
||||
z.demote()
|
||||
elif Node.difference(y) == 1:
|
||||
rotate_left(x)
|
||||
new_root = rotate_right(z)
|
||||
|
||||
y.promote()
|
||||
x.demote()
|
||||
z.demote()
|
||||
|
||||
return new_root
|
||||
|
||||
def __bottomup_rebalance(self, x):
|
||||
diffs = Node.differences(x.parent)
|
||||
# if diffs != (0, 1) and diffs != (1, 0):
|
||||
# return
|
||||
|
||||
while x.parent and (diffs == (0, 1) or diffs == (1, 0)):
|
||||
x.parent.promote()
|
||||
x = x.parent
|
||||
|
||||
diffs = Node.differences(x.parent)
|
||||
|
||||
if not x.parent:
|
||||
return
|
||||
|
||||
rotating_around_root = x.parent.parent is None
|
||||
new_root = x.parent
|
||||
|
||||
rank_difference = Node.difference(x)
|
||||
|
||||
if rank_difference != 0:
|
||||
return
|
||||
|
||||
if rank_difference == 0 and x.parent.left is x:
|
||||
new_root = self.__fix_0_child(
|
||||
x, x.right, x.parent, Node.rotate_left, Node.rotate_right
|
||||
)
|
||||
elif rank_difference == 0 and x.parent.right is x:
|
||||
new_root = self.__fix_0_child(
|
||||
x, x.left, x.parent, Node.rotate_right, Node.rotate_left
|
||||
)
|
||||
|
||||
if rotating_around_root:
|
||||
self.root = new_root
|
||||
|
||||
def insert(self, value):
|
||||
inserted_node = Node(value)
|
||||
|
||||
if not self.root:
|
||||
self.root = inserted_node
|
||||
return
|
||||
|
||||
parent = Node.find_parent_node(value, self.root)
|
||||
inserted_node.parent = parent
|
||||
|
||||
if value < parent.value:
|
||||
parent.left = inserted_node
|
||||
else:
|
||||
parent.right = inserted_node
|
||||
|
||||
self.__bottomup_rebalance(inserted_node)
|
||||
|
||||
def __transplant(self, u, v):
|
||||
if not u.parent:
|
||||
self.root = v
|
||||
elif u.parent.left is u:
|
||||
u.parent.left = v
|
||||
else:
|
||||
u.parent.right = v
|
||||
|
||||
if v:
|
||||
v.rank = u.rank
|
||||
v.parent = u.parent
|
||||
|
||||
@staticmethod
|
||||
def __fix_delete(x, y, z, reversed, rotate_left, rotate_right):
|
||||
new_root = x
|
||||
v = y.left
|
||||
w = y.right
|
||||
|
||||
if reversed:
|
||||
v, w = w, v
|
||||
|
||||
logger.debug(f"__fix_delete({x}, {y}, {z}, {reversed})")
|
||||
|
||||
w_diff = Node.difference(w, y)
|
||||
logger.debug(f"w_diff = {w_diff}")
|
||||
if w_diff == 1 and y.parent:
|
||||
logger.debug(f"y.parent = {y.parent}")
|
||||
new_root = rotate_left(y.parent)
|
||||
|
||||
y.promote()
|
||||
z.demote()
|
||||
|
||||
if z.type == NodeType.LEAF:
|
||||
z.demote()
|
||||
elif w_diff == 2 and v.parent:
|
||||
logger.debug(f"v.parent = {v.parent}")
|
||||
rotate_right(v.parent)
|
||||
new_root = rotate_left(v.parent)
|
||||
|
||||
v.promote().promote()
|
||||
y.demote()
|
||||
z.demote().demote()
|
||||
|
||||
return new_root
|
||||
|
||||
def __bottomup_delete(self, x, parent):
|
||||
x_diff = Node.difference(x, parent)
|
||||
if x_diff != 3 or not parent:
|
||||
return
|
||||
|
||||
y = parent.right if parent.left is x else parent.left
|
||||
y_diff = Node.difference(y, parent)
|
||||
|
||||
while (
|
||||
parent
|
||||
and x_diff == 3
|
||||
and y
|
||||
and (y_diff == 2 or Node.differences(y) == (2, 2))
|
||||
):
|
||||
parent.demote()
|
||||
if y_diff != 2:
|
||||
y.demote()
|
||||
|
||||
x = parent
|
||||
parent = x.parent
|
||||
if not parent:
|
||||
return
|
||||
y = parent.right if parent.left is x else parent.left
|
||||
|
||||
x_diff = Node.difference(x, parent)
|
||||
y_diff = Node.difference(y, parent)
|
||||
|
||||
if not parent:
|
||||
return
|
||||
|
||||
rotating_around_root = parent.parent is None
|
||||
new_root = parent
|
||||
|
||||
parent_node_diffs = Node.differences(parent)
|
||||
if parent_node_diffs in ((1, 3), (3, 1)):
|
||||
if parent.left is x:
|
||||
new_root = WAVLTree.__fix_delete(
|
||||
x,
|
||||
parent.right,
|
||||
parent,
|
||||
False,
|
||||
Node.rotate_left,
|
||||
Node.rotate_right,
|
||||
)
|
||||
else:
|
||||
new_root = WAVLTree.__fix_delete(
|
||||
x,
|
||||
parent.left,
|
||||
parent,
|
||||
True,
|
||||
Node.rotate_right,
|
||||
Node.rotate_left,
|
||||
)
|
||||
|
||||
if rotating_around_root:
|
||||
self.root = new_root
|
||||
|
||||
def __delete_fixup(self, y, parent=None):
|
||||
logger.debug(f"[__delete_fixup] y = {y}, parent = {parent}")
|
||||
|
||||
z = y if y else parent
|
||||
logger.debug(
|
||||
f"[z.demote()] Node.differences({repr(z)}) == (2, 2) ~>* {Node.differences(z)} == (2, 2)"
|
||||
)
|
||||
if Node.differences(z) == (2, 2):
|
||||
z.demote()
|
||||
|
||||
if parent:
|
||||
for y in (parent.left, parent.right):
|
||||
logger.debug(
|
||||
f"[bottom-up delete] Node.difference({y}, {parent}) == 3 ~>* {Node.difference(y, parent)} == 3"
|
||||
)
|
||||
if Node.difference(y, parent) == 3:
|
||||
self.__bottomup_delete(y, parent)
|
||||
|
||||
def __fix_after_delete(self, node, parent):
|
||||
while node or parent:
|
||||
self.__delete_fixup(node, parent)
|
||||
node, parent = parent, (parent.parent if parent else None)
|
||||
|
||||
def delete_node(self, node):
|
||||
y, parent = None, node.parent
|
||||
|
||||
if not node.left:
|
||||
logger.debug("node.left is None")
|
||||
y = node.right
|
||||
self.__transplant(node, node.right)
|
||||
elif not node.right:
|
||||
logger.debug("node.right is None")
|
||||
y = node.left
|
||||
self.__transplant(node, node.left)
|
||||
else:
|
||||
logger.debug("taking successor")
|
||||
n = Node.minimum(node.right)
|
||||
y, parent = None, (n.parent if n.parent is not node else n)
|
||||
|
||||
if n.parent is not node:
|
||||
parent = n.right if n.right else n.parent
|
||||
self.__transplant(n, n.right)
|
||||
n.right = node.right
|
||||
n.right.parent = n
|
||||
|
||||
self.__transplant(node, n)
|
||||
n.left = node.left
|
||||
n.left.parent = n
|
||||
|
||||
return (y, parent)
|
||||
|
||||
def __delete(self, value, node=None):
|
||||
if node is None:
|
||||
return
|
||||
|
||||
if node.value != value:
|
||||
return self.__delete(value, node.left if value < node.value else node.right)
|
||||
|
||||
(y, parent) = self.delete_node(node)
|
||||
self.__fix_after_delete(y, parent)
|
||||
|
||||
def delete(self, value):
|
||||
logger.debug(f"[DELETE] {value}")
|
||||
self.__delete(value, self.root)
|
||||
|
||||
def __str__(self):
|
||||
result = "digraph {\n"
|
||||
|
||||
queue = deque()
|
||||
queue.append(self.root)
|
||||
|
||||
edges = []
|
||||
while queue:
|
||||
node = queue.popleft()
|
||||
if not node:
|
||||
continue
|
||||
|
||||
result += f'"{str(node)}" [label="{node.value}, {node.rank}"];\n'
|
||||
for child in (node.left, node.right):
|
||||
if not child:
|
||||
continue
|
||||
|
||||
edges.append((node, child))
|
||||
queue.append(child)
|
||||
|
||||
for fromNode, toNode in edges:
|
||||
result += f'"{str(fromNode)}" -> "{str(toNode)}" [label="{Node.difference(toNode)}"]\n'
|
||||
|
||||
return result + "}\n"
|
Loading…
Reference in a new issue