https://www.acmicpc.net/problem/15681
간선에 가중치와 방향성이 없는 임의의 루트 있는 트리가 주어졌을 때, 아래의 쿼리에 답해보도록 하자.
정점 U를 루트로 하는 서브트리에 속한 정점의 수를 출력한다.
만약 이 문제를 해결하는 데에 어려움이 있다면, 하단의 힌트에 첨부한 문서를 참고하자.
입력
트리의 정점의 수 N과 루트의 번호 R, 쿼리의 수 Q가 주어진다. (2 ≤ N ≤ 105, 1 ≤ R ≤ N, 1 ≤ Q ≤ 105)
이어 N-1줄에 걸쳐, U V의 형태로 트리에 속한 간선의 정보가 주어진다. (1 ≤ U, V ≤ N, U ≠ V)
이는 U와 V를 양 끝점으로 하는 간선이 트리에 속함을 의미한다.
이어 Q줄에 걸쳐, 문제에 설명한 U가 하나씩 주어진다. (1 ≤ U ≤ N)
입력으로 주어지는 트리는 항상 올바른 트리임이 보장된다.
출력
Q줄에 걸쳐 각 쿼리의 답을 정수 하나로 출력한다.
예제 입력 1
9 5 3
1 3
4 3
5 4
5 6
6 7
2 3
9 6
6 8
5
4
8
예제 출력 1
9
4
1
굳이 복붙을 하지는 않았지만 문제 뒤에 긴 힌트가 있었다.
읽기 귀찮기도 하고 스포당하기 싫어서 처음에는 안 보고 풀었는데 메모리 초과가 났다.
왜냐면 Q개의 쿼리 입력에 대해 매번 트리를 탐색하면서 해당 정점 U를 찾은 후 서브트리의 정점의 수를 세었기 때문이었다. 모든 입력에 대해 트리 전체를 순회하여야 했던 것.
그러고 나서 힌트를 뒷부분만 살짝 보니...
메인 함수 내에서 makeTree(5, -1)과 countSubtreeNodes(5) 를 차례대로 한 번씩 호출할 경우, 5번을 루트로 하는 트리에서 모든 정점에 대해 각 정점을 루트로 하는 서브트리에 속한 정점의 수를 계산해둘 수가 있다. 이를 이용하면, 모든 질의 U에 대해 size[U] 를 출력하기만 하면 되므로, 정점이 10만 개, 질의가 10만 개인 데이터에서도 충분히 빠른 시간 내에 모든 질의를 처리할 수가 있게 될 것이다.
dp와 같은 방식으로 풀 수 있겠구나 하는 걸 깨닫게 되었다.
트리는 계층 구조로 이루어져 있으므로 부모 노드를 루트로 하는 서브 트리의 모든 노드의 개수는 자식을 루트로 하는 모든 서브 트리의 노드 개수의 합으로 구할 수 있기 때문이다.
그러니까 한 번 루트에서 리프까지의 모든 노드의 개수를 세면, 그 과정에서 동시에 각 서브 트리의 노드 개수까지 세어질 수 있는 것이다. 이 값들을 메모이제이션 해두고 그냥 쿼리 입력 값에 따라 출력만 하면 되는 거였다. i번 노드를 루트로 하는 서브 트리의 모든 노드 개수를 dp[i]라는 배열에 저장해 두기 때문에, 인덱스로 바로 접근할 수 있어 정점 U를 매번 찾을 필요도 없었다.
import java.io.*;
import java.util.*;
class Main {
static List<Integer>[] graph;
public static void main(String[] args) throws IOException {
BufferedReader br = new BufferedReader(new InputStreamReader(System.in));
StringTokenizer st = new StringTokenizer(br.readLine());
int N = Integer.parseInt(st.nextToken());
int R = Integer.parseInt(st.nextToken());
int Q = Integer.parseInt(st.nextToken());
graph = new List[N + 1];
for (int i = 1; i <= N; i++) {
graph[i] = new ArrayList<>();
}
for (int i = 0; i < N - 1; i++) {
st = new StringTokenizer(br.readLine());
int U = Integer.parseInt(st.nextToken());
int V = Integer.parseInt(st.nextToken());
graph[U].add(V);
graph[V].add(U);
}
StringBuilder sb = new StringBuilder();
Node rootNode = new Node(R);
makeTree(rootNode);
for (int i = 0; i < Q; i++) {
sb.append(countNodes(findNode(Integer.parseInt(br.readLine()), rootNode))).append("\n");
}
System.out.print(sb.toString());
}
private static void makeTree(Node subRoot) {
for (int childNum : graph[subRoot.num]) {
Node childNode = new Node(childNum);
subRoot.children.add(childNode);
graph[childNum].remove(graph[childNum].indexOf(subRoot.num));
makeTree(childNode);
}
}
private static Node findNode(int target, Node subRoot) {
if (subRoot.num == target) {
return subRoot;
}
if (subRoot.children.isEmpty()) {
return new Node(0);
}
for (Node child : subRoot.children) {
Node found = findNode(target, child);
if (found.num > 0) {
return found;
}
}
return new Node(0);
}
private static int countNodes(Node subRoot) {
if (subRoot.children.isEmpty()) {
return 1;
}
int count = 1;
for (Node child : subRoot.children) {
count += countNodes(child);
}
return count;
}
static class Node {
int num;
List<Node> children;
Node (int num) {
this.num = num;
this.children = new ArrayList<>();
}
}
}
위에 서술했듯, dp를 이용하면 findNode()라는 메소드는 필요가 없었다.
import java.io.*;
import java.util.*;
class Main {
static int[] numOfNodes;
static List<Integer>[] graph;
public static void main(String[] args) throws IOException {
BufferedReader br = new BufferedReader(new InputStreamReader(System.in));
StringTokenizer st = new StringTokenizer(br.readLine());
StringBuilder sb = new StringBuilder();
int N = Integer.parseInt(st.nextToken());
int R = Integer.parseInt(st.nextToken());
int Q = Integer.parseInt(st.nextToken());
graph = new List[N + 1];
numOfNodes = new int[N + 1];
for (int i = 1; i <= N; i++) {
graph[i] = new ArrayList<>();
}
for (int i = 0; i < N - 1; i++) {
st = new StringTokenizer(br.readLine());
int U = Integer.parseInt(st.nextToken());
int V = Integer.parseInt(st.nextToken());
graph[U].add(V);
graph[V].add(U);
}
Node rootNode = new Node(R);
makeTree(rootNode);
// 각 노드를 루트로 하는 서브 트리의 노드 개수 구하기
numOfNodes[R] = countNodes(rootNode);
for (int i = 0; i < Q; i++) {
sb.append(numOfNodes[Integer.parseInt(br.readLine())]).append("\n");
}
System.out.print(sb.toString());
}
private static void makeTree(Node subRoot) {
for (int childNum : graph[subRoot.num]) {
Node childNode = new Node(childNum);
subRoot.children.add(childNode);
graph[childNum].remove(graph[childNum].indexOf(subRoot.num));
makeTree(childNode);
}
}
private static int countNodes(Node subRoot) {
// 이미 노드의 개수를 구한 적이 있는 서브 트리인 경우
if (numOfNodes[subRoot.num] > 0) {
return numOfNodes[subRoot.num];
}
// 리프 노드인 경우 더 이상 서브 트리가 만들어질 수 없으므로 자기 자신의 개수인 1을 리턴하고 종료
if (subRoot.children.isEmpty()) {
return 1;
}
int count = 1;
// subRoot를 루트로 하는 서브 트리의 모든 노드의 개수
// == subRoot의 자식 노드들을 루트로 하는 모든 서브 트리의 노드 개수의 합
for (Node child : subRoot.children) {
int num = countNodes(child);
// child를 루트로 하는 서브 트리의 노드 개수를 메모이제이션
numOfNodes[child.num] = num;
count += num;
}
return count;
}
static class Node {
int num;
List<Node> children;
Node (int num) {
this.num = num;
this.children = new ArrayList<>();
}
}
}