루트 없는 트리가 주어진다. 이때, 트리의 루트를 1이라고 정했을 때, 각 노드의 부모를 구하는 프로그램을 작성하시오.
첫째 줄에 노드의 개수 N (2 ≤ N ≤ 100,000)이 주어진다. 둘째 줄부터 N-1개의 줄에 트리 상에서 연결된 두 정점이 주어진다.
첫째 줄부터 N-1개의 줄에 각 노드의 부모 노드 번호를 2번 노드부터 순서대로 출력한다.
import sys
sys.setrecursionlimit(10**6)
n = int(sys.stdin.readline())
tree = [[] for _ in range(n+1)]
parent = [0] * (n+1)
for _ in range(n-1):
node1, node2 = map(int, sys.stdin.readline().split())
tree[node1].append(node2)
tree[node2].append(node1)
visited = [False] * (n+1)
def traverse(node=1, prev_node=0):
visited[node] = True
parent[node] = prev_node
for neighbor in tree[node]:
if not visited[neighbor]:
visited[neighbor] = True
traverse(neighbor, node)
traverse()
for i in range(2, n+1):
print(parent[i])
내가 위 테스트케이스들을 보고 먼저 답안으로 제출했던 것은 다음과 같은 코드였다.
import sys
n = int(sys.stdin.readline())
tree = dict()
tree[1] = 0
for _ in range(n-1):
node1, node2 = map(int, sys.stdin.readline().split())
if node1 in tree:
tree[node2] = node1
else:
tree[node1] = node2
for i in range(2, n+1):
print(tree[i])
tree
딕셔너리는 노드 번호를 key로, 그리고 그 노드 번호에 대응되는 parent node 번호를 value로 가진다.
나는 1번 노드를 기준으로 트리가 확장될 것이라고 가정하고 이 코드를 짰지만, 문제에 그런 조건이 없었기 때문에 근거 없이 코딩을 하는 것을 섣부른 일이었다.
실제로 이 코드는 위 두 개의 테스트케이스에 대해서는 옳은 출력을 내지만, 실제로 이 코드를 백준에 제출하면 런타임 에러(KeyError
)를 낸다.
왜냐하면 만약 입력으로 들어온 두 노드가 모두 다 tree에 등록되어 있지 않았다면 else
문 바로 다음 줄에서 딕셔너리의 키에 node1
가 없게 되기 때문이다.
그래서 다음으로 짠 코드는 아래 코드였다.
import sys
from collections import deque
n = int(sys.stdin.readline())
tree = dict()
tree[1] = 0
queue = deque()
for _ in range(n-1):
queue.append(tuple(map(int, sys.stdin.readline().split())))
while queue:
node1, node2 = queue.popleft()
if node1 in tree:
tree[node2] = node1
elif node2 in tree:
tree[node1] = node2
else:
queue.append(tuple([node1, node2]))
for i in range(2, n+1):
print(tree[i])
만약 지금 당장 노드 쌍을 처리할 수 없다면 큐에 집어넣어 놓고 나중에 처리하자는 발상이었는데, 결과적으로는 시간 초과로 문제를 틀려버렸다.
따라서 위와 같은 방법으로는 문제를 해결할 수 없겠다는 생각이 들어, 그래프 탐색을 활용해 문제를 풀어보기로 했다.
import sys
sys.setrecursionlimit(10**6)
n = int(sys.stdin.readline())
tree = [[] for _ in range(n+1)]
parent = [0] * (n+1)
for _ in range(n-1):
node1, node2 = map(int, sys.stdin.readline().split())
tree[node1].append(node2)
tree[node2].append(node1)
visited = [False] * (n+1)
def traverse(node=1, prev_node=0):
visited[node] = True
parent[node] = prev_node
for neighbor in tree[node]:
if not visited[neighbor]:
visited[neighbor] = True
traverse(neighbor, node)
traverse()
for i in range(2, n+1):
print(parent[i])
이 코드는 DFS를 사용하되, DFS 과정에서 parent
라는 리스트에 각 노드의 부모 노드 번호를 기록한다.
import sys
sys.setrecursionlimit(10**6)
n = int(sys.stdin.readline())
tree = [[] for _ in range(n+1)]
parent = [0] * (n+1)
for _ in range(n-1):
node1, node2 = map(int, sys.stdin.readline().split())
tree[node1].append(node2)
tree[node2].append(node1)
visited = [False] * (n+1)
입력을 받고, 알고리즘 수행에 필요한 변수와 컬렉션들을 초기화해준다.
tree
는 인접 리스트 방식으로 그래프를 표현한 것이고, parent
는 각 노드의 부모 노드 번호를, visited
는 이후 DFS에서 노드에 방문했는지 기록하기 위한 것이다.
이때, 만약 DFS를 재귀를 이용해 구현했다면, sys.setrecursionlimit()
을 사용하지 않으면 런타임 에러(RecursionError
)가 발생할 수 있으니 주의해야 한다.
백준에서는 최대 재귀 깊이의 기본값이 1000으로 설정되어 있는 반면, 이 문제에서 노드의 개수가 최대 10만 개이다.
그 밖의 tree
, parent
, visited
와 같은 리스트들은 인덱스 0에 해당하는 부분을 모두 비워, 노드 번호에 리스트 인덱스 번호가 일치하도록 했다.
이 문제에 제시된 트리는 무방향이기 때문에,
tree[node1].append(node2)
tree[node2].append(node1)
와 같이 두 노드에 모두 간선을 넣어주었다.
def traverse(node=1, prev_node=0):
visited[node] = True
parent[node] = prev_node
for neighbor in tree[node]:
if not visited[neighbor]:
visited[neighbor] = True
traverse(neighbor, node)
DFS를 구현한 부분이다. traverse()
함수에 이전 노드 번호도 argument로 넣어줌으로써, 각 노드의 부모 노드를 기록할 수 있게 했다.
DFS에서는 1번 노드에서부터 시작하여 다른 모든 노드까지 탐색을 진행하므로, DFS를 이용한다면 나 이전에 탐색을 진행한 노드가 나의 부모 노드임을 알 수 있다.
traverse()
for i in range(2, n+1):
print(parent[i])
traverse
함수를 실행한 뒤, index 0과 1을 제외한 parent
의 요소들을 출력해주면 정답으로 인정된다.