#include <algorithm>
#include <cassert>
#include <cstddef>
#include <iostream>
#include <iterator>
#include <vector>

namespace {

// Writes the matrix `m` to std::cout: every element is followed by a single
// space, every row by a newline, and the whole matrix by one extra blank
// line.
template <typename T> void print_matrix(const std::vector<std::vector<T>> &m) {
    for (const auto &row : m) {
        // Bind by reference so that printing never copies an element.
        for (const auto &x : row) {
            std::cout << x << " ";
        }

        std::cout << "\n";
    }

    std::cout << "\n";
}

// Mutable view over one top-left to bottom-right diagonal of a matrix that is
// stored as a vector of rows.
//
// The diagonal starts at column `x` of row `y` and runs until it leaves the
// matrix. The length of row `y` is used as the column count, so the rows the
// diagonal crosses are expected to be at least that long. The view does not
// own the matrix; the matrix has to outlive the view and its iterators.
template <typename T> class diagonal {
    std::vector<std::vector<T>> &matrix;
    std::size_t x; // column of the first element
    std::size_t y; // row of the first element

    // Random access iterator that moves one column and one row per step.
    class diagonal_iter {
        // A pointer (not a reference) keeps the iterator default
        // constructible, makes copy assignment rebind the iterator instead of
        // copying matrix contents, and lets equality compare matrix identity
        // in O(1) instead of comparing every element.
        std::vector<std::vector<T>> *m = nullptr;
        std::size_t x = 0; // current column
        std::size_t y = 0; // current row

      public:
        using difference_type = std::ptrdiff_t;
        using value_type = T;
        using pointer = T *;
        using reference = T &;
        using iterator_category = std::random_access_iterator_tag;

        diagonal_iter() = default;

        diagonal_iter(std::vector<std::vector<T>> &matrix, std::size_t x,
                      std::size_t y)
            : m(&matrix), x(x), y(y) {}

        reference operator*() const { return (*m)[y][x]; }
        pointer operator->() const { return &(*m)[y][x]; }
        reference operator[](difference_type n) const { return *(*this + n); }

        diagonal_iter &operator++() {
            ++x;
            ++y;
            return *this;
        }

        diagonal_iter operator++(int) {
            diagonal_iter previous = *this;
            ++*this;
            return previous;
        }

        diagonal_iter &operator--() {
            --x;
            --y;
            return *this;
        }

        diagonal_iter operator--(int) {
            diagonal_iter previous = *this;
            --*this;
            return previous;
        }

        // `n` may be negative: the conversion to std::size_t wraps around, and
        // unsigned addition then yields the intended position as long as that
        // position is itself non-negative.
        diagonal_iter &operator+=(difference_type n) {
            x += static_cast<std::size_t>(n);
            y += static_cast<std::size_t>(n);
            return *this;
        }

        diagonal_iter &operator-=(difference_type n) {
            x -= static_cast<std::size_t>(n);
            y -= static_cast<std::size_t>(n);
            return *this;
        }

        diagonal_iter operator+(difference_type n) const {
            diagonal_iter moved = *this;
            moved += n;
            return moved;
        }

        friend diagonal_iter operator+(difference_type n,
                                       const diagonal_iter &it) {
            return it + n;
        }

        diagonal_iter operator-(difference_type n) const {
            diagonal_iter moved = *this;
            moved -= n;
            return moved;
        }

        // Signed distance between two iterators of the same diagonal. Column
        // and row always advance together, so the columns alone suffice.
        difference_type operator-(const diagonal_iter &rhs) const {
            return static_cast<difference_type>(x) -
                   static_cast<difference_type>(rhs.x);
        }

        // Two iterators are equal when they refer to the same matrix object
        // and the same cell.
        bool operator==(const diagonal_iter &rhs) const {
            return m == rhs.m && x == rhs.x && y == rhs.y;
        }
        bool operator!=(const diagonal_iter &rhs) const {
            return !(*this == rhs);
        }

        // Ordering is only meaningful between iterators of the same diagonal.
        bool operator<(const diagonal_iter &rhs) const { return x < rhs.x; }
        bool operator>(const diagonal_iter &rhs) const { return rhs < *this; }
        bool operator<=(const diagonal_iter &rhs) const {
            return !(rhs < *this);
        }
        bool operator>=(const diagonal_iter &rhs) const {
            return !(*this < rhs);
        }
    };

  public:
    using iterator = diagonal_iter;

    diagonal(std::vector<std::vector<T>> &matrix, std::size_t x, std::size_t y)
        : matrix(matrix), x(x), y(y) {}

    // Iterator to the first element of the diagonal.
    diagonal_iter begin() const { return diagonal_iter{matrix, x, y}; }

    // Iterator one step past the last element of the diagonal.
    diagonal_iter end() const {
        // A start position outside the matrix denotes an empty diagonal. This
        // also keeps the unsigned subtractions below from wrapping around.
        if (y >= matrix.size() || x >= matrix[y].size()) {
            return begin();
        }

        const auto steps = std::min(matrix[y].size() - x, matrix.size() - y);

        return diagonal_iter{matrix, x + steps, y + steps};
    }
};

// Range over every top-left to bottom-right diagonal of a matrix.
//
// The diagonals that start on the first row come first (left to right),
// followed by the ones that start in the first column of rows 1, 2, ... .
// Dereferencing an iterator yields a `diagonal<T>` view. The range does not
// own the matrix; the matrix has to outlive the range and its iterators.
template <typename T> class diagonals {
    std::vector<std::vector<T>> &_matrix;

    // Single-pass iterator over the start positions of the diagonals.
    class diagonals_iter {
        // Held by pointer so that equality compares matrix identity in O(1)
        // and the iterator stays copy assignable.
        std::vector<std::vector<T>> *m = nullptr;
        std::size_t x = 0; // start column of the current diagonal
        std::size_t y = 0; // start row of the current diagonal

      public:
        using difference_type = std::ptrdiff_t;
        using value_type = diagonal<T>;
        using pointer = void;
        using reference = diagonal<T>;
        using iterator_category = std::input_iterator_tag;

        diagonals_iter() = default;

        diagonals_iter(std::vector<std::vector<T>> &matrix, std::size_t x,
                       std::size_t y)
            : m(&matrix), x(x), y(y) {}

        bool operator==(const diagonals_iter &rhs) const {
            return m == rhs.m && x == rhs.x && y == rhs.y;
        }
        bool operator!=(const diagonals_iter &rhs) const {
            return !(*this == rhs);
        }

        diagonals_iter &operator++() {
            if (y != 0) {
                // iterating through diagonals down the first column
                ++y;
                return *this;
            }

            // iterating the diagonals along the first row
            ++x;
            // `>=` rather than `==`: when the rows are empty the column count
            // is zero and `x` would otherwise step past it and never match.
            if (x >= m->front().size()) {
                // switching to diagonals in the first column
                x = 0;
                ++y;
            }

            return *this;
        }

        diagonals_iter operator++(int) {
            diagonals_iter previous = *this;
            ++*this;
            return previous;
        }

        diagonal<T> operator*() const { return diagonal<T>{*m, x, y}; }
    };

  public:
    using iterator = diagonals_iter;

    diagonals(std::vector<std::vector<T>> &matrix) : _matrix(matrix) {}

    // Start position of the first diagonal: column 0 of row 0.
    diagonals_iter begin() const { return diagonals_iter{_matrix, 0, 0}; }

    // Past-the-end position: column 0 of the row after the last one. For a
    // matrix without rows this equals begin(), so the range is empty.
    diagonals_iter end() const {
        return diagonals_iter{_matrix, 0, _matrix.size()};
    }
};

} // namespace

class Solution {
  public:
    std::vector<std::vector<int>>
    diagonalSort(std::vector<std::vector<int>> mat) {
        for (auto d : diagonals(mat)) {
            std::sort(d.begin(), d.end());
        }

        return mat;
    }
};

static void test_case_1() {
    // 3x4 example:
    //   [[3,3,1,1],[2,2,1,2],[1,1,1,2]]
    // must turn into
    //   [[1,1,1,1],[1,2,2,2],[1,2,3,3]]
    const std::vector<std::vector<int>> input{
        {3, 3, 1, 1}, {2, 2, 1, 2}, {1, 1, 1, 2}};
    const std::vector<std::vector<int>> expected{
        {1, 1, 1, 1}, {1, 2, 2, 2}, {1, 2, 3, 3}};

    Solution s;
    assert(s.diagonalSort(input) == expected);
}

static void test_case_2() {
    // 5x6 example:
    //   [[11,25,66,1,69,7],[23,55,17,45,15,52],[75,31,36,44,58,8],
    //    [22,27,33,25,68,4],[84,28,14,11,5,50]]
    // must turn into
    //   [[5,17,4,1,52,7],[11,11,25,45,8,69],[14,23,25,44,58,15],
    //    [22,27,31,36,50,66],[84,28,75,33,55,68]]
    const std::vector<std::vector<int>> input{{11, 25, 66, 1, 69, 7},
                                              {23, 55, 17, 45, 15, 52},
                                              {75, 31, 36, 44, 58, 8},
                                              {22, 27, 33, 25, 68, 4},
                                              {84, 28, 14, 11, 5, 50}};
    const std::vector<std::vector<int>> expected{{5, 17, 4, 1, 52, 7},
                                                 {11, 11, 25, 45, 8, 69},
                                                 {14, 23, 25, 44, 58, 15},
                                                 {22, 27, 31, 36, 50, 66},
                                                 {84, 28, 75, 33, 55, 68}};

    Solution s;
    assert(s.diagonalSort(input) == expected);
}

int main() {
    // Run the test cases in their original order; each one checks its
    // expectation with assert.
    void (*const cases[])() = {test_case_1, test_case_2};
    for (auto run : cases) {
        run();
    }

    return 0;
}