[BOJ] 2887번 : 행성 터널(C++)[Gold I]

김준형·2021년 5월 24일
1

백준

목록 보기
10/13
post-thumbnail

Problem

Solution

  • MST 알고리즘의 시간복잡도는O(ElogV)이다. 그런데 모든 행성 간의 터널 연결 비용을 구하는데 O(N2)의 시간이 들고 N≤100,000이므로 다른 접근 방법이 필요하다.
  • 각 행성의 x,y,z 좌표 각각을 기준으로 삼아 정렬하면 x,y,z 좌표계에서의 행성 간 연결 비용을 O(NlogN)안에 구할 수 있다. 이때 행성 A와 B의 연결 비용을 3개의 좌표계 중 2개 이상의 좌표계에서 구할 수 있을 때는 그 중 최솟값을 A와 B의 연결 비용으로 정한다.
  • 모든 행성 간의 터널 연결 비용을 구한 것은 아니지만 적어도 모든 행성이 서로 연결되게 하는 터널들을 구한 것이다. 터널(간선)들과 연결 비용(가중치)를 알고 있으므로 MST 알고리즘을 적용한다.
#include <iostream>
#include <algorithm>
#include <vector>
#include <string>
#include <cstring>
#include <map>
using namespace std;

int N, p[100001];

void swap(int& x, int& y){
	int temp = x;
	x = y;
	y = temp;
}
/* Union-Find */
int Find(int n){
	if(p[n] == n) return p[n];
	return p[n] = Find(p[n]);
}

bool Union(int x, int y){
	x = Find(x);
	y = Find(y);
	if(x == y) return true;
	p[y] = x;
	return false;
}

int main(){
	scanf("%d",&N);
	
	vector<pair<int, int>> c[3]; //c[0] : x배열 c[1] : y배열 c[2] : z배열
	map<pair<int,int>, int> m; // ((n1, n2), dist)
	
	for(int i=0; i<N; i++){
		int x,y,z;
		scanf("%d%d%d", &x,&y,&z);
		c[0].push_back({x, i});
		c[1].push_back({y, i});
		c[2].push_back({z, i});
	}
	// x,y,z배열 정렬
	sort(c[0].begin(), c[0].end());
	sort(c[1].begin(), c[1].end());
	sort(c[2].begin(), c[2].end());
	
	for(int n=0; n<3; n++){
		for(int i=0; i<N-1; i++){
			int n1 = c[n][i].second, n2 = c[n][i+1].second; // n1, n2번째 행성 
			int c1 = c[n][i].first, c2 = c[n][i+1].first; // n1, n2번째 행성의 x|y|z좌표
			int dist = c2 - c1; // n1, n2번째 행성 x|y|z좌표의 거리
			if(n1 > n2) swap(n1, n2); // n1이 n2보다 작게 한다
			pair<int, int> p = {n1, n2};
			if(m.find(p) == m.end()) m[p] = dist;
			else m[p] = (m[p] < dist ? m[p] : dist);
		}	
	}
	
	vector<pair<int, pair<int, int>>> v; // (dist, (n1, n2))
	
	map<pair<int, int>, int>::iterator it;
	for(it=m.begin(); it!=m.end(); it++){
		int n1 = it->first.first, n2 = it->first.second;
		int dist = it->second;
		v.push_back({dist, {n1, n2}});
	}
	// dist값에 대해 정렬
	sort(v.begin(), v.end());
	
	int num = 0; 
	int sum = 0;
	for(int i=0; i<N; i++)
		p[i] = i;
	// MST 알고리즘
	for(int i=0; i<v.size(); i++){
		int n1 = v[i].second.first, n2 = v[i].second.second;
		int dist = v[i].first;
		
		if(num == N - 1) break;
		if(Union(n1, n2)) continue;
		
		num++;
		sum += dist;
	}
	printf("%d \n", sum);
}

Result

  • Union-Find 알고리즘의 이해가 부족하였다.

잘못된 Union 함수

bool Union(int x, int y){
	if(Find(x) == Find(y)) return true;
	p[y] = x;
	return false;
}
profile
코딩하는 군인입니다.

0개의 댓글