https://www.acmicpc.net/problem/2213
DP[N][0] : N번노드를 미포함하고 N번 노드를 루트로 하는 독립집합의 최대값.
DP[N][1] : N번노드를 포함하고 N번 노드를 루트로 하는 독립집합의 최대값.
DP[N][0] = max(DP[자식 노드][0], DP[자식 노드][1]) (모든 자식 노드의 값 더해줌)
DP[N][1] = score[N] + DP[자식 노드][0] (모든 자식 노드의 값 더해줌)
import java.io.*;
import java.util.*;
public class Main {
public static BufferedReader br;
public static BufferedWriter bw;
public static int N, result = 0;
public static int[] score;
public static int[][] dp;
public static boolean[] resChk;
public static List<List<Integer>> map = new ArrayList<>();
public static void input() throws IOException {
br = new BufferedReader(new InputStreamReader(System.in));
bw = new BufferedWriter(new OutputStreamWriter(System.out));
N = Integer.parseInt(br.readLine());
score = new int[N+1];
dp = new int[N+1][2];
resChk = new boolean[N+1];
for (int i = 0; i <= N; i++) {
map.add(new ArrayList<>());
Arrays.fill(dp[i], -1);
}
StringTokenizer st = new StringTokenizer(br.readLine());
for (int i = 1; i <= N; i++)
score[i] = Integer.parseInt(st.nextToken());
for (int i = 1; i < N; i++) {
st = new StringTokenizer(br.readLine());
int from = Integer.parseInt(st.nextToken());
int to = Integer.parseInt(st.nextToken());
map.get(from).add(to);
map.get(to).add(from);
}
}
//현재 노드 : cur, 이전 노드 : before, 포함 여부 : contain
public static int dfs(int cur, int before, int contain) {
if (dp[cur][contain] != -1) return dp[cur][contain];
int result = contain == 1 ? score[cur] : 0;
for (Integer next : map.get(cur)) {
if (next == before) continue;
//현재 노드 포함되는 경우 -> 다음 노드 미포함
if (contain == 1)
result += dfs(next, cur, 0);
//현재 노드 미포함되는 경우 -> 다음 노드 포함, 미포함 둘 다 고려
else
result += Math.max(dfs(next, cur, 0), dfs(next, cur, 1));
}
return dp[cur][contain] = result;
}
public static void check(int cur, int before, int contain) {
for (Integer next : map.get(cur)) {
if (next == before) continue;
if (contain == 1) {
check(next, cur, 0);
} else {
if (dp[next][0] > dp[next][1]) {
check(next, cur, 0);
} else {
resChk[next] = true;
check(next, cur, 1);
}
}
}
}
public static void solve() throws IOException {
dfs(1, 0, 0);
dfs(1, 0, 1);
if (dp[1][0] > dp[1][1]) {
result = dp[1][0];
check(1, 0, 0);
} else {
result = dp[1][1];
resChk[1] = true;
check(1, 0, 1);
}
//출력
bw.write(result + "\n");
for (int i = 1; i <= N; i++) {
if (!resChk[i]) continue;
bw.write(i + " ");
}
bw.write("\n");
bw.flush();
}
public static void main(String[] args) throws IOException {
input();
solve();
}
}