[알고리즘] 크루스칼(Kruskal)

채상엽·2023년 3월 16일
0

SW사관학교 정글

목록 보기
11/35
post-thumbnail

크루스칼(Kruskal) 알고리즘

그리디(Greedy) 알고리즘을 기반으로 간선에 가중치를 할당한 그래프의 모든 정점을 최소 비용으로 연결하는 최적해를 구하는 알고리즘이다.

최소 신장 트리(Minimun Spanning Tree, MST)를 찾는 알고리즘에 프림 알고리즘과 크루스칼 알고리즘이 사용될 수 있다. 그 중 크루스칼 알고리즘에 대해 알아보려고 한다.

먼저 신장 트리란 무엇일까?

신장 트리(Spanning Tree)

신장트리란 하나의 그래프가 있을때, 모든 노드가 연결되어 있되 사이클은 존재하지 않는 '부분' 그래프를 의미한다.

최소 신장 트리란 이렇게 발생 가능한 신장 트리의 경우의 수들 중에서 가중치의 합이 최소가 되는 신장 트리의 경우를 최소 신장 트리(MST)라고 한다.

크루스칼 알고리즘 최적해 도출 과정

크루스칼 알고리즘을 이용해 최적해를 뽑아내는 과정을 요약하면 다음과 같다.

  1. 주어진 모든 간선의 가중치를 비용이 낮은 순서대로 정렬한다.
  2. 정렬된 간선 순서로 탐색하며, 현재의 간선이 노드들 간에 사이클을 발생시키는지 확인한다.
  3. 사이클이 발생할 경우, 가중치가 더 낮더라도 생략하고, 그 다음으로 낮은 가중치의 간선을 다시 확인하고, 사이클이 발생하지 않는다면 신장 트리에 포함시킨다.
  4. 1~3번의 과정을 모든 간선에 대해 반복 수행한다.

그래서 코드로 어떻게...?

과정 그 자체로는 어려운 알고리즘이 아닌 것 같다. 그렇다면 이를 코드로는 어떻게 구현해야할까?

먼저 사이클 여부를 판단하는 방법이다.

사이클 여부를 판단하기 위해서, 각 연결된 노드의 부모(root)를 저장하는 부모 테이블을 리스트로 관리한다. 그림을 통해 알아보자

위와 같은 신장 트리가 있다면, 다음과 같이 간선/비용 리스트와 부모(root)테이블 리스트를 선언할 수 있다.

사이클 여부는 이 테이블에서 각 노드의 부모테이블 값이 같으면 사이클이 생기는 것으로 보면 된다.

크루스칼 알고리즘을 이해했다면, 이제는 Union & Find 알고리즘을 이해해야한다.

Union & Find 알고리즘

대표적인 그래프 알고리즘으로 '합집합 찾기' 라는 의미를 갖는다. 또 다른 이름으로는 '상호 배타적 집합(Disjoint-set)' 이라고 불리기도 한다. 여러 노드가 존재할 때, 두 개의 노드를 선택해서 현재 두 노드가 서로 같은 그래프에 속하는지 판별하는 알고리즘이다.

크게 2가지 연산으로 나뉘어진다.

  • Find : x 또는 y가 어떤 집합에 포함되어 있는지 찾는 연산
  • Union : x와 y가 포함되어 있는 집합을 합치는 연산

이는 크루스칼 알고리즘과 Union & Find 알고리즘을 사용해 최소 신장 트리의 최적해를 구하는 코드를 보면 더 쉽게 이해할 수 있다.

아래 예제는 백준 1197 최소 스패닝 트리 예제를 풀이한 코드이다.

import sys

V, E = map(int, sys.stdin.readline().split())

edges = []
# 입력으로 들어온 간선을 저장한다. (시작노드, 끝노드, 간선 가중치)
for _ in range(E):
    edges.append(list(map(int, sys.stdin.readline().split())))

# 간선 가중치를 기준으로 오름차순 정렬한다.
edges.sort(key= lambda x: x[2]) 

# 각 노드의 부모(root) 테이블의 값을 초기화 한다.
# 초기에는 서로 연결된 노드가 없으므로, 노드 자신의 값이 부모 테이블의 값이 된다.
parents = [0 for _ in range(V+1)]
for i in range(1, V+1):
    parents[i] = i

# *경로압축을 통해 해당 노드의 부모 노드를 찾고, 갱신되지 않았다면 새롭게 부모 노드를 갱신한다. 
def find(x):
    if parents[x] != x:
        parents[x] = find(parents[x])
    return parents[x]

# 연결하고자 하는 두 노드를 연결한다. 통일성을 가져가기 위해 부모 노드는 둘 중 노드의 값이 더 작은 것으로 한다.
def union(start, end):
    start = find(start)
    end = find(end)

    if start < end:
        parents[end] = start
    else:
        parents[start] = end

result = 0
for i in range(len(edges)):
    start, end, cost = edges[i]
    # 사이클이 생기는것을 확인하기 위해, 연결하는 노드간에 부모 테이블의 값이 같은지 여부를 검사한다.
    # 같다면 연결을 생략하고 다음 간선을 탐색한다. 
    if find(start) != find(end):
    	# 다르다면, 두 노드를 합치고 부모 테이블의 값을 더 값이 작은 쪽의 노드로 갱신한다.
        union(start, end)
        result += cost # 합쳤다면 해당 간선의 가중치를 결과에 더한다.
 
print(result)

출처

profile
프로게이머 연습생 출신 주니어 서버 개발자 채상엽입니다.

0개의 댓글