채완이는 신도시에 건물 사이를 잇는 양방향 도로를 만들려는 공사 계획을 세웠다.
공사 계획을 검토하면서 비용이 생각보다 많이 드는 것을 확인했다.
채완이는 공사하는 데 드는 비용을 아끼려고 한다.
모든 건물이 도로를 통해 연결되도록 최소한의 도로를 만들려고 한다.
채완이는 도로가 너무 많아 절약되는 금액을 계산하는 데 어려움을 겪고 있다.
채완이를 대신해 얼마나 절약이 되는지 계산해주자.
예산을 얼마나 절약 할 수 있는지 출력한다. 만약 모든 건물이 연결되어 있지 않는다면 -1을 출력한다.
핵심 아이디어는 n개의 도시를 위해 n-1개의 간선을 설치한다는 것 이다.
MST 유형의 문제라고 판단하고 Prim 알고리즘을 이용해서 풀었다.
초기에 다른 문제를 풀 때 처럼 코드를 짰는데, 시간 초과도 아니고 메모리 초과가 떴다.
from collections import deque
import sys
si = sys.stdin.readline
n, m = map(int, si().split())
graph = [[0] * (n + 1) for _ in range(n + 1)]
total_cost = 0
for i in range(m):
b1, b2, cost = map(int, si().split())
graph[b1][b2] = cost
graph[b2][b1] = cost
total_cost += cost
def solution(g):
visited = set()
visited.add(1)
w = 0
for _ in range(n - 1):
_min, _next = sys.maxsize, -1
for node in visited:
for j in range(1, n + 1):
if j not in visited and 0 < g[node][j] < _min:
_min = g[node][j]
_next = j
w += _min
visited.add(_next)
return w
print(total_cost - solution(graph))
아무래도 graph가 너무 무거워진 것이 원인인 듯 싶다. 코드를 조금만 바꿔서 구현하려고 했으나 실패.
heapq를 이용해서 같은 로직을 구현하니 바로 성공했다.
from heapq import *
from collections import defaultdict
import sys
si = sys.stdin.readline
n, m = map(int, si().split())
total = 0
graph = []
for _ in range(m):
n1, n2, weight = map(int, si().split())
graph.append((weight, n1, n2))
total += weight
def prim(g, start):
mst = 0
# 모든 간선의 정보를 담을 트리
tree = defaultdict(list)
for w, b1, b2 in g:
tree[b1].append((w, b1, b2))
tree[b2].append((w, b2, b1))
# 방문할 노드가 담긴 세트
visited = set()
visited.add(start)
# 시작노드와 연결된 간선들의 리스트
candidate_arr = tree[start]
# 시작 노드와 연결된 간선들을 힙 자료구조로 만든다. --> 가중치가 작은 순서로 나온다.
heapify(candidate_arr)
# print(f"candidate arr = {candidate_arr}")
while candidate_arr:
# 가중치, 현재 노드와 연결된 노드가 나옴.
w, b1, b2 = heappop(candidate_arr)
# print(f"heappop! w = {w}, b1 = {b1}, b2 = {b2}")
# 만약 연결된 노드가 방문하지 않았다면
if b2 not in visited:
# 방문처리를 한다. --> 현재 노드와 연결된 간선중 가장 가중치가 작으면서 가본 적이 없는 노드
visited.add(b2)
# 가중치를 더해준다.
mst += w
# print(f"visited = {visited}, now answer is {mst}")
# 다음 방문할 노드와 연결된 노드들을 순회
for node in tree[b2]:
# 방문한 적이 없는 노드라면 힙에 넣는다.
if node[2] not in visited:
# print(f"add candidate {node[2]}")
heappush(candidate_arr, node)
# print(f"now candidate_arr = {candidate_arr}")
# print(f"total = {total}")
# print(f"mst = {mst}")
return visited, total - mst
check, answer = prim(graph, 1)
if len(check) == n:
print(answer)
else:
print(-1)