문제풀이 - 유니온 파인드(Union Find)

Jinho Lee·2023년 12월 6일
0

도입

개요

  • Union-Find는 상호 배타적으로 이루어진 집합을 표현하기 위해 만들어진 자료구조로, Disjoint Set(서로소 집합)이라고도 한다.
  • 서로소 관계에 있는 복수의 집합에 대하여 집합들을 병합하는 Union 연산과 원소가 어떤 집합에 속해있는지 판단하는 Find 연산을 지원하여 Union Find라는 이름이 붙었다.
    • 서로소 : 공통 원소가 없는, 교집합이 공집합인 집합들 사이의 관계.
  • 원소 간의 연결 여부를 판단하는 데 유용하다.
  • 기본적으로 트리 구조이지만 그 구조가 큰 의미를 가지지는 않는 특이한 그래프 알고리즘이다.

Find 연산

  • Find 연산은 입력된 노드의 부모를 재귀적으로 거슬러 올라가 최상위 노드의 값을 반환한다. 이 최상위 노드는 각 집합과 일대일 대응되어 노드의 소속을 알 수 있다.
  • Big O 표기법으로 O(n)이 걸린다고 한다.

예시

// i가 원소인 집합의 대표를 찾는다.
 
#include<bits/stdc++.h>
using namespace std;
 
int find(int i)
 
{
 
    // 만약 i가 스스로의 부모라면
    // 즉, i가 최상위 노드라면
    if (parent[i] == i) {
 
        // i가 이 집합의 대표이다.
        return i;
    }
    else {
 
        // 만약 i가 스스로의 부모가 아니라면
        // i가 이 집합의 대표가 아니므로
        // 재귀적으로 그 부모를 찾아 호출한다.
        return find(parent[i]);
    }
}
 
// The code is contributed by Nidhi goel

최적화

  • Find 연산에서 방문하는 각 노드마다 결과값을 반환하기 전에 리스트에 저장하여 경로를 압축할 수 있다.
  • 이 경우 평균 O(log n)으로 단축된다.
  • 예시
#include <bits/stdc++.h>
using namespace std;
 
int find(int i) 
{
 
    if (Parent[i] == i) {
 
        return i;
    }
    else { 
 
        int result = find(Parent[i]);
 
        // i의 노드를 직접 이 집합의 대표(최상위 노드) 바로 아래로 움직여
        // 결과를 저장(cache)한다.
        Parent[i] = result;
       
        return result;
     }
}

Union 연산

  • 두 개 이상의 인자를 받아 Find 연산으로 그 집합을 찾고, 이를 하나의 최상위 노드 아래에 하나의 트리로 병합한다.
  • 서로소 집합만을 다루므로 합집합 연산과 같다.
  • Find 연산만이 시간에 영향을 미쳐, 시간복잡도는 Find 연산과 동일하다.

예시

// i를 원소로 갖는 집합과 j를 원소로 갖는 집합을 병합한다.
 
#include <bits/stdc++.h>
using namespace std;
 
void union(int i, int j) {
 
    // i의 집합을 찾는다.
    int irep = this.Find(i),
 
    // j의 집합을 찾는다.
    int jrep = this.Find(j);
 
    // i의 집합을 모두 j의 집합으로 옮겨
    // i의 대표가 j의 대표가 되도록한다.
    this.Parent[irep] = jrep;
}

최적화

  • Union 연산은 최악의 경우에 트리를 편중시킬 수 있다.
  • 이를 해결하기 위해 별개의 리스트를 만들어 어떠한 기준에 따라 큰 트리(집합)에 작은 트리(집합)을 합치도록 할 수 있다.
    • Union by Rank : 트리의 깊이에 따라 병합
    • Union by Size : 집합의 크기에 따라 병합

적용 예시

// C++ implementation of disjoint set

#include <bits/stdc++.h>
using namespace std;

class DisjSet {
	int *rank, *parent, n;

public:

	// Constructor to create and
	// initialize sets of n items
	DisjSet(int n)
	{
		rank = new int[n];
		parent = new int[n];
		this->n = n;
		makeSet();
	}

	// Creates n single item sets
	void makeSet()
	{
		for (int i = 0; i < n; i++) {
			parent[i] = i;
		}
	}

	// Finds set of given item x
	int find(int x)
	{
		// Finds the representative of the set
		// that x is an element of
		if (parent[x] != x) {

			// if x is not the parent of itself
			// Then x is not the representative of
			// his set,
			parent[x] = find(parent[x]);

			// so we recursively call Find on its parent
			// and move i's node directly under the
			// representative of this set
		}

		return parent[x];
	}

	// Do union of two sets by rank represented
	// by x and y.
	void Union(int x, int y)
	{
		// Find current sets of x and y
		int xset = find(x);
		int yset = find(y);

		// If they are already in same set
		if (xset == yset)
			return;

		// Put smaller ranked item under
		// bigger ranked item if ranks are
		// different
		if (rank[xset] < rank[yset]) {
			parent[xset] = yset;
		}
		else if (rank[xset] > rank[yset]) {
			parent[yset] = xset;
		}

		// If ranks are same, then increment
		// rank.
		else {
			parent[yset] = xset;
			rank[xset] = rank[xset] + 1;
		}
	}
};

// Driver Code
int main()
{

	// Function Call
	DisjSet obj(5);
	obj.Union(0, 2);
	obj.Union(4, 2);
	obj.Union(3, 1);

	if (obj.find(4) == obj.find(0))
		cout << "Yes\n";
	else
		cout << "No\n";
	if (obj.find(1) == obj.find(0))
		cout << "Yes\n";
	else
		cout << "No\n";

	return 0;
}
#include <iostream>
#include <vector>
using namespace std;

int N, M;
int knowTruth;
int parent[53]; 
// 각 번호를 그룹 이름으로, 만난 사람을 원소로 갖는 그룹을 만든다.
// 0번은 진실을 아는 사람
int temp;
vector<int> party[53];
int partyNum;

int lie;

// 유니온 파인드(Union Find) : 각 원소가 속해있는 집합을 구별하고 찾기 위한 방법
int findParent(int x)
{
	if (parent[x] != x)
		return parent[x] = findParent(parent[x]);
	return x;
}

void merge(int a, int b)
{
	int x = findParent(a);
	int y = findParent(b);

	if (x != y)
	{
		if (x < y) // 그룹 번호가 더 작은 쪽으로 합친다.
			parent[y] = x;
		else
			parent[x] = y;
	}
}

int main()
{
	ios_base::sync_with_stdio(false);
	cin.tie(nullptr);
	cout.tie(nullptr);

	cin >> N >> M;
	cin >> knowTruth;

	for (int i = 0; i <= N; ++i)
		parent[i] = i;

	while (knowTruth--)
	{
		cin >> temp;
		parent[temp] = 0;
	}

	for (int i = 0; i < M; ++i)
	{
		cin >> partyNum;
		cin >> temp;
		party[i].push_back(temp);

		for (int j = 1; j < partyNum; ++j)
		{
			cin >> temp;
			party[i].push_back(temp);
			merge(party[i][0], party[i][j]);
		}
	}

	lie = M;

	for (int i = 0; i < M; ++i)
	{
		for (int j : party[i])
		{
			if (findParent(parent[j]) == 0)
			{
				--lie;
				break;
			}
		}
	}

	cout << lie;

	return 0;
}

참조

0개의 댓글