N개의 마을로 이루어진 나라가 있다. 편의상 마을에는 1부터 N까지 번호가 붙어 있다고 하자. 이 나라는 트리(Tree) 구조로 이루어져 있다. 즉 마을과 마을 사이를 직접 잇는 N-1개의 길이 있으며, 각 길은 방향성이 없어서 A번 마을에서 B번 마을로 갈 수 있다면 B번 마을에서 A번 마을로 갈 수 있다. 또, 모든 마을은 연결되어 있다. 두 마을 사이에 직접 잇는 길이 있을 때, 두 마을이 인접해 있다고 한다.
이 나라의 주민들에게 성취감을 높여 주기 위해, 다음 세 가지 조건을 만족하면서 N개의 마을 중 몇 개의 마을을 '우수 마을'로 선정하려고 한다.
'우수 마을'로 선정된 마을 주민 수의 총 합을 최대로 해야 한다.
마을 사이의 충돌을 방지하기 위해서, 만일 두 마을이 인접해 있으면 두 마을을 모두 '우수 마을'로 선정할 수는 없다. 즉 '우수 마을'끼리는 서로 인접해 있을 수 없다.
선정되지 못한 마을에 경각심을 불러일으키기 위해서, '우수 마을'로 선정되지 못한 마을은 적어도 하나의 '우수 마을'과는 인접해 있어야 한다.
각 마을 주민 수와 마을 사이의 길에 대한 정보가 주어졌을 때, 주어진 조건을 만족하도록 '우수 마을'을 선정하는 프로그램을 작성하시오.
다이나믹 프로그래밍
트리
트리에서의 다이나믹 프로그래밍
트리
를 이용하여 DP
를 구성하면 된다. 그리디로는 풀 수 없고, 완전탐색이지만 한번 방문했던 노드를 방문하는 경우가 많으므로 DP
로 풀면 된다. 방문한 노드와 우수 마을 여부에 따라 DP
를 반환하면 된다. DP[idx][0]
혹은 DP[idx][1]
의 의미는, idx
번 노드에서 우수 마을 여부 별 최대값을 말한다. 0
은 idx
노드가 우수 마을이 아닌 경우, 1
은 우수 마을로 선정할 경우이다. 즉, 1
은 현재 탐색중인 노드가 '우수 마을'인 경우, 0
은 그렇지 않은 경우로 생각한다.
현재 탐색중인 노드가 1
이어서 dp[idx][1]
을 반환하는 경우, 다음 인접한 노드는 반드시 0
이어야 한다. 따라서 인접한 노드의 다음으로 인접한 노드가 0
이거나 1
이거나의 경우로 나뉘어져서, 그 최대값을 dp[idx][1]
에 누적시킨 다음, 자신의 가중치도 누적시켜서 반환하면 된다.
반대로 현재 탐색중인 노드가 0
이어서, dp[idx][0]
을 반환하는 경우, 인접한 노드중에는 반드시 적어도 1
이 있어야 한다. 즉, 1
이 하나라도 있다면 0
이 있어도 상관 없다. 따라서 다음 인접한 노드가 0
일 경우와 1
일 경우의 최대값을 받아서 반환해준다. 본래 문제의 의도대로면 이에 대해 프로그래밍적 장치를 만들어서 확실하게 해야 하겠지만, 3
번 조건을 불만족하면서 최적해인 경우는 없음을 이용하였다. 즉, 최적해인 경우, 0
인 마을은 반드시 1
인 마을과 인접해 있다.
어떤 노드를 루트노드로 잡아도 상관이 없는 점, 1-3
2-3
과 같이 부모 노드가 반드시 오름차순은 아니라는 점도 생각해야 한다. 따라서 visited
배열이 필요하고, 적재적소에 맞게 visited
배열의 값도 관리해주어야 한다.
#include <stdio.h>
#include <iostream>
#include <algorithm>
#include <vector>
#include <map>
using namespace std;
int dp[10000][2] = { 0, }, n;
vector<int> v;
multimap<int, int> mm;
bool visited[10000];
int sol(int idx, int d) {
if (dp[idx][d] > 0) return dp[idx][d];
auto rg = mm.equal_range(idx);
visited[idx] = true;
if (d) {
for (auto& it = rg.first; it != rg.second; it++) {
if (visited[it->second]) continue;
auto rg2 = mm.equal_range(it->second);
visited[it->second] = true;
for (auto& it2 = rg2.first; it2 != rg2.second; it2++) {
if (!visited[it2->second])
dp[idx][d] += max(sol(it2->second, 0), sol(it2->second, 1));
}
visited[it->second] = false;
}
dp[idx][d] += v[idx];
}
else {
for (auto& it = rg.first; it != rg.second; it++) {
if (!visited[it->second])
dp[idx][d] += max(sol(it->second, 1), sol(it->second, 0));
}
}
visited[idx] = false;
return dp[idx][d];
}
int main() {
int in1, in2, res = 0;
scanf("%d", &n);
for (int i = 0; i < n; i++) {
scanf("%d", &in1);
v.push_back(in1);
}
for (int i = 1; i < n; i++) {
scanf("%d%d", &in1, &in2);
mm.insert({ in1 - 1,in2 - 1 });
mm.insert({ in2 - 1,in1 - 1 });
}
cout << max(sol(0, 1), sol(0, 0));
return 0;
}