# python/avl.py
# 153 lines
# 3.9 KiB
# Python
# Raw Normal View History
from node import Node, Comparable
from ranked_tree import RankedTree, RotateFunction
import logging
from typing import TypeVar, Optional
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
# Type variable for the values stored in the tree, bounded by ``Comparable``
# from ``node`` (presumably a protocol for orderable values — confirm there).
T = TypeVar("T", bound=Comparable)
def _balance_factor(node: Optional[Node[T]]) -> int:
    """Return ``rank(right) - rank(left)`` for ``node``.

    Positive values mean the right subtree is taller, negative values mean the
    left one is. A missing (falsy) node is reported as balanced, i.e. ``0``.
    """
    if node:
        left_rank = Node.get_rank(node.left)
        right_rank = Node.get_rank(node.right)
        return right_rank - left_rank
    return 0
def _update_rank(node: Node[T]) -> None:
    """Set ``node.rank`` to one more than the larger of its children's ranks.

    Child ranks are read through ``Node.get_rank``, which is also used for
    missing children (presumably yielding a sentinel rank for ``None`` —
    confirm in ``node``).
    """
    child_ranks = Node.get_rank(node.left), Node.get_rank(node.right)
    node.rank = max(child_ranks) + 1
class AVLTree(RankedTree[T]):
    """AVL tree on top of the rank-based ``RankedTree`` framework.

    Invariants (checked by ``is_correct_node``): every node's ``rank`` equals
    ``1 + max(rank(left), rank(right))`` and its balance factor
    ``rank(right) - rank(left)`` lies in ``[-1, 1]``.

    Insertion rebalancing works on rank differences (``promote``/``demote``);
    deletion rebalancing recomputes ranks explicitly with ``_update_rank``.
    """

    def is_correct_node(
        self, node: Optional[Node[T]], recursive: bool = True
    ) -> bool:
        """Check the AVL invariants at ``node`` and, by default, its subtree.

        Args:
            node: Root of the subtree to check; a missing node is correct.
            recursive: When true, also check every descendant; otherwise only
                ``node`` itself is validated.

        Returns:
            ``True`` if every checked node has a balance factor in ``[-1, 1]``
            and a stored rank equal to ``1 + max`` of its children's ranks.
        """
        if not node:
            return True

        if not (
            -1 <= _balance_factor(node) <= 1
            and node.rank
            == 1 + max(Node.get_rank(node.left), Node.get_rank(node.right))
        ):
            return False

        return not recursive or (
            self.is_correct_node(node.left)
            and self.is_correct_node(node.right)
        )

    # region InsertRebalance

    def __fix_0_child(
        self,
        x: Node[T],
        y: Optional[Node[T]],
        z: Node[T],
        rotate_left: RotateFunction[T],
        rotate_right: RotateFunction[T],
    ) -> None:
        """Repair the case where ``x`` is a 0-child of its parent ``z``.

        Written for ``x`` being the left child of ``z``; the caller handles the
        mirror case by passing the rotation functions swapped.

        Args:
            x: Child of ``z`` whose rank difference is 0.
            y: Inner child of ``x`` (``x.right`` when ``x`` is a left child);
                may be missing.
            z: Parent of ``x``.
            rotate_left: Rotation applied at ``x`` in the double-rotation case.
            rotate_right: Rotation applied at ``z``.
        """
        if not y or Node.difference(y) == 2:
            # Inner child missing or a 2-child: single rotation at z.
            rotate_right(z)
            z.demote()
        elif Node.difference(y) == 1:
            # Inner child is a 1-child: double rotation brings y to the top of
            # the subtree, then ranks are adjusted around it.
            rotate_left(x)
            rotate_right(z)

            y.promote()
            x.demote()
            z.demote()

    def _insert_rebalance(self, x: Node[T]) -> None:
        """Rebalance upward after inserting ``x``.

        While ``Node.differences`` of ``x``'s parent is ``(0, 1)`` or
        ``(1, 0)``, promote the parent and continue from it. If the walk ends
        with ``x`` being a 0-child of its parent, ``__fix_0_child`` performs the
        final single or double rotation.
        """
        diffs = Node.differences(x.parent)
        while x.parent and diffs in ((0, 1), (1, 0)):
            x.parent.promote()
            x = x.parent
            diffs = Node.differences(x.parent)

        x_parent = x.parent
        if not x_parent or Node.difference(x) != 0:
            # Reached the root, or x is not a 0-child: nothing left to fix.
            return

        if x_parent.left is x:
            self.__fix_0_child(
                x, x.right, x_parent, self.rotate_left, self.rotate_right
            )
        elif x_parent.right is x:
            # Mirror image: the inner child is x.left and rotations swap.
            self.__fix_0_child(
                x, x.left, x_parent, self.rotate_right, self.rotate_left
            )

    # endregion InsertRebalance

    # region DeleteRebalance

    def __delete_rotate(
        self,
        x: Node[T],
        y: Node[T],
        leaning: int,
        rotate_left: RotateFunction[T],
        rotate_right: RotateFunction[T],
    ) -> bool:
        """Rotate at the unbalanced node ``x`` during deletion rebalancing.

        Written for a right-heavy ``x``; the caller swaps the rotation functions
        (and passes ``leaning=-1``) for the left-heavy case.

        Args:
            x: Node whose balance factor is outside ``[-1, 1]``.
            y: Child of ``x`` on the heavy side.
            leaning: ``1`` if ``x`` is right-heavy, ``-1`` if left-heavy.
            rotate_left: Rotation applied at ``x``; must return the new
                subtree root.
            rotate_right: Rotation applied at ``y`` in the double-rotation case.

        Returns:
            ``True`` when ``y``'s balance factor was non-zero; the deletion loop
            uses this to keep walking upward.
        """
        factor = _balance_factor(y)
        if factor not in (0, leaning):
            # y leans away from x's heavy side: straighten it first, turning
            # the single rotation below into a double rotation.
            rotate_right(y)
        new_root = rotate_left(x)

        # Children before the new root so each rank is computed from already
        # up-to-date child ranks.
        for n in filter(None, (new_root.left, new_root.right, new_root)):
            _update_rank(n)

        return factor != 0

    def __delete_fixup(
        self, x: Node[T], parent: Optional[Node[T]] = None
    ) -> bool:
        """Inspect ``x`` during deletion rebalancing and fix it if needed.

        Args:
            x: Node to inspect (and rotate if its balance factor is out of
                range).
            parent: Not used by this method; accepted so the caller can pass
                it alongside ``x``.

        Returns:
            ``True`` if rebalancing should continue at ``x``'s former parent.
        """
        factor = _balance_factor(x)
        if factor == 0:
            # Both sides equal: recompute the rank and keep going upward.
            _update_rank(x)
            return True
        if factor in (-1, 1):
            # Still within AVL bounds: stop without touching the rank.
            return False

        # Otherwise |factor| > 1: pick the heavy child and orient the rotations
        # so __delete_rotate only has to be written for the right-heavy case.
        y, leaning, to_left, to_right = (
            (x.right, 1, self.rotate_left, self.rotate_right)
            if factor == 2
            else (x.left, -1, self.rotate_right, self.rotate_left)
        )
        assert y, "an unbalanced node must have a child on its heavy side"
        return self.__delete_rotate(x, y, leaning, to_left, to_right)

    def _delete_rebalance(
        self, node: Optional[Node[T]], parent: Optional[Node[T]]
    ) -> None:
        """Walk upward after a deletion, fixing nodes until the tree is valid.

        Args:
            node: Starting node; presumably the node that took the removed
                node's place, ``None`` if there was none (confirm against
                ``RankedTree``'s delete).
            parent: Parent of that position; used as the starting point when
                ``node`` is ``None``.
        """
        if not node:
            if not parent:
                return
            node, parent = parent, parent.parent

        # ``parent`` is captured before ``__delete_fixup`` runs, so if a
        # rotation moves ``node`` downward, the walk still continues at its
        # original parent. The new subtree root's rank was already recomputed
        # by ``__delete_rotate``.
        while node and self.__delete_fixup(node, parent):
            node, parent = parent, parent.parent if parent else None

    # endregion DeleteRebalance