#include <algorithm>
#include <unordered_map>
#include <vector>

class Solution {
  public:
    /// Returns the length of the longest contiguous subarray of `nums` that
    /// contains as many elements equal to 1 as elements not equal to 1.
    ///
    /// Each element contributes +1 if it equals 1 and -1 otherwise to a
    /// running prefix sum. Two prefixes with the same sum enclose a balanced
    /// subarray. The answer is therefore the largest distance between the
    /// first occurrence of a prefix sum and any later occurrence of it.
    ///
    /// @param nums input values; any value other than 1 is treated like 0.
    /// @return length of the longest balanced subarray, or 0 if there is none
    ///         (including when `nums` is empty).
    int findMaxLength(const std::vector<int> &nums) {
        const int n = static_cast<int>(nums.size());

        // After k elements the running sum lies in [-k, k], so every sum fits
        // in [-n, n]. Offsetting by n gives a dense table indexed by sum, which
        // avoids hashing and the find-then-insert double lookup.
        // Recorded indices are always >= -1, so -2 is a safe "unseen" marker.
        constexpr int kUnseen = -2;
        std::vector<int> first_index(2 * nums.size() + 1, kUnseen);
        first_index[n] = -1; // the empty prefix has sum 0 "at" index -1

        int max_length = 0;
        int count = 0;
        for (int i = 0; i < n; ++i) {
            count += nums[i] == 1 ? 1 : -1;

            int &first = first_index[count + n];
            if (first == kUnseen) {
                // Keep only the earliest index so later matches span the most.
                first = i;
            } else {
                max_length = std::max(max_length, i - first);
            }
        }

        return max_length;
    }
};