
378. Kth Smallest Element in a Sorted Matrix

Tags: Binary Search, Heap

Given an n x n matrix where each of the rows and columns is sorted in ascending order, find the kth smallest element in the matrix.

Note that it is the kth smallest element in the sorted order, not the kth distinct element.


matrix = [
   [ 1,  5,  9],
   [10, 11, 13],
   [12, 13, 15]
],
k = 8,

return 13.

Note: You may assume k is always valid, 1 ≤ k ≤ n^2.

class Solution {
public:
    // Approach 1: max-heap.
    // Push every element into a max-heap, then discard the largest ones until
    // exactly k elements remain. The heap top is then the k-th smallest element
    // (equivalently the (n^2 - k + 1)-th largest); duplicates are counted.
    // Time O(n^2 log n), extra space O(n^2).
    // Assumes 1 <= k <= number of elements (as stated in the problem).
    int kthSmallest(vector<vector<int>>& matrix, int k) {
        priority_queue<int> pq;  // std::priority_queue is a max-heap by default
        for (const auto& row : matrix) {
            for (int e : row) {
                pq.push(e);
            }
        }

        // Remove the n^2 - k largest elements; the k smallest remain.
        while (static_cast<int>(pq.size()) > k) {
            pq.pop();
        }
        return pq.top();
    }
};
priority_queue 实现的是最大堆，所以需要查找的是第 n^2 - k + 1 大的数值（即第 k 小的数值）。由于所有 n^2 个元素都入堆，时间复杂度为 O(n^2 log n)。



class Solution {
    // Returns how many elements of `matrix` are <= target.
    // Each row is sorted ascending, so the scan of a row stops at the first
    // element greater than target. Worst case O(n^2) per call.
    // `k` is not used; it is kept so the helper's signature is unchanged.
    int valueBinarySearch(vector<vector<int>>& matrix, int k, int target) {
        (void)k;
        int count = 0;
        for (const auto& row : matrix) {
            for (int v : row) {
                if (v > target) {
                    break;  // everything after v in this row is also > target
                }
                ++count;
            }
        }
        return count;
    }

public:
    // Approach 2: binary search on the value range.
    // The answer lies in [matrix[0][0], matrix[n-1][n-1]]. The k-th smallest
    // element is the smallest value v such that at least k elements are <= v.
    // Assumes a non-empty n x n matrix and 1 <= k <= n^2.
    int kthSmallest(vector<vector<int>>& matrix, int k) {
        const int n = static_cast<int>(matrix.size());
        int left = matrix[0][0];
        int right = matrix[n - 1][n - 1];

        while (left < right) {
            // Compute in 64-bit: right - left can exceed INT_MAX when the
            // values span most of the int range. The difference is >= 0, so
            // the division floors and mid stays in [left, right).
            const int mid = static_cast<int>(
                left + (static_cast<long long>(right) - left) / 2);
            const int cnt = valueBinarySearch(matrix, k, mid);
            if (cnt >= k) {
                right = mid;  // mid may itself be the answer
            } else {
                left = mid + 1;
            }
        }
        return left;
    }
};
class Solution {
    // Returns how many elements of the n x n `matrix` are <= target, in O(n).
    // Walks a "staircase" starting at the bottom-left corner: each step either
    // moves one column right or one row up, so at most 2n steps are taken.
    int search(vector<vector<int>>& matrix, int target, int n) {
        int row = n - 1;
        int col = 0;
        int count = 0;
        while (row >= 0 && col < n) {
            if (matrix[row][col] <= target) {
                // Columns are sorted, so matrix[0..row][col] are all <= target:
                // add row + 1 and move on to the next column.
                count += row + 1;
                ++col;
            } else {
                // matrix[row][col] > target, and so is everything to its right
                // in this row; move up one row.
                --row;
            }
        }
        return count;
    }

public:
    // Approach 3: binary search on the value range, counting with the O(n)
    // staircase walk above. Time O(n log(max - min)).
    // Assumes a non-empty n x n matrix and 1 <= k <= n^2.
    int kthSmallest(vector<vector<int>>& matrix, int k) {
        const int n = static_cast<int>(matrix.size());
        int left = matrix[0][0];
        int right = matrix[n - 1][n - 1];
        while (left < right) {
            // Floor midpoint computed in 64-bit. (left + right) / 2 would both
            // overflow for large values and truncate toward zero for negative
            // sums, yielding mid == right and an infinite loop.
            const int mid = static_cast<int>(
                left + (static_cast<long long>(right) - left) / 2);
            const int count = search(matrix, mid, n);
            if (count >= k) {
                right = mid;  // mid may itself be the answer
            } else {
                left = mid + 1;
            }
        }
        return left;
    }
};