/**
 * Definition for a binary tree node.
 * struct TreeNode {
 *     int val;
 *     TreeNode *left;
 *     TreeNode *right;
 *     TreeNode() : val(0), left(nullptr), right(nullptr) {}
 *     TreeNode(int x) : val(x), left(nullptr), right(nullptr) {}
 *     TreeNode(int x, TreeNode *left, TreeNode *right) : val(x), left(left),
 * right(right) {}
 * };
 */
class Solution {
    int dfs(TreeNode *root, int start, int &time) {
        if (root == nullptr) {
            return 0;
        }

        int left = dfs(root->left, start, time);
        int right = dfs(root->right, start, time);

        if (root->val == start) {
            time = std::max(left, right);
            return -1;
        }

        if (left >= 0 && right >= 0) {
            return 1 + std::max(left, right);
        }

        auto t = std::abs(left) + std::abs(right);
        time = std::max(time, t);

        return std::min(left, right) - 1;
    }

  public:
    int amountOfTime(TreeNode *root, int start) {
        int time = 0;
        dfs(root, start, time);
        return time;
    }
};