feat: add avl and split wavl

Signed-off-by: Matej Focko <mfocko@redhat.com>
This commit is contained in:
Matej Focko 2022-01-30 15:03:34 +01:00
parent f960a24eb6
commit 4a832e1de8
No known key found for this signature in database
GPG key ID: 332171FADF1DB90B
3 changed files with 350 additions and 174 deletions

165
avl.py Normal file
View file

@ -0,0 +1,165 @@
from node import Node, Comparable, RotateFunction
from ranked_tree import RankedTree
import logging
from typing import TypeVar, Optional
logger = logging.getLogger(__name__)
T = TypeVar("T", bound=Comparable)
def _balance_factor(node: Optional[Node[T]]) -> int:
if not node:
return 0
left, right = Node.get_rank(node.left), Node.get_rank(node.right)
return right - left
def _update_rank(node: Node[T]) -> None:
left, right = Node.get_rank(node.left), Node.get_rank(node.right)
# logger.debug(f"[_update_rank] on {node} = ({left}, {right})")
node.rank = 1 + max(left, right)
class AVLTree(RankedTree[T]):
def is_correct_node(
self, node: Optional[Node[T]], recursive: bool = True
) -> bool:
if not node:
return True
if not (-1 <= _balance_factor(node) <= 1):
return False
return not recursive or (
self.is_correct_node(node.left)
and self.is_correct_node(node.right)
)
# region InsertRebalance
def __fix_0_child(
self,
x: Node[T],
y: Optional[Node[T]],
z: Node[T],
rotate_left: RotateFunction[T],
rotate_right: RotateFunction[T],
) -> Optional[Node[T]]:
new_root = x.parent
if not y or Node.difference(y) == 2:
new_root = rotate_right(z)
z.demote()
elif Node.difference(y) == 1:
rotate_left(x)
new_root = rotate_right(z)
y.promote()
x.demote()
z.demote()
return new_root
def _insert_rebalance(self, x: Node[T]) -> None:
diffs = Node.differences(x.parent)
while x.parent and (diffs == (0, 1) or diffs == (1, 0)):
x.parent.promote()
x = x.parent
diffs = Node.differences(x.parent)
if not x.parent:
return
rotating_around_root = x.parent.parent is None
new_root: Optional[Node[T]] = x.parent
rank_difference = Node.difference(x)
if rank_difference != 0:
return
x_parent = x.parent
assert x_parent is not None
if rank_difference == 0 and x.parent.left is x:
new_root = self.__fix_0_child(
x, x.right, x_parent, Node.rotate_left, Node.rotate_right
)
elif rank_difference == 0 and x.parent.right is x:
new_root = self.__fix_0_child(
x, x.left, x_parent, Node.rotate_right, Node.rotate_left
)
if rotating_around_root:
self.root = new_root
# endregion InsertRebalance
# region DeleteRebalance
def __delete_rotate(
self,
x: Node[T],
y: Node[T],
leaning: int,
rotating_around_root: bool,
rotate_left: RotateFunction[T],
rotate_right: RotateFunction[T],
) -> bool:
new_root = x
factor = _balance_factor(y)
if factor in (0, leaning):
new_root = rotate_left(x)
else:
rotate_right(y)
new_root = rotate_left(x)
for n in filter(None, (new_root.left, new_root.right, new_root)):
_update_rank(n)
if rotating_around_root:
self.root = new_root
return factor == 0
def __delete_fixup(
self, y: Optional[Node[T]], parent: Optional[Node[T]] = None
) -> bool:
x = y if y else parent
assert x
factor = _balance_factor(x)
if factor == 0:
_update_rank(x)
return False
elif factor in (-1, 1):
return True
rotating_around_root = x.parent is None
y, leaning, to_left, to_right = (
(x.right, 1, Node.rotate_left, Node.rotate_right)
if factor == 2
else (x.left, -1, Node.rotate_right, Node.rotate_left)
)
assert y
return self.__delete_rotate(
x,
y,
leaning,
rotating_around_root,
to_left,
to_right,
)
def _delete_rebalance(
self, node: Optional[Node[T]], parent: Optional[Node[T]]
) -> None:
while node or parent:
# TODO: Check if it is possible to not propagate all the way up.
# if self.__delete_fixup(node, parent):
# return
self.__delete_fixup(node, parent)
node, parent = parent, (parent.parent if parent else None)
# endregion DeleteRebalance

146
ranked_tree.py Normal file
View file

@ -0,0 +1,146 @@
from abc import abstractmethod
from node import Node, Comparable
from collections import deque
import logging
from typing import Deque, Optional, Tuple, TypeVar, Generic
logger = logging.getLogger(__name__)
T = TypeVar("T", bound=Comparable)
class RankedTree(Generic[T]):
def __init__(self) -> None:
self.root: Optional[Node[T]] = None
def __str__(self) -> str:
result = "digraph {\n"
queue: Deque[Optional[Node[T]]] = deque()
queue.append(self.root)
edges = []
while queue:
node = queue.popleft()
if not node:
continue
result += f'"{str(node)}" [label="{node.value}, {node.rank}"];\n'
for child in (node.left, node.right):
if not child:
continue
edges.append((node, child))
queue.append(child)
for from_node, to_node in edges:
label = f'[label="{Node.difference(to_node)}"]'
result += f'"{str(from_node)}" -> "{str(to_node)}" {label}\n'
return result + "}\n"
# region TreeSpecificMethods
@abstractmethod
def is_correct_node(self, node: Optional[Node[T]]) -> bool:
pass
@abstractmethod
def _insert_rebalance(self, x: Node[T]) -> None:
pass
@abstractmethod
def _delete_rebalance(
self, node: Optional[Node[T]], parent: Optional[Node[T]]
) -> None:
pass
# endregion TreeSpecificMethods
@property
def rank(self) -> int:
return Node.get_rank(self.root)
@property
def is_correct(self) -> bool:
return self.is_correct_node(self.root)
def search(
self, value: T, node: Optional[Node[T]] = None
) -> Optional[Node[T]]:
if not node:
node = self.root
return Node.search(value, node)
def insert(self, value: T) -> None:
inserted_node = Node(value)
if not self.root:
self.root = inserted_node
return
parent = Node.find_parent_node(value, self.root)
inserted_node.parent = parent
if value < parent.value:
parent.left = inserted_node
else:
parent.right = inserted_node
self._insert_rebalance(inserted_node)
def _transplant(self, u: Node[T], v: Optional[Node[T]]) -> None:
if not u.parent:
self.root = v
elif u.parent.left is u:
u.parent.left = v
else:
u.parent.right = v
if v:
v.rank = u.rank
v.parent = u.parent
def _delete_node(
self, node: Optional[Node[T]]
) -> Optional[Tuple[Optional[Node[T]], Optional[Node[T]]]]:
if node is None:
return None
y, parent = None, node.parent
if not node.left:
y = node.right
self._transplant(node, node.right)
elif not node.right:
y = node.left
self._transplant(node, node.left)
else:
n = Node.minimum(node.right)
y, parent = None, (n.parent if n.parent is not node else n)
if n.parent is not node:
parent = n.right if n.right else n.parent
self._transplant(n, n.right)
n.right = node.right
n.right.parent = n
self._transplant(node, n)
n.left = node.left
n.left.parent = n
return (y, parent)
def _delete(
self, value: T
) -> Optional[Tuple[Optional[Node[T]], Optional[Node[T]]]]:
node = self.root
while node is not None and node.value != value:
node = node.left if value < node.value else node.right
return self._delete_node(node)
def delete(self, value: T) -> None:
if to_be_rebalanced := self._delete(value):
y, parent = to_be_rebalanced
self._delete_rebalance(y, parent)

213
wavl.py
View file

@ -1,25 +1,17 @@
from node import Node, NodeType from avl import AVLTree
from node import Node, NodeType, Comparable, RotateFunction
from collections import deque
import logging import logging
from typing import TypeVar, Optional
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
T = TypeVar("T", bound=Comparable)
class WAVLTree: class WAVLTree(AVLTree[T]):
def __init__(self): def is_correct_node(
self.root = None self, node: Optional[Node[T]], recursive: bool = True
) -> bool:
@property
def rank(self):
return Node.rank(self.root)
@property
def is_correct(self):
return WAVLTree.is_correct_node(self.root)
@staticmethod
def is_correct_node(node, recursive=True):
if not node: if not node:
return True return True
@ -31,96 +23,21 @@ class WAVLTree:
return node.rank == 0 return node.rank == 0
return not recursive or ( return not recursive or (
WAVLTree.is_correct_node(node.left) and WAVLTree.is_correct_node(node.right) self.is_correct_node(node.left)
and self.is_correct_node(node.right)
) )
def search(self, value, node=None): # region DeleteRebalance
if not node:
node = self.root
return Node.search(value, node)
def __fix_0_child(self, x, y, z, rotate_left, rotate_right):
new_root = x.parent
if not y or Node.difference(y) == 2:
new_root = rotate_right(z)
z.demote()
elif Node.difference(y) == 1:
rotate_left(x)
new_root = rotate_right(z)
y.promote()
x.demote()
z.demote()
return new_root
def __bottomup_rebalance(self, x):
diffs = Node.differences(x.parent)
# if diffs != (0, 1) and diffs != (1, 0):
# return
while x.parent and (diffs == (0, 1) or diffs == (1, 0)):
x.parent.promote()
x = x.parent
diffs = Node.differences(x.parent)
if not x.parent:
return
rotating_around_root = x.parent.parent is None
new_root = x.parent
rank_difference = Node.difference(x)
if rank_difference != 0:
return
if rank_difference == 0 and x.parent.left is x:
new_root = self.__fix_0_child(
x, x.right, x.parent, Node.rotate_left, Node.rotate_right
)
elif rank_difference == 0 and x.parent.right is x:
new_root = self.__fix_0_child(
x, x.left, x.parent, Node.rotate_right, Node.rotate_left
)
if rotating_around_root:
self.root = new_root
def insert(self, value):
inserted_node = Node(value)
if not self.root:
self.root = inserted_node
return
parent = Node.find_parent_node(value, self.root)
inserted_node.parent = parent
if value < parent.value:
parent.left = inserted_node
else:
parent.right = inserted_node
self.__bottomup_rebalance(inserted_node)
def __transplant(self, u, v):
if not u.parent:
self.root = v
elif u.parent.left is u:
u.parent.left = v
else:
u.parent.right = v
if v:
v.rank = u.rank
v.parent = u.parent
@staticmethod @staticmethod
def __fix_delete(x, y, z, reversed, rotate_left, rotate_right): def __fix_delete(
x: Optional[Node[T]],
y: Node[T],
z: Node[T],
reversed: bool,
rotate_left: RotateFunction[T],
rotate_right: RotateFunction[T],
) -> Optional[Node[T]]:
new_root = x new_root = x
v = y.left v = y.left
w = y.right w = y.right
@ -132,6 +49,8 @@ class WAVLTree:
w_diff = Node.difference(w, y) w_diff = Node.difference(w, y)
logger.debug(f"w_diff = {w_diff}") logger.debug(f"w_diff = {w_diff}")
assert v
if w_diff == 1 and y.parent: if w_diff == 1 and y.parent:
logger.debug(f"y.parent = {y.parent}") logger.debug(f"y.parent = {y.parent}")
new_root = rotate_left(y.parent) new_root = rotate_left(y.parent)
@ -152,7 +71,9 @@ class WAVLTree:
return new_root return new_root
def __bottomup_delete(self, x, parent): def __bottomup_delete(
self, x: Optional[Node[T]], parent: Optional[Node[T]]
) -> None:
x_diff = Node.difference(x, parent) x_diff = Node.difference(x, parent)
if x_diff != 3 or not parent: if x_diff != 3 or not parent:
return return
@ -183,11 +104,12 @@ class WAVLTree:
return return
rotating_around_root = parent.parent is None rotating_around_root = parent.parent is None
new_root = parent new_root: Optional[Node[T]] = parent
parent_node_diffs = Node.differences(parent) parent_node_diffs = Node.differences(parent)
if parent_node_diffs in ((1, 3), (3, 1)): if parent_node_diffs in ((1, 3), (3, 1)):
if parent.left is x: if parent.left is x:
assert parent.right
new_root = WAVLTree.__fix_delete( new_root = WAVLTree.__fix_delete(
x, x,
parent.right, parent.right,
@ -197,6 +119,7 @@ class WAVLTree:
Node.rotate_right, Node.rotate_right,
) )
else: else:
assert parent.left
new_root = WAVLTree.__fix_delete( new_root = WAVLTree.__fix_delete(
x, x,
parent.left, parent.left,
@ -209,92 +132,34 @@ class WAVLTree:
if rotating_around_root: if rotating_around_root:
self.root = new_root self.root = new_root
def __delete_fixup(self, y, parent=None): def __delete_fixup(
self, y: Optional[Node[T]], parent: Optional[Node[T]] = None
) -> None:
logger.debug(f"[__delete_fixup] y = {y}, parent = {parent}") logger.debug(f"[__delete_fixup] y = {y}, parent = {parent}")
z = y if y else parent z = y if y else parent
logger.debug( logger.debug(
f"[z.demote()] Node.differences({repr(z)}) == (2, 2) ~>* {Node.differences(z)} == (2, 2)" f"[z.demote()] Node.differences({repr(z)}) == (2, 2) ~>*"
f"{Node.differences(z)} == (2, 2)"
) )
assert z
if Node.differences(z) == (2, 2): if Node.differences(z) == (2, 2):
z.demote() z.demote()
if parent: if parent:
for y in (parent.left, parent.right): for y in (parent.left, parent.right):
logger.debug( logger.debug(
f"[bottom-up delete] Node.difference({y}, {parent}) == 3 ~>* {Node.difference(y, parent)} == 3" f"[bottom-up delete] Node.difference({y}, {parent}) == 3"
f"~>* {Node.difference(y, parent)} == 3"
) )
if Node.difference(y, parent) == 3: if Node.difference(y, parent) == 3:
self.__bottomup_delete(y, parent) self.__bottomup_delete(y, parent)
def __fix_after_delete(self, node, parent): def _delete_rebalance(
self, node: Optional[Node[T]], parent: Optional[Node[T]]
) -> None:
while node or parent: while node or parent:
self.__delete_fixup(node, parent) self.__delete_fixup(node, parent)
node, parent = parent, (parent.parent if parent else None) node, parent = parent, (parent.parent if parent else None)
def delete_node(self, node): # endregion DeleteRebalance
y, parent = None, node.parent
if not node.left:
logger.debug("node.left is None")
y = node.right
self.__transplant(node, node.right)
elif not node.right:
logger.debug("node.right is None")
y = node.left
self.__transplant(node, node.left)
else:
logger.debug("taking successor")
n = Node.minimum(node.right)
y, parent = None, (n.parent if n.parent is not node else n)
if n.parent is not node:
parent = n.right if n.right else n.parent
self.__transplant(n, n.right)
n.right = node.right
n.right.parent = n
self.__transplant(node, n)
n.left = node.left
n.left.parent = n
return (y, parent)
def __delete(self, value, node=None):
if node is None:
return
if node.value != value:
return self.__delete(value, node.left if value < node.value else node.right)
(y, parent) = self.delete_node(node)
self.__fix_after_delete(y, parent)
def delete(self, value):
logger.debug(f"[DELETE] {value}")
self.__delete(value, self.root)
def __str__(self):
result = "digraph {\n"
queue = deque()
queue.append(self.root)
edges = []
while queue:
node = queue.popleft()
if not node:
continue
result += f'"{str(node)}" [label="{node.value}, {node.rank}"];\n'
for child in (node.left, node.right):
if not child:
continue
edges.append((node, child))
queue.append(child)
for fromNode, toNode in edges:
result += f'"{str(fromNode)}" -> "{str(toNode)}" [label="{Node.difference(toNode)}"]\n'
return result + "}\n"