Notion에서 작성한 글이라, 여기에서 더 깔끔하게 보실 수 있습니다! 😮😊
(240811 10:40 내용 추가)
costs
를 오름차순 정렬한 다음, 그리디하게 낮은 비용의 다리부터 놓으면서 정점이 모두 연결되는지 체크하면 된다.제한사항의 는 모든 정점이 서로 완전히 연결된(간선의 개수가 최대인) 상태인데, 그렇다면 최소한의 간선으로 모든 정점을 연결할 때의 간선 수는?
이를 종합하면, 비용에 따라 costs
를 오름차순 정렬한 다음, 낮은 비용의 간선부터 추가하면서 매번 이 간선의 추가로 사이클이 생기진 않는지 체크한다.
이를 반복하다가, 사이클이 생기지 않은 채로 개의 간선을 추가했다는 것은 모든 섬이 최소의 비용으로 연결됐다는 것을 의미하므로, 이 값을 return한다.
def cycle_check(adj, visited, cur, prev=None):
visited[cur] = True
for nxt in adj[cur]:
if not visited[nxt]:
if cycle_check(adj, visited, nxt, cur):
return True
elif nxt != prev: return True
return False
def solution(n, costs):
adj = [[] for _ in range(n)]
total = 0
e = 0
for u, v, cost in sorted(costs, key=lambda x: x[2]):
adj[u].append(v)
adj[v].append(u)
visited = [False]*n
if cycle_check(adj, visited, u):
adj[u].pop()
adj[v].pop()
else:
total += cost
e += 1
if e == n-1: return total
def find(x, parent):
if parent[x] == x: return x
parent[x] = find(parent[x], parent)
return parent[x]
def union(x, y, parent, rank):
xset = find(x, parent)
yset = find(y, parent)
if xset == yset: return
if rank[xset] < rank[yset]:
parent[xset] = yset
return
else:
if rank[xset] == rank[yset]:
rank[xset] += 1
parent[yset] = xset
return
def solution(n, costs):
total = 0
e = 0
parent = [*range(n)]
rank = [0]*n
for u, v, cost in sorted(costs, key=lambda x: x[2]):
if find(u, parent) != find(v, parent):
union(u, v, parent, rank)
total += cost
e += 1
if e == n-1: return total
(240815 17:30 내용 추가)
(이전 글)
정점 간의 거리를 저장하기 위해 인접 행렬을 선택했다. 인접 리스트에 비해 디버깅하기 상대적으로 편하다고 생각했고, 일단 단순하게나마 코드를 짜고 나서 바꿀 생각이었다. 지금 생각해보니 그냥 인접 리스트 썼어도 딱히 어려워지는 건 없다.
costs
를 cost
기준으로 오름차순 정렬한 다음, 정점에 간선을 하나씩 추가하는 방식으로 시작한다.
if not adj[u][v] or adj[u][v] > cost:
temp = temp - adj[u][v] + cost
이런 식으로 갱신을 하게 되면, 두 개 이상의 간선을 거쳐서 도달하는 비용이 차감되기때문에 정답보다 낮은 값이 출력된다. 추가한 간선들 자체는 adjacency matrix든 list든 따로 관리하고, 모든 정점이 연결되었는지 확인할 변수는 따로 마련하는 식으로 해결해볼 수 있을 것 같다. 내일 해보기!costs
에 대해 이 작접을 수행하게 되면 의 시간이 소모된다,,, 은 최대 , 그에 따라 는 최대 이지만 중간에 갱신이 일어나지 않을 때도 많을 테니 실제 실행시간은 이보다 적을 것이라고도 생각했고, 일단 짜고나서 시간초과가 나면 조금씩 깎아나가야겠다는 생각으로 짰다.찾아보니 MST, Union-Find라고 한다. 지금까지 공부했던 것들이 어느정도 탄탄하게 다져지고 나서 공부하려고 아직 찾아보진 않았는데,,, 어… 일단 내일 더 고민하면서 코드를 고쳐보다가 정 안되면 찾아보거나, 잠깐 미뤄두거나 해야겠다. MST는 제쳐두더라도 Union-Find는 난이도가 낮은 그래프 문제에서도 Union-Find로 엄청 빠른시간에 풀 수 있던데 이 부분은 꼭 찾아봐야겠다. 아직 그래프는 많이 부족한 것 같다,,, 😢
from collections import deque
def bfs(adj, v, visited):
queue = deque([(v, 0)])
visited[v] = True
while queue:
cur, prev_w = queue.popleft()
for nxt, cur_w in [(i, c) for i, c in enumerate(adj[cur]) if c]:
if not visited[nxt]:
visited[nxt] = True
if not adj[v][nxt] or prev_w + cur_w < adj[v][nxt]:
adj[v][nxt] = prev_w + cur_w
adj[nxt][v] = prev_w + cur_w
queue.append((nxt, prev_w+cur_w))
def solution(n, costs):
adj = [[0]*n for _ in range(n)]
temp = 0
mn = float('inf')
for u, v, cost in sorted(costs, key=lambda x: x[2]):
if not adj[u][v] or adj[u][v] > cost:
temp = temp - adj[u][v] + cost
adj[u][v] = cost
adj[v][u] = cost
for i in range(n):
visited = [False]*n
bfs(adj, i, visited)
if all(adj[0][i] for i in range(1, n)):
mn = min(mn, temp)
return mn