feat: add avl and split wavl
Signed-off-by: Matej Focko <mfocko@redhat.com>
This commit is contained in:
parent
f960a24eb6
commit
4a832e1de8
3 changed files with 350 additions and 174 deletions
165
avl.py
Normal file
165
avl.py
Normal file
|
@ -0,0 +1,165 @@
|
|||
from node import Node, Comparable, RotateFunction
|
||||
from ranked_tree import RankedTree
|
||||
|
||||
import logging
|
||||
from typing import TypeVar, Optional
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
T = TypeVar("T", bound=Comparable)
|
||||
|
||||
|
||||
def _balance_factor(node: Optional[Node[T]]) -> int:
|
||||
if not node:
|
||||
return 0
|
||||
|
||||
left, right = Node.get_rank(node.left), Node.get_rank(node.right)
|
||||
return right - left
|
||||
|
||||
|
||||
def _update_rank(node: Node[T]) -> None:
|
||||
left, right = Node.get_rank(node.left), Node.get_rank(node.right)
|
||||
# logger.debug(f"[_update_rank] on {node} = ({left}, {right})")
|
||||
node.rank = 1 + max(left, right)
|
||||
|
||||
|
||||
class AVLTree(RankedTree[T]):
|
||||
def is_correct_node(
|
||||
self, node: Optional[Node[T]], recursive: bool = True
|
||||
) -> bool:
|
||||
if not node:
|
||||
return True
|
||||
|
||||
if not (-1 <= _balance_factor(node) <= 1):
|
||||
return False
|
||||
|
||||
return not recursive or (
|
||||
self.is_correct_node(node.left)
|
||||
and self.is_correct_node(node.right)
|
||||
)
|
||||
|
||||
# region InsertRebalance
|
||||
|
||||
def __fix_0_child(
|
||||
self,
|
||||
x: Node[T],
|
||||
y: Optional[Node[T]],
|
||||
z: Node[T],
|
||||
rotate_left: RotateFunction[T],
|
||||
rotate_right: RotateFunction[T],
|
||||
) -> Optional[Node[T]]:
|
||||
new_root = x.parent
|
||||
|
||||
if not y or Node.difference(y) == 2:
|
||||
new_root = rotate_right(z)
|
||||
z.demote()
|
||||
elif Node.difference(y) == 1:
|
||||
rotate_left(x)
|
||||
new_root = rotate_right(z)
|
||||
|
||||
y.promote()
|
||||
x.demote()
|
||||
z.demote()
|
||||
|
||||
return new_root
|
||||
|
||||
def _insert_rebalance(self, x: Node[T]) -> None:
|
||||
diffs = Node.differences(x.parent)
|
||||
while x.parent and (diffs == (0, 1) or diffs == (1, 0)):
|
||||
x.parent.promote()
|
||||
x = x.parent
|
||||
|
||||
diffs = Node.differences(x.parent)
|
||||
|
||||
if not x.parent:
|
||||
return
|
||||
|
||||
rotating_around_root = x.parent.parent is None
|
||||
new_root: Optional[Node[T]] = x.parent
|
||||
|
||||
rank_difference = Node.difference(x)
|
||||
if rank_difference != 0:
|
||||
return
|
||||
|
||||
x_parent = x.parent
|
||||
assert x_parent is not None
|
||||
if rank_difference == 0 and x.parent.left is x:
|
||||
new_root = self.__fix_0_child(
|
||||
x, x.right, x_parent, Node.rotate_left, Node.rotate_right
|
||||
)
|
||||
elif rank_difference == 0 and x.parent.right is x:
|
||||
new_root = self.__fix_0_child(
|
||||
x, x.left, x_parent, Node.rotate_right, Node.rotate_left
|
||||
)
|
||||
|
||||
if rotating_around_root:
|
||||
self.root = new_root
|
||||
|
||||
# endregion InsertRebalance
|
||||
|
||||
# region DeleteRebalance
|
||||
|
||||
def __delete_rotate(
|
||||
self,
|
||||
x: Node[T],
|
||||
y: Node[T],
|
||||
leaning: int,
|
||||
rotating_around_root: bool,
|
||||
rotate_left: RotateFunction[T],
|
||||
rotate_right: RotateFunction[T],
|
||||
) -> bool:
|
||||
new_root = x
|
||||
|
||||
factor = _balance_factor(y)
|
||||
if factor in (0, leaning):
|
||||
new_root = rotate_left(x)
|
||||
else:
|
||||
rotate_right(y)
|
||||
new_root = rotate_left(x)
|
||||
|
||||
for n in filter(None, (new_root.left, new_root.right, new_root)):
|
||||
_update_rank(n)
|
||||
|
||||
if rotating_around_root:
|
||||
self.root = new_root
|
||||
return factor == 0
|
||||
|
||||
def __delete_fixup(
|
||||
self, y: Optional[Node[T]], parent: Optional[Node[T]] = None
|
||||
) -> bool:
|
||||
x = y if y else parent
|
||||
assert x
|
||||
|
||||
factor = _balance_factor(x)
|
||||
if factor == 0:
|
||||
_update_rank(x)
|
||||
return False
|
||||
elif factor in (-1, 1):
|
||||
return True
|
||||
|
||||
rotating_around_root = x.parent is None
|
||||
y, leaning, to_left, to_right = (
|
||||
(x.right, 1, Node.rotate_left, Node.rotate_right)
|
||||
if factor == 2
|
||||
else (x.left, -1, Node.rotate_right, Node.rotate_left)
|
||||
)
|
||||
assert y
|
||||
return self.__delete_rotate(
|
||||
x,
|
||||
y,
|
||||
leaning,
|
||||
rotating_around_root,
|
||||
to_left,
|
||||
to_right,
|
||||
)
|
||||
|
||||
def _delete_rebalance(
|
||||
self, node: Optional[Node[T]], parent: Optional[Node[T]]
|
||||
) -> None:
|
||||
while node or parent:
|
||||
# TODO: Check if it is possible to not propagate all the way up.
|
||||
# if self.__delete_fixup(node, parent):
|
||||
# return
|
||||
self.__delete_fixup(node, parent)
|
||||
node, parent = parent, (parent.parent if parent else None)
|
||||
|
||||
# endregion DeleteRebalance
|
146
ranked_tree.py
Normal file
146
ranked_tree.py
Normal file
|
@ -0,0 +1,146 @@
|
|||
from abc import abstractmethod
|
||||
from node import Node, Comparable
|
||||
|
||||
from collections import deque
|
||||
import logging
|
||||
from typing import Deque, Optional, Tuple, TypeVar, Generic
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
T = TypeVar("T", bound=Comparable)
|
||||
|
||||
|
||||
class RankedTree(Generic[T]):
|
||||
def __init__(self) -> None:
|
||||
self.root: Optional[Node[T]] = None
|
||||
|
||||
def __str__(self) -> str:
|
||||
result = "digraph {\n"
|
||||
|
||||
queue: Deque[Optional[Node[T]]] = deque()
|
||||
queue.append(self.root)
|
||||
|
||||
edges = []
|
||||
while queue:
|
||||
node = queue.popleft()
|
||||
if not node:
|
||||
continue
|
||||
|
||||
result += f'"{str(node)}" [label="{node.value}, {node.rank}"];\n'
|
||||
for child in (node.left, node.right):
|
||||
if not child:
|
||||
continue
|
||||
|
||||
edges.append((node, child))
|
||||
queue.append(child)
|
||||
|
||||
for from_node, to_node in edges:
|
||||
label = f'[label="{Node.difference(to_node)}"]'
|
||||
result += f'"{str(from_node)}" -> "{str(to_node)}" {label}\n'
|
||||
|
||||
return result + "}\n"
|
||||
|
||||
# region TreeSpecificMethods
|
||||
|
||||
@abstractmethod
|
||||
def is_correct_node(self, node: Optional[Node[T]]) -> bool:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def _insert_rebalance(self, x: Node[T]) -> None:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def _delete_rebalance(
|
||||
self, node: Optional[Node[T]], parent: Optional[Node[T]]
|
||||
) -> None:
|
||||
pass
|
||||
|
||||
# endregion TreeSpecificMethods
|
||||
|
||||
@property
|
||||
def rank(self) -> int:
|
||||
return Node.get_rank(self.root)
|
||||
|
||||
@property
|
||||
def is_correct(self) -> bool:
|
||||
return self.is_correct_node(self.root)
|
||||
|
||||
def search(
|
||||
self, value: T, node: Optional[Node[T]] = None
|
||||
) -> Optional[Node[T]]:
|
||||
if not node:
|
||||
node = self.root
|
||||
|
||||
return Node.search(value, node)
|
||||
|
||||
def insert(self, value: T) -> None:
|
||||
inserted_node = Node(value)
|
||||
|
||||
if not self.root:
|
||||
self.root = inserted_node
|
||||
return
|
||||
|
||||
parent = Node.find_parent_node(value, self.root)
|
||||
inserted_node.parent = parent
|
||||
|
||||
if value < parent.value:
|
||||
parent.left = inserted_node
|
||||
else:
|
||||
parent.right = inserted_node
|
||||
|
||||
self._insert_rebalance(inserted_node)
|
||||
|
||||
def _transplant(self, u: Node[T], v: Optional[Node[T]]) -> None:
|
||||
if not u.parent:
|
||||
self.root = v
|
||||
elif u.parent.left is u:
|
||||
u.parent.left = v
|
||||
else:
|
||||
u.parent.right = v
|
||||
|
||||
if v:
|
||||
v.rank = u.rank
|
||||
v.parent = u.parent
|
||||
|
||||
def _delete_node(
|
||||
self, node: Optional[Node[T]]
|
||||
) -> Optional[Tuple[Optional[Node[T]], Optional[Node[T]]]]:
|
||||
if node is None:
|
||||
return None
|
||||
|
||||
y, parent = None, node.parent
|
||||
|
||||
if not node.left:
|
||||
y = node.right
|
||||
self._transplant(node, node.right)
|
||||
elif not node.right:
|
||||
y = node.left
|
||||
self._transplant(node, node.left)
|
||||
else:
|
||||
n = Node.minimum(node.right)
|
||||
y, parent = None, (n.parent if n.parent is not node else n)
|
||||
|
||||
if n.parent is not node:
|
||||
parent = n.right if n.right else n.parent
|
||||
self._transplant(n, n.right)
|
||||
n.right = node.right
|
||||
n.right.parent = n
|
||||
|
||||
self._transplant(node, n)
|
||||
n.left = node.left
|
||||
n.left.parent = n
|
||||
|
||||
return (y, parent)
|
||||
|
||||
def _delete(
|
||||
self, value: T
|
||||
) -> Optional[Tuple[Optional[Node[T]], Optional[Node[T]]]]:
|
||||
node = self.root
|
||||
while node is not None and node.value != value:
|
||||
node = node.left if value < node.value else node.right
|
||||
return self._delete_node(node)
|
||||
|
||||
def delete(self, value: T) -> None:
|
||||
if to_be_rebalanced := self._delete(value):
|
||||
y, parent = to_be_rebalanced
|
||||
self._delete_rebalance(y, parent)
|
213
wavl.py
213
wavl.py
|
@ -1,25 +1,17 @@
|
|||
from node import Node, NodeType
|
||||
from avl import AVLTree
|
||||
from node import Node, NodeType, Comparable, RotateFunction
|
||||
|
||||
from collections import deque
|
||||
import logging
|
||||
from typing import TypeVar, Optional
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
T = TypeVar("T", bound=Comparable)
|
||||
|
||||
|
||||
class WAVLTree:
|
||||
def __init__(self):
|
||||
self.root = None
|
||||
|
||||
@property
|
||||
def rank(self):
|
||||
return Node.rank(self.root)
|
||||
|
||||
@property
|
||||
def is_correct(self):
|
||||
return WAVLTree.is_correct_node(self.root)
|
||||
|
||||
@staticmethod
|
||||
def is_correct_node(node, recursive=True):
|
||||
class WAVLTree(AVLTree[T]):
|
||||
def is_correct_node(
|
||||
self, node: Optional[Node[T]], recursive: bool = True
|
||||
) -> bool:
|
||||
if not node:
|
||||
return True
|
||||
|
||||
|
@ -31,96 +23,21 @@ class WAVLTree:
|
|||
return node.rank == 0
|
||||
|
||||
return not recursive or (
|
||||
WAVLTree.is_correct_node(node.left) and WAVLTree.is_correct_node(node.right)
|
||||
self.is_correct_node(node.left)
|
||||
and self.is_correct_node(node.right)
|
||||
)
|
||||
|
||||
def search(self, value, node=None):
|
||||
if not node:
|
||||
node = self.root
|
||||
|
||||
return Node.search(value, node)
|
||||
|
||||
def __fix_0_child(self, x, y, z, rotate_left, rotate_right):
|
||||
new_root = x.parent
|
||||
|
||||
if not y or Node.difference(y) == 2:
|
||||
new_root = rotate_right(z)
|
||||
z.demote()
|
||||
elif Node.difference(y) == 1:
|
||||
rotate_left(x)
|
||||
new_root = rotate_right(z)
|
||||
|
||||
y.promote()
|
||||
x.demote()
|
||||
z.demote()
|
||||
|
||||
return new_root
|
||||
|
||||
def __bottomup_rebalance(self, x):
|
||||
diffs = Node.differences(x.parent)
|
||||
# if diffs != (0, 1) and diffs != (1, 0):
|
||||
# return
|
||||
|
||||
while x.parent and (diffs == (0, 1) or diffs == (1, 0)):
|
||||
x.parent.promote()
|
||||
x = x.parent
|
||||
|
||||
diffs = Node.differences(x.parent)
|
||||
|
||||
if not x.parent:
|
||||
return
|
||||
|
||||
rotating_around_root = x.parent.parent is None
|
||||
new_root = x.parent
|
||||
|
||||
rank_difference = Node.difference(x)
|
||||
|
||||
if rank_difference != 0:
|
||||
return
|
||||
|
||||
if rank_difference == 0 and x.parent.left is x:
|
||||
new_root = self.__fix_0_child(
|
||||
x, x.right, x.parent, Node.rotate_left, Node.rotate_right
|
||||
)
|
||||
elif rank_difference == 0 and x.parent.right is x:
|
||||
new_root = self.__fix_0_child(
|
||||
x, x.left, x.parent, Node.rotate_right, Node.rotate_left
|
||||
)
|
||||
|
||||
if rotating_around_root:
|
||||
self.root = new_root
|
||||
|
||||
def insert(self, value):
|
||||
inserted_node = Node(value)
|
||||
|
||||
if not self.root:
|
||||
self.root = inserted_node
|
||||
return
|
||||
|
||||
parent = Node.find_parent_node(value, self.root)
|
||||
inserted_node.parent = parent
|
||||
|
||||
if value < parent.value:
|
||||
parent.left = inserted_node
|
||||
else:
|
||||
parent.right = inserted_node
|
||||
|
||||
self.__bottomup_rebalance(inserted_node)
|
||||
|
||||
def __transplant(self, u, v):
|
||||
if not u.parent:
|
||||
self.root = v
|
||||
elif u.parent.left is u:
|
||||
u.parent.left = v
|
||||
else:
|
||||
u.parent.right = v
|
||||
|
||||
if v:
|
||||
v.rank = u.rank
|
||||
v.parent = u.parent
|
||||
# region DeleteRebalance
|
||||
|
||||
@staticmethod
|
||||
def __fix_delete(x, y, z, reversed, rotate_left, rotate_right):
|
||||
def __fix_delete(
|
||||
x: Optional[Node[T]],
|
||||
y: Node[T],
|
||||
z: Node[T],
|
||||
reversed: bool,
|
||||
rotate_left: RotateFunction[T],
|
||||
rotate_right: RotateFunction[T],
|
||||
) -> Optional[Node[T]]:
|
||||
new_root = x
|
||||
v = y.left
|
||||
w = y.right
|
||||
|
@ -132,6 +49,8 @@ class WAVLTree:
|
|||
|
||||
w_diff = Node.difference(w, y)
|
||||
logger.debug(f"w_diff = {w_diff}")
|
||||
|
||||
assert v
|
||||
if w_diff == 1 and y.parent:
|
||||
logger.debug(f"y.parent = {y.parent}")
|
||||
new_root = rotate_left(y.parent)
|
||||
|
@ -152,7 +71,9 @@ class WAVLTree:
|
|||
|
||||
return new_root
|
||||
|
||||
def __bottomup_delete(self, x, parent):
|
||||
def __bottomup_delete(
|
||||
self, x: Optional[Node[T]], parent: Optional[Node[T]]
|
||||
) -> None:
|
||||
x_diff = Node.difference(x, parent)
|
||||
if x_diff != 3 or not parent:
|
||||
return
|
||||
|
@ -183,11 +104,12 @@ class WAVLTree:
|
|||
return
|
||||
|
||||
rotating_around_root = parent.parent is None
|
||||
new_root = parent
|
||||
new_root: Optional[Node[T]] = parent
|
||||
|
||||
parent_node_diffs = Node.differences(parent)
|
||||
if parent_node_diffs in ((1, 3), (3, 1)):
|
||||
if parent.left is x:
|
||||
assert parent.right
|
||||
new_root = WAVLTree.__fix_delete(
|
||||
x,
|
||||
parent.right,
|
||||
|
@ -197,6 +119,7 @@ class WAVLTree:
|
|||
Node.rotate_right,
|
||||
)
|
||||
else:
|
||||
assert parent.left
|
||||
new_root = WAVLTree.__fix_delete(
|
||||
x,
|
||||
parent.left,
|
||||
|
@ -209,92 +132,34 @@ class WAVLTree:
|
|||
if rotating_around_root:
|
||||
self.root = new_root
|
||||
|
||||
def __delete_fixup(self, y, parent=None):
|
||||
def __delete_fixup(
|
||||
self, y: Optional[Node[T]], parent: Optional[Node[T]] = None
|
||||
) -> None:
|
||||
logger.debug(f"[__delete_fixup] y = {y}, parent = {parent}")
|
||||
|
||||
z = y if y else parent
|
||||
logger.debug(
|
||||
f"[z.demote()] Node.differences({repr(z)}) == (2, 2) ~>* {Node.differences(z)} == (2, 2)"
|
||||
f"[z.demote()] Node.differences({repr(z)}) == (2, 2) ~>*"
|
||||
f"{Node.differences(z)} == (2, 2)"
|
||||
)
|
||||
assert z
|
||||
if Node.differences(z) == (2, 2):
|
||||
z.demote()
|
||||
|
||||
if parent:
|
||||
for y in (parent.left, parent.right):
|
||||
logger.debug(
|
||||
f"[bottom-up delete] Node.difference({y}, {parent}) == 3 ~>* {Node.difference(y, parent)} == 3"
|
||||
f"[bottom-up delete] Node.difference({y}, {parent}) == 3"
|
||||
f"~>* {Node.difference(y, parent)} == 3"
|
||||
)
|
||||
if Node.difference(y, parent) == 3:
|
||||
self.__bottomup_delete(y, parent)
|
||||
|
||||
def __fix_after_delete(self, node, parent):
|
||||
def _delete_rebalance(
|
||||
self, node: Optional[Node[T]], parent: Optional[Node[T]]
|
||||
) -> None:
|
||||
while node or parent:
|
||||
self.__delete_fixup(node, parent)
|
||||
node, parent = parent, (parent.parent if parent else None)
|
||||
|
||||
def delete_node(self, node):
|
||||
y, parent = None, node.parent
|
||||
|
||||
if not node.left:
|
||||
logger.debug("node.left is None")
|
||||
y = node.right
|
||||
self.__transplant(node, node.right)
|
||||
elif not node.right:
|
||||
logger.debug("node.right is None")
|
||||
y = node.left
|
||||
self.__transplant(node, node.left)
|
||||
else:
|
||||
logger.debug("taking successor")
|
||||
n = Node.minimum(node.right)
|
||||
y, parent = None, (n.parent if n.parent is not node else n)
|
||||
|
||||
if n.parent is not node:
|
||||
parent = n.right if n.right else n.parent
|
||||
self.__transplant(n, n.right)
|
||||
n.right = node.right
|
||||
n.right.parent = n
|
||||
|
||||
self.__transplant(node, n)
|
||||
n.left = node.left
|
||||
n.left.parent = n
|
||||
|
||||
return (y, parent)
|
||||
|
||||
def __delete(self, value, node=None):
|
||||
if node is None:
|
||||
return
|
||||
|
||||
if node.value != value:
|
||||
return self.__delete(value, node.left if value < node.value else node.right)
|
||||
|
||||
(y, parent) = self.delete_node(node)
|
||||
self.__fix_after_delete(y, parent)
|
||||
|
||||
def delete(self, value):
|
||||
logger.debug(f"[DELETE] {value}")
|
||||
self.__delete(value, self.root)
|
||||
|
||||
def __str__(self):
|
||||
result = "digraph {\n"
|
||||
|
||||
queue = deque()
|
||||
queue.append(self.root)
|
||||
|
||||
edges = []
|
||||
while queue:
|
||||
node = queue.popleft()
|
||||
if not node:
|
||||
continue
|
||||
|
||||
result += f'"{str(node)}" [label="{node.value}, {node.rank}"];\n'
|
||||
for child in (node.left, node.right):
|
||||
if not child:
|
||||
continue
|
||||
|
||||
edges.append((node, child))
|
||||
queue.append(child)
|
||||
|
||||
for fromNode, toNode in edges:
|
||||
result += f'"{str(fromNode)}" -> "{str(toNode)}" [label="{Node.difference(toNode)}"]\n'
|
||||
|
||||
return result + "}\n"
|
||||
# endregion DeleteRebalance
|
||||
|
|
Loading…
Reference in a new issue