fix(rotations): move rotations to the tree
Move rotations to the RankedTree, which allows us to factor out the functionality regarding rotations when root of the tree changes. Signed-off-by: Matej Focko <mfocko@redhat.com>
This commit is contained in:
parent
7b5bf0ec3a
commit
e28f578fc2
5 changed files with 89 additions and 112 deletions
37
avl.py
37
avl.py
|
@ -1,5 +1,5 @@
|
||||||
from node import Node, Comparable, RotateFunction
|
from node import Node, Comparable
|
||||||
from ranked_tree import RankedTree
|
from ranked_tree import RankedTree, RotateFunction
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
from typing import TypeVar, Optional
|
from typing import TypeVar, Optional
|
||||||
|
@ -49,22 +49,18 @@ class AVLTree(RankedTree[T]):
|
||||||
z: Node[T],
|
z: Node[T],
|
||||||
rotate_left: RotateFunction[T],
|
rotate_left: RotateFunction[T],
|
||||||
rotate_right: RotateFunction[T],
|
rotate_right: RotateFunction[T],
|
||||||
) -> Optional[Node[T]]:
|
) -> None:
|
||||||
new_root = x.parent
|
|
||||||
|
|
||||||
if not y or Node.difference(y) == 2:
|
if not y or Node.difference(y) == 2:
|
||||||
new_root = rotate_right(z)
|
rotate_right(z)
|
||||||
z.demote()
|
z.demote()
|
||||||
elif Node.difference(y) == 1:
|
elif Node.difference(y) == 1:
|
||||||
rotate_left(x)
|
rotate_left(x)
|
||||||
new_root = rotate_right(z)
|
rotate_right(z)
|
||||||
|
|
||||||
y.promote()
|
y.promote()
|
||||||
x.demote()
|
x.demote()
|
||||||
z.demote()
|
z.demote()
|
||||||
|
|
||||||
return new_root
|
|
||||||
|
|
||||||
def _insert_rebalance(self, x: Node[T]) -> None:
|
def _insert_rebalance(self, x: Node[T]) -> None:
|
||||||
diffs = Node.differences(x.parent)
|
diffs = Node.differences(x.parent)
|
||||||
while x.parent and (diffs == (0, 1) or diffs == (1, 0)):
|
while x.parent and (diffs == (0, 1) or diffs == (1, 0)):
|
||||||
|
@ -76,9 +72,6 @@ class AVLTree(RankedTree[T]):
|
||||||
if not x.parent:
|
if not x.parent:
|
||||||
return
|
return
|
||||||
|
|
||||||
rotating_around_root = x.parent.parent is None
|
|
||||||
new_root: Optional[Node[T]] = x.parent
|
|
||||||
|
|
||||||
rank_difference = Node.difference(x)
|
rank_difference = Node.difference(x)
|
||||||
if rank_difference != 0:
|
if rank_difference != 0:
|
||||||
return
|
return
|
||||||
|
@ -86,17 +79,14 @@ class AVLTree(RankedTree[T]):
|
||||||
x_parent = x.parent
|
x_parent = x.parent
|
||||||
assert x_parent is not None
|
assert x_parent is not None
|
||||||
if rank_difference == 0 and x.parent.left is x:
|
if rank_difference == 0 and x.parent.left is x:
|
||||||
new_root = self.__fix_0_child(
|
self.__fix_0_child(
|
||||||
x, x.right, x_parent, Node.rotate_left, Node.rotate_right
|
x, x.right, x_parent, self.rotate_left, self.rotate_right
|
||||||
)
|
)
|
||||||
elif rank_difference == 0 and x.parent.right is x:
|
elif rank_difference == 0 and x.parent.right is x:
|
||||||
new_root = self.__fix_0_child(
|
self.__fix_0_child(
|
||||||
x, x.left, x_parent, Node.rotate_right, Node.rotate_left
|
x, x.left, x_parent, self.rotate_right, self.rotate_left
|
||||||
)
|
)
|
||||||
|
|
||||||
if rotating_around_root:
|
|
||||||
self.root = new_root
|
|
||||||
|
|
||||||
# endregion InsertRebalance
|
# endregion InsertRebalance
|
||||||
|
|
||||||
# region DeleteRebalance
|
# region DeleteRebalance
|
||||||
|
@ -106,7 +96,6 @@ class AVLTree(RankedTree[T]):
|
||||||
x: Node[T],
|
x: Node[T],
|
||||||
y: Node[T],
|
y: Node[T],
|
||||||
leaning: int,
|
leaning: int,
|
||||||
rotating_around_root: bool,
|
|
||||||
rotate_left: RotateFunction[T],
|
rotate_left: RotateFunction[T],
|
||||||
rotate_right: RotateFunction[T],
|
rotate_right: RotateFunction[T],
|
||||||
) -> bool:
|
) -> bool:
|
||||||
|
@ -122,8 +111,6 @@ class AVLTree(RankedTree[T]):
|
||||||
for n in filter(None, (new_root.left, new_root.right, new_root)):
|
for n in filter(None, (new_root.left, new_root.right, new_root)):
|
||||||
_update_rank(n)
|
_update_rank(n)
|
||||||
|
|
||||||
if rotating_around_root:
|
|
||||||
self.root = new_root
|
|
||||||
return factor != 0
|
return factor != 0
|
||||||
|
|
||||||
def __delete_fixup(
|
def __delete_fixup(
|
||||||
|
@ -136,18 +123,16 @@ class AVLTree(RankedTree[T]):
|
||||||
elif factor in (-1, 1):
|
elif factor in (-1, 1):
|
||||||
return False
|
return False
|
||||||
|
|
||||||
rotating_around_root = x.parent is None
|
|
||||||
y, leaning, to_left, to_right = (
|
y, leaning, to_left, to_right = (
|
||||||
(x.right, 1, Node.rotate_left, Node.rotate_right)
|
(x.right, 1, self.rotate_left, self.rotate_right)
|
||||||
if factor == 2
|
if factor == 2
|
||||||
else (x.left, -1, Node.rotate_right, Node.rotate_left)
|
else (x.left, -1, self.rotate_right, self.rotate_left)
|
||||||
)
|
)
|
||||||
assert y
|
assert y
|
||||||
return self.__delete_rotate(
|
return self.__delete_rotate(
|
||||||
x,
|
x,
|
||||||
y,
|
y,
|
||||||
leaning,
|
leaning,
|
||||||
rotating_around_root,
|
|
||||||
to_left,
|
to_left,
|
||||||
to_right,
|
to_right,
|
||||||
)
|
)
|
||||||
|
|
55
node.py
55
node.py
|
@ -1,14 +1,12 @@
|
||||||
from abc import abstractmethod
|
from abc import abstractmethod
|
||||||
import enum
|
import enum
|
||||||
from typing import (
|
from typing import (
|
||||||
Callable,
|
|
||||||
Iterable,
|
Iterable,
|
||||||
Optional,
|
Optional,
|
||||||
Generic,
|
Generic,
|
||||||
Tuple,
|
Tuple,
|
||||||
TypeVar,
|
TypeVar,
|
||||||
Protocol,
|
Protocol,
|
||||||
Any,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@ -112,56 +110,6 @@ class Node(Generic[T]):
|
||||||
self.rank -= 1
|
self.rank -= 1
|
||||||
return self
|
return self
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def rotate_right(x: "Node[T]", tree: Any = None) -> "Node[T]":
|
|
||||||
parent = x.parent
|
|
||||||
y = x.left
|
|
||||||
# z = x.right
|
|
||||||
|
|
||||||
assert y is not None
|
|
||||||
if parent:
|
|
||||||
if parent.left is x:
|
|
||||||
parent.left = y
|
|
||||||
else:
|
|
||||||
parent.right = y
|
|
||||||
elif tree:
|
|
||||||
tree.root = y
|
|
||||||
|
|
||||||
x.left = y.right
|
|
||||||
if x.left:
|
|
||||||
x.left.parent = x
|
|
||||||
|
|
||||||
y.right = x
|
|
||||||
x.parent = y
|
|
||||||
y.parent = parent
|
|
||||||
|
|
||||||
return y
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def rotate_left(x: "Node[T]", tree: Any = None) -> "Node[T]":
|
|
||||||
parent = x.parent
|
|
||||||
# y = x.left
|
|
||||||
z = x.right
|
|
||||||
|
|
||||||
assert z is not None
|
|
||||||
if parent:
|
|
||||||
if parent.left is x:
|
|
||||||
parent.left = z
|
|
||||||
else:
|
|
||||||
parent.right = z
|
|
||||||
elif tree:
|
|
||||||
tree.root = z
|
|
||||||
|
|
||||||
x.right = z.left
|
|
||||||
if x.right:
|
|
||||||
x.right.parent = x
|
|
||||||
|
|
||||||
z.left = x
|
|
||||||
x.parent = z
|
|
||||||
z.parent = parent
|
|
||||||
|
|
||||||
return z
|
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def find_parent_node(
|
def find_parent_node(
|
||||||
value: T, node: "Node[T]", missing: bool = True
|
value: T, node: "Node[T]", missing: bool = True
|
||||||
|
@ -199,6 +147,3 @@ class Node(Generic[T]):
|
||||||
node = node.right
|
node = node.right
|
||||||
|
|
||||||
return node
|
return node
|
||||||
|
|
||||||
|
|
||||||
RotateFunction = Callable[[Node[T]], Node[T]]
|
|
||||||
|
|
|
@ -3,7 +3,7 @@ from node import Node, Comparable
|
||||||
from abc import abstractmethod
|
from abc import abstractmethod
|
||||||
from collections import deque
|
from collections import deque
|
||||||
import logging
|
import logging
|
||||||
from typing import Deque, Iterable, Optional, Tuple, TypeVar, Generic
|
from typing import Callable, Deque, Iterable, Optional, Tuple, TypeVar, Generic
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
T = TypeVar("T", bound=Comparable)
|
T = TypeVar("T", bound=Comparable)
|
||||||
|
@ -58,6 +58,58 @@ class RankedTree(Generic[T]):
|
||||||
|
|
||||||
# endregion TreeSpecificMethods
|
# endregion TreeSpecificMethods
|
||||||
|
|
||||||
|
# region Rotations
|
||||||
|
|
||||||
|
def rotate_right(self, x: "Node[T]") -> "Node[T]":
|
||||||
|
parent = x.parent
|
||||||
|
y = x.left
|
||||||
|
# z = x.right
|
||||||
|
|
||||||
|
assert y is not None
|
||||||
|
if parent:
|
||||||
|
if parent.left is x:
|
||||||
|
parent.left = y
|
||||||
|
else:
|
||||||
|
parent.right = y
|
||||||
|
else:
|
||||||
|
self.root = y
|
||||||
|
|
||||||
|
x.left = y.right
|
||||||
|
if x.left:
|
||||||
|
x.left.parent = x
|
||||||
|
|
||||||
|
y.right = x
|
||||||
|
x.parent = y
|
||||||
|
y.parent = parent
|
||||||
|
|
||||||
|
return y
|
||||||
|
|
||||||
|
def rotate_left(self, x: "Node[T]") -> "Node[T]":
|
||||||
|
parent = x.parent
|
||||||
|
# y = x.left
|
||||||
|
z = x.right
|
||||||
|
|
||||||
|
assert z is not None
|
||||||
|
if parent:
|
||||||
|
if parent.left is x:
|
||||||
|
parent.left = z
|
||||||
|
else:
|
||||||
|
parent.right = z
|
||||||
|
else:
|
||||||
|
self.root = z
|
||||||
|
|
||||||
|
x.right = z.left
|
||||||
|
if x.right:
|
||||||
|
x.right.parent = x
|
||||||
|
|
||||||
|
z.left = x
|
||||||
|
x.parent = z
|
||||||
|
z.parent = parent
|
||||||
|
|
||||||
|
return z
|
||||||
|
|
||||||
|
# endregion Rotations
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def rank(self) -> int:
|
def rank(self) -> int:
|
||||||
return Node.get_rank(self.root)
|
return Node.get_rank(self.root)
|
||||||
|
@ -148,3 +200,6 @@ class RankedTree(Generic[T]):
|
||||||
"""
|
"""
|
||||||
if self.root:
|
if self.root:
|
||||||
yield from self.root
|
yield from self.root
|
||||||
|
|
||||||
|
|
||||||
|
RotateFunction = Callable[[Node[T]], Node[T]]
|
||||||
|
|
27
rbt.py
27
rbt.py
|
@ -1,5 +1,5 @@
|
||||||
from node import Node, Comparable, RotateFunction
|
from node import Node, Comparable
|
||||||
from ranked_tree import RankedTree
|
from ranked_tree import RankedTree, RotateFunction
|
||||||
|
|
||||||
import enum
|
import enum
|
||||||
import logging
|
import logging
|
||||||
|
@ -13,6 +13,7 @@ class Colour(enum.IntEnum):
|
||||||
"""
|
"""
|
||||||
Represents colour of the edge or node.
|
Represents colour of the edge or node.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
Red = 0
|
Red = 0
|
||||||
Black = 1
|
Black = 1
|
||||||
|
|
||||||
|
@ -80,12 +81,12 @@ class RBTree(RankedTree[T]):
|
||||||
# ======
|
# ======
|
||||||
# z’s uncle y is black and z is a right child
|
# z’s uncle y is black and z is a right child
|
||||||
z = p
|
z = p
|
||||||
rotate_left(tree=self, x=p)
|
rotate_left(p)
|
||||||
else:
|
else:
|
||||||
# Case 3
|
# Case 3
|
||||||
# ======
|
# ======
|
||||||
# z’s uncle y is black and z is a left child
|
# z’s uncle y is black and z is a left child
|
||||||
rotate_right(tree=self, x=pp)
|
rotate_right(pp)
|
||||||
|
|
||||||
return z
|
return z
|
||||||
|
|
||||||
|
@ -97,11 +98,11 @@ class RBTree(RankedTree[T]):
|
||||||
assert pp
|
assert pp
|
||||||
if p == pp.left:
|
if p == pp.left:
|
||||||
z = self._insert_rebalance_step(
|
z = self._insert_rebalance_step(
|
||||||
z, pp.right, p.right, Node.rotate_left, Node.rotate_right
|
z, pp.right, p.right, self.rotate_left, self.rotate_right
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
z = self._insert_rebalance_step(
|
z = self._insert_rebalance_step(
|
||||||
z, pp.left, p.left, Node.rotate_right, Node.rotate_left
|
z, pp.left, p.left, self.rotate_right, self.rotate_left
|
||||||
)
|
)
|
||||||
|
|
||||||
# endregion InsertRebalance
|
# endregion InsertRebalance
|
||||||
|
@ -121,7 +122,7 @@ class RBTree(RankedTree[T]):
|
||||||
# Case 1
|
# Case 1
|
||||||
# ======
|
# ======
|
||||||
# x’s sibling w is red
|
# x’s sibling w is red
|
||||||
rotate_left(tree=self, x=parent)
|
rotate_left(parent)
|
||||||
w = right(parent)
|
w = right(parent)
|
||||||
|
|
||||||
if Node.differences(w) == (Colour.Black, Colour.Black):
|
if Node.differences(w) == (Colour.Black, Colour.Black):
|
||||||
|
@ -136,7 +137,7 @@ class RBTree(RankedTree[T]):
|
||||||
# x’s sibling w is black,
|
# x’s sibling w is black,
|
||||||
# w’s left child is red, and w’s right child is black
|
# w’s left child is red, and w’s right child is black
|
||||||
if Node.difference(right(w), w) == Colour.Black:
|
if Node.difference(right(w), w) == Colour.Black:
|
||||||
rotate_right(tree=self, x=w)
|
rotate_right(w)
|
||||||
w = right(parent)
|
w = right(parent)
|
||||||
|
|
||||||
# Case 4
|
# Case 4
|
||||||
|
@ -144,7 +145,7 @@ class RBTree(RankedTree[T]):
|
||||||
# x’s sibling w is black, and w’s right child is red
|
# x’s sibling w is black, and w’s right child is red
|
||||||
parent.rank -= Colour.Black
|
parent.rank -= Colour.Black
|
||||||
w.rank += Colour.Black
|
w.rank += Colour.Black
|
||||||
rotate_left(tree=self, x=parent)
|
rotate_left(parent)
|
||||||
x = self.root
|
x = self.root
|
||||||
|
|
||||||
return x, (x.parent if x else None)
|
return x, (x.parent if x else None)
|
||||||
|
@ -162,8 +163,8 @@ class RBTree(RankedTree[T]):
|
||||||
parent.right,
|
parent.right,
|
||||||
parent,
|
parent,
|
||||||
lambda x: x.right,
|
lambda x: x.right,
|
||||||
Node.rotate_left,
|
self.rotate_left,
|
||||||
Node.rotate_right,
|
self.rotate_right,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
node, parent = self._delete_rebalance_step(
|
node, parent = self._delete_rebalance_step(
|
||||||
|
@ -171,8 +172,8 @@ class RBTree(RankedTree[T]):
|
||||||
parent.left,
|
parent.left,
|
||||||
parent,
|
parent,
|
||||||
lambda x: x.left,
|
lambda x: x.left,
|
||||||
Node.rotate_right,
|
self.rotate_right,
|
||||||
Node.rotate_left,
|
self.rotate_left,
|
||||||
)
|
)
|
||||||
|
|
||||||
# endregion DeleteRebalance
|
# endregion DeleteRebalance
|
||||||
|
|
25
wavl.py
25
wavl.py
|
@ -1,5 +1,6 @@
|
||||||
|
from ranked_tree import RotateFunction
|
||||||
from avl import AVLTree
|
from avl import AVLTree
|
||||||
from node import Node, NodeType, Comparable, RotateFunction
|
from node import Node, NodeType, Comparable
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
from typing import TypeVar, Optional
|
from typing import TypeVar, Optional
|
||||||
|
@ -35,11 +36,9 @@ class WAVLTree(AVLTree[T]):
|
||||||
y: Node[T],
|
y: Node[T],
|
||||||
z: Node[T],
|
z: Node[T],
|
||||||
reversed: bool,
|
reversed: bool,
|
||||||
rotating_around_root: bool,
|
|
||||||
rotate_left: RotateFunction[T],
|
rotate_left: RotateFunction[T],
|
||||||
rotate_right: RotateFunction[T],
|
rotate_right: RotateFunction[T],
|
||||||
) -> None:
|
) -> None:
|
||||||
new_root = x
|
|
||||||
v = y.left
|
v = y.left
|
||||||
w = y.right
|
w = y.right
|
||||||
|
|
||||||
|
@ -49,9 +48,7 @@ class WAVLTree(AVLTree[T]):
|
||||||
w_diff = Node.difference(w, y)
|
w_diff = Node.difference(w, y)
|
||||||
|
|
||||||
if w_diff == 1 and y.parent:
|
if w_diff == 1 and y.parent:
|
||||||
new_root = rotate_left(y.parent)
|
rotate_left(y.parent)
|
||||||
if rotating_around_root:
|
|
||||||
self.root = new_root
|
|
||||||
|
|
||||||
y.promote()
|
y.promote()
|
||||||
z.demote()
|
z.demote()
|
||||||
|
@ -60,9 +57,7 @@ class WAVLTree(AVLTree[T]):
|
||||||
z.demote()
|
z.demote()
|
||||||
elif w_diff == 2 and v.parent:
|
elif w_diff == 2 and v.parent:
|
||||||
rotate_right(v.parent)
|
rotate_right(v.parent)
|
||||||
new_root = rotate_left(v.parent)
|
rotate_left(v.parent)
|
||||||
if rotating_around_root:
|
|
||||||
self.root = new_root
|
|
||||||
|
|
||||||
v.promote().promote()
|
v.promote().promote()
|
||||||
y.demote()
|
y.demote()
|
||||||
|
@ -97,8 +92,6 @@ class WAVLTree(AVLTree[T]):
|
||||||
x_diff = Node.difference(x, parent)
|
x_diff = Node.difference(x, parent)
|
||||||
y_diff = Node.difference(y, parent)
|
y_diff = Node.difference(y, parent)
|
||||||
|
|
||||||
rotating_around_root = parent.parent is None
|
|
||||||
|
|
||||||
parent_node_diffs = Node.differences(parent)
|
parent_node_diffs = Node.differences(parent)
|
||||||
if parent_node_diffs in ((1, 3), (3, 1)):
|
if parent_node_diffs in ((1, 3), (3, 1)):
|
||||||
if parent.left is x:
|
if parent.left is x:
|
||||||
|
@ -108,9 +101,8 @@ class WAVLTree(AVLTree[T]):
|
||||||
parent.right,
|
parent.right,
|
||||||
parent,
|
parent,
|
||||||
False,
|
False,
|
||||||
rotating_around_root,
|
self.rotate_left,
|
||||||
Node.rotate_left,
|
self.rotate_right,
|
||||||
Node.rotate_right,
|
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
assert parent.left
|
assert parent.left
|
||||||
|
@ -119,9 +111,8 @@ class WAVLTree(AVLTree[T]):
|
||||||
parent.left,
|
parent.left,
|
||||||
parent,
|
parent,
|
||||||
True,
|
True,
|
||||||
rotating_around_root,
|
self.rotate_right,
|
||||||
Node.rotate_right,
|
self.rotate_left,
|
||||||
Node.rotate_left,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
def _delete_rebalance(
|
def _delete_rebalance(
|
||||||
|
|
Loading…
Reference in a new issue