# python/ranked_tree.py (file-listing metadata: 205 lines, 4.9 KiB, Python)

from node import Node, Comparable
from abc import abstractmethod
from collections import deque
import logging
from typing import Callable, Deque, Iterable, Optional, Tuple, TypeVar, Generic
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
# Type of the values stored in the tree; bounded by ``Comparable`` (imported
# from ``node``).
T = TypeVar("T", bound=Comparable)
class RankedTree(Generic[T]):
    """
    Base class for binary search trees whose nodes carry a ``rank``.

    Provides the shared plumbing (search, insertion, deletion, rotations and
    Graphviz output). The balancing rules are supplied through the hooks
    ``is_correct_node``, ``_insert_rebalance`` and ``_delete_rebalance``.

    NOTE(review): the hooks are decorated with ``@abstractmethod``, but this
    class does not use ``ABCMeta``, so the decorator is not enforced and the
    class can be instantiated directly; the hooks then do nothing.
    """

    def __init__(self) -> None:
        """Create an empty tree."""
        self.root: Optional[Node[T]] = None

    def _get_unwrapped_graph(self) -> str:
        """
        Render the nodes and edges of the tree as Graphviz DOT statements.

        Node statements are emitted in breadth-first order, followed by the
        edge statements in the order the edges were discovered. The
        surrounding ``digraph { ... }`` is added by ``__str__``.

        Returns:
            The DOT statements, each terminated by a newline; an empty string
            for an empty tree.
        """
        # Collect the pieces and join once instead of growing a string.
        lines = []
        edges = []

        queue: Deque[Node[T]] = deque()
        if self.root is not None:
            queue.append(self.root)

        while queue:
            node = queue.popleft()
            lines.append(
                f'"{node!s}" [label="{node.value}, {node.rank}"];\n'
            )

            for child in (node.left, node.right):
                if child is None:
                    continue
                edges.append((node, child))
                queue.append(child)

        for from_node, to_node in edges:
            label = f'[label="{Node.difference(to_node)}"]'
            lines.append(f'"{from_node!s}" -> "{to_node!s}" {label}\n')

        return "".join(lines)

    def __str__(self) -> str:
        """Return the tree as a complete Graphviz ``digraph`` document."""
        return "digraph {\n" + self._get_unwrapped_graph() + "}\n"

    # region TreeSpecificMethods

    @abstractmethod
    def is_correct_node(self, node: Optional[Node[T]]) -> bool:
        """
        Hook: check the tree-specific invariant for ``node``.

        ``is_correct`` calls this with the root of the tree.
        """
        pass

    @abstractmethod
    def _insert_rebalance(self, x: Node[T]) -> None:
        """
        Hook: rebalance the tree after ``x`` has been linked in by ``insert``.

        Not called when the inserted node became the root of an empty tree.
        """
        pass

    @abstractmethod
    def _delete_rebalance(
        self, node: Optional[Node[T]], parent: Optional[Node[T]]
    ) -> None:
        """
        Hook: rebalance the tree after a node has been unlinked by ``delete``.

        Args:
            node: Child that took the place of the removed node, if any.
            parent: Former parent of the removed node, or ``None`` when the
                removed node was the root.
        """
        pass

    # endregion TreeSpecificMethods

    def _replace_child(self, old: Node[T], new: Optional[Node[T]]) -> None:
        """
        Make ``new`` occupy the slot that ``old`` has under its parent.

        Only the downward link (or ``self.root`` when ``old`` has no parent)
        is updated; setting ``new.parent`` is left to the caller.
        """
        parent = old.parent
        if parent is None:
            self.root = new
        elif parent.left is old:
            parent.left = new
        else:
            parent.right = new

    # region Rotations

    def rotate_right(self, x: "Node[T]") -> "Node[T]":
        """
        Rotate the subtree rooted at ``x`` to the right.

        The left child of ``x`` takes the place of ``x`` (under the same
        parent, or as the tree root), ``x`` becomes its right child, and its
        former right subtree becomes the left subtree of ``x``.

        Args:
            x: Root of the subtree to rotate; must have a left child.

        Returns:
            The new root of the rotated subtree (the former left child).
        """
        parent = x.parent
        y = x.left
        assert y is not None

        self._replace_child(x, y)

        x.left = y.right
        if x.left is not None:
            x.left.parent = x

        y.right = x
        x.parent = y
        y.parent = parent

        return y

    def rotate_left(self, x: "Node[T]") -> "Node[T]":
        """
        Rotate the subtree rooted at ``x`` to the left.

        The right child of ``x`` takes the place of ``x`` (under the same
        parent, or as the tree root), ``x`` becomes its left child, and its
        former left subtree becomes the right subtree of ``x``.

        Args:
            x: Root of the subtree to rotate; must have a right child.

        Returns:
            The new root of the rotated subtree (the former right child).
        """
        parent = x.parent
        z = x.right
        assert z is not None

        self._replace_child(x, z)

        x.right = z.left
        if x.right is not None:
            x.right.parent = x

        z.left = x
        x.parent = z
        z.parent = parent

        return z

    # endregion Rotations

    @property
    def rank(self) -> int:
        """Rank of the tree, i.e. ``Node.get_rank`` of its root."""
        return Node.get_rank(self.root)

    @property
    def is_correct(self) -> bool:
        """Result of ``is_correct_node`` applied to the root."""
        return self.is_correct_node(self.root)

    def search(
        self, value: T, node: Optional[Node[T]] = None
    ) -> Optional[Node[T]]:
        """
        Look up ``value`` by delegating to ``Node.search``.

        Args:
            value: Value to look for.
            node: Node to start from; the root of the tree when omitted.

        Returns:
            Whatever ``Node.search`` returns for the starting node.
        """
        start = self.root if node is None else node
        return Node.search(value, start)

    def insert(self, value: T) -> None:
        """
        Insert ``value`` into the tree and rebalance.

        The first value becomes the root without any rebalancing. Otherwise
        the new node is attached under the parent found by
        ``Node.find_parent_node`` (left if ``value`` is smaller than the
        parent's value, right otherwise) and ``_insert_rebalance`` is called.
        When no parent is found, nothing is inserted (presumably the value is
        already present; confirm against ``Node.find_parent_node``).
        """
        inserted_node = Node(value)

        if self.root is None:
            self.root = inserted_node
            return

        parent = Node.find_parent_node(value, self.root)
        if parent is None:
            return

        inserted_node.parent = parent
        if value < parent.value:
            parent.left = inserted_node
        else:
            parent.right = inserted_node

        self._insert_rebalance(inserted_node)

    def _transplant(self, u: Node[T], v: Optional[Node[T]]) -> None:
        """
        Replace the subtree rooted at ``u`` with the subtree rooted at ``v``.

        When ``v`` is not ``None`` it also takes over the rank and the parent
        of ``u``. The links stored in ``u`` itself are left untouched.
        """
        self._replace_child(u, v)

        if v is not None:
            v.rank = u.rank
            v.parent = u.parent

    def _delete_node(
        self, node: Optional[Node[T]]
    ) -> Optional[Tuple[Optional[Node[T]], Optional[Node[T]]]]:
        """
        Unlink ``node`` from the tree without rebalancing.

        A node with two children is not unlinked itself: its value is swapped
        with the value of ``Node.minimum(node.right)`` and that node is
        removed instead.

        Returns:
            ``None`` when ``node`` is ``None``; otherwise a pair of the child
            that replaced the removed node (possibly ``None``) and the former
            parent of the removed node (``None`` if it was the root).
        """
        if node is None:
            return None

        if node.left is not None and node.right is not None:
            successor = Node.minimum(node.right)
            node.value, successor.value = successor.value, node.value
            return self._delete_node(successor)

        # At most one child: splice that child (or nothing) into the slot.
        parent = node.parent
        replacement = node.right if node.left is None else node.left
        self._transplant(node, replacement)

        return (replacement, parent)

    def _delete(
        self, value: T
    ) -> Optional[Tuple[Optional[Node[T]], Optional[Node[T]]]]:
        """
        Find the node holding ``value`` and unlink it via ``_delete_node``.

        Returns:
            The result of ``_delete_node``; ``None`` when ``value`` is not in
            the tree.
        """
        node = self.root
        while node is not None and node.value != value:
            node = node.left if value < node.value else node.right

        return self._delete_node(node)

    def delete(self, value: T) -> None:
        """
        Remove ``value`` from the tree and rebalance.

        Does nothing when ``value`` is not present.
        """
        to_be_rebalanced = self._delete(value)
        if to_be_rebalanced is None:
            return

        y, parent = to_be_rebalanced
        self._delete_rebalance(y, parent)

    def __iter__(self) -> Iterable[T]:
        """
        Yields:
            Keys from the tree in an inorder fashion.
        """
        if self.root is not None:
            yield from self.root
# Type alias for a callable that takes a node and returns a node; this is the
# shape of the bound ``RankedTree.rotate_left`` / ``rotate_right`` methods.
RotateFunction = Callable[[Node[T]], Node[T]]