class Solution {
  /**
   * Counts the pairs of distinct leaves in the tree rooted at {@code root}
   * whose connecting path has at most {@code distance} edges.
   *
   * @param root the tree root; a {@code null} root yields 0
   * @param distance the maximum allowed path length between two leaves
   * @return the number of leaf pairs within {@code distance}
   */
  public int countPairs(TreeNode root, int distance) {
    int[] total = new int[1];
    leafDepths(root, distance, total);
    return total[0];
  }

  /**
   * Post-order walk over the subtree rooted at {@code node}.
   *
   * Returns a histogram whose slot {@code k} holds the number of leaves
   * exactly {@code k} edges below {@code node}. Only depths
   * {@code 0..distance} are tracked. Every leaf pair whose path passes
   * through {@code node}, and whose length is at most {@code distance}, is
   * added to {@code total[0]}.
   */
  private int[] leafDepths(TreeNode node, int distance, int[] total) {
    int[] depths = new int[distance + 1];
    if (node == null) {
      return depths;
    }
    if (node.left == null && node.right == null) {
      depths[0] = 1;
      return depths;
    }

    int[] lhs = leafDepths(node.left, distance, total);
    int[] rhs = leafDepths(node.right, distance, total);

    // A leaf `a` edges below the left child and a leaf `b` edges below the
    // right child are (a + 1) + (b + 1) edges apart through this node.
    for (int a = 0; a + 2 <= distance; a++) {
      for (int b = 0; a + b + 2 <= distance; b++) {
        total[0] += lhs[a] * rhs[b];
      }
    }

    // Shift both child histograms one level deeper relative to this node.
    // Leaves that would land beyond `distance` are no longer tracked.
    for (int k = distance; k >= 1; k--) {
      depths[k] = lhs[k - 1] + rhs[k - 1];
    }
    return depths;
  }
}