이전 포스트의 문제였던 우수마을과 거의 동일한 문제이다. 다른 점은 이전 포스팅으로 비유하면 우수마을로 선정된 노드들도 같이 출력하는 것이 다른 점이다.
우수마을로 선정된 노드들을 찾는 방법은 트리의 루트에서 역으로 내려가서 탐색하는 방법이 있다.
두 가지 상황이 존재한다.
참고로 우수마을(최대 독립 집합에 속한 노드)을 1 우수마을이 아닌 마을을 0이라고 하겠다.
1. 부모 노드가 1 : 자식 노드는 무조건 0이다.
2. 부모 노드가 0 : 자식 노드(child)는 0이거나 1이다. 이 때 dp[1][chlid]이 dp[0][child]보다 크다면 자식 노드(child)는 1이다. 아니라면 0이다.
1번은 자명하다. 2번을 살펴보자
2번에서 루트 노드가 0이 선택되었다는 말은 자식 노드 중에 무조건 1이 존재한다는 뜻이다. 그럼 루트의 자식 노드 중에서 누가 선택되었는지 알아보는 방법은 결국 dp[0][child]와 dp[1][child]를 비교해 선택되었는지 아닌지 살펴보는 방법 뿐이다. 왜냐하면 루트 노드를 만들 때 그렇게 선택당했기 때문이다. 이것을 재귀적으로 루트 노드에서부터 리프 노드까지 실행하면 된다. 부모 노드가 0일 때 가능성은 두 가지이다. 부모 노드의 부모 노드가 1이어서 부모 노드가 강제 0인 경우. 또 다른 경우는 부모 노드의 자식 중에 1인 노드가 하나라도 있는 경우이다. 부모 노드가 0일 때 자식 노드가 모두 0인 경우는 존재할 수 없다는 것을 이전 포스트에서 확인하였다. 그럴 경우 부모 노드는 1이 무조건적으로 선택되기 때문이다. 어쨌든 부모노드가 0이라면 자식 노드는 1과 0중 선택할 수 있다. 마찬가지로 dp[1][child]와 dp[0][child]를 비교하여 1인지 0인지 확인한다.
전체 코드는 다음과 같다.
#include<iostream>
#include<vector>
#include<algorithm>
using namespace std;
int n;
void DFS(vector<int> &w, vector<vector<int>> &edge, vector<vector<int>> &dp, int parent,int cur) {
if ( edge[cur].size() == 1 && edge[cur][0] == parent ) {
dp[0][cur] = 0;
dp[1][cur] = w[cur];
return;
}
dp[1][cur] += w[cur];
for ( int child : edge[cur] ) {
if ( child == parent )continue;
DFS(w, edge, dp, cur, child);
dp[0][cur] += max(dp[0][child], dp[1][child]);
dp[1][cur] += dp[0][child];
}
}
void check(vector<vector<int>> &dp, vector<vector<int>> &edge, vector<int> &res, int cur, int parent, bool is) {
bool isreal = 0;
if ( !is ) {
if ( dp[1][cur] > dp[0][cur] ) {
res.emplace_back(cur);
isreal = 1;
}
}
for ( int child : edge[cur] ) {
if ( child == parent )continue;
check(dp, edge, res, child, cur, isreal);
}
}
int main() {
ios_base::sync_with_stdio(0); cin.tie(NULL); cout.tie(NULL);
cin >> n;
vector<int> w(n + 1);
vector<vector<int>> edge(n + 1, vector<int>());
vector<vector<int>> dp(2, vector<int>(n + 1, 0));
vector<int> res;
for ( int i = 1; i < n + 1; i++ ) {
cin >> w[i];
}
for ( int i = 0; i < n - 1; i++ ) {
int a, b;
cin >> a >> b;
edge[a].emplace_back(b);
edge[b].emplace_back(a);
}
DFS(w, edge, dp, 0, 1);
check(dp, edge, res, 1, 0, 0);
sort(res.begin(), res.end());
cout << max(dp[0][1], dp[1][1]) << endl;
for ( int i = 0; i < res.size(); i++ ) {
printf("%d ", res[i]);
}
return 0;
}