풀이
- 말단 노드 간 거리가 distance보다 짧은 말단 노드 개수를 구해야 한다
- 일단 말단 노드를 모두 찾아 리스트에 저장
- 그래프를 구축한다
- 말단 노드를 순회하면서 모든 그래프의 노드들과 거리를 계산하다
코드
class Solution {
private int distance;
private Map<TreeNode, List<TreeNode>> graph;
private Set<TreeNode> leaves;
public int countPairs(TreeNode root, int distance) {
this.distance = distance;
this.leaves = new HashSet();
this.graph = new HashMap();
getGraph(root, null);
int pairs = 0;
for(TreeNode leaf : leaves){
pairs += bfs(leaf);
}
return pairs / 2;
}
public void getGraph(TreeNode node, TreeNode parent){
if(node == null) return;
if(node.left == null && node.right == null){
leaves.add(node);
}
if(parent != null){
graph.computeIfAbsent(parent, k -> new ArrayList<>()).add(node);
graph.computeIfAbsent(node, k -> new ArrayList<>()).add(parent);
}
getGraph(node.left, node);
getGraph(node.right, node);
}
public int bfs(TreeNode leaf){
Queue<Pair> queue = new LinkedList<>();
Set<TreeNode> visited = new HashSet();
queue.add(new Pair(leaf, 0));
visited.add(leaf);
int cnt = 0;
while(!queue.isEmpty()){
Pair pair = queue.poll();
TreeNode node = pair.node;
int dist = pair.dist;
if(dist > distance) continue;
if(dist > 0 && leaves.contains(node)){
cnt++;
}
for(TreeNode nextNode : graph.getOrDefault(node, Collections.emptyList())){
if(!visited.contains(nextNode)){
visited.add(nextNode);
queue.add(new Pair(nextNode, dist + 1));
}
}
}
return cnt;
}
class Pair {
TreeNode node;
int dist;
public Pair(TreeNode node, int dist){
this.node = node;
this.dist = dist;
}
}
}
회고
- bfs 순회 시 level를 부여해서 level이 가장 큰 (마지막 level은 말단 노드) 노드를 찾아서 각 말단 노드의 거리를 구하는 방식으로 했다가 말단 노드는 level이 다를 수 있음 + 코드가 저세상으로 가버림으로 인해 포기
- 말단 노드를 먼저 구하는 것을 생각하지 못 해서 다른 코드 참고해서 작성했다
- 어려웠다!