[C++] 백준 2213: 트리의 독립집합

Cyan·2024년 8월 15일
0

코딩 테스트

목록 보기
161/166

백준 2213: 트리의 독립집합

문제 요약

그래프 G(V, E)에서 정점의 부분 집합 S에 속한 모든 정점쌍이 서로 인접하지 않으면 (정점쌍을 잇는 간선이 없으면) S를 독립 집합(independent set)이라고 한다. 독립 집합의 크기는 정점에 가중치가 주어져 있지 않을 경우는 독립 집합에 속한 정점의 수를 말하고, 정점에 가중치가 주어져 있으면 독립 집합에 속한 정점의 가중치의 합으로 정의한다. 독립 집합이 공집합일 때 그 크기는 0이라고 하자. 크기가 최대인 독립 집합을 최대 독립 집합이라고 한다.

문제는 일반적인 그래프가 아니라 트리(연결되어 있고 사이클이 없는 그래프)와 각 정점의 가중치가 양의 정수로 주어져 있을 때, 최대 독립 집합을 구하는 것이다.

문제 분류

  • 다이나믹 프로그래밍
  • 트리
  • 트리에서의 다이나믹 프로그래밍

문제 풀이

탑-다운 방식의 DP로 풀었다.
모든 정점이 트리로 연결되어 있다는 것이 첫 번째 핵심이다. 즉, 어느 노드를 루트 노드로 잡아도 상관이 없는데, 나는 0번 노드(첫 번째 노드)를 루트노드로 삼아 탐색했다.

sol(0, 0)sol(0, 1)을 모두 호출하여 두 값의 최댓값을 구한다. 처음에 0번 노드를 루트노드로 잡는다는 뜻이고, 0번 노드를 포함하지 않을 경우(0)와 포함하는 경우(1)를 모두 따져보는 것이다.
탐색 과정에서는 이렇게 현재 탐색중인 노드가 독립집합에 포함되는가 포함되지 않는가를 체크한다. 현재 노드를 독립집합에 포함시키면, 다음에 탐색하는 노드는 독립집합에 반드시 포함되지 않아야 한다. 그리고 마지막에 자신의 가중치 w[idx]를 누적시킨다.
현재 노드가 독립집합에 포함되지 않는다면, 다음에 탐색하는 노드는 독립집합에 포함될 수도 있고, 포함되지 않을 수도 있다. 그 두 가지 경우의 최댓값을 누적시켜서 구한다.
리프노드이고 자신을 독립노드에 포함시킨다면, 자신의 가중치 w[idx]를 반환한다.
트리의 부분집합 역시 트리인 점을 이용했다고 볼 수 있다.

이제 사용된 노드를 구해야한다. 나는 0번 노드를 루트노드로 삼아 해결했으므로, dp[0][0]혹은 dp[0][1]에 그 최종 결과값이 들어있을 것이다. 두 값을 서로 비교하여 처음 시작노드를 잡고, sol2()로 탐색한다. 여기 역시 비슷하다. 현재 노드가 포함된다면 정답 벡터인 ary에 추가한다. 그리고 다음 노드를 모두 포함하지 않는 쪽으로 탐색한다. 현재 노드가 포함되지 않는다면, 다음 노드가 포함될 경우와 포함되지 않을 경우를 비교하여 더 큰 방향으로 탐색한다.

마지막으로 ary를 정렬하여 순서대로 출력하면 된다.

풀이 코드

#include <stdio.h>
#include <iostream>
#include <algorithm>
#include <vector>
#include <memory.h>

using namespace std;

int dp[10000][2], w[10000], n;
bool visited[10000];

vector<int> v[10000], ary;

int sol(int idx, int p)
{
	if (dp[idx][p] > -1) return dp[idx][p];
	int flag = 1;
	visited[idx] = true;
	dp[idx][p] = 0;
	for (auto& it : v[idx]) {
		if (!visited[it]) {
			flag = 0;
			if (p)
				dp[idx][p] += sol(it, 0);
			else
				dp[idx][p] += max(sol(it, 0), sol(it, 1));
		}		
	}
	if (p) dp[idx][p] += w[idx];
	visited[idx] = false;
	if (flag && p) return w[idx];

	return dp[idx][p];
}

void sol2(int idx, int p)
{
	visited[idx] = true;
	if (p) {
		ary.push_back(idx + 1);
		for (auto& it : v[idx]) {
			if (!visited[it])
				sol2(it, 0);
		}
	}
	else {
		for (auto& it : v[idx]) {
			if (!visited[it]) {
				if (dp[it][0] >= dp[it][1])
					sol2(it, 0);
				else sol2(it, 1);
			}
		}
	}
	visited[idx] = false;
}

int main()
{
	int in, in2, res;
	memset(dp, -1, sizeof(dp));
	cin >> n;
	for (int i = 0; i < n; i++)
		scanf("%d", w + i);
	while (scanf("%d%d", &in, &in2) != EOF) {
		v[in2 - 1].push_back(in - 1);
		v[in - 1].push_back(in2 - 1);
	}
	res = max(sol(0, 0), sol(0, 1));
	cout << res << '\n';
	if (dp[0][1] == res)
		sol2(0, 1);
	else
		sol2(0, 0);

	sort(ary.begin(), ary.end());
	for (auto& it : ary)
		printf("%d ", it);
	return 0;
}

0개의 댓글