문제 : https://www.acmicpc.net/problem/11437
트리문제를 풀다가 트리 dp문제를 접하고 문제를 해결하지 못하고 '다른 트리 dp문제가 무엇이 있을까?'라는 생각에 찾게 된 문제이다.
초기 구상과정
LCA(Lowest Common Ancestor) : 최소 공통 조상.
트리 내부에서 두 개의 노드 사이의 최소로 갈 수 있는 공통 조상을 찾는 문제이다. 간단하게 초반에는 노드의 깊이들을 계산하여 가지고 있고 그 깊이를 기반으로 한땀한땀 올라가면 될 것이라고 생각했다. 다만, 난이도를 보고 시작해서 그런지 시간초과가 날 것 같다는 가벼운 예상이 되었다.(문제를 좀 풀어봐서 그런가? 쓸데없는 것만 느는것 같은 기분.)
초기 코드
초기코드는 어떻게 풀어봐도 해결점이 안나오고 답도 안나오길래 문제가 심각해서 폐기처분했다..
초기 코드에서부터 해결점
어떻게 하면 해결할 수 있을까 고민을 하다가 구선생님의 도움을 받았다. 직접적인 코드를 보고 내가 짜게 되면 그냥 베끼는 수준 밖에 되지 않으니 다른 사람들의 설명들을 보고 해결해보고자 했다.
알려져있는 알고리즘
비유로서 생각을 하니 나에게는 좀 편하게 느껴졌다. 처음에는 이게 무슨 소리인가 이해하기 힘들었었지만 비유로 생각하니 좀 편하게 다가오는 감이 있었다.
최종 코드
import sys
sys.setrecursionlimit(10 ** 8)
input = sys.stdin.readline
def depth(n, dep):
level[n] = dep
visit[n] = 1
for nod in info[n]:
if visit[nod]: continue
par0[nod] = n
depth(nod, dep + 1)
def lca(a, b):
if level[a] > level[b]:
a, b = b, a
for i in range(idx - 1, -1, -1):
if level[par[i][b]] >= level[a]:
b = par[i][b]
if a == b:
return a
for i in range(idx - 1, -1, -1):
if par[i][a] and par[i][b] and par[i][a] != par[i][b]:
a = par[i][a]
b = par[i][b]
return par[0][a]
N = int(input())
info = [[] for _ in range(N + 1)]
level = [0] * (N + 1)
par0 = [0] * (N + 1)
for _ in range(N - 1):
P, C = map(int, input().split())
info[P].append(C)
info[C].append(P)
visit = [0] * (N + 1)
depth(1, 1)
par = [par0]
deepest = max(level)
idx = 0
while deepest >= 1:
deepest //= 2
tmp1 = par[idx]
tmp2 = [0] * (N + 1)
for i in range(N + 1):
tmp2[i] = tmp1[tmp1[i]]
par.append(tmp2)
idx += 1
for _ in range(int(input())):
a, b = map(int, input().split())
print(lca(a, b))
알려져있는 알고리즘을 설명만을 읽어본 뒤 계속해서 디버깅으로 찍어가며 어떻게 동작하고 있는지 내 코드들을 살펴보면서 했다.
처음에 작성을 하며 '읽어오는 데이터를 어떻게 처리하면 좋을까?'를 굉장히 많이 고민했다. 인접리스트 형태로 받을지, 인접행렬 형태로 받을지도 고민을 많이 했고 또 깊이를 어떤 방식으로 저장해놓아야할까?를 고민했었다.
읽어오는 데이터를 어떻게 처리하면 좋을까?
이 부분에 대해서는 처음에는 인접행렬 형태로 받아봤었다. 인접행렬로 받으니 혹시라도 N이 커지게 되면 굉장히 쓸데없이 인덱싱을 해가며 데이터들을 찾아야하는 문제가 생길 것이 예상되어 인접리스트 형태로 받는 것으로 처리했다.
for _ in range(N - 1):
P, C = map(int, input().split())
info[P].append(C)
info[C].append(P)
깊이는 어떤 방식으로 저장해놓아야하나?
이것도 이것저것 시도를 해봤었다. "dictionary 형태로 받으면 좋지 않을까?"라는 생각이 들어서 lev : [nodes] 형태로 저장하는 dictionary로 받았었는데 level을 이용해서 접근하는 것은 쉬워도 node를 통해서 접근하는 것이 너무 번거롭고 시간도 많이 들게 되어서 list에 index를 번호로 잡고 저장하는 방식으로 각 노드들의 깊이를 저장했다.
level = [0] * (N + 1)
깊이를 저장하는 방식을 dfs를 이용해서 했는데 결국엔 깊이를 저장할 떄 부모의 정보 또한 들어간다는 것을 알아서 처음에는 dfs를 이용해서 level만 저장하고 따로 부모-자식 관계를 찾아보는 형태의 코드를 작성했었는데 이를 dfs를 사용해서 저장하는 방식으로 바꾸었다.
def depth(n, dep):
level[n] = dep
visit[n] = 1
for nod in info[n]:
if visit[nod]: continue
par0[nod] = n
depth(nod, dep + 1)
뛰어서 올라가는 것이 더 나은데 그를 위한 전처리는 어떻게 할까?
뛰어서 올라가는 것 = 2^n 위의 조상으로 올라가는 것
모든 노드에서 2^N 의 조상을 구해서 저장하고 있다면 '뛰어서 올라간다'는 것이 가능하다는 것을 알았다. 뛰어서 올라가기 위해서는 각 노드의 2^N을 구해야하는데 어떻게 구할까? 답부터 말하자면 이 때 dp를 이용하게 된다.
처음에 바로 위의 조상 정보를 저장할 수 있으니 dfs를 통해서 각 노드의 조상정보를 이미 구해놓았으니 이를 쓰자! >> 2^0의 정보
바로 위의 조상의 정보를 가지고 있으니 조상의 2^0의 정보를 가져와서 저장하면 2^1 위의 조상 정보를 가져오는 꼴이 된다.
2^1의 조상의 2^1 위 조상을 가져오면 2^2
.
.
.
이런 순이므로 2^N 위 조상 = 2^(N -1) 위 조상의 2^(N-1) 조상이라는 점화식이 되고,
par[N][i] = par[par[N - 1][i]][N - 1]이라는 dp식이 만들어지게 된다.
deepest = max(level)
idx = 0
while deepest >= 1:
deepest //= 2
tmp1 = par[idx]
tmp2 = [0] * (N + 1)
for i in range(N + 1):
tmp2[i] = tmp1[tmp1[i]]
par.append(tmp2)
idx += 1
덤으로 이 과정에서 가장 깊은 곳에서부터 최대로 끌어올릴 수 있는 인덱스 값을 찾아서 max로 가져와서 나중에 쓸 것을 저장한다.
이후 두 개의 노드를 받고 공통 조상을 찾는 방법을 쓰면 된다.
def lca(a, b):
if level[a] > level[b]:
a, b = b, a
for i in range(idx - 1, -1, -1):
if level[par[i][b]] >= level[a]:
b = par[i][b]
if a == b:
return a
for i in range(idx - 1, -1, -1):
if par[i][a] and par[i][b] and par[i][a] != par[i][b]:
a = par[i][a]
b = par[i][b]
return par[0][a]
두 개의 깊이 차이를 없애서 동등하게 만들고 두개를 순서대로 위로 올려보면 공통 조상으로 올라가 값이 같아지는 경우가 되면 답이므로 위의 코드를 사용하면 찾을 수 있다!