#include <algorithm>
#include <cassert>
#include <iostream>
#include <vector>

using namespace std;

class Solution {
  public:
    /// @brief Returns the k-th smallest element (1-based, duplicates counted
    ///        separately) of a matrix whose rows and columns are sorted ascending.
    ///
    /// Binary-searches the *value* range [matrix[0][0], matrix.back().back()]
    /// for the smallest value v such that at least k elements are <= v. When k
    /// is within range, that value is an element of the matrix.
    ///
    /// @param matrix Non-empty matrix. Every row must be sorted ascending, and
    ///               matrix[0][0] / matrix.back().back() must be the global
    ///               minimum / maximum (true when columns are sorted as well).
    /// @param k      1-based rank; expected to lie in [1, total element count].
    /// @return The k-th smallest value in the matrix.
    int kthSmallest(const std::vector<std::vector<int>> &matrix, int k) {
        // Precondition (checked in debug builds): front()/back() on an empty
        // vector is undefined behaviour.
        assert(!matrix.empty() && !matrix.front().empty() && !matrix.back().empty());

        int low = matrix.front().front();
        int high = matrix.back().back();

        while (low < high) {
            // Compute the span in 64 bits: `high - low` in int overflows (UB)
            // when the values straddle a wide range such as INT_MIN..INT_MAX.
            // The resulting midpoint lies in [low, high], so it fits in int.
            const int mid = low + static_cast<int>((static_cast<long long>(high) - low) / 2);

            // Count elements <= mid; upper_bound requires each row to be sorted.
            long long rank = 0;
            for (const auto &row : matrix) {
                rank += std::upper_bound(row.begin(), row.end(), mid) - row.begin();
            }

            if (rank < k) {
                low = mid + 1; // fewer than k elements <= mid: answer is larger
            } else {
                high = mid; // mid may itself be the answer
            }
        }

        return low;
    }
};

int main() {
    Solution s;

    vector<vector<int>> m;
    int k;

    m = {{1, 5, 9}, {10, 11, 13}, {12, 13, 15}};
    k = 8;
    assert(s.kthSmallest(m, k) == 13);

    m = {{-5}};
    k = 1;
    assert(s.kthSmallest(m, k) == -5);
}