BOJ 1389. 케빈 베이컨의 6단계 법칙 (S1)

급식·2023년 9월 26일
0

알고리즘

목록 보기
3/12
post-thumbnail

접근

뭔가 막연하게 최단 경로의 합을 구해야겠다 싶기는 했는데, 어떻게 구해야 할지 감이 안와서 힌트를 보니까 플로이드-워셜 알고리즘을 사용한다고 나와 있어서 참고했다.

핵심 아이디어는, 두 노드를 직접적으로 잇는 경로보다 특정 노드를 경유하는 경로의 가중치가 더 작은 경우 원래의 경로를 경유 경로로 덮어쓰면 된다. 경로는 물론 끊어져있을 수도 있고, 이 경우 만약에 어떤 방법으로든 두 노드를 잇는 경로가 있다면 경로 가중치가 무한대에서 일정 값으로 특정되겠지?


코드

코드에 주석으로 각 단계가 어떤 일을 하는지 적어 놓기는 했는데, 눈에 잘 안들어오면 위에 링크해놓은 분께서 플로이드-워셜 알고리즘에 대해 잘 정리해 두셨으니 한 번 읽어보고 다시 살펴보면 좋을 것 같다.

from typing import List
from sys import stdin

input = stdin.readline

from math import inf


def solution(n: int, graph: List[List[bool]]) -> int:
    # 연결되어 있지 않으면 inf로 처리함
    result = [[inf] * n for _ in range(n)]

    for src in range(n):
        for dest in range(n):
            # 어차피 순회 과정에서 src == dest는 건너 뛸거긴 한데,,
            # src == dest는 시작지와 목적지가 같은 것이므로 비교할 필요가 없음
            if src == dest:
                result[src][dest] = 0
            # 연결된 경우 해당 간선의 가중치를 1로 설정해줌 (bool -> int)
            elif graph[src][dest]:
                result[src][dest] = 1

    # src-dest 연결보다 src-via-dest 연결이 더 짧은 경우 기존 연결을 via를 경유하는 방식으로 갱신해줌
    # 경유지 via는 모든 노드가 될 수 있으므로 [0, n)을 범위로 함
    for via in range(n):
        for src in range(n):
            # for dest in range(n)으로 안 한건, result가 대칭 행렬이므로
            # 경로 가중치 최솟값을 갱신할 때마다 그 반대의 요소도 초기화해주면 되기 때문
            # (N^2 -> 1/2 * N^2)
            for dest in range(src, n):
                # 1. via - via - dest == via - dest
                # 2. src - via - via == src - via
                # 3. src - dest == 순회할 필요 X
                if via == src or via == dest or src == dest:
                    continue
                # 경유하는 루트가 더 짧은 경우 대칭으로 갱신
                new_route = result[src][via] + result[via][dest]
                if new_route < result[src][dest]:
                    result[src][dest] = result[dest][src] = new_route

    # (유저 번호, 점수) 쌍
    bacon_nums = [(i, sum(result[i])) for i in range(n)]
    # 점수 오름차순 정렬하되, 점수가 같은 경우 번호 오름차순 정렬
    bacon_nums.sort(key=lambda p: (p[1], p[0]))

    return bacon_nums[0][0]


if __name__ == "__main__":
    n, m = map(int, input().split())
    # 가독성을 위해 유저 번호를 1번이 아닌 0번부터 시작하도록 처리
    graph = [[False] * n for _ in range(n)]
    for _ in range(m):
        a, b = map(int, input().split())
        graph[a - 1][b - 1] = graph[b - 1][a - 1] = True

    print(solution(n, graph) + 1)

느낀 점

일단,, 알고리즘을 왜 공부해야 하는지 덜컥 와닿았다. 막연하게 '이렇게 하면 될 것 같은데,,' 하는 일들은 최소한 코딩 테스트에 나올법한 범주 안에선 정해져 있고, 그 범주들 각각에 대응되는 이미 잘 정리된 방법론들이 대부분 있으니 단어에 겁먹지 말고 이것저것 많이 맛을 봐야겠다.

둘째로, 다른 사람들은 보통 if via == src or via == dest or src == dest: 처럼, 어찌 보면 사족일 수도 있는 코드를 잘 안쓰는 것 같다. 나는 손으로 테스트 케이스를 하나씩 그려가면서 하는 편이고, 특히나 그래프 탐색 문제는 대부분 무향 그래프여서 대칭되는 부분이나 무시해도 되는 부분을 메모하면서 푸는 편이라 자꾸 저런 코드가 나오는 것 같다.

심지어 pruning 하듯이 특정 케이스를 계산할 필요가 없으면 continue 하는 식으로 연산 횟수를 줄여 처리 속도를 높일 수 있을거라 생각했는데, 막상 저 코드가 들어가니까 한 10ms 정도 더 늦게 풀렸다.

테스트 케이스를 손으로 옮기는건 좋은데, 좀 더 알고리즘의 base case를 중심으로 먼저 생각해보는 습관을 들여야겠다.

profile
뭐 먹고 살지.

0개의 댓글