[백준] 2213번 트리의 독립집합

donghyeok·2023년 7월 15일
0

알고리즘 문제풀이

목록 보기
130/171

문제 설명

https://www.acmicpc.net/problem/2213

문제 풀이

  • 다이나믹 프로그래밍으로 풀이하였다.
  • 트리에서 DP를 사용하는 가장 대표적인 문제라고 볼 수 있다.
  • 점화식은 다음과 같이 구성된다.

    DP[N][0] : N번노드를 미포함하고 N번 노드를 루트로 하는 독립집합의 최대값.
    DP[N][1] : N번노드를 포함하고 N번 노드를 루트로 하는 독립집합의 최대값.
    DP[N][0] = max(DP[자식 노드][0], DP[자식 노드][1]) (모든 자식 노드의 값 더해줌)
    DP[N][1] = score[N] + DP[자식 노드][0] (모든 자식 노드의 값 더해줌)

  • 트리에서는 어떤 노드를 루트 노드로 놓아도 상관 없으므로 1번 노드를 루트로 두고 위와 같은 점화식을 바탕으로 DFS를 진행한다.
  • 또한 집합의 원소들을 출력할 때는 DP값을 바탕으로 재귀를 통해 추적하여 출력한다.

소스 코드 (JAVA)

import java.io.*;
import java.util.*;

public class Main {
    public static BufferedReader br;
    public static BufferedWriter bw;

    public static int N, result = 0;
    public static int[] score;
    public static int[][] dp;
    public static boolean[] resChk;
    public static List<List<Integer>> map = new ArrayList<>();

    public static void input() throws IOException {
        br = new BufferedReader(new InputStreamReader(System.in));
        bw = new BufferedWriter(new OutputStreamWriter(System.out));
        N = Integer.parseInt(br.readLine());
        score = new int[N+1];
        dp = new int[N+1][2];
        resChk = new boolean[N+1];
        for (int i = 0; i <= N; i++) {
            map.add(new ArrayList<>());
            Arrays.fill(dp[i], -1);
        }
        StringTokenizer st = new StringTokenizer(br.readLine());
        for (int i = 1; i <= N; i++)
            score[i] = Integer.parseInt(st.nextToken());
        for (int i = 1; i < N; i++) {
            st = new StringTokenizer(br.readLine());
            int from = Integer.parseInt(st.nextToken());
            int to = Integer.parseInt(st.nextToken());
            map.get(from).add(to);
            map.get(to).add(from);
        }
    }

    //현재 노드 : cur, 이전 노드 : before, 포함 여부 : contain
    public static int dfs(int cur, int before, int contain) {
        if (dp[cur][contain] != -1) return dp[cur][contain];
        int result = contain == 1 ? score[cur] : 0;
        for (Integer next : map.get(cur)) {
            if (next == before) continue;
            //현재 노드 포함되는 경우 -> 다음 노드 미포함
            if (contain == 1)
                result += dfs(next, cur, 0);
            //현재 노드 미포함되는 경우 -> 다음 노드 포함, 미포함 둘 다 고려
            else
                result += Math.max(dfs(next, cur, 0), dfs(next, cur, 1));
        }
        return dp[cur][contain] = result;
    }

    public static void check(int cur, int before, int contain) {
        for (Integer next : map.get(cur)) {
            if (next == before) continue;
            if (contain == 1) {
                check(next, cur, 0);
            } else {
                if (dp[next][0] > dp[next][1]) {
                    check(next, cur, 0);
                } else {
                    resChk[next] = true;
                    check(next, cur, 1);
                }
            }
        }
    }


    public static void solve() throws IOException {
        dfs(1, 0, 0);
        dfs(1, 0, 1);
        if (dp[1][0] > dp[1][1]) {
            result = dp[1][0];
            check(1, 0, 0);
        } else {
            result = dp[1][1];
            resChk[1] = true;
            check(1, 0, 1);
        }

        //출력
        bw.write(result + "\n");
        for (int i = 1; i <= N; i++) {
            if (!resChk[i]) continue;
            bw.write(i + " ");
        }
        bw.write("\n");
        bw.flush();
    }

    public static void main(String[] args) throws IOException {
        input();
        solve();
    }
}

0개의 댓글