From e28f578fc28855458cf5b265f9c2106b0cb525fc Mon Sep 17 00:00:00 2001 From: Matej Focko Date: Sun, 22 May 2022 21:00:18 +0200 Subject: [PATCH] fix(rotations): move rotations to the tree Move rotations to the RankedTree, which allows us to factor out the functionality regarding rotations when root of the tree changes. Signed-off-by: Matej Focko --- avl.py | 37 ++++++++++---------------------- node.py | 55 ------------------------------------------------ ranked_tree.py | 57 +++++++++++++++++++++++++++++++++++++++++++++++++- rbt.py | 27 ++++++++++++------------ wavl.py | 25 +++++++--------------- 5 files changed, 89 insertions(+), 112 deletions(-) diff --git a/avl.py b/avl.py index 00f16b5..6b798ac 100644 --- a/avl.py +++ b/avl.py @@ -1,5 +1,5 @@ -from node import Node, Comparable, RotateFunction -from ranked_tree import RankedTree +from node import Node, Comparable +from ranked_tree import RankedTree, RotateFunction import logging from typing import TypeVar, Optional @@ -49,22 +49,18 @@ class AVLTree(RankedTree[T]): z: Node[T], rotate_left: RotateFunction[T], rotate_right: RotateFunction[T], - ) -> Optional[Node[T]]: - new_root = x.parent - + ) -> None: if not y or Node.difference(y) == 2: - new_root = rotate_right(z) + rotate_right(z) z.demote() elif Node.difference(y) == 1: rotate_left(x) - new_root = rotate_right(z) + rotate_right(z) y.promote() x.demote() z.demote() - return new_root - def _insert_rebalance(self, x: Node[T]) -> None: diffs = Node.differences(x.parent) while x.parent and (diffs == (0, 1) or diffs == (1, 0)): @@ -76,9 +72,6 @@ class AVLTree(RankedTree[T]): if not x.parent: return - rotating_around_root = x.parent.parent is None - new_root: Optional[Node[T]] = x.parent - rank_difference = Node.difference(x) if rank_difference != 0: return @@ -86,17 +79,14 @@ class AVLTree(RankedTree[T]): x_parent = x.parent assert x_parent is not None if rank_difference == 0 and x.parent.left is x: - new_root = self.__fix_0_child( - x, x.right, x_parent, Node.rotate_left, Node.rotate_right + self.__fix_0_child( + x, x.right, x_parent, self.rotate_left, self.rotate_right ) elif rank_difference == 0 and x.parent.right is x: - new_root = self.__fix_0_child( - x, x.left, x_parent, Node.rotate_right, Node.rotate_left + self.__fix_0_child( + x, x.left, x_parent, self.rotate_right, self.rotate_left ) - if rotating_around_root: - self.root = new_root - # endregion InsertRebalance # region DeleteRebalance @@ -106,7 +96,6 @@ class AVLTree(RankedTree[T]): x: Node[T], y: Node[T], leaning: int, - rotating_around_root: bool, rotate_left: RotateFunction[T], rotate_right: RotateFunction[T], ) -> bool: @@ -122,8 +111,6 @@ class AVLTree(RankedTree[T]): for n in filter(None, (new_root.left, new_root.right, new_root)): _update_rank(n) - if rotating_around_root: - self.root = new_root return factor != 0 def __delete_fixup( @@ -136,18 +123,16 @@ class AVLTree(RankedTree[T]): elif factor in (-1, 1): return False - rotating_around_root = x.parent is None y, leaning, to_left, to_right = ( - (x.right, 1, Node.rotate_left, Node.rotate_right) + (x.right, 1, self.rotate_left, self.rotate_right) if factor == 2 - else (x.left, -1, Node.rotate_right, Node.rotate_left) + else (x.left, -1, self.rotate_right, self.rotate_left) ) assert y return self.__delete_rotate( x, y, leaning, - rotating_around_root, to_left, to_right, ) diff --git a/node.py b/node.py index a1ac28d..7494858 100644 --- a/node.py +++ b/node.py @@ -1,14 +1,12 @@ from abc import abstractmethod import enum from typing import ( - Callable, Iterable, Optional, Generic, Tuple, TypeVar, Protocol, - Any, ) @@ -112,56 +110,6 @@ class Node(Generic[T]): self.rank -= 1 return self - @staticmethod - def rotate_right(x: "Node[T]", tree: Any = None) -> "Node[T]": - parent = x.parent - y = x.left - # z = x.right - - assert y is not None - if parent: - if parent.left is x: - parent.left = y - else: - parent.right = y - elif tree: - tree.root = y - - x.left = y.right - if x.left: - x.left.parent = x - - y.right = x - x.parent = y - y.parent = parent - - return y - - @staticmethod - def rotate_left(x: "Node[T]", tree: Any = None) -> "Node[T]": - parent = x.parent - # y = x.left - z = x.right - - assert z is not None - if parent: - if parent.left is x: - parent.left = z - else: - parent.right = z - elif tree: - tree.root = z - - x.right = z.left - if x.right: - x.right.parent = x - - z.left = x - x.parent = z - z.parent = parent - - return z - @staticmethod def find_parent_node( value: T, node: "Node[T]", missing: bool = True @@ -199,6 +147,3 @@ class Node(Generic[T]): node = node.right return node - - -RotateFunction = Callable[[Node[T]], Node[T]] diff --git a/ranked_tree.py b/ranked_tree.py index 80c49b4..6b8ce7a 100644 --- a/ranked_tree.py +++ b/ranked_tree.py @@ -3,7 +3,7 @@ from node import Node, Comparable from abc import abstractmethod from collections import deque import logging -from typing import Deque, Iterable, Optional, Tuple, TypeVar, Generic +from typing import Callable, Deque, Iterable, Optional, Tuple, TypeVar, Generic logger = logging.getLogger(__name__) T = TypeVar("T", bound=Comparable) @@ -58,6 +58,58 @@ class RankedTree(Generic[T]): # endregion TreeSpecificMethods + # region Rotations + + def rotate_right(self, x: "Node[T]") -> "Node[T]": + parent = x.parent + y = x.left + # z = x.right + + assert y is not None + if parent: + if parent.left is x: + parent.left = y + else: + parent.right = y + else: + self.root = y + + x.left = y.right + if x.left: + x.left.parent = x + + y.right = x + x.parent = y + y.parent = parent + + return y + + def rotate_left(self, x: "Node[T]") -> "Node[T]": + parent = x.parent + # y = x.left + z = x.right + + assert z is not None + if parent: + if parent.left is x: + parent.left = z + else: + parent.right = z + else: + self.root = z + + x.right = z.left + if x.right: + x.right.parent = x + + z.left = x + x.parent = z + z.parent = parent + + return z + + # endregion Rotations + @property def rank(self) -> int: return Node.get_rank(self.root) @@ -148,3 +200,6 @@ class RankedTree(Generic[T]): """ if self.root: yield from self.root + + +RotateFunction = Callable[[Node[T]], Node[T]] diff --git a/rbt.py b/rbt.py index 8a7b7ea..0d24ab9 100644 --- a/rbt.py +++ b/rbt.py @@ -1,5 +1,5 @@ -from node import Node, Comparable, RotateFunction -from ranked_tree import RankedTree +from node import Node, Comparable +from ranked_tree import RankedTree, RotateFunction import enum import logging @@ -13,6 +13,7 @@ class Colour(enum.IntEnum): """ Represents colour of the edge or node. """ + Red = 0 Black = 1 @@ -80,12 +81,12 @@ class RBTree(RankedTree[T]): # ====== # z’s uncle y is black and z is a right child z = p - rotate_left(tree=self, x=p) + rotate_left(p) else: # Case 3 # ====== # z’s uncle y is black and z is a left child - rotate_right(tree=self, x=pp) + rotate_right(pp) return z @@ -97,11 +98,11 @@ class RBTree(RankedTree[T]): assert pp if p == pp.left: z = self._insert_rebalance_step( - z, pp.right, p.right, Node.rotate_left, Node.rotate_right + z, pp.right, p.right, self.rotate_left, self.rotate_right ) else: z = self._insert_rebalance_step( - z, pp.left, p.left, Node.rotate_right, Node.rotate_left + z, pp.left, p.left, self.rotate_right, self.rotate_left ) # endregion InsertRebalance @@ -121,7 +122,7 @@ class RBTree(RankedTree[T]): # Case 1 # ====== # x’s sibling w is red - rotate_left(tree=self, x=parent) + rotate_left(parent) w = right(parent) if Node.differences(w) == (Colour.Black, Colour.Black): @@ -136,7 +137,7 @@ class RBTree(RankedTree[T]): # x’s sibling w is black, # w’s left child is red, and w’s right child is black if Node.difference(right(w), w) == Colour.Black: - rotate_right(tree=self, x=w) + rotate_right(w) w = right(parent) # Case 4 @@ -144,7 +145,7 @@ class RBTree(RankedTree[T]): # x’s sibling w is black, and w’s right child is red parent.rank -= Colour.Black w.rank += Colour.Black - rotate_left(tree=self, x=parent) + rotate_left(parent) x = self.root return x, (x.parent if x else None) @@ -162,8 +163,8 @@ class RBTree(RankedTree[T]): parent.right, parent, lambda x: x.right, - Node.rotate_left, - Node.rotate_right, + self.rotate_left, + self.rotate_right, ) else: node, parent = self._delete_rebalance_step( @@ -171,8 +172,8 @@ class RBTree(RankedTree[T]): parent.left, parent, lambda x: x.left, - Node.rotate_right, - Node.rotate_left, + self.rotate_right, + self.rotate_left, ) # endregion DeleteRebalance diff --git a/wavl.py b/wavl.py index 11e3621..06d2d5e 100644 --- a/wavl.py +++ b/wavl.py @@ -1,5 +1,6 @@ +from ranked_tree import RotateFunction from avl import AVLTree -from node import Node, NodeType, Comparable, RotateFunction +from node import Node, NodeType, Comparable import logging from typing import TypeVar, Optional @@ -35,11 +36,9 @@ class WAVLTree(AVLTree[T]): y: Node[T], z: Node[T], reversed: bool, - rotating_around_root: bool, rotate_left: RotateFunction[T], rotate_right: RotateFunction[T], ) -> None: - new_root = x v = y.left w = y.right @@ -49,9 +48,7 @@ class WAVLTree(AVLTree[T]): w_diff = Node.difference(w, y) if w_diff == 1 and y.parent: - new_root = rotate_left(y.parent) - if rotating_around_root: - self.root = new_root + rotate_left(y.parent) y.promote() z.demote() @@ -60,9 +57,7 @@ class WAVLTree(AVLTree[T]): z.demote() elif w_diff == 2 and v.parent: rotate_right(v.parent) - new_root = rotate_left(v.parent) - if rotating_around_root: - self.root = new_root + rotate_left(v.parent) v.promote().promote() y.demote() @@ -97,8 +92,6 @@ class WAVLTree(AVLTree[T]): x_diff = Node.difference(x, parent) y_diff = Node.difference(y, parent) - rotating_around_root = parent.parent is None - parent_node_diffs = Node.differences(parent) if parent_node_diffs in ((1, 3), (3, 1)): if parent.left is x: @@ -108,9 +101,8 @@ class WAVLTree(AVLTree[T]): parent.right, parent, False, - rotating_around_root, - Node.rotate_left, - Node.rotate_right, + self.rotate_left, + self.rotate_right, ) else: assert parent.left @@ -119,9 +111,8 @@ class WAVLTree(AVLTree[T]): parent.left, parent, True, - rotating_around_root, - Node.rotate_right, - Node.rotate_left, + self.rotate_right, + self.rotate_left, ) def _delete_rebalance(