https://www.acmicpc.net/problem/11725
단순 DFS또는 BFS 문제였지만, 메모리와 시간 제약이 생각보다 깐깐했던 문제였다.
처음은 아래와 같이 구현했었는데, 메모리 초과가 났다.
import sys
N = int(sys.stdin.readline())
# N의 크기가 100,000이기 때문에 2차원 리스트로 생성하면 100,000 * 100,000 = 10,000,000,000
# 이 되기 때문에 메모리 초과가 발생하였다.
isConnected = [[False for _ in range(N+1)] for _ in range(N+1)]
isVisits = [False for _ in range(N+1)]
for _ in range(N-1):
a, b = map(int, sys.stdin.readline().split())
isConnected[a][b] = True
isConnected[b][a] = True
results = [0 for i in range(N+1)]
def dfs(start):
isVisits[start] = True
for i in range(1, N+1):
if not isVisits[i] and (isConnected[start][i] or isConnected[i][start]):
dfs(i)
results[i] = start
dfs(1)
for i in range(2, N+1):
print(results[i])
메모리 초과를 개선한 코드는 아래와 같다. 기존처럼 2차원 배열이긴 하지만, 2차원 배열을 (N+1) * (N+1)로 모두 그리는 것이 아닌, 노드의 위치에 해당하는 리스트만 값을 추가하여 연결 관계를 표시할 수 있도록 수정하였다.
그러나 이번에는 시간초과가 발생하였다.
import sys
N = int(sys.stdin.readline())
isVisits = [False for _ in range(N+1)]
connections = [[] for _ in range(N+1)]
for _ in range(N-1):
a, b = map(int, sys.stdin.readline().split())
connections[a].append(b)
connections[b].append(a)
results = dict()
def dfs(start):
isVisits[start] = True
# 마찬가지로 N이 최대 100,000이기 때문에, 매번 재귀마다 1부터 100,000까지를 탐색하면
# 시간초과가 발생한다.
for i in range(1, N+1):
if not isVisits[i] and i in connections[start]:
dfs(i)
results[i] = start
dfs(1)
for i in range(2, N+1):
print(results[i])
이를 개선하기 위해서, 매번 재귀 때 모든 1부터 N까지의 경우의 수를 탐색하기보다는 앞선 connections
리스트에 저장된 인덱스내에서 연관관계를 가지고 있는 노드만 탐색하도록 수정해서 시간 초과를 해결할 수 있었다.
최종적으로 DFS를 이용해 정답을 받은 코드는 아래와 같다.
import sys
sys.setrecursionlimit(10**6)
N = int(sys.stdin.readline())
isVisits = [False for _ in range(N+1)]
connections = [[] for _ in range(N+1)]
for _ in range(N-1):
a, b = map(int, sys.stdin.readline().split())
connections[a].append(b)
connections[b].append(a)
results = dict()
def dfs(start):
isVisits[start] = True
# 현재 노드와 연결되어 있는 정점들에 대해서만 탐색을 진행한다.
for i in connections[start]:
if not isVisits[i]:
dfs(i)
results[i] = start
dfs(1)
for i in range(2, N+1):
print(results[i])
추가적으로 BFS로도 풀어보았다.
import sys
from collections import deque
N = int(sys.stdin.readline())
connections = [[] for _ in range(N+1)]
isVisits = [False for _ in range(N+1)]
for _ in range(N-1):
a, b = map(int, sys.stdin.readline().split())
connections[a].append(b)
connections[b].append(a)
queue = deque()
results = dict()
def bfs(start):
isVisits[start] = True
queue.append(start)
while queue:
popped = queue.popleft()
for i in connections[popped]:
if not isVisits[i]:
isVisits[i] = True
queue.append(i)
results[i] = popped
bfs(1)
for i in range(2, N+1):
print(results[i])