뭔가 막연하게 최단 경로의 합을 구해야겠다 싶기는 했는데, 어떻게 구해야 할지 감이 안와서 힌트를 보니까 플로이드-워셜 알고리즘을 사용한다고 나와 있어서 참고했다.
핵심 아이디어는, 두 노드를 직접적으로 잇는 경로보다 특정 노드를 경유하는 경로의 가중치가 더 작은 경우 원래의 경로를 경유 경로로 덮어쓰면 된다. 경로는 물론 끊어져있을 수도 있고, 이 경우 만약에 어떤 방법으로든 두 노드를 잇는 경로가 있다면 경로 가중치가 무한대에서 일정 값으로 특정되겠지?
코드에 주석으로 각 단계가 어떤 일을 하는지 적어 놓기는 했는데, 눈에 잘 안들어오면 위에 링크해놓은 분께서 플로이드-워셜 알고리즘에 대해 잘 정리해 두셨으니 한 번 읽어보고 다시 살펴보면 좋을 것 같다.
from typing import List
from sys import stdin
input = stdin.readline
from math import inf
def solution(n: int, graph: List[List[bool]]) -> int:
# 연결되어 있지 않으면 inf로 처리함
result = [[inf] * n for _ in range(n)]
for src in range(n):
for dest in range(n):
# 어차피 순회 과정에서 src == dest는 건너 뛸거긴 한데,,
# src == dest는 시작지와 목적지가 같은 것이므로 비교할 필요가 없음
if src == dest:
result[src][dest] = 0
# 연결된 경우 해당 간선의 가중치를 1로 설정해줌 (bool -> int)
elif graph[src][dest]:
result[src][dest] = 1
# src-dest 연결보다 src-via-dest 연결이 더 짧은 경우 기존 연결을 via를 경유하는 방식으로 갱신해줌
# 경유지 via는 모든 노드가 될 수 있으므로 [0, n)을 범위로 함
for via in range(n):
for src in range(n):
# for dest in range(n)으로 안 한건, result가 대칭 행렬이므로
# 경로 가중치 최솟값을 갱신할 때마다 그 반대의 요소도 초기화해주면 되기 때문
# (N^2 -> 1/2 * N^2)
for dest in range(src, n):
# 1. via - via - dest == via - dest
# 2. src - via - via == src - via
# 3. src - dest == 순회할 필요 X
if via == src or via == dest or src == dest:
continue
# 경유하는 루트가 더 짧은 경우 대칭으로 갱신
new_route = result[src][via] + result[via][dest]
if new_route < result[src][dest]:
result[src][dest] = result[dest][src] = new_route
# (유저 번호, 점수) 쌍
bacon_nums = [(i, sum(result[i])) for i in range(n)]
# 점수 오름차순 정렬하되, 점수가 같은 경우 번호 오름차순 정렬
bacon_nums.sort(key=lambda p: (p[1], p[0]))
return bacon_nums[0][0]
if __name__ == "__main__":
n, m = map(int, input().split())
# 가독성을 위해 유저 번호를 1번이 아닌 0번부터 시작하도록 처리
graph = [[False] * n for _ in range(n)]
for _ in range(m):
a, b = map(int, input().split())
graph[a - 1][b - 1] = graph[b - 1][a - 1] = True
print(solution(n, graph) + 1)
일단,, 알고리즘을 왜 공부해야 하는지 덜컥 와닿았다. 막연하게 '이렇게 하면 될 것 같은데,,' 하는 일들은 최소한 코딩 테스트에 나올법한 범주 안에선 정해져 있고, 그 범주들 각각에 대응되는 이미 잘 정리된 방법론들이 대부분 있으니 단어에 겁먹지 말고 이것저것 많이 맛을 봐야겠다.
둘째로, 다른 사람들은 보통 if via == src or via == dest or src == dest:
처럼, 어찌 보면 사족일 수도 있는 코드를 잘 안쓰는 것 같다. 나는 손으로 테스트 케이스를 하나씩 그려가면서 하는 편이고, 특히나 그래프 탐색 문제는 대부분 무향 그래프여서 대칭되는 부분이나 무시해도 되는 부분을 메모하면서 푸는 편이라 자꾸 저런 코드가 나오는 것 같다.
심지어 pruning 하듯이 특정 케이스를 계산할 필요가 없으면 continue
하는 식으로 연산 횟수를 줄여 처리 속도를 높일 수 있을거라 생각했는데, 막상 저 코드가 들어가니까 한 10ms 정도 더 늦게 풀렸다.
테스트 케이스를 손으로 옮기는건 좋은데, 좀 더 알고리즘의 base case를 중심으로 먼저 생각해보는 습관을 들여야겠다.