1197 최소 스패닝 트리

정민용·2024년 3월 20일

백준

목록 보기
259/286

Kruskal Algorithm

import sys


class Tree:
    """A node in the disjoint-set (union-find) forest used for Kruskal.

    Fields, all taken directly from the constructor arguments:
      val    -- the vertex label this node stands for
      parent -- the parent node in the forest
      height -- rank consulted by union-by-height (defaults to 1)
    """

    def __init__(self, val, parent, height=1):
        self.val, self.parent, self.height = val, parent, height


def find(node):
    """Return the root of the set that ``node`` belongs to.

    A node counts as a root when its ``val`` equals its parent's ``val``.
    Path compression: every node visited on the way up is re-pointed
    directly at the root before returning.
    """
    # First pass: climb to the root.
    root = node
    while root.val != root.parent.val:
        root = root.parent

    # Second pass: hang every node on the path directly off the root.
    while node.val != node.parent.val:
        next_up = node.parent
        node.parent = root
        node = next_up

    return root


def union(a, b):
    """Merge the sets containing ``a`` and ``b`` using union by height.

    The root with the smaller height is hung beneath the other root. On a
    tie, ``b``'s root goes under ``a``'s root and that root's height grows
    by one. Nothing happens if both already share a root.
    """
    root_a, root_b = find(a), find(b)
    if root_a == root_b:
        return

    if root_a.height < root_b.height:
        root_a.parent = root_b
        return

    # root_a is at least as tall, so it stays the root.
    root_b.parent = root_a
    if root_a.height == root_b.height:
        root_a.height += 1


def kruskal(graph):
    """Return the total weight of the edges Kruskal's algorithm accepts.

    ``graph`` is an iterable of ``(node_x, node_y, cost)`` tuples whose
    endpoints are union-find nodes. Edges are taken in the order given, so
    the caller must pass them sorted by ascending cost for the result to be
    minimal (the script below sorts them first). An edge is accepted, and
    its endpoints' sets merged, only if it joins two different sets.
    """
    total = 0
    for x, y, cost in graph:
        root_x, root_y = find(x), find(y)
        if root_x != root_y:
            union(root_x, root_y)
            total += cost
    return total


# First input line: vertex count and edge count (only the edge count is used).
v, e = map(int, sys.stdin.readline().split())
treedict = {}  # vertex label -> its union-find node
graph = []     # (node_a, node_b, weight) triples

for _ in range(e):
    a, b, c = map(int, sys.stdin.readline().split())

    # Lazily create a singleton set for each endpoint seen for the first time.
    for label in (a, b):
        if label not in treedict:
            node = Tree(label, None)
            node.parent = node  # a fresh node is its own root
            treedict[label] = node

    graph.append((treedict[a], treedict[b], c))

# Kruskal needs the edges in ascending weight order.
graph.sort(key=lambda edge: edge[2])

sys.stdout.write(str(kruskal(graph)))

Prim Algorithm

import sys
from collections import defaultdict
import heapq

# First input line: vertex count and edge count.
v, e = map(int, sys.stdin.readline().split())
# Adjacency lists: vertex -> list of [weight, from, to] edges.
graph = defaultdict(list)
# Per-vertex visited flags, indexed by vertex label (index 0 unused).
visited = [0 for _ in range(v + 1)]

for _ in range(e):
    # Each edge line: endpoints a, b and weight c; stored in both directions.
    a, b, c = map(int, sys.stdin.readline().split())
    for src, dst in ((a, b), (b, a)):
        graph[src].append([c, src, dst])

def prim(graph, node):
    """Return the total weight of a minimum spanning tree grown from ``node``.

    Lazy Prim's algorithm: a min-heap holds candidate edges leaving the
    current tree. The cheapest edge reaching an unvisited vertex is
    accepted, and that vertex's edges to unvisited vertices are pushed.

    Args:
        graph: mapping from vertex to a list of ``[weight, from, to]`` edges.
            It is not modified.
        node: the vertex to start from.

    Returns:
        The sum of the accepted edge weights. Only vertices reachable from
        ``node`` are spanned.
    """
    # Local visited set: the function no longer depends on (or corrupts)
    # module-level state, so it can be called more than once.
    visited = {node}
    # Copy: heapify/heappush would otherwise reorder and grow graph[node].
    candidate = list(graph[node])
    heapq.heapify(candidate)
    total_weight = 0

    while candidate:
        weight, _, dst = heapq.heappop(candidate)
        if dst in visited:
            continue  # stale edge: both endpoints are already in the tree
        visited.add(dst)
        total_weight += weight

        for edge in graph[dst]:
            if edge[2] not in visited:
                heapq.heappush(candidate, edge)

    return total_weight


# Grow the tree from vertex 1 and print its total weight (no trailing newline).
print(prim(graph, 1), end="")

1197 최소 스패닝 트리

0개의 댓글