fix(rotations): move rotations to the tree

Move rotations to the RankedTree, which allows us to factor out the
functionality regarding rotations when root of the tree changes.

Signed-off-by: Matej Focko <mfocko@redhat.com>
This commit is contained in:
Matej Focko 2022-05-22 21:00:18 +02:00
parent 7b5bf0ec3a
commit e28f578fc2
Signed by: mfocko
GPG key ID: 7C47D46246790496
5 changed files with 89 additions and 112 deletions

37
avl.py
View file

@ -1,5 +1,5 @@
from node import Node, Comparable, RotateFunction from node import Node, Comparable
from ranked_tree import RankedTree from ranked_tree import RankedTree, RotateFunction
import logging import logging
from typing import TypeVar, Optional from typing import TypeVar, Optional
@ -49,22 +49,18 @@ class AVLTree(RankedTree[T]):
z: Node[T], z: Node[T],
rotate_left: RotateFunction[T], rotate_left: RotateFunction[T],
rotate_right: RotateFunction[T], rotate_right: RotateFunction[T],
) -> Optional[Node[T]]: ) -> None:
new_root = x.parent
if not y or Node.difference(y) == 2: if not y or Node.difference(y) == 2:
new_root = rotate_right(z) rotate_right(z)
z.demote() z.demote()
elif Node.difference(y) == 1: elif Node.difference(y) == 1:
rotate_left(x) rotate_left(x)
new_root = rotate_right(z) rotate_right(z)
y.promote() y.promote()
x.demote() x.demote()
z.demote() z.demote()
return new_root
def _insert_rebalance(self, x: Node[T]) -> None: def _insert_rebalance(self, x: Node[T]) -> None:
diffs = Node.differences(x.parent) diffs = Node.differences(x.parent)
while x.parent and (diffs == (0, 1) or diffs == (1, 0)): while x.parent and (diffs == (0, 1) or diffs == (1, 0)):
@ -76,9 +72,6 @@ class AVLTree(RankedTree[T]):
if not x.parent: if not x.parent:
return return
rotating_around_root = x.parent.parent is None
new_root: Optional[Node[T]] = x.parent
rank_difference = Node.difference(x) rank_difference = Node.difference(x)
if rank_difference != 0: if rank_difference != 0:
return return
@ -86,17 +79,14 @@ class AVLTree(RankedTree[T]):
x_parent = x.parent x_parent = x.parent
assert x_parent is not None assert x_parent is not None
if rank_difference == 0 and x.parent.left is x: if rank_difference == 0 and x.parent.left is x:
new_root = self.__fix_0_child( self.__fix_0_child(
x, x.right, x_parent, Node.rotate_left, Node.rotate_right x, x.right, x_parent, self.rotate_left, self.rotate_right
) )
elif rank_difference == 0 and x.parent.right is x: elif rank_difference == 0 and x.parent.right is x:
new_root = self.__fix_0_child( self.__fix_0_child(
x, x.left, x_parent, Node.rotate_right, Node.rotate_left x, x.left, x_parent, self.rotate_right, self.rotate_left
) )
if rotating_around_root:
self.root = new_root
# endregion InsertRebalance # endregion InsertRebalance
# region DeleteRebalance # region DeleteRebalance
@ -106,7 +96,6 @@ class AVLTree(RankedTree[T]):
x: Node[T], x: Node[T],
y: Node[T], y: Node[T],
leaning: int, leaning: int,
rotating_around_root: bool,
rotate_left: RotateFunction[T], rotate_left: RotateFunction[T],
rotate_right: RotateFunction[T], rotate_right: RotateFunction[T],
) -> bool: ) -> bool:
@ -122,8 +111,6 @@ class AVLTree(RankedTree[T]):
for n in filter(None, (new_root.left, new_root.right, new_root)): for n in filter(None, (new_root.left, new_root.right, new_root)):
_update_rank(n) _update_rank(n)
if rotating_around_root:
self.root = new_root
return factor != 0 return factor != 0
def __delete_fixup( def __delete_fixup(
@ -136,18 +123,16 @@ class AVLTree(RankedTree[T]):
elif factor in (-1, 1): elif factor in (-1, 1):
return False return False
rotating_around_root = x.parent is None
y, leaning, to_left, to_right = ( y, leaning, to_left, to_right = (
(x.right, 1, Node.rotate_left, Node.rotate_right) (x.right, 1, self.rotate_left, self.rotate_right)
if factor == 2 if factor == 2
else (x.left, -1, Node.rotate_right, Node.rotate_left) else (x.left, -1, self.rotate_right, self.rotate_left)
) )
assert y assert y
return self.__delete_rotate( return self.__delete_rotate(
x, x,
y, y,
leaning, leaning,
rotating_around_root,
to_left, to_left,
to_right, to_right,
) )

55
node.py
View file

@ -1,14 +1,12 @@
from abc import abstractmethod from abc import abstractmethod
import enum import enum
from typing import ( from typing import (
Callable,
Iterable, Iterable,
Optional, Optional,
Generic, Generic,
Tuple, Tuple,
TypeVar, TypeVar,
Protocol, Protocol,
Any,
) )
@ -112,56 +110,6 @@ class Node(Generic[T]):
self.rank -= 1 self.rank -= 1
return self return self
@staticmethod
def rotate_right(x: "Node[T]", tree: Any = None) -> "Node[T]":
parent = x.parent
y = x.left
# z = x.right
assert y is not None
if parent:
if parent.left is x:
parent.left = y
else:
parent.right = y
elif tree:
tree.root = y
x.left = y.right
if x.left:
x.left.parent = x
y.right = x
x.parent = y
y.parent = parent
return y
@staticmethod
def rotate_left(x: "Node[T]", tree: Any = None) -> "Node[T]":
parent = x.parent
# y = x.left
z = x.right
assert z is not None
if parent:
if parent.left is x:
parent.left = z
else:
parent.right = z
elif tree:
tree.root = z
x.right = z.left
if x.right:
x.right.parent = x
z.left = x
x.parent = z
z.parent = parent
return z
@staticmethod @staticmethod
def find_parent_node( def find_parent_node(
value: T, node: "Node[T]", missing: bool = True value: T, node: "Node[T]", missing: bool = True
@ -199,6 +147,3 @@ class Node(Generic[T]):
node = node.right node = node.right
return node return node
RotateFunction = Callable[[Node[T]], Node[T]]

View file

@ -3,7 +3,7 @@ from node import Node, Comparable
from abc import abstractmethod from abc import abstractmethod
from collections import deque from collections import deque
import logging import logging
from typing import Deque, Iterable, Optional, Tuple, TypeVar, Generic from typing import Callable, Deque, Iterable, Optional, Tuple, TypeVar, Generic
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
T = TypeVar("T", bound=Comparable) T = TypeVar("T", bound=Comparable)
@ -58,6 +58,58 @@ class RankedTree(Generic[T]):
# endregion TreeSpecificMethods # endregion TreeSpecificMethods
# region Rotations
def rotate_right(self, x: "Node[T]") -> "Node[T]":
parent = x.parent
y = x.left
# z = x.right
assert y is not None
if parent:
if parent.left is x:
parent.left = y
else:
parent.right = y
else:
self.root = y
x.left = y.right
if x.left:
x.left.parent = x
y.right = x
x.parent = y
y.parent = parent
return y
def rotate_left(self, x: "Node[T]") -> "Node[T]":
parent = x.parent
# y = x.left
z = x.right
assert z is not None
if parent:
if parent.left is x:
parent.left = z
else:
parent.right = z
else:
self.root = z
x.right = z.left
if x.right:
x.right.parent = x
z.left = x
x.parent = z
z.parent = parent
return z
# endregion Rotations
@property @property
def rank(self) -> int: def rank(self) -> int:
return Node.get_rank(self.root) return Node.get_rank(self.root)
@ -148,3 +200,6 @@ class RankedTree(Generic[T]):
""" """
if self.root: if self.root:
yield from self.root yield from self.root
RotateFunction = Callable[[Node[T]], Node[T]]

27
rbt.py
View file

@ -1,5 +1,5 @@
from node import Node, Comparable, RotateFunction from node import Node, Comparable
from ranked_tree import RankedTree from ranked_tree import RankedTree, RotateFunction
import enum import enum
import logging import logging
@ -13,6 +13,7 @@ class Colour(enum.IntEnum):
""" """
Represents colour of the edge or node. Represents colour of the edge or node.
""" """
Red = 0 Red = 0
Black = 1 Black = 1
@ -80,12 +81,12 @@ class RBTree(RankedTree[T]):
# ====== # ======
# zs uncle y is black and z is a right child # zs uncle y is black and z is a right child
z = p z = p
rotate_left(tree=self, x=p) rotate_left(p)
else: else:
# Case 3 # Case 3
# ====== # ======
# zs uncle y is black and z is a left child # zs uncle y is black and z is a left child
rotate_right(tree=self, x=pp) rotate_right(pp)
return z return z
@ -97,11 +98,11 @@ class RBTree(RankedTree[T]):
assert pp assert pp
if p == pp.left: if p == pp.left:
z = self._insert_rebalance_step( z = self._insert_rebalance_step(
z, pp.right, p.right, Node.rotate_left, Node.rotate_right z, pp.right, p.right, self.rotate_left, self.rotate_right
) )
else: else:
z = self._insert_rebalance_step( z = self._insert_rebalance_step(
z, pp.left, p.left, Node.rotate_right, Node.rotate_left z, pp.left, p.left, self.rotate_right, self.rotate_left
) )
# endregion InsertRebalance # endregion InsertRebalance
@ -121,7 +122,7 @@ class RBTree(RankedTree[T]):
# Case 1 # Case 1
# ====== # ======
# xs sibling w is red # xs sibling w is red
rotate_left(tree=self, x=parent) rotate_left(parent)
w = right(parent) w = right(parent)
if Node.differences(w) == (Colour.Black, Colour.Black): if Node.differences(w) == (Colour.Black, Colour.Black):
@ -136,7 +137,7 @@ class RBTree(RankedTree[T]):
# xs sibling w is black, # xs sibling w is black,
# ws left child is red, and ws right child is black # ws left child is red, and ws right child is black
if Node.difference(right(w), w) == Colour.Black: if Node.difference(right(w), w) == Colour.Black:
rotate_right(tree=self, x=w) rotate_right(w)
w = right(parent) w = right(parent)
# Case 4 # Case 4
@ -144,7 +145,7 @@ class RBTree(RankedTree[T]):
# xs sibling w is black, and ws right child is red # xs sibling w is black, and ws right child is red
parent.rank -= Colour.Black parent.rank -= Colour.Black
w.rank += Colour.Black w.rank += Colour.Black
rotate_left(tree=self, x=parent) rotate_left(parent)
x = self.root x = self.root
return x, (x.parent if x else None) return x, (x.parent if x else None)
@ -162,8 +163,8 @@ class RBTree(RankedTree[T]):
parent.right, parent.right,
parent, parent,
lambda x: x.right, lambda x: x.right,
Node.rotate_left, self.rotate_left,
Node.rotate_right, self.rotate_right,
) )
else: else:
node, parent = self._delete_rebalance_step( node, parent = self._delete_rebalance_step(
@ -171,8 +172,8 @@ class RBTree(RankedTree[T]):
parent.left, parent.left,
parent, parent,
lambda x: x.left, lambda x: x.left,
Node.rotate_right, self.rotate_right,
Node.rotate_left, self.rotate_left,
) )
# endregion DeleteRebalance # endregion DeleteRebalance

25
wavl.py
View file

@ -1,5 +1,6 @@
from ranked_tree import RotateFunction
from avl import AVLTree from avl import AVLTree
from node import Node, NodeType, Comparable, RotateFunction from node import Node, NodeType, Comparable
import logging import logging
from typing import TypeVar, Optional from typing import TypeVar, Optional
@ -35,11 +36,9 @@ class WAVLTree(AVLTree[T]):
y: Node[T], y: Node[T],
z: Node[T], z: Node[T],
reversed: bool, reversed: bool,
rotating_around_root: bool,
rotate_left: RotateFunction[T], rotate_left: RotateFunction[T],
rotate_right: RotateFunction[T], rotate_right: RotateFunction[T],
) -> None: ) -> None:
new_root = x
v = y.left v = y.left
w = y.right w = y.right
@ -49,9 +48,7 @@ class WAVLTree(AVLTree[T]):
w_diff = Node.difference(w, y) w_diff = Node.difference(w, y)
if w_diff == 1 and y.parent: if w_diff == 1 and y.parent:
new_root = rotate_left(y.parent) rotate_left(y.parent)
if rotating_around_root:
self.root = new_root
y.promote() y.promote()
z.demote() z.demote()
@ -60,9 +57,7 @@ class WAVLTree(AVLTree[T]):
z.demote() z.demote()
elif w_diff == 2 and v.parent: elif w_diff == 2 and v.parent:
rotate_right(v.parent) rotate_right(v.parent)
new_root = rotate_left(v.parent) rotate_left(v.parent)
if rotating_around_root:
self.root = new_root
v.promote().promote() v.promote().promote()
y.demote() y.demote()
@ -97,8 +92,6 @@ class WAVLTree(AVLTree[T]):
x_diff = Node.difference(x, parent) x_diff = Node.difference(x, parent)
y_diff = Node.difference(y, parent) y_diff = Node.difference(y, parent)
rotating_around_root = parent.parent is None
parent_node_diffs = Node.differences(parent) parent_node_diffs = Node.differences(parent)
if parent_node_diffs in ((1, 3), (3, 1)): if parent_node_diffs in ((1, 3), (3, 1)):
if parent.left is x: if parent.left is x:
@ -108,9 +101,8 @@ class WAVLTree(AVLTree[T]):
parent.right, parent.right,
parent, parent,
False, False,
rotating_around_root, self.rotate_left,
Node.rotate_left, self.rotate_right,
Node.rotate_right,
) )
else: else:
assert parent.left assert parent.left
@ -119,9 +111,8 @@ class WAVLTree(AVLTree[T]):
parent.left, parent.left,
parent, parent,
True, True,
rotating_around_root, self.rotate_right,
Node.rotate_right, self.rotate_left,
Node.rotate_left,
) )
def _delete_rebalance( def _delete_rebalance(