#include <algorithm>
#include <cassert>
#include <climits>
#include <iostream>
#include <vector>
using namespace std;

class Solution {
public:
  /// @brief Returns the k-th smallest value (1-based, duplicates counted) in a
  ///        matrix whose rows and columns are sorted in ascending order.
  ///
  /// Binary-searches the answer over the value range
  /// [matrix[0][0], matrix.back().back()]. For each candidate `mid`, it counts
  /// the entries <= mid (one upper_bound per row) and narrows to the smallest
  /// value whose count reaches k. That value is always an element of the
  /// matrix.
  ///
  /// @param matrix Non-empty matrix with ascending rows, whose top-left and
  ///               bottom-right entries are its minimum and maximum (true
  ///               when columns are also ascending). Not modified.
  /// @param k      1-based rank. Out-of-range values are clamped: k < 1
  ///               yields the minimum, k > element count yields the maximum.
  /// @return The k-th smallest element.
  int kthSmallest(vector<vector<int>> &matrix, int k) {
    // front()/back() on an empty container is undefined behaviour.
    assert(!matrix.empty() && !matrix.front().empty() &&
           !matrix.back().empty());
    int low = matrix.front().front();
    int high = matrix.back().back();
    while (low < high) {
      // Compute the midpoint in 64-bit: `high - low` overflows int (UB) when
      // the values span e.g. INT_MIN..INT_MAX. The result lies in
      // [low, high), so narrowing back to int is safe.
      const int mid =
          static_cast<int>(low + (static_cast<long long>(high) - low) / 2);
      // Count entries <= mid. Only "count < k" matters, so stop early once
      // k is reached; long long keeps the running sum from overflowing.
      long long rank = 0;
      for (const auto &row : matrix) {
        rank += upper_bound(row.begin(), row.end(), mid) - row.begin();
        if (rank >= k) {
          break;
        }
      }
      if (rank < k) {
        low = mid + 1;  // mid < high <= INT_MAX, so this cannot overflow.
      } else {
        high = mid;
      }
    }
    return low;
  }
};

int main() {
  Solution s;
  vector<vector<int>> m;
  int k;

  m = {{1, 5, 9}, {10, 11, 13}, {12, 13, 15}};
  k = 8;
  assert(s.kthSmallest(m, k) == 13);

  m = {{-5}};
  k = 1;
  assert(s.kthSmallest(m, k) == -5);

  // Regression: values spanning the full int range used to overflow
  // `high - low` in the midpoint computation.
  m = {{INT_MIN, 0}, {0, INT_MAX}};
  assert(s.kthSmallest(m, 1) == INT_MIN);
  assert(s.kthSmallest(m, 3) == 0);
  assert(s.kthSmallest(m, 4) == INT_MAX);
}