백준 2213 트리의 독립집합

wook2·2022년 4월 28일
0

알고리즘

목록 보기
90/117
post-custom-banner

https://www.acmicpc.net/problem/2213

트리에서 dp를 이용하는 문제이다.
트리는 어떤 정점을 잡고 나머지를 아래로 늘어트리면 어느 노드던지 루트가 될 수 있다. 나는 1을 루트노드로 잡고 문제를 해결했다.

이 문제는 전형적인 dp가 문제를 해결하는 방식을 착안하는데,
부모노드와 자식노드가 2개가 있는 트리를 한번 생각해보자.
이 트리에서 독립집합은 어떻게 구할 수 있을까?

루트노드부터 생각해 보았을때,
1) 루트노드가 독립집합에 있는 경우, 바로 아래 자식은 독립집합에 포함 될 수 없다.
2) 루트노드가 독립집합에 속해있지 않은 경우, 바로 아래 자식은 독립집합에 포함 할 수있다.

즉 2가지의 상태를 dp로 만들어야 하며, 노드마다 2개의 상태를 가질 수 있는 dp를 만들어야 한다.
dp를 [n][2] 로 만든다고 가정해 보겠다. n은 노드의 번호, 0과1은 해당 노드를 가지고 있음의 여부(1이면 해당 노드 포함)
그럼 루트 노드에서 독립집합의 최댓값을 구하려면, dp[n][0]과 dp[n][1]중 최대값을 구해야 한다.

dp[n][0]은 어떻게 구할까?
나는 이런 재귀적인 사고에 빠져드는 문제에서 이런 방식으로 접근한다.
바로 아래 노드의 dp값을 어떻게 구하는지는 모르겠지만, dp[n][0]을 구하기 위해서는 어떤 자식노드에서 dp[자식노드][0]와 dp[자식노드][1]중 큰 값을 가져오면 된다.

dp[n][0]은 dp[자식노드][1]의 값을 가져오면 된다.

이렇게 처리하다보면 마지막에 dfs의 밑바닥에서는
dp[node][0] = 0,dp[node][1] = 노드의 값이 되는데,
dfs에서 바닥을 찍고 올라가면서 필요했던 값들이 채워지게 된다.

재귀,dp,dfs등의 문제를 풀때는 재귀를 파고 들면서 어떻게 값이 변하는지를 추적하다 보면, 재귀의 깊이가 깊어지는 순간에 매우 복잡함을 느낄 것이다.
그렇기 때문에 일반화된 경우를 생각하고, 그저 코드로 옮기는 것이 재귀를 푸는 방법이라고 생각한다.

n = int(input())
arr = list(map(int,input().split()))
tree = [[] for i in range(n+1)]
dp = [[0]*2 for i in range(n+1)]
track = [[[],[]] for i in range(n+1)]
for i in range(n-1):
    a,b = map(int,input().split())
    tree[a].append(b)
    tree[b].append(a)
def dfs(node,visited):
    dp[node][1] = arr[node-1]
    track[node][1].append(node)
    for x in tree[node]:
        if not visited[x]:
            visited[x] = 1
            if not dp[x][0]:
                dfs(x,visited)
            visited[x] = 0
            dp[node][1] += dp[x][0]
            track[node][1].extend(track[x][0])
            if dp[x][0] > dp[x][1]:
                dp[node][0] += dp[x][0]
                track[node][0].extend(track[x][0])
            else:
                dp[node][0] += dp[x][1]
                track[node][0].extend(track[x][1])

v = [0]*(n+1)
v[1] = 1
dfs(1,v)
if dp[1][0] > dp[1][1]:
    track[1][0].sort()
    print(dp[1][0])
    print(*track[1][0])
else:
    track[1][1].sort()
    print(dp[1][1])
    print(*track[1][1])


profile
꾸준히 공부하자
post-custom-banner

0개의 댓글