트리에 동적 계획법을 적용해 봅시다.
Java / Python
트리 DP의 기본을 다지는 문제. 아래에 풀이 설명도 있습니다!
이번 문제는 간선에 가중치와 방향성이 없는 임의의 루트 있는 트리가 주어졌을 때, 정점 U를 루트로 하는 서브트리에 속한 정점의 수를 출력하는 문제이다.
문제의 힌트에서 makeTree, countSubtreeNode 함수를 이용하면 쉽게 구현가능하다.
import java.io.*;
import java.util.*;
public class Main {
static int[] parent;
static int[] size;
static ArrayList<Integer>[] list, tree;
public static void main(String[] args) throws Exception {
BufferedReader br = new BufferedReader(new InputStreamReader(System.in));
BufferedWriter bw = new BufferedWriter(new OutputStreamWriter(System.out));
StringTokenizer st = new StringTokenizer(br.readLine());
StringBuilder sb = new StringBuilder();
int N = Integer.parseInt(st.nextToken());
int R = Integer.parseInt(st.nextToken());
int Q = Integer.parseInt(st.nextToken());
parent = new int[N + 1];
for (int i = 0; i < N; i++) {
parent[i] = i;
}
size = new int[N + 1];
list = new ArrayList[N + 1];
tree = new ArrayList[N + 1];
for (int i = 0; i < list.length; i++) {
list[i] = new ArrayList<>();
tree[i] = new ArrayList<>();
}
for (int i = 1; i < N; i++) {
st = new StringTokenizer(br.readLine(), " ");
int u = Integer.parseInt(st.nextToken());
int v = Integer.parseInt(st.nextToken());
list[u].add(v);
list[v].add(u);
}
makeTree(R, -1);
countSubtreeNodes(R);
while (Q-- > 0) {
int query = Integer.parseInt(br.readLine());
sb.append(size[query]).append("\n");
}
bw.write(sb.toString() + "\n");
bw.flush();
bw.close();
br.close();
}
public static void makeTree(int curNode, int p) {
for (int node : list[curNode]) {
if (node != p) {
tree[curNode].add(node);
parent[node] = curNode;
makeTree(node, curNode);
}
}
}
public static void countSubtreeNodes(int curNode) {
size[curNode] = 1;
for (int node : tree[curNode]) {
countSubtreeNodes(node);
size[curNode] += size[node];
}
}
}
import sys
input = sys.stdin.readline
sys.setrecursionlimit(10**9)
def NodeCount(x):
cnt[x]=1
for i in tree[x]:
if not cnt[i]:
NodeCount(i)
cnt[x] += cnt[i]
N, R, Q = map(int, input().split())
tree = [[] for _ in range(N+1)]
cnt = [0]*(N+1)
for i in range(N-1):
a, b = map(int, input().split())
tree[a].append(b)
tree[b].append(a)
NodeCount(R)
for i in range(Q):
u = int(input())
print(cnt[u])