[알고리즘] 최소 신장 트리 (Kruskal, Prim 알고리즘)

donghyeok·2023년 2월 11일
0

알고리즘

목록 보기
13/20

최소 신장 트리 (Minimum Spanning Tree)

신장 트리란 (Spanning Tree)

  • 그래프 내의 모든 정점을 포함하며 최소 간선을 가지는 트리
  • n개의 정점을 가지는 그래프에서 최소 간선 수는 n-1개가 된다
  • 사이클이 존재하지 않는다.
  • 하나의 그래프에서 여러 신장 트리가 존재할 수 있다

최소 신장 트리 (Minimum Spanning Tree)

  • 신장 트리 중에서 가중치의 합이 가장 적은 신장 트리를 뜻한다.
  • 네트워크에 있는 모든 정점들을 가장 적은 수의 간선과 비용으로 연결하는 트리

1. Kruskal 알고리즘

  • 간선 선택을 기반으로 하는 알고리즘이다.
  • 시간 복잡도는 O(ElogE)로 그래프에 적은 숫자의 간선을 가지는 희소 그래프 (Sparse Graph)에 유리하다.

알고리즘 개요

  1. 그래프의 간선을 가중치의 오름차순으로 정렬한다.
  2. 간선들을 순회하며 사이클을 만드는 간선을 제외하고 선택한다. (Union-Find 알고리즘)

알고리즘 구현

public int find (int a) {
	if (parent[a] == a) return a;
    else return parent[a] = find(parent[a]);
}

public void union (int a, int b) {
	a = find(a);
    b = find(b);
    if (a != b)
    	parent[b] = a;
}

//costs[n][0], costs[n][1] : 연결된 두 노드 번호
//costs[n][2] : 연결된 간선 가중치 
//return : 최소 가중치 리턴 
public int kruskal(int[][] costs) {
	Arrays.sort(costs, (o1, o2) -> o1[2] - o2[2]);
    int result = 0;
    for (int i = 0; i < costs.length; i++) {
        int a = costs[i][0];
        int b = costs[i][1];
        int val = costs[i][2];
        if (find(a) != find(b)) {
            union(a, b);
            result += val;
        }
    }
    return result;
}

2. Prim 알고리즘

  • 정점 선택을 기반으로하는 알고리즘이다.
  • 시간 복잡도는 O(n^2)으로 그래프에 간선이 많이 존재하는 밀집 그래프(Dense Graph)에 유리하다.

알고리즘 개요

  1. 시작 정점을 우선 순위 큐에 넣어준다. (가중치 0)
  2. 우선 순위 큐에서 가중치가 가장 작은 노드를 빼고 해당 노드에서 인접한 방문하지 않은 노드를 모두 넣어준다. (큐에서 뺄때 방문 여부를 true로 지정한다.)
  3. 2)의 과정을 큐에 원소가 없을 때까지 반복한다.

알고리즘 구현

public class Point implements Comparable<Point>{
        int node, cost;
        
        public Point (int node, int cost) {
            this.node = node;
            this.cost = cost;
        }
        
        @Override
        public int compareTo (Point o) {
            return this.cost - o.cost;
        }
    }
    
public int prim() {
        boolean[] visit = new boolean[n];
        PriorityQueue<Point> q = new PriorityQueue<>();
        q.add(new Point(0, 0));
        
        int result = 0;
        while(!q.isEmpty()) {
            Point cur = q.poll();
            
            //큐에서 뺄때 방문체크를 진행해야함
            if (visit[cur.node]) continue;
            visit[cur.node] = true;
            result += cur.cost;
            
            for (int i = 0; i < map.get(cur.node).size(); i++) {
                int next = map.get(cur.node).get(i).node;
                int cost = map.get(cur.node).get(i).cost;
                if (visit[next]) continue;
                q.add(new Point(next, cost));
            }
        }
        
        return result;
    }

0개의 댓글