#include <unordered_map>

#ifdef _MF_TEST
// Singly linked list node used for local testing (only compiled when
// _MF_TEST is defined). Every constructor funnels into the
// (value, successor) form.
struct ListNode {
    int val;
    ListNode *next;
    ListNode() : ListNode(0, nullptr) {}
    ListNode(int x) : ListNode(x, nullptr) {}
    ListNode(int x, ListNode *next) : val{x}, next{next} {}
};
#endif

class Solution {
  public:
    /// Removes consecutive runs of nodes whose values sum to zero and
    /// returns the head of the resulting list.
    ///
    /// Single pass over the list tracking running prefix sums. When the
    /// prefix sum at the current node equals the prefix sum recorded at an
    /// earlier surviving node `prev`, the nodes after `prev` up to and
    /// including the current node sum to zero and are unlinked.
    ///
    /// Unlinked nodes are NOT freed; node ownership stays with the caller.
    ///
    /// @param head  first node of the list (may be nullptr).
    /// @return      first node of the resulting list (nullptr if every node
    ///              was removed).
    ListNode *removeZeroSumSublists(ListNode *head) {
        // Sentinel with value 0: prefix sum 0 maps to it, so a zero-sum run
        // that starts at `head` is removed like any other run.
        ListNode front(0, head);

        // Prefix sum -> surviving node at which that prefix sum occurs.
        // Invariant: the keys are exactly the prefix sums of the surviving
        // list up to the current node, and they are pairwise distinct.
        // 64-bit accumulation avoids signed-overflow UB (and the spurious
        // matches that wrapped sums would cause) on long lists of large values.
        std::unordered_map<long long, ListNode *> sums;

        long long sum = 0;
        for (ListNode *node = &front; node != nullptr; node = node->next) {
            sum += node->val;

            if (auto it = sums.find(sum); it != sums.end()) {
                ListNode *prev = it->second;

                // Forget the prefix sums of the nodes being removed: every
                // node strictly after `prev` and before the current node.
                // The entry for `sum` itself belongs to `prev` and is kept.
                long long running = sum;
                for (ListNode *cur = prev->next; cur != node; cur = cur->next) {
                    running += cur->val;
                    sums.erase(running);
                }

                // Unlink prev->next .. node. `sum` is again the prefix sum at
                // `prev`. The loop then continues from node->next, which is
                // now prev's successor.
                prev->next = node->next;
            } else {
                sums.emplace(sum, node);
            }
        }

        return front.next;
    }
};