6497 전력난

정민용·2024년 3월 20일

백준

목록 보기
261/286

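Kruskal's Algorithm — the maximum saving equals the sum of all road lengths minus the weight of a minimum spanning tree.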

import sys


class Tree:
    """A node of a disjoint-set (union-find) forest.

    val    -- the label this node represents.
    parent -- the node this one points to. A root points to itself; the
              constructor's default is None, so a fresh root must be given
              ``parent = <itself>`` explicitly.
    height -- the rank consulted by union-by-rank; starts at 1.
    """

    def __init__(self, val, parent=None, height=1):
        self.val, self.parent, self.height = val, parent, height


def find(node):
    """Return the root of the set containing *node*, compressing the path.

    A root is a node whose parent is itself. On the way back out of the
    recursion every node on the path is re-pointed directly at the root,
    so later lookups along the same path take a single step.
    """
    parent = node.parent
    if node == parent:
        return node
    root = find(parent)
    node.parent = root
    return root


def union(a, b):
    """Merge the sets containing *a* and *b* using union by rank.

    Does nothing when both nodes already share a root. If a's root has a
    strictly smaller height, it is attached under b's root. Otherwise b's
    root is attached under a's root. In that case, when the two heights
    were equal, a's root's height grows by one.
    """
    root_a = find(a)
    root_b = find(b)
    if root_a == root_b:
        return
    if root_a.height < root_b.height:
        root_a.parent = root_b
        return
    root_b.parent = root_a
    if root_b.height == root_a.height:
        root_a.height += 1


def _kruskal_savings(edges):
    """Return the sum of all edge lengths minus the weight of a minimum spanning forest.

    edges -- iterable of (x, y, z) integer triples, each an undirected edge
             between labels x and y with length z.

    Edges are scanned in ascending length (Kruskal's algorithm). An edge is
    kept only when it joins two different union-find sets. Returns 0 when
    there are no edges.
    """
    nodes = {}      # label -> Tree node, created the first time a label is seen
    weighted = []   # (Tree, Tree, length) triples, later sorted by length
    total = 0
    for x, y, z in edges:
        total += z
        for label in (x, y):
            if label not in nodes:
                root = Tree(label)
                root.parent = root  # a fresh node is the root of its own set
                nodes[label] = root
        weighted.append((nodes[x], nodes[y], z))

    weighted.sort(key=lambda edge: edge[2])

    kept = 0
    for u, v, z in weighted:
        if find(u) != find(v):
            union(u, v)
            kept += z
    return total - kept


def main():
    """Answer test cases read from stdin, one line of output per case.

    Each case starts with a line "m n" followed by n edge lines "x y z".
    Input ends with the line "0 0".
    """
    readline = sys.stdin.readline
    m, n = map(int, readline().split())
    # Stop only on "0 0". A case with n == 0 edges (e.g. m = 1, n = 0) is
    # still a real case that needs an answer, so the previous condition
    # `m != 0 and n != 0` ended the loop too early and skipped every later case.
    while m != 0 or n != 0:
        edges = [tuple(map(int, readline().split())) for _ in range(n)]
        sys.stdout.write(str(_kruskal_savings(edges)) + "\n")
        m, n = map(int, readline().split())


if __name__ == "__main__":
    main()

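Prim's Algorithm — same idea (total road length minus MST weight), but the MST is grown outward from a single house using a binary heap.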

import sys, heapq
from collections import defaultdict

n, m = map(int, sys.stdin.readline().split())
while n != 0:
    graph = defaultdict(list)
    visited = [0] * (n+1)
    total_dist = 0

    for _ in range(m):
        a, b, c = map(int, sys.stdin.readline().split())
        graph[a].append([c, a, b])
        graph[b].append([c, b, a])
        total_dist += c

    def prim(graph, node):
        visited[node] = 1
        candidate = graph[node]
        heapq.heapify(candidate)
        total_weight = 0

        while candidate:
            w, u, v = heapq.heappop(candidate)

            if visited[v] == 0:
                visited[v] = 1
                total_weight += w

                for edge in graph[v]:
                    if visited[edge[2]] == 0:
                        heapq.heappush(candidate, edge)

        return total_weight

    min_dist = prim(graph, 1)
    sys.stdout.write(str(total_dist - min_dist) + "\n")

    n, m = map(int, sys.stdin.readline().split())

6497 전력난

0개의 댓글