152 lines
3.9 KiB
Python
152 lines
3.9 KiB
Python
from node import Node, Comparable
|
|
from ranked_tree import RankedTree, RotateFunction
|
|
|
|
import logging
|
|
from typing import TypeVar, Optional
|
|
|
|
# Type variable for the values stored in tree nodes, bound to the project's
# ``Comparable`` type (presumably values that support ordering — see node.py).
T = TypeVar("T", bound=Comparable)

# Module-level logger; nothing in the code visible here logs through it yet.
logger = logging.getLogger(__name__)
|
|
|
|
|
|
def _balance_factor(node: Optional[Node[T]]) -> int:
    """Return the rank of *node*'s right child minus the rank of its left child.

    A missing (falsy) node has a balance factor of 0. Missing children are
    ranked by whatever ``Node.get_rank`` reports for ``None``.
    """
    if node:
        left_rank = Node.get_rank(node.left)
        return Node.get_rank(node.right) - left_rank
    return 0
|
|
|
|
|
|
def _update_rank(node: Node[T]) -> None:
    """Recompute ``node.rank`` as one more than the larger of its children's ranks."""
    child_ranks = (Node.get_rank(node.left), Node.get_rank(node.right))
    node.rank = max(child_ranks) + 1
|
|
|
|
|
|
class AVLTree(RankedTree[T]):
    """AVL tree implemented on top of the rank-balanced ``RankedTree`` base.

    Invariant (checked by ``is_correct_node``): every node's rank equals one
    more than the larger rank of its children, and the ranks of its two
    children differ by at most one.

    NOTE(review): ``Node.difference(n)`` is assumed to be the rank gap between
    ``n``'s parent and ``n``; ``Node.differences(p)`` the pair of gaps from
    ``p`` to its left and right children; ``promote``/``demote`` are assumed
    to raise/lower a node's rank by one. Confirm against node.py.
    """

    def is_correct_node(
        self, node: Optional[Node[T]], recursive: bool = True
    ) -> bool:
        """Check the AVL invariants at *node*.

        Args:
            node: Root of the subtree to check; a missing node is valid.
            recursive: When true, check every descendant as well; otherwise
                only *node* itself is checked.

        Returns:
            ``True`` if *node*'s balance factor is within ``[-1, 1]`` and its
            rank is ``1 + max(child ranks)`` — and, when *recursive*, the same
            holds for every node in the subtree.
        """
        if not node:
            return True

        if not (
            -1 <= _balance_factor(node) <= 1
            and node.rank
            == 1 + max(Node.get_rank(node.left), Node.get_rank(node.right))
        ):
            return False

        return not recursive or (
            self.is_correct_node(node.left)
            and self.is_correct_node(node.right)
        )

    # region InsertRebalance

    def __fix_0_child(
        self,
        x: Node[T],
        y: Optional[Node[T]],
        z: Node[T],
        rotate_left: RotateFunction[T],
        rotate_right: RotateFunction[T],
    ) -> None:
        """Repair ``x`` being a 0-child of ``z`` using one or two rotations.

        Written for the orientation where ``x`` is ``z``'s left child; the
        caller handles the mirrored orientation by swapping the rotations.

        Args:
            x: Child of ``z`` whose rank difference to ``z`` is 0.
            y: ``x``'s inner child (``x.right`` in this orientation), or None.
            z: Parent of ``x``; the node where the invariant is violated.
            rotate_left: Rotation applied at ``x`` in the double-rotation case.
            rotate_right: Rotation applied at ``z`` in both cases.
        """
        if not y or Node.difference(y) == 2:
            # Inner child missing or a 2-child: a single rotation at z.
            rotate_right(z)
            z.demote()
        elif Node.difference(y) == 1:
            # Inner child is a 1-child: double rotation (at x, then at z),
            # after which y's rank rises and x's and z's ranks drop.
            rotate_left(x)
            rotate_right(z)

            y.promote()
            x.demote()
            z.demote()
        # Any other rank difference for y leaves the tree untouched.

    def _insert_rebalance(self, x: Node[T]) -> None:
        """Restore the AVL invariant after inserting ``x``.

        Promotes ancestors while the parent has a (0, 1) / (1, 0) pair of rank
        differences. If ``x`` is then left as a 0-child of its parent, the
        violation is repaired with rotations.
        """
        diffs = Node.differences(x.parent)
        while x.parent and diffs in ((0, 1), (1, 0)):
            # Promoting the parent fixes it but may push the 0-child
            # violation one level higher, so keep walking up.
            x.parent.promote()
            x = x.parent

            diffs = Node.differences(x.parent)

        parent = x.parent
        if not parent:
            return

        # Only a remaining 0-child needs rotations; otherwise we are done.
        if Node.difference(x) != 0:
            return

        if parent.left is x:
            self.__fix_0_child(
                x, x.right, parent, self.rotate_left, self.rotate_right
            )
        elif parent.right is x:
            # Mirrored orientation: inner child is x.left, rotations swapped.
            self.__fix_0_child(
                x, x.left, parent, self.rotate_right, self.rotate_left
            )

    # endregion InsertRebalance

    # region DeleteRebalance

    def __delete_rotate(
        self,
        x: Node[T],
        y: Node[T],
        leaning: int,
        rotate_left: RotateFunction[T],
        rotate_right: RotateFunction[T],
    ) -> bool:
        """Rotate at ``x``, whose balance factor is out of range, with child ``y``.

        Args:
            x: The unbalanced node.
            y: ``x``'s child on the heavy side.
            leaning: Balance factor of ``y`` (same sign as the heavy side) for
                which a single rotation at ``x`` suffices.
            rotate_left: Rotation applied at ``x``. Its return value is used as
                the new root of the rotated subtree.
            rotate_right: Rotation applied first at ``y`` when ``y`` leans
                the other way (double-rotation case).

        Returns:
            ``True`` when ``y`` was not perfectly balanced (``factor != 0``).
            The caller treats this as "subtree got shorter, keep going up".
        """
        factor = _balance_factor(y)
        if factor not in (0, leaning):
            # y leans away from the heavy side: straighten it first so a
            # rotation at x rebalances the subtree (double rotation).
            rotate_right(y)
        new_root = rotate_left(x)

        # Recompute ranks bottom-up: the new root's children, then the root.
        for n in filter(None, (new_root.left, new_root.right, new_root)):
            _update_rank(n)

        return factor != 0

    def __delete_fixup(self, x: Node[T]) -> bool:
        """Rebalance ``x`` after one of its subtrees may have shrunk.

        Returns:
            ``True`` if rebalancing must continue at ``x``'s former parent.
            That is the case when ``x`` is now balanced (its rank is
            recomputed), or when a rotation left the subtree shorter.
            ``False`` if it can stop.
        """
        factor = _balance_factor(x)
        if factor == 0:
            _update_rank(x)
            return True
        elif factor in (-1, 1):
            # One side is still as tall as before, so x's rank is unchanged.
            return False

        # Factor is out of range: rotate towards the heavy side.
        y, leaning, to_left, to_right = (
            (x.right, 1, self.rotate_left, self.rotate_right)
            if factor == 2
            else (x.left, -1, self.rotate_right, self.rotate_left)
        )
        assert y
        return self.__delete_rotate(x, y, leaning, to_left, to_right)

    def _delete_rebalance(
        self, node: Optional[Node[T]], parent: Optional[Node[T]]
    ) -> None:
        """Restore the AVL invariant after a deletion.

        Args:
            node: Node at which to start rebalancing, or ``None`` — in which
                case rebalancing starts at ``parent``.
            parent: The node to continue with after ``node``, expected to be
                ``node``'s parent (or the parent of the emptied position).
        """
        if not node:
            if not parent:
                return
            node, parent = parent, parent.parent

        # ``parent`` is tracked separately rather than read from ``node.parent``
        # after the fixup, because a rotation at ``node`` may change that link.
        while node and self.__delete_fixup(node):
            node, parent = parent, (parent.parent if parent else None)

    # endregion DeleteRebalance
|