union-find 알고리즘

박선영·2024년 9월 25일

알고리즘

목록 보기
1/1
post-thumbnail

union-find 알고리즘이란?

두 노드가 같은 유니온(그래프)에 있는지 찾는 알고리즘

자세한 알고리즘은 예제를 통해 알아보자


예제

오늘은 새 학기 새로운 반에서 처음 시작하는 날이다.

현수네 반 학생은 N명이다. 현수는 각 학생들의 친구관계를 알고 싶다. 모든 학생은 1부터 N까지 번호가 부여되어 있고, 현수에게는 각각 두 명의 학생은 친구 관계 가 번호로 표현된 숫자쌍이 주어진다.

만약 (1, 2), (2, 3), (3, 4)의 숫자쌍이 주어지면 1번 학생과 2번 학생이 친구이고, 2번 학생과 3번 학생이 친구, 3번 학생과 4번 학생이 친구이다. 그리고 1번 학생과 4번 학생은 2번과 3번을 통해서 친구관계가 된다.

학생의 친구관계를 나타내는 숫자쌍이 주어지면 특정 두 명이 친구인지를 판별하는 프로그램 을 작성하세요. 두 학생이 친구이면 “YES"이고, 아니면 ”NO"를 출력한다.

▣ 입력

첫 번째 줄에 반 학생수인 자연수 N(1<=N<=1,000)과 숫자쌍의 개수인 M(1<=M<=3,000)이 주어지고, 다음 M개의 줄에 걸쳐 숫자쌍이 주어진다. 마지막 줄에는 두 학생이 친구인지 확인하는 숫자쌍이 주어진다.

▣ 출력

첫 번째 줄에 “YES"또는 "NO"를 출력한다.

예시 입력

9 7
1 2
2 3
3 4
4 5
6 7
7 8
8 9
3 8

예시 출력

NO

설명

노드가 다 루트 노드인 초기 상태이다. 이 때 노드들은 전부 다른 유니온에 있다.

1 2 가 입력되었을 때 노드 1의 parent가 2로 바뀌어, 노드 1과 노드 2가 연결된다. 이 때 1, 2는 루트 노드가 2로 같으므로 같은 유니온에 있다.

전부 입력되었을 때의 상태이다.
1, 2, 3, 4, 5 노드는 루트노드가 5로 같으므로 같은 유니온에 있다.
6, 7, 8, 9 노드는 루트노드가 9로 같으므로 같은 유니온에 있다.

노드 3과 노드 8은 각각 루트노드가 5, 8로 다르므로 다른 유니온에 있는 것을 확인할 수 있다. 그러므로 답은 false이다.

코드

C

#include <stdio.h>

int parent[1001];

int find(int v) { //최상위 노드를 찾음
    if (v == parent[v])
        return v;

    return find(parent[v]);
}

void union(int x, int y) { //최상위 노드끼리 연결시킴
    x = find(x);
    y = find(y);

    parent[x] = y;
}

void set(int k) { //1부터 k까지 부모배열 값 초기화
    for(int i=1; i<=k; i++)
        parent[i]=i;
}

int main(void) {
    int n, m;
    scanf("%d %d", &n, &m);
    set(n);

    int a, b;
    for (int i=0; i<m; i++) { //m개 입력받고 연결
        scanf("%d %d", &a, &b);
        Union(a, b);
    }

    scanf("%d %d", &a, &b);
    if(find(a) == find(b))
        printf("YES\n");
    else
        printf("NO\n");
}

python

n, m = map(int, input().split())
parent = [i for i in range(n+1)]

#최상위 노드를 찾음
def find(x): 
    if lst[x] == x:
        return x
    lst[x] = root(lst[x])
    return lst[x]
    
# 최상위 노드끼리 연결시킴
def union(x, y): 
	root_x, root_y = find(x), root(y)
    parent[root_x] = root_y

# m개 입력받고 연결
for i in range(m):
    a, b = map(int, input().split())
    union(a, b)

a, b = map(int, input().split())
if find(a) == find(b):
	print("YES")
else:
	print("NO")

0개의 댓글