
연결해야 하는 지점(황선자씨, 우주신)의 좌표가 순서대로 주어지고, 이미 연결된 두 지점의 정보가 주어질 때, 모든 지점을 연결하기 위한 최소 비용(거리)를 구하면 되는 문제.
정신나간 문제 설명에 살짝 어지러울 수 있지만 모든 지점을 최소 비용으로 연결해야하는 MST 문제임을 알 수 있습니다. 단, 이미 연결되어있는 지점을 어떻게 처리할지는 살짝 고민해야 할 요소입니다.
MST를 형성하기 위한 알고리즘으로 크루스칼 알고리즘과 프림 알고리즘이 있고, 체감상 Union-Find 알고리즘을 활용해야하는 크루스칼 알고리즘을 많이 활용하는 것 같습니다. 그 구현도 프림 알고리즘에 비해 크루스칼 알고리즘이 더 쉽다고 많이 알려져있는 것 같습니다.
하지만 저는 우선순위 큐를 활용하는 프림 알고리즘을 더 선호합니다. 그래프 이론을 BFS, 다익스트라로 먼저 접했기 때문에 그 구현이 비슷한 프림 알고리즘이 구현하기 더 쉽게 느껴지기 때문입니다. 따라서 이번 문제도 프림 알고리즘으로 접근해보도록 하겠습니다.
문제의 입력은 좌표 형태로 주어집니다. 우리는 MST를 형성하고 프림 알고리즘을 적용하기 위해 각 지점 간의 거리를 관리해야 합니다. 따라서 이 좌표 입력을 인접 행렬로 변환하는 작업을 먼저 거칩니다.
import sys
from typing import Tuple
from heapq import heappush, heappop
input = sys.stdin.readline
INF = 10e9
N, M = map(int, input().rstrip().split())
gods = [tuple(map(int, input().rstrip().split())) for _ in range(N)]
adj = [[INF for _ in range(N + 1)] for _ in range(N + 1)] # distance matrix
def get_distance(p1: Tuple[int], p2: Tuple[int]) -> float:
"""
두 좌표를 입력 받아 두 좌표 사이의 거리를 반환한다.
"""
return ((p1[0] - p2[0]) ** 2 + (p1[1] - p2[1]) ** 2) ** 0.5
for idx, points in enumerate(gods, 1):
for other_idx, other_points in enumerate(gods, 1):
# 본인과의 거리 제외
if idx == other_idx:
continue
# 좌표 사이의 거리 저장
adj[idx][other_idx] = get_distance(points, other_points)
인접행렬 형성에 입력받은 좌표 리스트들을 중첩 순회하는 과정이 필요합니다. N은 1,000 이하이니 무리는 없어보입니다.
다음 입력으로는 이미 연결되어 있는 두 지점의 정보가 주어집니다. 이미 연결되어있다는 것을 두 지점사이의 거리(비용)가 0이라는 것으로 표현한다면 프림 알고리즘 수행 과정에서 자동으로 이가 반영될 것입니다.
# 이미 연결되어있는 지점 처리
for _ in range(M):
st, en = map(int, input().rstrip().split())
# 이미 연결되어있는 비용은 0으로 처리
adj[st][en] = adj[en][st] = 0
여기서 주의할 점은 두 지점이 이미 연결되어있다고 해서 MST에 해당 연결 부가 포함된다고 확신할 수는 없습니다. 두 지점 사이의 거리(비용)이 0이더라도 해당 연결 부를 MST에 포함했을 때 다른 연결 부가 더 높은 비용으로 연결될 수도 있기 때문입니다.
따라서 위와 같이 비용을 0으로 처리해줘야 모든 경우 중에서 최적의 비용을 고려할 수 있게 됩니다. 시작부터 해당 연결부들을 MST에 포함하면 극단적인 예시에서 틀리게 됩니다.
이제 이미 연결되어있는 지점과 모든 지점의 거리 정보가 담긴 인접 행렬 형성이 되었으니, 프림 알고리즘을 적용해서 MST를 만드는 과정을 수행해보겠습니다.
heap = [(0, 1)] # 우선순위 큐 for prim
answer = 0
visited = [True] + [False for _ in range(N)] # MST 포함 여부
# prim
while heap:
w, u = heappop(heap)
if not visited[u]:
answer += w
visited[u] = True
for nxt_v, nxt_w in enumerate(adj[u]):
if not visited[nxt_v]:
heappush(heap, (nxt_w, nxt_v))
print(format(answer, ".2f")) # 둘째자리까지 반올림 (둘째자리까지 출력해야 하므로 round 안됨)
현재 MST에 포함되어 있는 지점을 기점으로 아직 MST에 포함되지 않은 지점 중 가장 작은 비용(거리)을 가진 지점을 포함해 나갑니다. 이 과정에서 우선 순위 큐 역할을 하는 heap이 활용됩니다.
MST에 지점을 포함할 때 해당 연결 부의 비용을 결과 변수(answer)에 더합니다. 이미 연결된 지점이 MST에 포함될 때는 0이 자동으로 반영되게 됩니다.
마지막으로 결과를 소수점 둘째자리까지 반올림하여 출력합니다. 여기서 round를 활용하게 되면 반올림 이후 소수점이 모두 0인 경우 첫째자리까지만 출력해줍니다.
ex) round(4.01, 2)을 출력하면 4.0이 출력됩니다.
따라서 format 함수를 적용해서 자동 반올림과 함께 둘째자리까지 정상적으로 출력되게 만들어주면 됩니다.