package main

// balanceBST rebalances the binary tree rooted at root using the
// Day-Stout-Warren (DSW) procedure and returns the new root.
//
// The existing nodes are relinked in place; the only extra node is a
// temporary dummy header that sits above the real root. A nil root yields nil.
func balanceBST(root *TreeNode) *TreeNode {
	if root == nil {
		return nil
	}

	// The dummy header gives every rotation a parent whose Right link can be
	// rewired, including rotations that replace the real root.
	var pseudo TreeNode
	pseudo.Right = root

	// Phase 1: flatten the tree into a right-leaning vine. While the node
	// below tail still has a left child, rotate that child up; otherwise the
	// node is final, so step onto it and count it.
	size := 0
	for tail := &pseudo; tail.Right != nil; {
		node := tail.Right
		if left := node.Left; left != nil {
			node.Left = left.Right
			left.Right = node
			tail.Right = left
			continue
		}
		tail = node
		size++
	}

	// compress performs `times` left rotations down the backbone, each one
	// lifting a backbone node over its predecessor and then moving past it.
	compress := func(times int) {
		scanner := &pseudo
		for ; times > 0; times-- {
			child := scanner.Right
			lifted := child.Right
			child.Right = lifted.Left
			lifted.Left = child
			scanner.Right = lifted
			scanner = lifted
		}
	}

	// Phase 2: full is the largest value of the form 2^k - 1 that does not
	// exceed size, i.e. the node count of the biggest perfect tree that fits.
	full := 1
	for full<<1 <= size+1 {
		full <<= 1
	}
	full--

	// Push the surplus nodes down first, then halve the backbone repeatedly.
	compress(size - full)
	for full > 1 {
		full >>= 1
		compress(full)
	}

	return pseudo.Right
}