跳转至

230. Kth Smallest Element in a BST

Tags: Medium Binary Search

Links: https://leetcode.com/problems/kth-smallest-element-in-a-bst/


Given a binary search tree, write a function kthSmallest to find the **k**th smallest element in it.

Note: You may assume k is always valid, 1 ≤ k ≤ BST's total elements.

Example 1:

Input: root = [3,1,4,null,2], k = 1
   3
  / \
 1   4
  \
   2
Output: 1

Example 2:

Input: root = [5,3,6,2,4,null,null,1], k = 3
       5
      / \
     3   6
    / \
   2   4
  /
 1
Output: 3

Follow up: What if the BST is modified (insert/delete operations) often and you need to find the kth smallest frequently? How would you optimize the kthSmallest routine?


/**
 * Definition for a binary tree node.
 * struct TreeNode {
 *     int val;
 *     TreeNode *left;
 *     TreeNode *right;
 *     TreeNode(int x) : val(x), left(NULL), right(NULL) {}
 * };
 */
class Solution {
public:
    int kthSmallest(TreeNode* root, int k) {
        std::ios_base::sync_with_stdio(false);
        cin.tie(NULL);
        cout.tie(NULL);

        vector<int> num;
        inorderTraversal(num, root);

        return num[k - 1];
    }

    void inorderTraversal(vector<int> & num, TreeNode *root)
    {
        if (!root) return;
        inorderTraversal(num, root -> left);
        num.push_back(root -> val);
        inorderTraversal(num, root -> right);
    }
};

中序遍历后的序列是有序的。其实并不需要将所有的遍历数据都写出来,只需要用一个计数器即可,节省存储空间。

/**
 * Definition for a binary tree node.
 * struct TreeNode {
 *     int val;
 *     TreeNode *left;
 *     TreeNode *right;
 *     TreeNode(int x) : val(x), left(NULL), right(NULL) {}
 * };
 */
class Solution {
public:
    int kthSmallest(TreeNode* root, int k) {
        std::ios_base::sync_with_stdio(false);
        cin.tie(NULL);
        cout.tie(NULL);

        stack<TreeNode*> s;
        TreeNode *p = root;
        int cnt = 0;
        while (!s.empty() || p) {
            if (p) {
                s.push(p);
                p = p -> left;
            }
            else {
                p = s.top(); s.pop();
                ++cnt;
                if (cnt == k) return p -> val;
                p = p -> right;
            }
        }

        return -1;
    }
};

第三种方法是先计算根节点左子树的节点个数 num:若 num + 1 == k,则根节点就是答案;若 num + 1 > k,则在左子树中继续查找第 k 小的元素;否则在右子树中查找第 k - num - 1 小的元素。加速方法是利用一个 unordered_map 来存储已经计算过的子树大小。

/**
 * Definition for a binary tree node.
 * struct TreeNode {
 *     int val;
 *     TreeNode *left;
 *     TreeNode *right;
 *     TreeNode(int x) : val(x), left(NULL), right(NULL) {}
 * };
 */
class Solution {
    unordered_map<TreeNode *, int> um;
public:
    int kthSmallest(TreeNode* root, int k) {
        std::ios_base::sync_with_stdio(false);
        cin.tie(NULL);
        cout.tie(NULL);

        int num = count(root -> left);
        if (num + 1 == k) return root -> val;
        else if (num + 1 > k) {
            return kthSmallest(root -> left, k);
        }
        return kthSmallest(root -> right, k - num - 1);
    }

    int count(TreeNode *root)
    {
        if (!root) return 0;
        if (um.find(root) != um.end()) return um[root];

        return um[root] = 1 + count(root -> left) + count(root -> right);
    }
};

这道题的 follow up 是:如果该 BST 被修改得很频繁,那么我们就需要自定义数据结构,让每个节点存储以它为根的子树的大小。

class Solution
{
// Copy of a TreeNode that additionally stores `cnt`, the number of nodes in
// the subtree rooted at this node (the node itself included).
struct myTreeNode
{
    int val, cnt;
    myTreeNode *left, *right;
    myTreeNode(int x) : val(x), cnt(0), left(nullptr), right(nullptr) {}
};

    // Frees every node of a tree produced by build(); null is a no-op.
    static void destroy(myTreeNode *node)
    {
        if (!node) return;
        destroy(node -> left);
        destroy(node -> right);
        delete node;
    }

public:
    /**
     * Builds a size-annotated copy of the tree rooted at `root`.
     *
     * Every node is allocated with `new`; the returned tree is owned by the
     * caller.
     *
     * @param root tree to copy; null yields null
     * @return root of the annotated copy
     */
    myTreeNode *build(TreeNode *root)
    {
        if (!root) return nullptr;

        myTreeNode *myRoot = new myTreeNode(root -> val);
        myRoot -> left = build(root -> left);
        myRoot -> right = build(root -> right);

        // Subtree size = left size + right size + this node.
        if (myRoot -> left) myRoot -> cnt += myRoot -> left -> cnt;
        if (myRoot -> right) myRoot -> cnt += myRoot -> right -> cnt;
        myRoot -> cnt += 1;

        return myRoot;
    }

    /**
     * Returns the k-th smallest value of the binary search tree.
     *
     * @param root root of the binary search tree (may be null)
     * @param k    1-based rank of the requested element
     * @return the k-th smallest value, or -1 if there is no such element
     */
    int kthSmallest(TreeNode *root, int k)
    {
        myTreeNode *myRoot = build(root);
        const int answer = solve(myRoot, k);
        destroy(myRoot);  // the annotated copy is only needed for this query
        return answer;
    }

    /**
     * Finds the k-th smallest value in a size-annotated tree.
     *
     * At each node the size of the left subtree decides the direction: the
     * answer is in the left subtree when k does not exceed that size, is the
     * node itself when k is exactly one more, and otherwise is found in the
     * right subtree after discounting the left subtree and the node.
     *
     * @param myRoot root of the annotated tree (may be null)
     * @param k      1-based rank of the requested element
     * @return the k-th smallest value, or -1 if the descent runs off the tree
     *         (empty tree or k outside 1..size)
     */
    int solve(myTreeNode *myRoot, int k)
    {
        while (myRoot) {
            const int leftCnt = myRoot -> left ? myRoot -> left -> cnt : 0;
            if (k <= leftCnt) {
                myRoot = myRoot -> left;
            }
            else if (k == leftCnt + 1) {
                return myRoot -> val;
            }
            else {
                k -= leftCnt + 1;
                myRoot = myRoot -> right;
            }
        }

        return -1;
    }
};