class Solution { private record Result(int[] reachable, int pairs) { public Result(int[] reachable) { this(reachable, 0); } public static Result deadEnd(int distance) { return new Result(new int[distance + 1]); } public static Result leaf(int distance) { int[] reachable = new int[distance + 1]; reachable[0] = 1; return new Result(reachable); } public static Result merge(int distance, Result left, Result right) { int[] merged = new int[distance + 1]; // we're moving by one for (int i = 0; i < distance; ++i) { merged[i + 1] += left.reachable[i] + right.reachable[i]; } int pairs = left.pairs + right.pairs; for (int i = 0; i <= distance; ++i) { for (int j = 0; j <= distance; ++j) { if (i + j + 2 <= distance) { pairs += left.reachable[i] * right.reachable[j]; } } } return new Result(merged, pairs); } } private boolean isLeaf(TreeNode node) { return node.left == null && node.right == null; } private Result traverse(TreeNode node, int distance) { if (node == null) { return Result.deadEnd(distance); } else if (isLeaf(node)) { return Result.leaf(distance); } var left = traverse(node.left, distance); var right = traverse(node.right, distance); return Result.merge(distance, left, right); } public int countPairs(TreeNode root, int distance) { return traverse(root, distance).pairs; } }