Remember that an inorder traversal (left -> root -> right) visits the nodes of a BST in sorted (ascending) order.
So the solution is:
# Definition for a binary tree node.
# class TreeNode:
# def __init__(self, val=0, left=None, right=None):
# self.val = val
# self.left = left
# self.right = right
class Solution:
    """Recursive inorder solution for the k-th smallest value in a BST."""

    def kthSmallest(self, root: "Optional[TreeNode]", k: int) -> int:
        """Return the k-th smallest (1-indexed) value stored in the BST.

        Args:
            root: Root node of the BST; may be None for an empty tree.
            k: 1-indexed rank of the wanted value.

        Returns:
            The value of the k-th node in inorder sequence, or None when
            the tree has fewer than k nodes.
        """
        ans = None

        def inorder(now):
            # k is counted down in place: it is the number of nodes still
            # to be visited, including the target.
            nonlocal ans, k
            if not now:
                return
            inorder(now.left)
            # A deeper call already found the answer: unwind without
            # looking at this node or its right subtree.
            if ans is not None:
                return
            if k == 1:
                ans = now.val
                return
            k -= 1  # this node comes before the target; one fewer to skip
            inorder(now.right)

        inorder(root)
        return ans
So once we return to a previous call frame and see that a deeper call has already set ans, we just return immediately and keep unwinding the call stack.
A cleaner way, though, would be:
class Solution:
    """Counter-based inorder solution for the k-th smallest value in a BST."""

    def kthSmallest(self, root: "Optional[TreeNode]", k: int) -> int:
        """
        Finds the k-th smallest element in a Binary Search Tree (BST).

        The inorder walk uses an explicit stack instead of recursion, so a
        deep (skewed) tree cannot exceed Python's recursion limit.

        Args:
            root: The root node of the BST.
            k: The desired rank (1-indexed) of the smallest element.

        Returns:
            The value of the k-th smallest element in the BST, or None when
            the tree has fewer than k nodes.
        """
        stack = []  # ancestors whose own value has not been counted yet
        node = root
        count = 0  # number of nodes visited so far, in sorted order

        while stack or node:
            # Go all the way left first: the smallest unvisited value
            # is at the bottom of the left spine.
            while node:
                stack.append(node)
                node = node.left
            node = stack.pop()
            # Count the node BEFORE comparing with k, so the first
            # visited node has rank 1.
            count += 1
            if count == k:
                return node.val
            node = node.right

        return None
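So we traverse all the way down the left side first, and only then increment count. The counter must be incremented before the count == k check. It should not start at ans = 1 and be incremented only after the check: in that case the counter still equals k when we return to the ancestors, so they pass the check as well and the answer gets overwritten.
Wrong: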
/**
 * INTENTIONALLY WRONG inorder walk, kept as a counter-example.
 *
 * Reads and writes `ans` (used as the visit counter) and `tmp` (used as the
 * result); both are declared outside this snippet - presumably fields of the
 * enclosing class, TODO confirm.
 *
 * Why it fails: the `ans == k` check runs BEFORE the counter is incremented,
 * and the matching branch returns without incrementing. `ans` therefore still
 * equals k when control returns to the callers, so every ancestor that reached
 * the match through its left subtree passes the same check and overwrites
 * `tmp` with its own value. There is also no "already found" guard.
 *
 * @param node current subtree root; null ends the recursion
 * @param k    rank being searched for
 */
public void dfs(TreeNode node, int k) {
if (node == null) return;
// Go left: smaller values are visited first
dfs(node.left, k);
// Debug print: shows the counter not advancing after a match
System.out.println("Visiting node " + node.val + ", ans=" + ans + ", tmp=" + tmp);
// Check k - BUG: compared before this node has been counted
if (ans == k) {
tmp = node.val;
System.out.println("Found k-th smallest! tmp=" + tmp);
// Returns with ans unchanged, so the caller's check above matches again
return;
}
// Increment counter - too late: skipped entirely on the matching path
ans += 1;
// Go right
dfs(node.right, k);
}
Correct:
/**
 * K-th smallest element in a BST via a recursive inorder walk that stops
 * as soon as the k-th node has been visited.
 */
class Solution {
    // Number of nodes visited so far in inorder sequence (a counter, despite the name).
    int ans = 0;
    // Value of the k-th visited node; stays 0 if the tree has fewer than k nodes.
    int tmp = 0;
    // Explicit "found" flag: a node value of 0 is legal, so tmp != 0 cannot signal this.
    private boolean found = false;

    /**
     * Returns the k-th smallest (1-indexed) value in the BST.
     *
     * @param root root of the BST, may be null
     * @param k    1-indexed rank of the wanted value
     * @return the k-th smallest value, or 0 if the tree has fewer than k nodes
     */
    public int kthSmallest(TreeNode root, int k) {
        // Reset per-call state so the same instance can be queried repeatedly.
        ans = 0;
        tmp = 0;
        found = false;
        dfs(root, k);
        return tmp;
    }

    /**
     * Inorder walk: left subtree, then this node, then right subtree.
     *
     * @param node current subtree root; null ends the recursion
     * @param k    rank being searched for
     */
    public void dfs(TreeNode node, int k) {
        if (node == null || found) return;
        dfs(node.left, k);
        // The answer may have been found inside the left subtree.
        if (found) return;
        // Count this node BEFORE comparing, so the first visited node has rank 1.
        ans += 1;
        if (ans == k) {
            tmp = node.val;
            found = true;
            return;
        }
        dfs(node.right, k);
    }
}
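Time: O(h + k), where h is the tree height - we first walk down to the leftmost node, then visit k nodes in order.
Space: O(h) for the call stack - O(n) in the worst case (skewed tree), O(log n) if the tree is balanced.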