Compare commits

...

9 commits
submit ... main

Author SHA1 Message Date
4e2e227f58
ci: add GitLab CI
Signed-off-by: Matej Focko <mfocko@redhat.com>
2022-05-24 16:24:58 +02:00
dc6905b5e2
fix(mypy): fix major issues with mypy
Signed-off-by: Matej Focko <mfocko@redhat.com>
2022-05-22 21:07:40 +02:00
e28f578fc2
fix(rotations): move rotations to the tree
Move rotations to the RankedTree, which allows us to factor out the
functionality regarding rotations when root of the tree changes.

Signed-off-by: Matej Focko <mfocko@redhat.com>
2022-05-22 21:00:18 +02:00
7b5bf0ec3a
fix(rbt): improve naming and add docstrings
Signed-off-by: Matej Focko <mfocko@redhat.com>
2022-05-22 00:38:48 +02:00
2b02eac10e
fix(rbt): use explicit colours instead of constants
Signed-off-by: Matej Focko <mfocko@redhat.com>
2022-05-20 15:47:03 +02:00
d39e11713c
feat(rbt): finish delete
Signed-off-by: Matej Focko <mfocko@redhat.com>
2022-05-20 12:02:52 +02:00
02ce49c86e
feat(rb): add mockup
Signed-off-by: Matej Focko <mfocko@redhat.com>
2022-05-20 11:50:56 +02:00
3bba030d53
feat: handle root change in rotate function
Signed-off-by: Matej Focko <mfocko@redhat.com>
2022-05-20 11:50:48 +02:00
a8de124eeb
chore(flake): fix flake remarks
Signed-off-by: Matej Focko <mfocko@redhat.com>
2022-05-20 11:50:08 +02:00
11 changed files with 351 additions and 114 deletions

42
.gitlab-ci.yml Normal file
View file

@ -0,0 +1,42 @@
image: python:latest
variables:
PIP_CACHE_DIR: "$CI_PROJECT_DIR/.cache/pip"
cache:
paths:
- .cache/pip
- venv/
before_script:
- python --version
- pip install virtualenv
- virtualenv venv
- source venv/bin/activate
black:
stage: test
script:
- pip install black
- make black
flake8:
stage: test
script:
- pip install flake8
- make flake8
pytest:
stage: test
script:
- pip install hypothesis pytest pytest-cov
- make check
coverage: '/TOTAL.*\s([.\d]+)%/'
artifacts:
when: always
reports:
junit: report.xml
coverage_report:
coverage_format: cobertura
path: coverage.xml

View file

@ -1,5 +1,5 @@
check:
pytest -vvv --log-level debug --cov --cov-report=term-missing
pytest -vvv --log-level debug --cov --cov-report=term-missing --cov-report=xml:coverage.xml --junitxml=report.xml
lint: black flake8 mypy
@ -10,7 +10,7 @@ flake8:
flake8 *.py
mypy:
mypy --strict --disallow-any-explicit node.py avl.py wavl.py ravl.py
mypy --strict --disallow-any-explicit node.py ranked_tree.py avl.py wavl.py ravl.py rbt.py
clean:
rm -rf __pycache__ .hypothesis .mypy_cache .pytest_cache

41
avl.py
View file

@ -1,5 +1,5 @@
from node import Node, Comparable, RotateFunction
from ranked_tree import RankedTree
from node import Node, Comparable
from ranked_tree import RankedTree, RotateFunction
import logging
from typing import TypeVar, Optional
@ -49,22 +49,18 @@ class AVLTree(RankedTree[T]):
z: Node[T],
rotate_left: RotateFunction[T],
rotate_right: RotateFunction[T],
) -> Optional[Node[T]]:
new_root = x.parent
) -> None:
if not y or Node.difference(y) == 2:
new_root = rotate_right(z)
rotate_right(z)
z.demote()
elif Node.difference(y) == 1:
rotate_left(x)
new_root = rotate_right(z)
rotate_right(z)
y.promote()
x.demote()
z.demote()
return new_root
def _insert_rebalance(self, x: Node[T]) -> None:
diffs = Node.differences(x.parent)
while x.parent and (diffs == (0, 1) or diffs == (1, 0)):
@ -76,9 +72,6 @@ class AVLTree(RankedTree[T]):
if not x.parent:
return
rotating_around_root = x.parent.parent is None
new_root: Optional[Node[T]] = x.parent
rank_difference = Node.difference(x)
if rank_difference != 0:
return
@ -86,17 +79,14 @@ class AVLTree(RankedTree[T]):
x_parent = x.parent
assert x_parent is not None
if rank_difference == 0 and x.parent.left is x:
new_root = self.__fix_0_child(
x, x.right, x_parent, Node.rotate_left, Node.rotate_right
self.__fix_0_child(
x, x.right, x_parent, self.rotate_left, self.rotate_right
)
elif rank_difference == 0 and x.parent.right is x:
new_root = self.__fix_0_child(
x, x.left, x_parent, Node.rotate_right, Node.rotate_left
self.__fix_0_child(
x, x.left, x_parent, self.rotate_right, self.rotate_left
)
if rotating_around_root:
self.root = new_root
# endregion InsertRebalance
# region DeleteRebalance
@ -106,7 +96,6 @@ class AVLTree(RankedTree[T]):
x: Node[T],
y: Node[T],
leaning: int,
rotating_around_root: bool,
rotate_left: RotateFunction[T],
rotate_right: RotateFunction[T],
) -> bool:
@ -122,12 +111,10 @@ class AVLTree(RankedTree[T]):
for n in filter(None, (new_root.left, new_root.right, new_root)):
_update_rank(n)
if rotating_around_root:
self.root = new_root
return factor != 0
def __delete_fixup(
self, x: Optional[Node[T]], parent: Optional[Node[T]] = None
self, x: Node[T], parent: Optional[Node[T]] = None
) -> bool:
factor = _balance_factor(x)
if factor == 0:
@ -136,18 +123,16 @@ class AVLTree(RankedTree[T]):
elif factor in (-1, 1):
return False
rotating_around_root = x.parent is None
y, leaning, to_left, to_right = (
(x.right, 1, Node.rotate_left, Node.rotate_right)
(x.right, 1, self.rotate_left, self.rotate_right)
if factor == 2
else (x.left, -1, Node.rotate_right, Node.rotate_left)
else (x.left, -1, self.rotate_right, self.rotate_left)
)
assert y
return self.__delete_rotate(
x,
y,
leaning,
rotating_around_root,
to_left,
to_right,
)
@ -158,7 +143,7 @@ class AVLTree(RankedTree[T]):
if not node and not parent:
return
if not node:
if not node and parent:
node, parent = parent, parent.parent
while node and self.__delete_fixup(node, parent):

57
node.py
View file

@ -1,8 +1,8 @@
from abc import abstractmethod
import enum
from typing import (
Callable,
Iterable,
Iterator,
Optional,
Generic,
Tuple,
@ -28,7 +28,7 @@ class NodeType(enum.IntEnum):
T = TypeVar("T", bound=Comparable)
class Node(Generic[T]):
class Node(Generic[T], Iterable[T]):
def __init__(
self,
value: T,
@ -62,7 +62,7 @@ class Node(Generic[T]):
def __str__(self) -> str:
return f"Node(value={self.value}, rank={self.rank})"
def __iter__(self) -> Iterable[T]:
def __iter__(self) -> Iterator[T]:
"""
Yields:
Keys from the subtree rooted at the node in an inorder fashion.
@ -111,56 +111,10 @@ class Node(Generic[T]):
self.rank -= 1
return self
@staticmethod
def rotate_right(x: "Node[T]") -> "Node[T]":
parent = x.parent
y = x.left
# z = x.right
assert y is not None
if parent:
if parent.left is x:
parent.left = y
else:
parent.right = y
x.left = y.right
if x.left:
x.left.parent = x
y.right = x
x.parent = y
y.parent = parent
return y
@staticmethod
def rotate_left(x: "Node[T]") -> "Node[T]":
parent = x.parent
# y = x.left
z = x.right
assert z is not None
if parent:
if parent.left is x:
parent.left = z
else:
parent.right = z
x.right = z.left
if x.right:
x.right.parent = x
z.left = x
x.parent = z
z.parent = parent
return z
@staticmethod
def find_parent_node(
value: T, node: "Node[T]", missing: bool = True
) -> "Node[T]":
) -> "Optional[Node[T]]":
new_node: "Optional[Node[T]]" = node
while new_node and (missing or new_node.value != value):
@ -194,6 +148,3 @@ class Node(Generic[T]):
node = node.right
return node
RotateFunction = Callable[[Node[T]], Node[T]]

View file

@ -3,7 +3,7 @@ from node import Node, Comparable
from abc import abstractmethod
from collections import deque
import logging
from typing import Deque, Iterable, Optional, Tuple, TypeVar, Generic
from typing import Callable, Deque, Iterable, Optional, Tuple, TypeVar, Generic
logger = logging.getLogger(__name__)
T = TypeVar("T", bound=Comparable)
@ -58,6 +58,58 @@ class RankedTree(Generic[T]):
# endregion TreeSpecificMethods
# region Rotations
def rotate_right(self, x: "Node[T]") -> "Node[T]":
parent = x.parent
y = x.left
# z = x.right
assert y is not None
if parent:
if parent.left is x:
parent.left = y
else:
parent.right = y
else:
self.root = y
x.left = y.right
if x.left:
x.left.parent = x
y.right = x
x.parent = y
y.parent = parent
return y
def rotate_left(self, x: "Node[T]") -> "Node[T]":
parent = x.parent
# y = x.left
z = x.right
assert z is not None
if parent:
if parent.left is x:
parent.left = z
else:
parent.right = z
else:
self.root = z
x.right = z.left
if x.right:
x.right.parent = x
z.left = x
x.parent = z
z.parent = parent
return z
# endregion Rotations
@property
def rank(self) -> int:
return Node.get_rank(self.root)
@ -122,8 +174,7 @@ class RankedTree(Generic[T]):
self._transplant(node, node.left)
else:
n = Node.minimum(node.right)
node.value = n.value
n.value = None
node.value, n.value = n.value, node.value
return self._delete_node(n)
return (y, parent)
@ -148,3 +199,6 @@ class RankedTree(Generic[T]):
"""
if self.root:
yield from self.root
RotateFunction = Callable[[Node[T]], Node[T]]

179
rbt.py Normal file
View file

@ -0,0 +1,179 @@
from node import Node, Comparable
from ranked_tree import RankedTree, RotateFunction
import enum
import logging
from typing import Callable, Tuple, TypeVar, Optional
logger = logging.getLogger(__name__)
T = TypeVar("T", bound=Comparable)
class Colour(enum.IntEnum):
"""
Represents colour of the edge or node.
"""
Red = 0
Black = 1
def has_double_black(x: Optional[Node[T]]) -> bool:
"""
Checks for double black child of x.
Args:
x: Node to be checked.
Returns:
`true`, if `x` has a double black node, `false` otherwise.
"""
return x is not None and 2 in Node.differences(x)
class RBTree(RankedTree[T]):
def is_correct_node(
self, node: Optional[Node[T]], recursive: bool = True
) -> bool:
if not node:
return True
left, right = Node.differences(node)
if left not in (Colour.Red, Colour.Black):
# left subtree has invalid difference
return False
elif right not in (Colour.Red, Colour.Black):
# right subtree has invalid difference
return False
if Node.difference(node) == Colour.Red and (
left == Colour.Red or right == Colour.Red
):
# two consecutive red nodes
return False
return not recursive or (
self.is_correct_node(node.left)
and self.is_correct_node(node.right)
)
# region InsertRebalance
def _insert_rebalance_step(
self,
z: Node[T],
y: Optional[Node[T]],
right_child: Node[T],
rotate_left: RotateFunction[T],
rotate_right: RotateFunction[T],
) -> Node[T]:
p = z.parent
pp = p.parent
if y is not None and Node.difference(y) == Colour.Red:
# Case 1
# ======
# zs uncle y is red
pp.rank += Colour.Black
z = pp
elif z == right_child:
# Case 2
# ======
# zs uncle y is black and z is a right child
z = p
rotate_left(p)
else:
# Case 3
# ======
# zs uncle y is black and z is a left child
rotate_right(pp)
return z
def _insert_rebalance(self, z: Node[T]) -> None:
while z.parent is not None and Node.difference(z.parent) == Colour.Red:
p = z.parent
pp = p.parent
assert pp
if p == pp.left:
z = self._insert_rebalance_step(
z, pp.right, p.right, self.rotate_left, self.rotate_right
)
else:
z = self._insert_rebalance_step(
z, pp.left, p.left, self.rotate_right, self.rotate_left
)
# endregion InsertRebalance
# region DeleteRebalance
def _delete_rebalance_step(
self,
x: Node[T],
w: Node[T],
parent: Node[T],
right: Callable[[Node[T]], Optional[Node[T]]],
rotate_left: RotateFunction[T],
rotate_right: RotateFunction[T],
) -> Tuple[Optional[Node[T]], Optional[Node[T]]]:
if Node.difference(w) == Colour.Red:
# Case 1
# ======
# xs sibling w is red
rotate_left(parent)
w = right(parent)
if Node.differences(w) == (Colour.Black, Colour.Black):
# Case 2
# ======
# xs sibling w is black, and both of ws children are black
parent.rank -= Colour.Black
x = parent
else:
# Case 3
# ======
# xs sibling w is black,
# ws left child is red, and ws right child is black
if Node.difference(right(w), w) == Colour.Black:
rotate_right(w)
w = right(parent)
# Case 4
# ======
# xs sibling w is black, and ws right child is red
parent.rank -= Colour.Black
w.rank += Colour.Black
rotate_left(parent)
x = self.root
return x, (x.parent if x else None)
def _delete_rebalance(
self, node: Optional[Node[T]], parent: Optional[Node[T]]
) -> None:
if not node and not parent:
return
while node != self.root and has_double_black(parent):
if node == parent.left:
node, parent = self._delete_rebalance_step(
node,
parent.right,
parent,
lambda x: x.right,
self.rotate_left,
self.rotate_right,
)
else:
node, parent = self._delete_rebalance_step(
node,
parent.left,
parent,
lambda x: x.left,
self.rotate_right,
self.rotate_left,
)
# endregion DeleteRebalance

View file

@ -1,11 +1,8 @@
from avl import AVLTree, _balance_factor
import logging
import random
from typing import List
import hypothesis
import hypothesis.strategies as st
import pytest
logger = logging.getLogger(__name__)

View file

@ -1,10 +1,10 @@
from avl import AVLTree
from wavl import WAVLTree
from ravl import RAVLTree
from rbt import RBTree
import logging
import random
from typing import List
import hypothesis
import hypothesis.strategies as st
@ -14,8 +14,7 @@ logger = logging.getLogger(__name__)
@pytest.mark.parametrize(
("RankBalancedTree"),
[AVLTree, WAVLTree, RAVLTree]
("RankBalancedTree"), [RBTree, AVLTree, WAVLTree, RAVLTree]
)
@hypothesis.settings(max_examples=10000)
@hypothesis.given(values=st.sets(st.integers()))
@ -45,8 +44,7 @@ def _report(t_before: str, t_after: str) -> None:
@pytest.mark.parametrize(
("RankBalancedTree"),
[AVLTree, WAVLTree, RAVLTree]
("RankBalancedTree"), [RBTree, AVLTree, WAVLTree, RAVLTree]
)
@hypothesis.settings(max_examples=10000)
@hypothesis.given(config=delete_strategy())

View file

@ -1,10 +1,7 @@
from ravl import RAVLTree
import hypothesis
import hypothesis.strategies as st
import logging
import pytest
import random
logger = logging.getLogger(__name__)

43
test_rbt.py Normal file
View file

@ -0,0 +1,43 @@
from rbt import RBTree
import logging
import pytest
logger = logging.getLogger(__name__)
@pytest.mark.parametrize(
"values, delete_order",
[
([0, 1, -1], [0, -1, 1]),
([0, 1, 2, 3, 4, 5], [4, 2, 1, 0, 5, 3]),
(
[32, 1, 2, 3, 4, 5, 0, -2, -4, -3, -1],
[-4, -3, 1, 2, 5, 3, -2, 4, 32, -1, 0],
),
([0, 1, 2, -2, -3, -1], [-3, 2, 1, 0, -1, -2]),
([0, 1, 2, 3, 4, 5, -1], [4, 2, 1, 0, 5, 3, -1]),
],
)
def test_delete_minimal(values, delete_order):
tree = RBTree()
for value in values:
tree.insert(value)
for value in delete_order:
logger.info("Deleting %s", value)
before = str(tree)
tree.delete(value)
after = str(tree)
try:
assert tree.is_correct
except AssertionError:
logger.info(
f"[FAIL] Delete {value} from {values} in order {delete_order}"
)
logger.info(f"Before:\n{before}")
logger.info(f"After:\n{after}")
raise

25
wavl.py
View file

@ -1,5 +1,6 @@
from ranked_tree import RotateFunction
from avl import AVLTree
from node import Node, NodeType, Comparable, RotateFunction
from node import Node, NodeType, Comparable
import logging
from typing import TypeVar, Optional
@ -35,11 +36,9 @@ class WAVLTree(AVLTree[T]):
y: Node[T],
z: Node[T],
reversed: bool,
rotating_around_root: bool,
rotate_left: RotateFunction[T],
rotate_right: RotateFunction[T],
) -> None:
new_root = x
v = y.left
w = y.right
@ -49,9 +48,7 @@ class WAVLTree(AVLTree[T]):
w_diff = Node.difference(w, y)
if w_diff == 1 and y.parent:
new_root = rotate_left(y.parent)
if rotating_around_root:
self.root = new_root
rotate_left(y.parent)
y.promote()
z.demote()
@ -60,9 +57,7 @@ class WAVLTree(AVLTree[T]):
z.demote()
elif w_diff == 2 and v.parent:
rotate_right(v.parent)
new_root = rotate_left(v.parent)
if rotating_around_root:
self.root = new_root
rotate_left(v.parent)
v.promote().promote()
y.demote()
@ -97,8 +92,6 @@ class WAVLTree(AVLTree[T]):
x_diff = Node.difference(x, parent)
y_diff = Node.difference(y, parent)
rotating_around_root = parent.parent is None
parent_node_diffs = Node.differences(parent)
if parent_node_diffs in ((1, 3), (3, 1)):
if parent.left is x:
@ -108,9 +101,8 @@ class WAVLTree(AVLTree[T]):
parent.right,
parent,
False,
rotating_around_root,
Node.rotate_left,
Node.rotate_right,
self.rotate_left,
self.rotate_right,
)
else:
assert parent.left
@ -119,9 +111,8 @@ class WAVLTree(AVLTree[T]):
parent.left,
parent,
True,
rotating_around_root,
Node.rotate_right,
Node.rotate_left,
self.rotate_right,
self.rotate_left,
)
def _delete_rebalance(