Compare commits

...

9 commits
submit ... main

Author SHA1 Message Date
4e2e227f58
ci: add GitLab CI
Signed-off-by: Matej Focko <mfocko@redhat.com>
2022-05-24 16:24:58 +02:00
dc6905b5e2
fix(mypy): fix major issues with mypy
Signed-off-by: Matej Focko <mfocko@redhat.com>
2022-05-22 21:07:40 +02:00
e28f578fc2
fix(rotations): move rotations to the tree
Move rotations to the RankedTree, which allows us to factor out the
functionality regarding rotations when root of the tree changes.

Signed-off-by: Matej Focko <mfocko@redhat.com>
2022-05-22 21:00:18 +02:00
7b5bf0ec3a
fix(rbt): improve naming and add docstrings
Signed-off-by: Matej Focko <mfocko@redhat.com>
2022-05-22 00:38:48 +02:00
2b02eac10e
fix(rbt): use explicit colours instead of constants
Signed-off-by: Matej Focko <mfocko@redhat.com>
2022-05-20 15:47:03 +02:00
d39e11713c
feat(rbt): finish delete
Signed-off-by: Matej Focko <mfocko@redhat.com>
2022-05-20 12:02:52 +02:00
02ce49c86e
feat(rb): add mockup
Signed-off-by: Matej Focko <mfocko@redhat.com>
2022-05-20 11:50:56 +02:00
3bba030d53
feat: handle root change in rotate function
Signed-off-by: Matej Focko <mfocko@redhat.com>
2022-05-20 11:50:48 +02:00
a8de124eeb
chore(flake): fix flake remarks
Signed-off-by: Matej Focko <mfocko@redhat.com>
2022-05-20 11:50:08 +02:00
11 changed files with 351 additions and 114 deletions

42
.gitlab-ci.yml Normal file
View file

@ -0,0 +1,42 @@
image: python:latest
variables:
PIP_CACHE_DIR: "$CI_PROJECT_DIR/.cache/pip"
cache:
paths:
- .cache/pip
- venv/
before_script:
- python --version
- pip install virtualenv
- virtualenv venv
- source venv/bin/activate
black:
stage: test
script:
- pip install black
- make black
flake8:
stage: test
script:
- pip install flake8
- make flake8
pytest:
stage: test
script:
- pip install hypothesis pytest pytest-cov
- make check
coverage: '/TOTAL.*\s([.\d]+)%/'
artifacts:
when: always
reports:
junit: report.xml
coverage_report:
coverage_format: cobertura
path: coverage.xml

View file

@ -1,5 +1,5 @@
check: check:
pytest -vvv --log-level debug --cov --cov-report=term-missing pytest -vvv --log-level debug --cov --cov-report=term-missing --cov-report=xml:coverage.xml --junitxml=report.xml
lint: black flake8 mypy lint: black flake8 mypy
@ -10,7 +10,7 @@ flake8:
flake8 *.py flake8 *.py
mypy: mypy:
mypy --strict --disallow-any-explicit node.py avl.py wavl.py ravl.py mypy --strict --disallow-any-explicit node.py ranked_tree.py avl.py wavl.py ravl.py rbt.py
clean: clean:
rm -rf __pycache__ .hypothesis .mypy_cache .pytest_cache rm -rf __pycache__ .hypothesis .mypy_cache .pytest_cache

41
avl.py
View file

@ -1,5 +1,5 @@
from node import Node, Comparable, RotateFunction from node import Node, Comparable
from ranked_tree import RankedTree from ranked_tree import RankedTree, RotateFunction
import logging import logging
from typing import TypeVar, Optional from typing import TypeVar, Optional
@ -49,22 +49,18 @@ class AVLTree(RankedTree[T]):
z: Node[T], z: Node[T],
rotate_left: RotateFunction[T], rotate_left: RotateFunction[T],
rotate_right: RotateFunction[T], rotate_right: RotateFunction[T],
) -> Optional[Node[T]]: ) -> None:
new_root = x.parent
if not y or Node.difference(y) == 2: if not y or Node.difference(y) == 2:
new_root = rotate_right(z) rotate_right(z)
z.demote() z.demote()
elif Node.difference(y) == 1: elif Node.difference(y) == 1:
rotate_left(x) rotate_left(x)
new_root = rotate_right(z) rotate_right(z)
y.promote() y.promote()
x.demote() x.demote()
z.demote() z.demote()
return new_root
def _insert_rebalance(self, x: Node[T]) -> None: def _insert_rebalance(self, x: Node[T]) -> None:
diffs = Node.differences(x.parent) diffs = Node.differences(x.parent)
while x.parent and (diffs == (0, 1) or diffs == (1, 0)): while x.parent and (diffs == (0, 1) or diffs == (1, 0)):
@ -76,9 +72,6 @@ class AVLTree(RankedTree[T]):
if not x.parent: if not x.parent:
return return
rotating_around_root = x.parent.parent is None
new_root: Optional[Node[T]] = x.parent
rank_difference = Node.difference(x) rank_difference = Node.difference(x)
if rank_difference != 0: if rank_difference != 0:
return return
@ -86,17 +79,14 @@ class AVLTree(RankedTree[T]):
x_parent = x.parent x_parent = x.parent
assert x_parent is not None assert x_parent is not None
if rank_difference == 0 and x.parent.left is x: if rank_difference == 0 and x.parent.left is x:
new_root = self.__fix_0_child( self.__fix_0_child(
x, x.right, x_parent, Node.rotate_left, Node.rotate_right x, x.right, x_parent, self.rotate_left, self.rotate_right
) )
elif rank_difference == 0 and x.parent.right is x: elif rank_difference == 0 and x.parent.right is x:
new_root = self.__fix_0_child( self.__fix_0_child(
x, x.left, x_parent, Node.rotate_right, Node.rotate_left x, x.left, x_parent, self.rotate_right, self.rotate_left
) )
if rotating_around_root:
self.root = new_root
# endregion InsertRebalance # endregion InsertRebalance
# region DeleteRebalance # region DeleteRebalance
@ -106,7 +96,6 @@ class AVLTree(RankedTree[T]):
x: Node[T], x: Node[T],
y: Node[T], y: Node[T],
leaning: int, leaning: int,
rotating_around_root: bool,
rotate_left: RotateFunction[T], rotate_left: RotateFunction[T],
rotate_right: RotateFunction[T], rotate_right: RotateFunction[T],
) -> bool: ) -> bool:
@ -122,12 +111,10 @@ class AVLTree(RankedTree[T]):
for n in filter(None, (new_root.left, new_root.right, new_root)): for n in filter(None, (new_root.left, new_root.right, new_root)):
_update_rank(n) _update_rank(n)
if rotating_around_root:
self.root = new_root
return factor != 0 return factor != 0
def __delete_fixup( def __delete_fixup(
self, x: Optional[Node[T]], parent: Optional[Node[T]] = None self, x: Node[T], parent: Optional[Node[T]] = None
) -> bool: ) -> bool:
factor = _balance_factor(x) factor = _balance_factor(x)
if factor == 0: if factor == 0:
@ -136,18 +123,16 @@ class AVLTree(RankedTree[T]):
elif factor in (-1, 1): elif factor in (-1, 1):
return False return False
rotating_around_root = x.parent is None
y, leaning, to_left, to_right = ( y, leaning, to_left, to_right = (
(x.right, 1, Node.rotate_left, Node.rotate_right) (x.right, 1, self.rotate_left, self.rotate_right)
if factor == 2 if factor == 2
else (x.left, -1, Node.rotate_right, Node.rotate_left) else (x.left, -1, self.rotate_right, self.rotate_left)
) )
assert y assert y
return self.__delete_rotate( return self.__delete_rotate(
x, x,
y, y,
leaning, leaning,
rotating_around_root,
to_left, to_left,
to_right, to_right,
) )
@ -158,7 +143,7 @@ class AVLTree(RankedTree[T]):
if not node and not parent: if not node and not parent:
return return
if not node: if not node and parent:
node, parent = parent, parent.parent node, parent = parent, parent.parent
while node and self.__delete_fixup(node, parent): while node and self.__delete_fixup(node, parent):

57
node.py
View file

@ -1,8 +1,8 @@
from abc import abstractmethod from abc import abstractmethod
import enum import enum
from typing import ( from typing import (
Callable,
Iterable, Iterable,
Iterator,
Optional, Optional,
Generic, Generic,
Tuple, Tuple,
@ -28,7 +28,7 @@ class NodeType(enum.IntEnum):
T = TypeVar("T", bound=Comparable) T = TypeVar("T", bound=Comparable)
class Node(Generic[T]): class Node(Generic[T], Iterable[T]):
def __init__( def __init__(
self, self,
value: T, value: T,
@ -62,7 +62,7 @@ class Node(Generic[T]):
def __str__(self) -> str: def __str__(self) -> str:
return f"Node(value={self.value}, rank={self.rank})" return f"Node(value={self.value}, rank={self.rank})"
def __iter__(self) -> Iterable[T]: def __iter__(self) -> Iterator[T]:
""" """
Yields: Yields:
Keys from the subtree rooted at the node in an inorder fashion. Keys from the subtree rooted at the node in an inorder fashion.
@ -111,56 +111,10 @@ class Node(Generic[T]):
self.rank -= 1 self.rank -= 1
return self return self
@staticmethod
def rotate_right(x: "Node[T]") -> "Node[T]":
parent = x.parent
y = x.left
# z = x.right
assert y is not None
if parent:
if parent.left is x:
parent.left = y
else:
parent.right = y
x.left = y.right
if x.left:
x.left.parent = x
y.right = x
x.parent = y
y.parent = parent
return y
@staticmethod
def rotate_left(x: "Node[T]") -> "Node[T]":
parent = x.parent
# y = x.left
z = x.right
assert z is not None
if parent:
if parent.left is x:
parent.left = z
else:
parent.right = z
x.right = z.left
if x.right:
x.right.parent = x
z.left = x
x.parent = z
z.parent = parent
return z
@staticmethod @staticmethod
def find_parent_node( def find_parent_node(
value: T, node: "Node[T]", missing: bool = True value: T, node: "Node[T]", missing: bool = True
) -> "Node[T]": ) -> "Optional[Node[T]]":
new_node: "Optional[Node[T]]" = node new_node: "Optional[Node[T]]" = node
while new_node and (missing or new_node.value != value): while new_node and (missing or new_node.value != value):
@ -194,6 +148,3 @@ class Node(Generic[T]):
node = node.right node = node.right
return node return node
RotateFunction = Callable[[Node[T]], Node[T]]

View file

@ -3,7 +3,7 @@ from node import Node, Comparable
from abc import abstractmethod from abc import abstractmethod
from collections import deque from collections import deque
import logging import logging
from typing import Deque, Iterable, Optional, Tuple, TypeVar, Generic from typing import Callable, Deque, Iterable, Optional, Tuple, TypeVar, Generic
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
T = TypeVar("T", bound=Comparable) T = TypeVar("T", bound=Comparable)
@ -58,6 +58,58 @@ class RankedTree(Generic[T]):
# endregion TreeSpecificMethods # endregion TreeSpecificMethods
# region Rotations
def rotate_right(self, x: "Node[T]") -> "Node[T]":
parent = x.parent
y = x.left
# z = x.right
assert y is not None
if parent:
if parent.left is x:
parent.left = y
else:
parent.right = y
else:
self.root = y
x.left = y.right
if x.left:
x.left.parent = x
y.right = x
x.parent = y
y.parent = parent
return y
def rotate_left(self, x: "Node[T]") -> "Node[T]":
parent = x.parent
# y = x.left
z = x.right
assert z is not None
if parent:
if parent.left is x:
parent.left = z
else:
parent.right = z
else:
self.root = z
x.right = z.left
if x.right:
x.right.parent = x
z.left = x
x.parent = z
z.parent = parent
return z
# endregion Rotations
@property @property
def rank(self) -> int: def rank(self) -> int:
return Node.get_rank(self.root) return Node.get_rank(self.root)
@ -122,8 +174,7 @@ class RankedTree(Generic[T]):
self._transplant(node, node.left) self._transplant(node, node.left)
else: else:
n = Node.minimum(node.right) n = Node.minimum(node.right)
node.value = n.value node.value, n.value = n.value, node.value
n.value = None
return self._delete_node(n) return self._delete_node(n)
return (y, parent) return (y, parent)
@ -148,3 +199,6 @@ class RankedTree(Generic[T]):
""" """
if self.root: if self.root:
yield from self.root yield from self.root
RotateFunction = Callable[[Node[T]], Node[T]]

179
rbt.py Normal file
View file

@ -0,0 +1,179 @@
from node import Node, Comparable
from ranked_tree import RankedTree, RotateFunction
import enum
import logging
from typing import Callable, Tuple, TypeVar, Optional
logger = logging.getLogger(__name__)
T = TypeVar("T", bound=Comparable)
class Colour(enum.IntEnum):
"""
Represents colour of the edge or node.
"""
Red = 0
Black = 1
def has_double_black(x: Optional[Node[T]]) -> bool:
"""
Checks for double black child of x.
Args:
x: Node to be checked.
Returns:
`true`, if `x` has a double black node, `false` otherwise.
"""
return x is not None and 2 in Node.differences(x)
class RBTree(RankedTree[T]):
def is_correct_node(
self, node: Optional[Node[T]], recursive: bool = True
) -> bool:
if not node:
return True
left, right = Node.differences(node)
if left not in (Colour.Red, Colour.Black):
# left subtree has invalid difference
return False
elif right not in (Colour.Red, Colour.Black):
# right subtree has invalid difference
return False
if Node.difference(node) == Colour.Red and (
left == Colour.Red or right == Colour.Red
):
# two consecutive red nodes
return False
return not recursive or (
self.is_correct_node(node.left)
and self.is_correct_node(node.right)
)
# region InsertRebalance
def _insert_rebalance_step(
self,
z: Node[T],
y: Optional[Node[T]],
right_child: Node[T],
rotate_left: RotateFunction[T],
rotate_right: RotateFunction[T],
) -> Node[T]:
p = z.parent
pp = p.parent
if y is not None and Node.difference(y) == Colour.Red:
# Case 1
# ======
# zs uncle y is red
pp.rank += Colour.Black
z = pp
elif z == right_child:
# Case 2
# ======
# zs uncle y is black and z is a right child
z = p
rotate_left(p)
else:
# Case 3
# ======
# zs uncle y is black and z is a left child
rotate_right(pp)
return z
def _insert_rebalance(self, z: Node[T]) -> None:
while z.parent is not None and Node.difference(z.parent) == Colour.Red:
p = z.parent
pp = p.parent
assert pp
if p == pp.left:
z = self._insert_rebalance_step(
z, pp.right, p.right, self.rotate_left, self.rotate_right
)
else:
z = self._insert_rebalance_step(
z, pp.left, p.left, self.rotate_right, self.rotate_left
)
# endregion InsertRebalance
# region DeleteRebalance
def _delete_rebalance_step(
self,
x: Node[T],
w: Node[T],
parent: Node[T],
right: Callable[[Node[T]], Optional[Node[T]]],
rotate_left: RotateFunction[T],
rotate_right: RotateFunction[T],
) -> Tuple[Optional[Node[T]], Optional[Node[T]]]:
if Node.difference(w) == Colour.Red:
# Case 1
# ======
# xs sibling w is red
rotate_left(parent)
w = right(parent)
if Node.differences(w) == (Colour.Black, Colour.Black):
# Case 2
# ======
# xs sibling w is black, and both of ws children are black
parent.rank -= Colour.Black
x = parent
else:
# Case 3
# ======
# xs sibling w is black,
# ws left child is red, and ws right child is black
if Node.difference(right(w), w) == Colour.Black:
rotate_right(w)
w = right(parent)
# Case 4
# ======
# xs sibling w is black, and ws right child is red
parent.rank -= Colour.Black
w.rank += Colour.Black
rotate_left(parent)
x = self.root
return x, (x.parent if x else None)
def _delete_rebalance(
self, node: Optional[Node[T]], parent: Optional[Node[T]]
) -> None:
if not node and not parent:
return
while node != self.root and has_double_black(parent):
if node == parent.left:
node, parent = self._delete_rebalance_step(
node,
parent.right,
parent,
lambda x: x.right,
self.rotate_left,
self.rotate_right,
)
else:
node, parent = self._delete_rebalance_step(
node,
parent.left,
parent,
lambda x: x.left,
self.rotate_right,
self.rotate_left,
)
# endregion DeleteRebalance

View file

@ -1,11 +1,8 @@
from avl import AVLTree, _balance_factor from avl import AVLTree, _balance_factor
import logging import logging
import random
from typing import List from typing import List
import hypothesis
import hypothesis.strategies as st
import pytest import pytest
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)

View file

@ -1,10 +1,10 @@
from avl import AVLTree from avl import AVLTree
from wavl import WAVLTree from wavl import WAVLTree
from ravl import RAVLTree from ravl import RAVLTree
from rbt import RBTree
import logging import logging
import random import random
from typing import List
import hypothesis import hypothesis
import hypothesis.strategies as st import hypothesis.strategies as st
@ -14,8 +14,7 @@ logger = logging.getLogger(__name__)
@pytest.mark.parametrize( @pytest.mark.parametrize(
("RankBalancedTree"), ("RankBalancedTree"), [RBTree, AVLTree, WAVLTree, RAVLTree]
[AVLTree, WAVLTree, RAVLTree]
) )
@hypothesis.settings(max_examples=10000) @hypothesis.settings(max_examples=10000)
@hypothesis.given(values=st.sets(st.integers())) @hypothesis.given(values=st.sets(st.integers()))
@ -45,8 +44,7 @@ def _report(t_before: str, t_after: str) -> None:
@pytest.mark.parametrize( @pytest.mark.parametrize(
("RankBalancedTree"), ("RankBalancedTree"), [RBTree, AVLTree, WAVLTree, RAVLTree]
[AVLTree, WAVLTree, RAVLTree]
) )
@hypothesis.settings(max_examples=10000) @hypothesis.settings(max_examples=10000)
@hypothesis.given(config=delete_strategy()) @hypothesis.given(config=delete_strategy())

View file

@ -1,10 +1,7 @@
from ravl import RAVLTree from ravl import RAVLTree
import hypothesis
import hypothesis.strategies as st
import logging import logging
import pytest import pytest
import random
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)

43
test_rbt.py Normal file
View file

@ -0,0 +1,43 @@
from rbt import RBTree
import logging
import pytest
logger = logging.getLogger(__name__)
@pytest.mark.parametrize(
"values, delete_order",
[
([0, 1, -1], [0, -1, 1]),
([0, 1, 2, 3, 4, 5], [4, 2, 1, 0, 5, 3]),
(
[32, 1, 2, 3, 4, 5, 0, -2, -4, -3, -1],
[-4, -3, 1, 2, 5, 3, -2, 4, 32, -1, 0],
),
([0, 1, 2, -2, -3, -1], [-3, 2, 1, 0, -1, -2]),
([0, 1, 2, 3, 4, 5, -1], [4, 2, 1, 0, 5, 3, -1]),
],
)
def test_delete_minimal(values, delete_order):
tree = RBTree()
for value in values:
tree.insert(value)
for value in delete_order:
logger.info("Deleting %s", value)
before = str(tree)
tree.delete(value)
after = str(tree)
try:
assert tree.is_correct
except AssertionError:
logger.info(
f"[FAIL] Delete {value} from {values} in order {delete_order}"
)
logger.info(f"Before:\n{before}")
logger.info(f"After:\n{after}")
raise

25
wavl.py
View file

@ -1,5 +1,6 @@
from ranked_tree import RotateFunction
from avl import AVLTree from avl import AVLTree
from node import Node, NodeType, Comparable, RotateFunction from node import Node, NodeType, Comparable
import logging import logging
from typing import TypeVar, Optional from typing import TypeVar, Optional
@ -35,11 +36,9 @@ class WAVLTree(AVLTree[T]):
y: Node[T], y: Node[T],
z: Node[T], z: Node[T],
reversed: bool, reversed: bool,
rotating_around_root: bool,
rotate_left: RotateFunction[T], rotate_left: RotateFunction[T],
rotate_right: RotateFunction[T], rotate_right: RotateFunction[T],
) -> None: ) -> None:
new_root = x
v = y.left v = y.left
w = y.right w = y.right
@ -49,9 +48,7 @@ class WAVLTree(AVLTree[T]):
w_diff = Node.difference(w, y) w_diff = Node.difference(w, y)
if w_diff == 1 and y.parent: if w_diff == 1 and y.parent:
new_root = rotate_left(y.parent) rotate_left(y.parent)
if rotating_around_root:
self.root = new_root
y.promote() y.promote()
z.demote() z.demote()
@ -60,9 +57,7 @@ class WAVLTree(AVLTree[T]):
z.demote() z.demote()
elif w_diff == 2 and v.parent: elif w_diff == 2 and v.parent:
rotate_right(v.parent) rotate_right(v.parent)
new_root = rotate_left(v.parent) rotate_left(v.parent)
if rotating_around_root:
self.root = new_root
v.promote().promote() v.promote().promote()
y.demote() y.demote()
@ -97,8 +92,6 @@ class WAVLTree(AVLTree[T]):
x_diff = Node.difference(x, parent) x_diff = Node.difference(x, parent)
y_diff = Node.difference(y, parent) y_diff = Node.difference(y, parent)
rotating_around_root = parent.parent is None
parent_node_diffs = Node.differences(parent) parent_node_diffs = Node.differences(parent)
if parent_node_diffs in ((1, 3), (3, 1)): if parent_node_diffs in ((1, 3), (3, 1)):
if parent.left is x: if parent.left is x:
@ -108,9 +101,8 @@ class WAVLTree(AVLTree[T]):
parent.right, parent.right,
parent, parent,
False, False,
rotating_around_root, self.rotate_left,
Node.rotate_left, self.rotate_right,
Node.rotate_right,
) )
else: else:
assert parent.left assert parent.left
@ -119,9 +111,8 @@ class WAVLTree(AVLTree[T]):
parent.left, parent.left,
parent, parent,
True, True,
rotating_around_root, self.rotate_right,
Node.rotate_right, self.rotate_left,
Node.rotate_left,
) )
def _delete_rebalance( def _delete_rebalance(