Compare commits

..

No commits in common. "main" and "submit" have entirely different histories.
main ... submit

11 changed files with 114 additions and 351 deletions

View file

@ -1,42 +0,0 @@
image: python:latest
variables:
PIP_CACHE_DIR: "$CI_PROJECT_DIR/.cache/pip"
cache:
paths:
- .cache/pip
- venv/
before_script:
- python --version
- pip install virtualenv
- virtualenv venv
- source venv/bin/activate
black:
stage: test
script:
- pip install black
- make black
flake8:
stage: test
script:
- pip install flake8
- make flake8
pytest:
stage: test
script:
- pip install hypothesis pytest pytest-cov
- make check
coverage: '/TOTAL.*\s([.\d]+)%/'
artifacts:
when: always
reports:
junit: report.xml
coverage_report:
coverage_format: cobertura
path: coverage.xml

View file

@ -1,5 +1,5 @@
check:
pytest -vvv --log-level debug --cov --cov-report=term-missing --cov-report=xml:coverage.xml --junitxml=report.xml
pytest -vvv --log-level debug --cov --cov-report=term-missing
lint: black flake8 mypy
@ -10,7 +10,7 @@ flake8:
flake8 *.py
mypy:
mypy --strict --disallow-any-explicit node.py ranked_tree.py avl.py wavl.py ravl.py rbt.py
mypy --strict --disallow-any-explicit node.py avl.py wavl.py ravl.py
clean:
rm -rf __pycache__ .hypothesis .mypy_cache .pytest_cache

41
avl.py
View file

@ -1,5 +1,5 @@
from node import Node, Comparable
from ranked_tree import RankedTree, RotateFunction
from node import Node, Comparable, RotateFunction
from ranked_tree import RankedTree
import logging
from typing import TypeVar, Optional
@ -49,18 +49,22 @@ class AVLTree(RankedTree[T]):
z: Node[T],
rotate_left: RotateFunction[T],
rotate_right: RotateFunction[T],
) -> None:
) -> Optional[Node[T]]:
new_root = x.parent
if not y or Node.difference(y) == 2:
rotate_right(z)
new_root = rotate_right(z)
z.demote()
elif Node.difference(y) == 1:
rotate_left(x)
rotate_right(z)
new_root = rotate_right(z)
y.promote()
x.demote()
z.demote()
return new_root
def _insert_rebalance(self, x: Node[T]) -> None:
diffs = Node.differences(x.parent)
while x.parent and (diffs == (0, 1) or diffs == (1, 0)):
@ -72,6 +76,9 @@ class AVLTree(RankedTree[T]):
if not x.parent:
return
rotating_around_root = x.parent.parent is None
new_root: Optional[Node[T]] = x.parent
rank_difference = Node.difference(x)
if rank_difference != 0:
return
@ -79,14 +86,17 @@ class AVLTree(RankedTree[T]):
x_parent = x.parent
assert x_parent is not None
if rank_difference == 0 and x.parent.left is x:
self.__fix_0_child(
x, x.right, x_parent, self.rotate_left, self.rotate_right
new_root = self.__fix_0_child(
x, x.right, x_parent, Node.rotate_left, Node.rotate_right
)
elif rank_difference == 0 and x.parent.right is x:
self.__fix_0_child(
x, x.left, x_parent, self.rotate_right, self.rotate_left
new_root = self.__fix_0_child(
x, x.left, x_parent, Node.rotate_right, Node.rotate_left
)
if rotating_around_root:
self.root = new_root
# endregion InsertRebalance
# region DeleteRebalance
@ -96,6 +106,7 @@ class AVLTree(RankedTree[T]):
x: Node[T],
y: Node[T],
leaning: int,
rotating_around_root: bool,
rotate_left: RotateFunction[T],
rotate_right: RotateFunction[T],
) -> bool:
@ -111,10 +122,12 @@ class AVLTree(RankedTree[T]):
for n in filter(None, (new_root.left, new_root.right, new_root)):
_update_rank(n)
if rotating_around_root:
self.root = new_root
return factor != 0
def __delete_fixup(
self, x: Node[T], parent: Optional[Node[T]] = None
self, x: Optional[Node[T]], parent: Optional[Node[T]] = None
) -> bool:
factor = _balance_factor(x)
if factor == 0:
@ -123,16 +136,18 @@ class AVLTree(RankedTree[T]):
elif factor in (-1, 1):
return False
rotating_around_root = x.parent is None
y, leaning, to_left, to_right = (
(x.right, 1, self.rotate_left, self.rotate_right)
(x.right, 1, Node.rotate_left, Node.rotate_right)
if factor == 2
else (x.left, -1, self.rotate_right, self.rotate_left)
else (x.left, -1, Node.rotate_right, Node.rotate_left)
)
assert y
return self.__delete_rotate(
x,
y,
leaning,
rotating_around_root,
to_left,
to_right,
)
@ -143,7 +158,7 @@ class AVLTree(RankedTree[T]):
if not node and not parent:
return
if not node and parent:
if not node:
node, parent = parent, parent.parent
while node and self.__delete_fixup(node, parent):

57
node.py
View file

@ -1,8 +1,8 @@
from abc import abstractmethod
import enum
from typing import (
Callable,
Iterable,
Iterator,
Optional,
Generic,
Tuple,
@ -28,7 +28,7 @@ class NodeType(enum.IntEnum):
T = TypeVar("T", bound=Comparable)
class Node(Generic[T], Iterable[T]):
class Node(Generic[T]):
def __init__(
self,
value: T,
@ -62,7 +62,7 @@ class Node(Generic[T], Iterable[T]):
def __str__(self) -> str:
return f"Node(value={self.value}, rank={self.rank})"
def __iter__(self) -> Iterator[T]:
def __iter__(self) -> Iterable[T]:
"""
Yields:
Keys from the subtree rooted at the node in an inorder fashion.
@ -111,10 +111,56 @@ class Node(Generic[T], Iterable[T]):
self.rank -= 1
return self
@staticmethod
def rotate_right(x: "Node[T]") -> "Node[T]":
parent = x.parent
y = x.left
# z = x.right
assert y is not None
if parent:
if parent.left is x:
parent.left = y
else:
parent.right = y
x.left = y.right
if x.left:
x.left.parent = x
y.right = x
x.parent = y
y.parent = parent
return y
@staticmethod
def rotate_left(x: "Node[T]") -> "Node[T]":
parent = x.parent
# y = x.left
z = x.right
assert z is not None
if parent:
if parent.left is x:
parent.left = z
else:
parent.right = z
x.right = z.left
if x.right:
x.right.parent = x
z.left = x
x.parent = z
z.parent = parent
return z
@staticmethod
def find_parent_node(
value: T, node: "Node[T]", missing: bool = True
) -> "Optional[Node[T]]":
) -> "Node[T]":
new_node: "Optional[Node[T]]" = node
while new_node and (missing or new_node.value != value):
@ -148,3 +194,6 @@ class Node(Generic[T], Iterable[T]):
node = node.right
return node
RotateFunction = Callable[[Node[T]], Node[T]]

View file

@ -3,7 +3,7 @@ from node import Node, Comparable
from abc import abstractmethod
from collections import deque
import logging
from typing import Callable, Deque, Iterable, Optional, Tuple, TypeVar, Generic
from typing import Deque, Iterable, Optional, Tuple, TypeVar, Generic
logger = logging.getLogger(__name__)
T = TypeVar("T", bound=Comparable)
@ -58,58 +58,6 @@ class RankedTree(Generic[T]):
# endregion TreeSpecificMethods
# region Rotations
def rotate_right(self, x: "Node[T]") -> "Node[T]":
parent = x.parent
y = x.left
# z = x.right
assert y is not None
if parent:
if parent.left is x:
parent.left = y
else:
parent.right = y
else:
self.root = y
x.left = y.right
if x.left:
x.left.parent = x
y.right = x
x.parent = y
y.parent = parent
return y
def rotate_left(self, x: "Node[T]") -> "Node[T]":
parent = x.parent
# y = x.left
z = x.right
assert z is not None
if parent:
if parent.left is x:
parent.left = z
else:
parent.right = z
else:
self.root = z
x.right = z.left
if x.right:
x.right.parent = x
z.left = x
x.parent = z
z.parent = parent
return z
# endregion Rotations
@property
def rank(self) -> int:
return Node.get_rank(self.root)
@ -174,7 +122,8 @@ class RankedTree(Generic[T]):
self._transplant(node, node.left)
else:
n = Node.minimum(node.right)
node.value, n.value = n.value, node.value
node.value = n.value
n.value = None
return self._delete_node(n)
return (y, parent)
@ -199,6 +148,3 @@ class RankedTree(Generic[T]):
"""
if self.root:
yield from self.root
RotateFunction = Callable[[Node[T]], Node[T]]

179
rbt.py
View file

@ -1,179 +0,0 @@
from node import Node, Comparable
from ranked_tree import RankedTree, RotateFunction
import enum
import logging
from typing import Callable, Tuple, TypeVar, Optional
logger = logging.getLogger(__name__)
T = TypeVar("T", bound=Comparable)
class Colour(enum.IntEnum):
"""
Represents colour of the edge or node.
"""
Red = 0
Black = 1
def has_double_black(x: Optional[Node[T]]) -> bool:
"""
Checks for double black child of x.
Args:
x: Node to be checked.
Returns:
`true`, if `x` has a double black node, `false` otherwise.
"""
return x is not None and 2 in Node.differences(x)
class RBTree(RankedTree[T]):
def is_correct_node(
self, node: Optional[Node[T]], recursive: bool = True
) -> bool:
if not node:
return True
left, right = Node.differences(node)
if left not in (Colour.Red, Colour.Black):
# left subtree has invalid difference
return False
elif right not in (Colour.Red, Colour.Black):
# right subtree has invalid difference
return False
if Node.difference(node) == Colour.Red and (
left == Colour.Red or right == Colour.Red
):
# two consecutive red nodes
return False
return not recursive or (
self.is_correct_node(node.left)
and self.is_correct_node(node.right)
)
# region InsertRebalance
def _insert_rebalance_step(
self,
z: Node[T],
y: Optional[Node[T]],
right_child: Node[T],
rotate_left: RotateFunction[T],
rotate_right: RotateFunction[T],
) -> Node[T]:
p = z.parent
pp = p.parent
if y is not None and Node.difference(y) == Colour.Red:
# Case 1
# ======
# zs uncle y is red
pp.rank += Colour.Black
z = pp
elif z == right_child:
# Case 2
# ======
# zs uncle y is black and z is a right child
z = p
rotate_left(p)
else:
# Case 3
# ======
# zs uncle y is black and z is a left child
rotate_right(pp)
return z
def _insert_rebalance(self, z: Node[T]) -> None:
while z.parent is not None and Node.difference(z.parent) == Colour.Red:
p = z.parent
pp = p.parent
assert pp
if p == pp.left:
z = self._insert_rebalance_step(
z, pp.right, p.right, self.rotate_left, self.rotate_right
)
else:
z = self._insert_rebalance_step(
z, pp.left, p.left, self.rotate_right, self.rotate_left
)
# endregion InsertRebalance
# region DeleteRebalance
def _delete_rebalance_step(
self,
x: Node[T],
w: Node[T],
parent: Node[T],
right: Callable[[Node[T]], Optional[Node[T]]],
rotate_left: RotateFunction[T],
rotate_right: RotateFunction[T],
) -> Tuple[Optional[Node[T]], Optional[Node[T]]]:
if Node.difference(w) == Colour.Red:
# Case 1
# ======
# xs sibling w is red
rotate_left(parent)
w = right(parent)
if Node.differences(w) == (Colour.Black, Colour.Black):
# Case 2
# ======
# xs sibling w is black, and both of ws children are black
parent.rank -= Colour.Black
x = parent
else:
# Case 3
# ======
# xs sibling w is black,
# ws left child is red, and ws right child is black
if Node.difference(right(w), w) == Colour.Black:
rotate_right(w)
w = right(parent)
# Case 4
# ======
# xs sibling w is black, and ws right child is red
parent.rank -= Colour.Black
w.rank += Colour.Black
rotate_left(parent)
x = self.root
return x, (x.parent if x else None)
def _delete_rebalance(
self, node: Optional[Node[T]], parent: Optional[Node[T]]
) -> None:
if not node and not parent:
return
while node != self.root and has_double_black(parent):
if node == parent.left:
node, parent = self._delete_rebalance_step(
node,
parent.right,
parent,
lambda x: x.right,
self.rotate_left,
self.rotate_right,
)
else:
node, parent = self._delete_rebalance_step(
node,
parent.left,
parent,
lambda x: x.left,
self.rotate_right,
self.rotate_left,
)
# endregion DeleteRebalance

View file

@ -1,8 +1,11 @@
from avl import AVLTree, _balance_factor
import logging
import random
from typing import List
import hypothesis
import hypothesis.strategies as st
import pytest
logger = logging.getLogger(__name__)

View file

@ -1,10 +1,10 @@
from avl import AVLTree
from wavl import WAVLTree
from ravl import RAVLTree
from rbt import RBTree
import logging
import random
from typing import List
import hypothesis
import hypothesis.strategies as st
@ -14,7 +14,8 @@ logger = logging.getLogger(__name__)
@pytest.mark.parametrize(
("RankBalancedTree"), [RBTree, AVLTree, WAVLTree, RAVLTree]
("RankBalancedTree"),
[AVLTree, WAVLTree, RAVLTree]
)
@hypothesis.settings(max_examples=10000)
@hypothesis.given(values=st.sets(st.integers()))
@ -44,7 +45,8 @@ def _report(t_before: str, t_after: str) -> None:
@pytest.mark.parametrize(
("RankBalancedTree"), [RBTree, AVLTree, WAVLTree, RAVLTree]
("RankBalancedTree"),
[AVLTree, WAVLTree, RAVLTree]
)
@hypothesis.settings(max_examples=10000)
@hypothesis.given(config=delete_strategy())

View file

@ -1,7 +1,10 @@
from ravl import RAVLTree
import hypothesis
import hypothesis.strategies as st
import logging
import pytest
import random
logger = logging.getLogger(__name__)

View file

@ -1,43 +0,0 @@
from rbt import RBTree
import logging
import pytest
logger = logging.getLogger(__name__)
@pytest.mark.parametrize(
"values, delete_order",
[
([0, 1, -1], [0, -1, 1]),
([0, 1, 2, 3, 4, 5], [4, 2, 1, 0, 5, 3]),
(
[32, 1, 2, 3, 4, 5, 0, -2, -4, -3, -1],
[-4, -3, 1, 2, 5, 3, -2, 4, 32, -1, 0],
),
([0, 1, 2, -2, -3, -1], [-3, 2, 1, 0, -1, -2]),
([0, 1, 2, 3, 4, 5, -1], [4, 2, 1, 0, 5, 3, -1]),
],
)
def test_delete_minimal(values, delete_order):
tree = RBTree()
for value in values:
tree.insert(value)
for value in delete_order:
logger.info("Deleting %s", value)
before = str(tree)
tree.delete(value)
after = str(tree)
try:
assert tree.is_correct
except AssertionError:
logger.info(
f"[FAIL] Delete {value} from {values} in order {delete_order}"
)
logger.info(f"Before:\n{before}")
logger.info(f"After:\n{after}")
raise

25
wavl.py
View file

@ -1,6 +1,5 @@
from ranked_tree import RotateFunction
from avl import AVLTree
from node import Node, NodeType, Comparable
from node import Node, NodeType, Comparable, RotateFunction
import logging
from typing import TypeVar, Optional
@ -36,9 +35,11 @@ class WAVLTree(AVLTree[T]):
y: Node[T],
z: Node[T],
reversed: bool,
rotating_around_root: bool,
rotate_left: RotateFunction[T],
rotate_right: RotateFunction[T],
) -> None:
new_root = x
v = y.left
w = y.right
@ -48,7 +49,9 @@ class WAVLTree(AVLTree[T]):
w_diff = Node.difference(w, y)
if w_diff == 1 and y.parent:
rotate_left(y.parent)
new_root = rotate_left(y.parent)
if rotating_around_root:
self.root = new_root
y.promote()
z.demote()
@ -57,7 +60,9 @@ class WAVLTree(AVLTree[T]):
z.demote()
elif w_diff == 2 and v.parent:
rotate_right(v.parent)
rotate_left(v.parent)
new_root = rotate_left(v.parent)
if rotating_around_root:
self.root = new_root
v.promote().promote()
y.demote()
@ -92,6 +97,8 @@ class WAVLTree(AVLTree[T]):
x_diff = Node.difference(x, parent)
y_diff = Node.difference(y, parent)
rotating_around_root = parent.parent is None
parent_node_diffs = Node.differences(parent)
if parent_node_diffs in ((1, 3), (3, 1)):
if parent.left is x:
@ -101,8 +108,9 @@ class WAVLTree(AVLTree[T]):
parent.right,
parent,
False,
self.rotate_left,
self.rotate_right,
rotating_around_root,
Node.rotate_left,
Node.rotate_right,
)
else:
assert parent.left
@ -111,8 +119,9 @@ class WAVLTree(AVLTree[T]):
parent.left,
parent,
True,
self.rotate_right,
self.rotate_left,
rotating_around_root,
Node.rotate_right,
Node.rotate_left,
)
def _delete_rebalance(