fix(rotations): move rotations to the tree
Move rotations to the RankedTree, which allows us to factor out the functionality regarding rotations when root of the tree changes. Signed-off-by: Matej Focko <mfocko@redhat.com>
This commit is contained in:
parent
7b5bf0ec3a
commit
e28f578fc2
5 changed files with 89 additions and 112 deletions
37
avl.py
37
avl.py
|
@ -1,5 +1,5 @@
|
|||
from node import Node, Comparable, RotateFunction
|
||||
from ranked_tree import RankedTree
|
||||
from node import Node, Comparable
|
||||
from ranked_tree import RankedTree, RotateFunction
|
||||
|
||||
import logging
|
||||
from typing import TypeVar, Optional
|
||||
|
@ -49,22 +49,18 @@ class AVLTree(RankedTree[T]):
|
|||
z: Node[T],
|
||||
rotate_left: RotateFunction[T],
|
||||
rotate_right: RotateFunction[T],
|
||||
) -> Optional[Node[T]]:
|
||||
new_root = x.parent
|
||||
|
||||
) -> None:
|
||||
if not y or Node.difference(y) == 2:
|
||||
new_root = rotate_right(z)
|
||||
rotate_right(z)
|
||||
z.demote()
|
||||
elif Node.difference(y) == 1:
|
||||
rotate_left(x)
|
||||
new_root = rotate_right(z)
|
||||
rotate_right(z)
|
||||
|
||||
y.promote()
|
||||
x.demote()
|
||||
z.demote()
|
||||
|
||||
return new_root
|
||||
|
||||
def _insert_rebalance(self, x: Node[T]) -> None:
|
||||
diffs = Node.differences(x.parent)
|
||||
while x.parent and (diffs == (0, 1) or diffs == (1, 0)):
|
||||
|
@ -76,9 +72,6 @@ class AVLTree(RankedTree[T]):
|
|||
if not x.parent:
|
||||
return
|
||||
|
||||
rotating_around_root = x.parent.parent is None
|
||||
new_root: Optional[Node[T]] = x.parent
|
||||
|
||||
rank_difference = Node.difference(x)
|
||||
if rank_difference != 0:
|
||||
return
|
||||
|
@ -86,17 +79,14 @@ class AVLTree(RankedTree[T]):
|
|||
x_parent = x.parent
|
||||
assert x_parent is not None
|
||||
if rank_difference == 0 and x.parent.left is x:
|
||||
new_root = self.__fix_0_child(
|
||||
x, x.right, x_parent, Node.rotate_left, Node.rotate_right
|
||||
self.__fix_0_child(
|
||||
x, x.right, x_parent, self.rotate_left, self.rotate_right
|
||||
)
|
||||
elif rank_difference == 0 and x.parent.right is x:
|
||||
new_root = self.__fix_0_child(
|
||||
x, x.left, x_parent, Node.rotate_right, Node.rotate_left
|
||||
self.__fix_0_child(
|
||||
x, x.left, x_parent, self.rotate_right, self.rotate_left
|
||||
)
|
||||
|
||||
if rotating_around_root:
|
||||
self.root = new_root
|
||||
|
||||
# endregion InsertRebalance
|
||||
|
||||
# region DeleteRebalance
|
||||
|
@ -106,7 +96,6 @@ class AVLTree(RankedTree[T]):
|
|||
x: Node[T],
|
||||
y: Node[T],
|
||||
leaning: int,
|
||||
rotating_around_root: bool,
|
||||
rotate_left: RotateFunction[T],
|
||||
rotate_right: RotateFunction[T],
|
||||
) -> bool:
|
||||
|
@ -122,8 +111,6 @@ class AVLTree(RankedTree[T]):
|
|||
for n in filter(None, (new_root.left, new_root.right, new_root)):
|
||||
_update_rank(n)
|
||||
|
||||
if rotating_around_root:
|
||||
self.root = new_root
|
||||
return factor != 0
|
||||
|
||||
def __delete_fixup(
|
||||
|
@ -136,18 +123,16 @@ class AVLTree(RankedTree[T]):
|
|||
elif factor in (-1, 1):
|
||||
return False
|
||||
|
||||
rotating_around_root = x.parent is None
|
||||
y, leaning, to_left, to_right = (
|
||||
(x.right, 1, Node.rotate_left, Node.rotate_right)
|
||||
(x.right, 1, self.rotate_left, self.rotate_right)
|
||||
if factor == 2
|
||||
else (x.left, -1, Node.rotate_right, Node.rotate_left)
|
||||
else (x.left, -1, self.rotate_right, self.rotate_left)
|
||||
)
|
||||
assert y
|
||||
return self.__delete_rotate(
|
||||
x,
|
||||
y,
|
||||
leaning,
|
||||
rotating_around_root,
|
||||
to_left,
|
||||
to_right,
|
||||
)
|
||||
|
|
55
node.py
55
node.py
|
@ -1,14 +1,12 @@
|
|||
from abc import abstractmethod
|
||||
import enum
|
||||
from typing import (
|
||||
Callable,
|
||||
Iterable,
|
||||
Optional,
|
||||
Generic,
|
||||
Tuple,
|
||||
TypeVar,
|
||||
Protocol,
|
||||
Any,
|
||||
)
|
||||
|
||||
|
||||
|
@ -112,56 +110,6 @@ class Node(Generic[T]):
|
|||
self.rank -= 1
|
||||
return self
|
||||
|
||||
@staticmethod
|
||||
def rotate_right(x: "Node[T]", tree: Any = None) -> "Node[T]":
|
||||
parent = x.parent
|
||||
y = x.left
|
||||
# z = x.right
|
||||
|
||||
assert y is not None
|
||||
if parent:
|
||||
if parent.left is x:
|
||||
parent.left = y
|
||||
else:
|
||||
parent.right = y
|
||||
elif tree:
|
||||
tree.root = y
|
||||
|
||||
x.left = y.right
|
||||
if x.left:
|
||||
x.left.parent = x
|
||||
|
||||
y.right = x
|
||||
x.parent = y
|
||||
y.parent = parent
|
||||
|
||||
return y
|
||||
|
||||
@staticmethod
|
||||
def rotate_left(x: "Node[T]", tree: Any = None) -> "Node[T]":
|
||||
parent = x.parent
|
||||
# y = x.left
|
||||
z = x.right
|
||||
|
||||
assert z is not None
|
||||
if parent:
|
||||
if parent.left is x:
|
||||
parent.left = z
|
||||
else:
|
||||
parent.right = z
|
||||
elif tree:
|
||||
tree.root = z
|
||||
|
||||
x.right = z.left
|
||||
if x.right:
|
||||
x.right.parent = x
|
||||
|
||||
z.left = x
|
||||
x.parent = z
|
||||
z.parent = parent
|
||||
|
||||
return z
|
||||
|
||||
@staticmethod
|
||||
def find_parent_node(
|
||||
value: T, node: "Node[T]", missing: bool = True
|
||||
|
@ -199,6 +147,3 @@ class Node(Generic[T]):
|
|||
node = node.right
|
||||
|
||||
return node
|
||||
|
||||
|
||||
RotateFunction = Callable[[Node[T]], Node[T]]
|
||||
|
|
|
@ -3,7 +3,7 @@ from node import Node, Comparable
|
|||
from abc import abstractmethod
|
||||
from collections import deque
|
||||
import logging
|
||||
from typing import Deque, Iterable, Optional, Tuple, TypeVar, Generic
|
||||
from typing import Callable, Deque, Iterable, Optional, Tuple, TypeVar, Generic
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
T = TypeVar("T", bound=Comparable)
|
||||
|
@ -58,6 +58,58 @@ class RankedTree(Generic[T]):
|
|||
|
||||
# endregion TreeSpecificMethods
|
||||
|
||||
# region Rotations
|
||||
|
||||
def rotate_right(self, x: "Node[T]") -> "Node[T]":
|
||||
parent = x.parent
|
||||
y = x.left
|
||||
# z = x.right
|
||||
|
||||
assert y is not None
|
||||
if parent:
|
||||
if parent.left is x:
|
||||
parent.left = y
|
||||
else:
|
||||
parent.right = y
|
||||
else:
|
||||
self.root = y
|
||||
|
||||
x.left = y.right
|
||||
if x.left:
|
||||
x.left.parent = x
|
||||
|
||||
y.right = x
|
||||
x.parent = y
|
||||
y.parent = parent
|
||||
|
||||
return y
|
||||
|
||||
def rotate_left(self, x: "Node[T]") -> "Node[T]":
|
||||
parent = x.parent
|
||||
# y = x.left
|
||||
z = x.right
|
||||
|
||||
assert z is not None
|
||||
if parent:
|
||||
if parent.left is x:
|
||||
parent.left = z
|
||||
else:
|
||||
parent.right = z
|
||||
else:
|
||||
self.root = z
|
||||
|
||||
x.right = z.left
|
||||
if x.right:
|
||||
x.right.parent = x
|
||||
|
||||
z.left = x
|
||||
x.parent = z
|
||||
z.parent = parent
|
||||
|
||||
return z
|
||||
|
||||
# endregion Rotations
|
||||
|
||||
@property
|
||||
def rank(self) -> int:
|
||||
return Node.get_rank(self.root)
|
||||
|
@ -148,3 +200,6 @@ class RankedTree(Generic[T]):
|
|||
"""
|
||||
if self.root:
|
||||
yield from self.root
|
||||
|
||||
|
||||
RotateFunction = Callable[[Node[T]], Node[T]]
|
||||
|
|
27
rbt.py
27
rbt.py
|
@ -1,5 +1,5 @@
|
|||
from node import Node, Comparable, RotateFunction
|
||||
from ranked_tree import RankedTree
|
||||
from node import Node, Comparable
|
||||
from ranked_tree import RankedTree, RotateFunction
|
||||
|
||||
import enum
|
||||
import logging
|
||||
|
@ -13,6 +13,7 @@ class Colour(enum.IntEnum):
|
|||
"""
|
||||
Represents colour of the edge or node.
|
||||
"""
|
||||
|
||||
Red = 0
|
||||
Black = 1
|
||||
|
||||
|
@ -80,12 +81,12 @@ class RBTree(RankedTree[T]):
|
|||
# ======
|
||||
# z’s uncle y is black and z is a right child
|
||||
z = p
|
||||
rotate_left(tree=self, x=p)
|
||||
rotate_left(p)
|
||||
else:
|
||||
# Case 3
|
||||
# ======
|
||||
# z’s uncle y is black and z is a left child
|
||||
rotate_right(tree=self, x=pp)
|
||||
rotate_right(pp)
|
||||
|
||||
return z
|
||||
|
||||
|
@ -97,11 +98,11 @@ class RBTree(RankedTree[T]):
|
|||
assert pp
|
||||
if p == pp.left:
|
||||
z = self._insert_rebalance_step(
|
||||
z, pp.right, p.right, Node.rotate_left, Node.rotate_right
|
||||
z, pp.right, p.right, self.rotate_left, self.rotate_right
|
||||
)
|
||||
else:
|
||||
z = self._insert_rebalance_step(
|
||||
z, pp.left, p.left, Node.rotate_right, Node.rotate_left
|
||||
z, pp.left, p.left, self.rotate_right, self.rotate_left
|
||||
)
|
||||
|
||||
# endregion InsertRebalance
|
||||
|
@ -121,7 +122,7 @@ class RBTree(RankedTree[T]):
|
|||
# Case 1
|
||||
# ======
|
||||
# x’s sibling w is red
|
||||
rotate_left(tree=self, x=parent)
|
||||
rotate_left(parent)
|
||||
w = right(parent)
|
||||
|
||||
if Node.differences(w) == (Colour.Black, Colour.Black):
|
||||
|
@ -136,7 +137,7 @@ class RBTree(RankedTree[T]):
|
|||
# x’s sibling w is black,
|
||||
# w’s left child is red, and w’s right child is black
|
||||
if Node.difference(right(w), w) == Colour.Black:
|
||||
rotate_right(tree=self, x=w)
|
||||
rotate_right(w)
|
||||
w = right(parent)
|
||||
|
||||
# Case 4
|
||||
|
@ -144,7 +145,7 @@ class RBTree(RankedTree[T]):
|
|||
# x’s sibling w is black, and w’s right child is red
|
||||
parent.rank -= Colour.Black
|
||||
w.rank += Colour.Black
|
||||
rotate_left(tree=self, x=parent)
|
||||
rotate_left(parent)
|
||||
x = self.root
|
||||
|
||||
return x, (x.parent if x else None)
|
||||
|
@ -162,8 +163,8 @@ class RBTree(RankedTree[T]):
|
|||
parent.right,
|
||||
parent,
|
||||
lambda x: x.right,
|
||||
Node.rotate_left,
|
||||
Node.rotate_right,
|
||||
self.rotate_left,
|
||||
self.rotate_right,
|
||||
)
|
||||
else:
|
||||
node, parent = self._delete_rebalance_step(
|
||||
|
@ -171,8 +172,8 @@ class RBTree(RankedTree[T]):
|
|||
parent.left,
|
||||
parent,
|
||||
lambda x: x.left,
|
||||
Node.rotate_right,
|
||||
Node.rotate_left,
|
||||
self.rotate_right,
|
||||
self.rotate_left,
|
||||
)
|
||||
|
||||
# endregion DeleteRebalance
|
||||
|
|
25
wavl.py
25
wavl.py
|
@ -1,5 +1,6 @@
|
|||
from ranked_tree import RotateFunction
|
||||
from avl import AVLTree
|
||||
from node import Node, NodeType, Comparable, RotateFunction
|
||||
from node import Node, NodeType, Comparable
|
||||
|
||||
import logging
|
||||
from typing import TypeVar, Optional
|
||||
|
@ -35,11 +36,9 @@ class WAVLTree(AVLTree[T]):
|
|||
y: Node[T],
|
||||
z: Node[T],
|
||||
reversed: bool,
|
||||
rotating_around_root: bool,
|
||||
rotate_left: RotateFunction[T],
|
||||
rotate_right: RotateFunction[T],
|
||||
) -> None:
|
||||
new_root = x
|
||||
v = y.left
|
||||
w = y.right
|
||||
|
||||
|
@ -49,9 +48,7 @@ class WAVLTree(AVLTree[T]):
|
|||
w_diff = Node.difference(w, y)
|
||||
|
||||
if w_diff == 1 and y.parent:
|
||||
new_root = rotate_left(y.parent)
|
||||
if rotating_around_root:
|
||||
self.root = new_root
|
||||
rotate_left(y.parent)
|
||||
|
||||
y.promote()
|
||||
z.demote()
|
||||
|
@ -60,9 +57,7 @@ class WAVLTree(AVLTree[T]):
|
|||
z.demote()
|
||||
elif w_diff == 2 and v.parent:
|
||||
rotate_right(v.parent)
|
||||
new_root = rotate_left(v.parent)
|
||||
if rotating_around_root:
|
||||
self.root = new_root
|
||||
rotate_left(v.parent)
|
||||
|
||||
v.promote().promote()
|
||||
y.demote()
|
||||
|
@ -97,8 +92,6 @@ class WAVLTree(AVLTree[T]):
|
|||
x_diff = Node.difference(x, parent)
|
||||
y_diff = Node.difference(y, parent)
|
||||
|
||||
rotating_around_root = parent.parent is None
|
||||
|
||||
parent_node_diffs = Node.differences(parent)
|
||||
if parent_node_diffs in ((1, 3), (3, 1)):
|
||||
if parent.left is x:
|
||||
|
@ -108,9 +101,8 @@ class WAVLTree(AVLTree[T]):
|
|||
parent.right,
|
||||
parent,
|
||||
False,
|
||||
rotating_around_root,
|
||||
Node.rotate_left,
|
||||
Node.rotate_right,
|
||||
self.rotate_left,
|
||||
self.rotate_right,
|
||||
)
|
||||
else:
|
||||
assert parent.left
|
||||
|
@ -119,9 +111,8 @@ class WAVLTree(AVLTree[T]):
|
|||
parent.left,
|
||||
parent,
|
||||
True,
|
||||
rotating_around_root,
|
||||
Node.rotate_right,
|
||||
Node.rotate_left,
|
||||
self.rotate_right,
|
||||
self.rotate_left,
|
||||
)
|
||||
|
||||
def _delete_rebalance(
|
||||
|
|
Loading…
Reference in a new issue