Compare commits

..

No commits in common. "main" and "submit" have entirely different histories.
main ... submit

11 changed files with 114 additions and 351 deletions

View file

@ -1,42 +0,0 @@
image: python:latest
variables:
PIP_CACHE_DIR: "$CI_PROJECT_DIR/.cache/pip"
cache:
paths:
- .cache/pip
- venv/
before_script:
- python --version
- pip install virtualenv
- virtualenv venv
- source venv/bin/activate
black:
stage: test
script:
- pip install black
- make black
flake8:
stage: test
script:
- pip install flake8
- make flake8
pytest:
stage: test
script:
- pip install hypothesis pytest pytest-cov
- make check
coverage: '/TOTAL.*\s([.\d]+)%/'
artifacts:
when: always
reports:
junit: report.xml
coverage_report:
coverage_format: cobertura
path: coverage.xml

View file

@ -1,5 +1,5 @@
check: check:
pytest -vvv --log-level debug --cov --cov-report=term-missing --cov-report=xml:coverage.xml --junitxml=report.xml pytest -vvv --log-level debug --cov --cov-report=term-missing
lint: black flake8 mypy lint: black flake8 mypy
@ -10,7 +10,7 @@ flake8:
flake8 *.py flake8 *.py
mypy: mypy:
mypy --strict --disallow-any-explicit node.py ranked_tree.py avl.py wavl.py ravl.py rbt.py mypy --strict --disallow-any-explicit node.py avl.py wavl.py ravl.py
clean: clean:
rm -rf __pycache__ .hypothesis .mypy_cache .pytest_cache rm -rf __pycache__ .hypothesis .mypy_cache .pytest_cache

41
avl.py
View file

@ -1,5 +1,5 @@
from node import Node, Comparable from node import Node, Comparable, RotateFunction
from ranked_tree import RankedTree, RotateFunction from ranked_tree import RankedTree
import logging import logging
from typing import TypeVar, Optional from typing import TypeVar, Optional
@ -49,18 +49,22 @@ class AVLTree(RankedTree[T]):
z: Node[T], z: Node[T],
rotate_left: RotateFunction[T], rotate_left: RotateFunction[T],
rotate_right: RotateFunction[T], rotate_right: RotateFunction[T],
) -> None: ) -> Optional[Node[T]]:
new_root = x.parent
if not y or Node.difference(y) == 2: if not y or Node.difference(y) == 2:
rotate_right(z) new_root = rotate_right(z)
z.demote() z.demote()
elif Node.difference(y) == 1: elif Node.difference(y) == 1:
rotate_left(x) rotate_left(x)
rotate_right(z) new_root = rotate_right(z)
y.promote() y.promote()
x.demote() x.demote()
z.demote() z.demote()
return new_root
def _insert_rebalance(self, x: Node[T]) -> None: def _insert_rebalance(self, x: Node[T]) -> None:
diffs = Node.differences(x.parent) diffs = Node.differences(x.parent)
while x.parent and (diffs == (0, 1) or diffs == (1, 0)): while x.parent and (diffs == (0, 1) or diffs == (1, 0)):
@ -72,6 +76,9 @@ class AVLTree(RankedTree[T]):
if not x.parent: if not x.parent:
return return
rotating_around_root = x.parent.parent is None
new_root: Optional[Node[T]] = x.parent
rank_difference = Node.difference(x) rank_difference = Node.difference(x)
if rank_difference != 0: if rank_difference != 0:
return return
@ -79,14 +86,17 @@ class AVLTree(RankedTree[T]):
x_parent = x.parent x_parent = x.parent
assert x_parent is not None assert x_parent is not None
if rank_difference == 0 and x.parent.left is x: if rank_difference == 0 and x.parent.left is x:
self.__fix_0_child( new_root = self.__fix_0_child(
x, x.right, x_parent, self.rotate_left, self.rotate_right x, x.right, x_parent, Node.rotate_left, Node.rotate_right
) )
elif rank_difference == 0 and x.parent.right is x: elif rank_difference == 0 and x.parent.right is x:
self.__fix_0_child( new_root = self.__fix_0_child(
x, x.left, x_parent, self.rotate_right, self.rotate_left x, x.left, x_parent, Node.rotate_right, Node.rotate_left
) )
if rotating_around_root:
self.root = new_root
# endregion InsertRebalance # endregion InsertRebalance
# region DeleteRebalance # region DeleteRebalance
@ -96,6 +106,7 @@ class AVLTree(RankedTree[T]):
x: Node[T], x: Node[T],
y: Node[T], y: Node[T],
leaning: int, leaning: int,
rotating_around_root: bool,
rotate_left: RotateFunction[T], rotate_left: RotateFunction[T],
rotate_right: RotateFunction[T], rotate_right: RotateFunction[T],
) -> bool: ) -> bool:
@ -111,10 +122,12 @@ class AVLTree(RankedTree[T]):
for n in filter(None, (new_root.left, new_root.right, new_root)): for n in filter(None, (new_root.left, new_root.right, new_root)):
_update_rank(n) _update_rank(n)
if rotating_around_root:
self.root = new_root
return factor != 0 return factor != 0
def __delete_fixup( def __delete_fixup(
self, x: Node[T], parent: Optional[Node[T]] = None self, x: Optional[Node[T]], parent: Optional[Node[T]] = None
) -> bool: ) -> bool:
factor = _balance_factor(x) factor = _balance_factor(x)
if factor == 0: if factor == 0:
@ -123,16 +136,18 @@ class AVLTree(RankedTree[T]):
elif factor in (-1, 1): elif factor in (-1, 1):
return False return False
rotating_around_root = x.parent is None
y, leaning, to_left, to_right = ( y, leaning, to_left, to_right = (
(x.right, 1, self.rotate_left, self.rotate_right) (x.right, 1, Node.rotate_left, Node.rotate_right)
if factor == 2 if factor == 2
else (x.left, -1, self.rotate_right, self.rotate_left) else (x.left, -1, Node.rotate_right, Node.rotate_left)
) )
assert y assert y
return self.__delete_rotate( return self.__delete_rotate(
x, x,
y, y,
leaning, leaning,
rotating_around_root,
to_left, to_left,
to_right, to_right,
) )
@ -143,7 +158,7 @@ class AVLTree(RankedTree[T]):
if not node and not parent: if not node and not parent:
return return
if not node and parent: if not node:
node, parent = parent, parent.parent node, parent = parent, parent.parent
while node and self.__delete_fixup(node, parent): while node and self.__delete_fixup(node, parent):

57
node.py
View file

@ -1,8 +1,8 @@
from abc import abstractmethod from abc import abstractmethod
import enum import enum
from typing import ( from typing import (
Callable,
Iterable, Iterable,
Iterator,
Optional, Optional,
Generic, Generic,
Tuple, Tuple,
@ -28,7 +28,7 @@ class NodeType(enum.IntEnum):
T = TypeVar("T", bound=Comparable) T = TypeVar("T", bound=Comparable)
class Node(Generic[T], Iterable[T]): class Node(Generic[T]):
def __init__( def __init__(
self, self,
value: T, value: T,
@ -62,7 +62,7 @@ class Node(Generic[T], Iterable[T]):
def __str__(self) -> str: def __str__(self) -> str:
return f"Node(value={self.value}, rank={self.rank})" return f"Node(value={self.value}, rank={self.rank})"
def __iter__(self) -> Iterator[T]: def __iter__(self) -> Iterable[T]:
""" """
Yields: Yields:
Keys from the subtree rooted at the node in an inorder fashion. Keys from the subtree rooted at the node in an inorder fashion.
@ -111,10 +111,56 @@ class Node(Generic[T], Iterable[T]):
self.rank -= 1 self.rank -= 1
return self return self
@staticmethod
def rotate_right(x: "Node[T]") -> "Node[T]":
parent = x.parent
y = x.left
# z = x.right
assert y is not None
if parent:
if parent.left is x:
parent.left = y
else:
parent.right = y
x.left = y.right
if x.left:
x.left.parent = x
y.right = x
x.parent = y
y.parent = parent
return y
@staticmethod
def rotate_left(x: "Node[T]") -> "Node[T]":
parent = x.parent
# y = x.left
z = x.right
assert z is not None
if parent:
if parent.left is x:
parent.left = z
else:
parent.right = z
x.right = z.left
if x.right:
x.right.parent = x
z.left = x
x.parent = z
z.parent = parent
return z
@staticmethod @staticmethod
def find_parent_node( def find_parent_node(
value: T, node: "Node[T]", missing: bool = True value: T, node: "Node[T]", missing: bool = True
) -> "Optional[Node[T]]": ) -> "Node[T]":
new_node: "Optional[Node[T]]" = node new_node: "Optional[Node[T]]" = node
while new_node and (missing or new_node.value != value): while new_node and (missing or new_node.value != value):
@ -148,3 +194,6 @@ class Node(Generic[T], Iterable[T]):
node = node.right node = node.right
return node return node
RotateFunction = Callable[[Node[T]], Node[T]]

View file

@ -3,7 +3,7 @@ from node import Node, Comparable
from abc import abstractmethod from abc import abstractmethod
from collections import deque from collections import deque
import logging import logging
from typing import Callable, Deque, Iterable, Optional, Tuple, TypeVar, Generic from typing import Deque, Iterable, Optional, Tuple, TypeVar, Generic
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
T = TypeVar("T", bound=Comparable) T = TypeVar("T", bound=Comparable)
@ -58,58 +58,6 @@ class RankedTree(Generic[T]):
# endregion TreeSpecificMethods # endregion TreeSpecificMethods
# region Rotations
def rotate_right(self, x: "Node[T]") -> "Node[T]":
parent = x.parent
y = x.left
# z = x.right
assert y is not None
if parent:
if parent.left is x:
parent.left = y
else:
parent.right = y
else:
self.root = y
x.left = y.right
if x.left:
x.left.parent = x
y.right = x
x.parent = y
y.parent = parent
return y
def rotate_left(self, x: "Node[T]") -> "Node[T]":
parent = x.parent
# y = x.left
z = x.right
assert z is not None
if parent:
if parent.left is x:
parent.left = z
else:
parent.right = z
else:
self.root = z
x.right = z.left
if x.right:
x.right.parent = x
z.left = x
x.parent = z
z.parent = parent
return z
# endregion Rotations
@property @property
def rank(self) -> int: def rank(self) -> int:
return Node.get_rank(self.root) return Node.get_rank(self.root)
@ -174,7 +122,8 @@ class RankedTree(Generic[T]):
self._transplant(node, node.left) self._transplant(node, node.left)
else: else:
n = Node.minimum(node.right) n = Node.minimum(node.right)
node.value, n.value = n.value, node.value node.value = n.value
n.value = None
return self._delete_node(n) return self._delete_node(n)
return (y, parent) return (y, parent)
@ -199,6 +148,3 @@ class RankedTree(Generic[T]):
""" """
if self.root: if self.root:
yield from self.root yield from self.root
RotateFunction = Callable[[Node[T]], Node[T]]

179
rbt.py
View file

@ -1,179 +0,0 @@
from node import Node, Comparable
from ranked_tree import RankedTree, RotateFunction
import enum
import logging
from typing import Callable, Tuple, TypeVar, Optional
logger = logging.getLogger(__name__)
T = TypeVar("T", bound=Comparable)
class Colour(enum.IntEnum):
"""
Represents colour of the edge or node.
"""
Red = 0
Black = 1
def has_double_black(x: Optional[Node[T]]) -> bool:
"""
Checks for double black child of x.
Args:
x: Node to be checked.
Returns:
`true`, if `x` has a double black node, `false` otherwise.
"""
return x is not None and 2 in Node.differences(x)
class RBTree(RankedTree[T]):
def is_correct_node(
self, node: Optional[Node[T]], recursive: bool = True
) -> bool:
if not node:
return True
left, right = Node.differences(node)
if left not in (Colour.Red, Colour.Black):
# left subtree has invalid difference
return False
elif right not in (Colour.Red, Colour.Black):
# right subtree has invalid difference
return False
if Node.difference(node) == Colour.Red and (
left == Colour.Red or right == Colour.Red
):
# two consecutive red nodes
return False
return not recursive or (
self.is_correct_node(node.left)
and self.is_correct_node(node.right)
)
# region InsertRebalance
def _insert_rebalance_step(
self,
z: Node[T],
y: Optional[Node[T]],
right_child: Node[T],
rotate_left: RotateFunction[T],
rotate_right: RotateFunction[T],
) -> Node[T]:
p = z.parent
pp = p.parent
if y is not None and Node.difference(y) == Colour.Red:
# Case 1
# ======
# zs uncle y is red
pp.rank += Colour.Black
z = pp
elif z == right_child:
# Case 2
# ======
# zs uncle y is black and z is a right child
z = p
rotate_left(p)
else:
# Case 3
# ======
# zs uncle y is black and z is a left child
rotate_right(pp)
return z
def _insert_rebalance(self, z: Node[T]) -> None:
while z.parent is not None and Node.difference(z.parent) == Colour.Red:
p = z.parent
pp = p.parent
assert pp
if p == pp.left:
z = self._insert_rebalance_step(
z, pp.right, p.right, self.rotate_left, self.rotate_right
)
else:
z = self._insert_rebalance_step(
z, pp.left, p.left, self.rotate_right, self.rotate_left
)
# endregion InsertRebalance
# region DeleteRebalance
def _delete_rebalance_step(
self,
x: Node[T],
w: Node[T],
parent: Node[T],
right: Callable[[Node[T]], Optional[Node[T]]],
rotate_left: RotateFunction[T],
rotate_right: RotateFunction[T],
) -> Tuple[Optional[Node[T]], Optional[Node[T]]]:
if Node.difference(w) == Colour.Red:
# Case 1
# ======
# xs sibling w is red
rotate_left(parent)
w = right(parent)
if Node.differences(w) == (Colour.Black, Colour.Black):
# Case 2
# ======
# xs sibling w is black, and both of ws children are black
parent.rank -= Colour.Black
x = parent
else:
# Case 3
# ======
# xs sibling w is black,
# ws left child is red, and ws right child is black
if Node.difference(right(w), w) == Colour.Black:
rotate_right(w)
w = right(parent)
# Case 4
# ======
# xs sibling w is black, and ws right child is red
parent.rank -= Colour.Black
w.rank += Colour.Black
rotate_left(parent)
x = self.root
return x, (x.parent if x else None)
def _delete_rebalance(
self, node: Optional[Node[T]], parent: Optional[Node[T]]
) -> None:
if not node and not parent:
return
while node != self.root and has_double_black(parent):
if node == parent.left:
node, parent = self._delete_rebalance_step(
node,
parent.right,
parent,
lambda x: x.right,
self.rotate_left,
self.rotate_right,
)
else:
node, parent = self._delete_rebalance_step(
node,
parent.left,
parent,
lambda x: x.left,
self.rotate_right,
self.rotate_left,
)
# endregion DeleteRebalance

View file

@ -1,8 +1,11 @@
from avl import AVLTree, _balance_factor from avl import AVLTree, _balance_factor
import logging import logging
import random
from typing import List from typing import List
import hypothesis
import hypothesis.strategies as st
import pytest import pytest
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)

View file

@ -1,10 +1,10 @@
from avl import AVLTree from avl import AVLTree
from wavl import WAVLTree from wavl import WAVLTree
from ravl import RAVLTree from ravl import RAVLTree
from rbt import RBTree
import logging import logging
import random import random
from typing import List
import hypothesis import hypothesis
import hypothesis.strategies as st import hypothesis.strategies as st
@ -14,7 +14,8 @@ logger = logging.getLogger(__name__)
@pytest.mark.parametrize( @pytest.mark.parametrize(
("RankBalancedTree"), [RBTree, AVLTree, WAVLTree, RAVLTree] ("RankBalancedTree"),
[AVLTree, WAVLTree, RAVLTree]
) )
@hypothesis.settings(max_examples=10000) @hypothesis.settings(max_examples=10000)
@hypothesis.given(values=st.sets(st.integers())) @hypothesis.given(values=st.sets(st.integers()))
@ -44,7 +45,8 @@ def _report(t_before: str, t_after: str) -> None:
@pytest.mark.parametrize( @pytest.mark.parametrize(
("RankBalancedTree"), [RBTree, AVLTree, WAVLTree, RAVLTree] ("RankBalancedTree"),
[AVLTree, WAVLTree, RAVLTree]
) )
@hypothesis.settings(max_examples=10000) @hypothesis.settings(max_examples=10000)
@hypothesis.given(config=delete_strategy()) @hypothesis.given(config=delete_strategy())

View file

@ -1,7 +1,10 @@
from ravl import RAVLTree from ravl import RAVLTree
import hypothesis
import hypothesis.strategies as st
import logging import logging
import pytest import pytest
import random
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)

View file

@ -1,43 +0,0 @@
from rbt import RBTree
import logging
import pytest
logger = logging.getLogger(__name__)
@pytest.mark.parametrize(
"values, delete_order",
[
([0, 1, -1], [0, -1, 1]),
([0, 1, 2, 3, 4, 5], [4, 2, 1, 0, 5, 3]),
(
[32, 1, 2, 3, 4, 5, 0, -2, -4, -3, -1],
[-4, -3, 1, 2, 5, 3, -2, 4, 32, -1, 0],
),
([0, 1, 2, -2, -3, -1], [-3, 2, 1, 0, -1, -2]),
([0, 1, 2, 3, 4, 5, -1], [4, 2, 1, 0, 5, 3, -1]),
],
)
def test_delete_minimal(values, delete_order):
tree = RBTree()
for value in values:
tree.insert(value)
for value in delete_order:
logger.info("Deleting %s", value)
before = str(tree)
tree.delete(value)
after = str(tree)
try:
assert tree.is_correct
except AssertionError:
logger.info(
f"[FAIL] Delete {value} from {values} in order {delete_order}"
)
logger.info(f"Before:\n{before}")
logger.info(f"After:\n{after}")
raise

25
wavl.py
View file

@ -1,6 +1,5 @@
from ranked_tree import RotateFunction
from avl import AVLTree from avl import AVLTree
from node import Node, NodeType, Comparable from node import Node, NodeType, Comparable, RotateFunction
import logging import logging
from typing import TypeVar, Optional from typing import TypeVar, Optional
@ -36,9 +35,11 @@ class WAVLTree(AVLTree[T]):
y: Node[T], y: Node[T],
z: Node[T], z: Node[T],
reversed: bool, reversed: bool,
rotating_around_root: bool,
rotate_left: RotateFunction[T], rotate_left: RotateFunction[T],
rotate_right: RotateFunction[T], rotate_right: RotateFunction[T],
) -> None: ) -> None:
new_root = x
v = y.left v = y.left
w = y.right w = y.right
@ -48,7 +49,9 @@ class WAVLTree(AVLTree[T]):
w_diff = Node.difference(w, y) w_diff = Node.difference(w, y)
if w_diff == 1 and y.parent: if w_diff == 1 and y.parent:
rotate_left(y.parent) new_root = rotate_left(y.parent)
if rotating_around_root:
self.root = new_root
y.promote() y.promote()
z.demote() z.demote()
@ -57,7 +60,9 @@ class WAVLTree(AVLTree[T]):
z.demote() z.demote()
elif w_diff == 2 and v.parent: elif w_diff == 2 and v.parent:
rotate_right(v.parent) rotate_right(v.parent)
rotate_left(v.parent) new_root = rotate_left(v.parent)
if rotating_around_root:
self.root = new_root
v.promote().promote() v.promote().promote()
y.demote() y.demote()
@ -92,6 +97,8 @@ class WAVLTree(AVLTree[T]):
x_diff = Node.difference(x, parent) x_diff = Node.difference(x, parent)
y_diff = Node.difference(y, parent) y_diff = Node.difference(y, parent)
rotating_around_root = parent.parent is None
parent_node_diffs = Node.differences(parent) parent_node_diffs = Node.differences(parent)
if parent_node_diffs in ((1, 3), (3, 1)): if parent_node_diffs in ((1, 3), (3, 1)):
if parent.left is x: if parent.left is x:
@ -101,8 +108,9 @@ class WAVLTree(AVLTree[T]):
parent.right, parent.right,
parent, parent,
False, False,
self.rotate_left, rotating_around_root,
self.rotate_right, Node.rotate_left,
Node.rotate_right,
) )
else: else:
assert parent.left assert parent.left
@ -111,8 +119,9 @@ class WAVLTree(AVLTree[T]):
parent.left, parent.left,
parent, parent,
True, True,
self.rotate_right, rotating_around_root,
self.rotate_left, Node.rotate_right,
Node.rotate_left,
) )
def _delete_rebalance( def _delete_rebalance(