import java.io.BufferedReader;
import java.io.IOException;
import java.io.InputStreamReader;
import java.util.ArrayList;
import java.util.List;
public class Main {
static int N;
static List<Integer>[] graph;
static int[] villages;
static int[][] dp;
public static void main(String[] args) throws IOException {
BufferedReader br = new BufferedReader(new InputStreamReader(System.in));
N = Integer.parseInt(br.readLine());
graph = new ArrayList[N+1];
villages = new int[N + 1];
dp = new int[N + 1][2];
String[] inputs = br.readLine().split(" ");
for (int i = 0; i < N + 1; i++) {
graph[i] = new ArrayList<>();
}
for (int i = 0; i < inputs.length; i++) {
villages[i + 1] = Integer.parseInt(inputs[i]);
}
for (int i = 0; i < N - 1; i++) {
inputs = br.readLine().split(" ");
int A = Integer.parseInt(inputs[0]);
int B = Integer.parseInt(inputs[1]);
graph[A].add(B);
graph[B].add(A);
}
System.out.println(Math.max(dfs(1, 1, -1), dfs(1, 0, -1)));
}
//특정노드가 켜져 있을때의 최대값을 구하고 싶다면, 그 자손들을 모두 루트노드로 보고 이 루트노드들이 꺼져있을때의 최대값을 구해 더해주면 된다.
//특정 노드가 꺼져 있을때의 최대값을 구하고 싶다면, 그 자손들을 모두 루트노드로 보고, 각각의 루트노드가 켜져있을때와 꺼져있을때의 최대값중에 큰값을 더해주면 된다
private static int dfs(int root, int onOff, int parent) {
if (dp[root][onOff] != 0) {
return dp[root][onOff];
}
int count = 0;
if (onOff == 1) {
count = villages[root];
}
for (int adjNode : graph[root]) {
if (adjNode != parent) {
if (onOff == 1) {
count += dfs(adjNode, 0, root);
} else if (onOff == 0) {
count += Math.max(dfs(adjNode, 0, root), dfs(adjNode, 1, root));
}
}
}
dp[root][onOff] = count;
return count;
}
}
거의 다 풀었는데 점화식에서 사소한 실수를 했다.
특정 노드가 꺼져 있을 때 최대 값은 그 자손들이 각각 켜져있을때 혹은 꺼져있을 때의 최대값을 더해주면 된다.
다만, 모든 꺼진 마을은 켜져있는 마을이랑 붙어 있어야 한다는데 이 조건을 어떻게 자연스럽게 만족하는지 증명을 아직 하지 못했다. 직감적으로는 알 것 같은데. 내가 직감적인 증명은 이렇다.
예제에서는 1 - 2 - 6 - 7 마을이 연결되어있고 1을 루트라 잡았을 때 1과 7을 켜야만 정답이 나온다. 결국 우려하는 부분은 세 번 다 꺼져 있는 마을이 나오면 어떡하냐는 것인데 마을의 인구수는 양수이고 Max값을 이용해서 가져오기 때문에 ABC 다 꺼져있는 값을 가져올 이유가 없다. 도중에 하나는 켜져있는게 최대값이 더 크기 때문이다.