fix(rotations): move rotations to the tree

Move rotations to the RankedTree, which allows us to factor out the
functionality regarding rotations when root of the tree changes.

Signed-off-by: Matej Focko <mfocko@redhat.com>
This commit is contained in:
Matej Focko 2022-05-22 21:00:18 +02:00
parent 7b5bf0ec3a
commit e28f578fc2
Signed by: mfocko
GPG key ID: 7C47D46246790496
5 changed files with 89 additions and 112 deletions

37
avl.py
View file

@ -1,5 +1,5 @@
from node import Node, Comparable, RotateFunction
from ranked_tree import RankedTree
from node import Node, Comparable
from ranked_tree import RankedTree, RotateFunction
import logging
from typing import TypeVar, Optional
@ -49,22 +49,18 @@ class AVLTree(RankedTree[T]):
z: Node[T],
rotate_left: RotateFunction[T],
rotate_right: RotateFunction[T],
) -> Optional[Node[T]]:
new_root = x.parent
) -> None:
if not y or Node.difference(y) == 2:
new_root = rotate_right(z)
rotate_right(z)
z.demote()
elif Node.difference(y) == 1:
rotate_left(x)
new_root = rotate_right(z)
rotate_right(z)
y.promote()
x.demote()
z.demote()
return new_root
def _insert_rebalance(self, x: Node[T]) -> None:
diffs = Node.differences(x.parent)
while x.parent and (diffs == (0, 1) or diffs == (1, 0)):
@ -76,9 +72,6 @@ class AVLTree(RankedTree[T]):
if not x.parent:
return
rotating_around_root = x.parent.parent is None
new_root: Optional[Node[T]] = x.parent
rank_difference = Node.difference(x)
if rank_difference != 0:
return
@ -86,17 +79,14 @@ class AVLTree(RankedTree[T]):
x_parent = x.parent
assert x_parent is not None
if rank_difference == 0 and x.parent.left is x:
new_root = self.__fix_0_child(
x, x.right, x_parent, Node.rotate_left, Node.rotate_right
self.__fix_0_child(
x, x.right, x_parent, self.rotate_left, self.rotate_right
)
elif rank_difference == 0 and x.parent.right is x:
new_root = self.__fix_0_child(
x, x.left, x_parent, Node.rotate_right, Node.rotate_left
self.__fix_0_child(
x, x.left, x_parent, self.rotate_right, self.rotate_left
)
if rotating_around_root:
self.root = new_root
# endregion InsertRebalance
# region DeleteRebalance
@ -106,7 +96,6 @@ class AVLTree(RankedTree[T]):
x: Node[T],
y: Node[T],
leaning: int,
rotating_around_root: bool,
rotate_left: RotateFunction[T],
rotate_right: RotateFunction[T],
) -> bool:
@ -122,8 +111,6 @@ class AVLTree(RankedTree[T]):
for n in filter(None, (new_root.left, new_root.right, new_root)):
_update_rank(n)
if rotating_around_root:
self.root = new_root
return factor != 0
def __delete_fixup(
@ -136,18 +123,16 @@ class AVLTree(RankedTree[T]):
elif factor in (-1, 1):
return False
rotating_around_root = x.parent is None
y, leaning, to_left, to_right = (
(x.right, 1, Node.rotate_left, Node.rotate_right)
(x.right, 1, self.rotate_left, self.rotate_right)
if factor == 2
else (x.left, -1, Node.rotate_right, Node.rotate_left)
else (x.left, -1, self.rotate_right, self.rotate_left)
)
assert y
return self.__delete_rotate(
x,
y,
leaning,
rotating_around_root,
to_left,
to_right,
)

55
node.py
View file

@ -1,14 +1,12 @@
from abc import abstractmethod
import enum
from typing import (
Callable,
Iterable,
Optional,
Generic,
Tuple,
TypeVar,
Protocol,
Any,
)
@ -112,56 +110,6 @@ class Node(Generic[T]):
self.rank -= 1
return self
@staticmethod
def rotate_right(x: "Node[T]", tree: Any = None) -> "Node[T]":
parent = x.parent
y = x.left
# z = x.right
assert y is not None
if parent:
if parent.left is x:
parent.left = y
else:
parent.right = y
elif tree:
tree.root = y
x.left = y.right
if x.left:
x.left.parent = x
y.right = x
x.parent = y
y.parent = parent
return y
@staticmethod
def rotate_left(x: "Node[T]", tree: Any = None) -> "Node[T]":
parent = x.parent
# y = x.left
z = x.right
assert z is not None
if parent:
if parent.left is x:
parent.left = z
else:
parent.right = z
elif tree:
tree.root = z
x.right = z.left
if x.right:
x.right.parent = x
z.left = x
x.parent = z
z.parent = parent
return z
@staticmethod
def find_parent_node(
value: T, node: "Node[T]", missing: bool = True
@ -199,6 +147,3 @@ class Node(Generic[T]):
node = node.right
return node
RotateFunction = Callable[[Node[T]], Node[T]]

View file

@ -3,7 +3,7 @@ from node import Node, Comparable
from abc import abstractmethod
from collections import deque
import logging
from typing import Deque, Iterable, Optional, Tuple, TypeVar, Generic
from typing import Callable, Deque, Iterable, Optional, Tuple, TypeVar, Generic
logger = logging.getLogger(__name__)
T = TypeVar("T", bound=Comparable)
@ -58,6 +58,58 @@ class RankedTree(Generic[T]):
# endregion TreeSpecificMethods
# region Rotations
def rotate_right(self, x: "Node[T]") -> "Node[T]":
parent = x.parent
y = x.left
# z = x.right
assert y is not None
if parent:
if parent.left is x:
parent.left = y
else:
parent.right = y
else:
self.root = y
x.left = y.right
if x.left:
x.left.parent = x
y.right = x
x.parent = y
y.parent = parent
return y
def rotate_left(self, x: "Node[T]") -> "Node[T]":
parent = x.parent
# y = x.left
z = x.right
assert z is not None
if parent:
if parent.left is x:
parent.left = z
else:
parent.right = z
else:
self.root = z
x.right = z.left
if x.right:
x.right.parent = x
z.left = x
x.parent = z
z.parent = parent
return z
# endregion Rotations
@property
def rank(self) -> int:
return Node.get_rank(self.root)
@ -148,3 +200,6 @@ class RankedTree(Generic[T]):
"""
if self.root:
yield from self.root
RotateFunction = Callable[[Node[T]], Node[T]]

27
rbt.py
View file

@ -1,5 +1,5 @@
from node import Node, Comparable, RotateFunction
from ranked_tree import RankedTree
from node import Node, Comparable
from ranked_tree import RankedTree, RotateFunction
import enum
import logging
@ -13,6 +13,7 @@ class Colour(enum.IntEnum):
"""
Represents colour of the edge or node.
"""
Red = 0
Black = 1
@ -80,12 +81,12 @@ class RBTree(RankedTree[T]):
# ======
# zs uncle y is black and z is a right child
z = p
rotate_left(tree=self, x=p)
rotate_left(p)
else:
# Case 3
# ======
# zs uncle y is black and z is a left child
rotate_right(tree=self, x=pp)
rotate_right(pp)
return z
@ -97,11 +98,11 @@ class RBTree(RankedTree[T]):
assert pp
if p == pp.left:
z = self._insert_rebalance_step(
z, pp.right, p.right, Node.rotate_left, Node.rotate_right
z, pp.right, p.right, self.rotate_left, self.rotate_right
)
else:
z = self._insert_rebalance_step(
z, pp.left, p.left, Node.rotate_right, Node.rotate_left
z, pp.left, p.left, self.rotate_right, self.rotate_left
)
# endregion InsertRebalance
@ -121,7 +122,7 @@ class RBTree(RankedTree[T]):
# Case 1
# ======
# xs sibling w is red
rotate_left(tree=self, x=parent)
rotate_left(parent)
w = right(parent)
if Node.differences(w) == (Colour.Black, Colour.Black):
@ -136,7 +137,7 @@ class RBTree(RankedTree[T]):
# xs sibling w is black,
# ws left child is red, and ws right child is black
if Node.difference(right(w), w) == Colour.Black:
rotate_right(tree=self, x=w)
rotate_right(w)
w = right(parent)
# Case 4
@ -144,7 +145,7 @@ class RBTree(RankedTree[T]):
# xs sibling w is black, and ws right child is red
parent.rank -= Colour.Black
w.rank += Colour.Black
rotate_left(tree=self, x=parent)
rotate_left(parent)
x = self.root
return x, (x.parent if x else None)
@ -162,8 +163,8 @@ class RBTree(RankedTree[T]):
parent.right,
parent,
lambda x: x.right,
Node.rotate_left,
Node.rotate_right,
self.rotate_left,
self.rotate_right,
)
else:
node, parent = self._delete_rebalance_step(
@ -171,8 +172,8 @@ class RBTree(RankedTree[T]):
parent.left,
parent,
lambda x: x.left,
Node.rotate_right,
Node.rotate_left,
self.rotate_right,
self.rotate_left,
)
# endregion DeleteRebalance

25
wavl.py
View file

@ -1,5 +1,6 @@
from ranked_tree import RotateFunction
from avl import AVLTree
from node import Node, NodeType, Comparable, RotateFunction
from node import Node, NodeType, Comparable
import logging
from typing import TypeVar, Optional
@ -35,11 +36,9 @@ class WAVLTree(AVLTree[T]):
y: Node[T],
z: Node[T],
reversed: bool,
rotating_around_root: bool,
rotate_left: RotateFunction[T],
rotate_right: RotateFunction[T],
) -> None:
new_root = x
v = y.left
w = y.right
@ -49,9 +48,7 @@ class WAVLTree(AVLTree[T]):
w_diff = Node.difference(w, y)
if w_diff == 1 and y.parent:
new_root = rotate_left(y.parent)
if rotating_around_root:
self.root = new_root
rotate_left(y.parent)
y.promote()
z.demote()
@ -60,9 +57,7 @@ class WAVLTree(AVLTree[T]):
z.demote()
elif w_diff == 2 and v.parent:
rotate_right(v.parent)
new_root = rotate_left(v.parent)
if rotating_around_root:
self.root = new_root
rotate_left(v.parent)
v.promote().promote()
y.demote()
@ -97,8 +92,6 @@ class WAVLTree(AVLTree[T]):
x_diff = Node.difference(x, parent)
y_diff = Node.difference(y, parent)
rotating_around_root = parent.parent is None
parent_node_diffs = Node.differences(parent)
if parent_node_diffs in ((1, 3), (3, 1)):
if parent.left is x:
@ -108,9 +101,8 @@ class WAVLTree(AVLTree[T]):
parent.right,
parent,
False,
rotating_around_root,
Node.rotate_left,
Node.rotate_right,
self.rotate_left,
self.rotate_right,
)
else:
assert parent.left
@ -119,9 +111,8 @@ class WAVLTree(AVLTree[T]):
parent.left,
parent,
True,
rotating_around_root,
Node.rotate_right,
Node.rotate_left,
self.rotate_right,
self.rotate_left,
)
def _delete_rebalance(