📖 백준 2887번 : https://www.acmicpc.net/problem/2887

| 시간 제한 | 메모리 제한 |
|---|---|
| 1 초 | 128 MB |
때는 2040년, 이민혁은 우주에 자신만의 왕국을 만들었다. 왕국은 N개의 행성으로 이루어져 있다. 민혁이는 이 행성을 효율적으로 지배하기 위해서 행성을 연결하는 터널을 만들려고 한다.
행성은 3차원 좌표위의 한 점으로 생각하면 된다. 두 행성 A(xA, yA, zA)와 B(xB, yB, zB)를 터널로 연결할 때 드는 비용은 min(|xA-xB|, |yA-yB|, |zA-zB|)이다.
민혁이는 터널을 총 N-1개 건설해서 모든 행성이 서로 연결되게 하려고 한다. 이때, 모든 행성을 터널로 연결하는데 필요한 최소 비용을 구하는 프로그램을 작성하시오.
첫째 줄에 행성의 개수 N이 주어진다. (1 ≤ N ≤ 100,000) 다음 N개 줄에는 각 행성의 x, y, z좌표가 주어진다. 좌표는 -109보다 크거나 같고, 109보다 작거나 같은 정수이다. 한 위치에 행성이 두 개 이상 있는 경우는 없다.
첫째 줄에 모든 행성을 터널로 연결하는데 필요한 최소 비용을 출력한다.
오랜만에 빠르게 풀이를 떠올리고 바로 풀어버려서 만족스러운 문제였다. 단순하게 생각하면 각 노드마다의 모든 간선을 찾은 후에, 최소 스패닝 트리를 찾으면 된다. 하지만 모든 노드 사이의 간선을 찾는 것은 O(n^2)이 걸리므로 시간 초과를 받게된다. 따라서 최소 스패닝 트리를 구성할 때 필요한 간선들의 후보를 잘 뽑아서 최소 스패닝 트리를 구성하면 될 것 같다는 생각을 했다.
문제에서 말하는 간선의 길이는 각 좌표값의 차의 절대값이다. 이 것에 대해서 조금 생각해보면 노드 사이의 연결되는 간선의 후보를 x,y,z를 기준으로 가장 인접한 노드 사이의 간선으로 생각할 수 있다. 각 노드들을 x,y,z를 기준으로 정렬을 해주고 인접한 노드 사이의 간선의 길이를 계산해서 저장하면 최대 30만개 정도의 간선 후보를 찾을 수 있다.
x,y,z를 기준으로 정렬해서 인접한 노드만을 보는 것이 가장 최적이다. 왜냐하면 문제에서 간선 길이의 정의를 min(|xA-xB|, |yA-yB|, |zA-zB|)라고 주었는데, 간선의 길이를 최소로 만드려면 각각의 좌표를 기준으로 가장 인접한 노드를 통해서 구한 값이 최소 길이라는 것이 자명하기 때문이다. 즉 정렬했을 때 인접하지 않은 두 노드를 고른 간선의 길이가 인접한 노드를 고른 간선의 길이보다 무조건 크기 때문이다.
앞서 말한데로 x,y,z를 기준으로 각각 정렬하고 인접한 노드 사이의 간선을 저장해주면 mst를 구하는 간단한 문제로 변한다. 코드를 보면 graph를 x, y, z를 기준으로 정렬하고 인접한 노드의 간선을 뽑아서 edges에 저장한다. 마지막으로 edges에 저장된 간선들로 mst를 구성하면 최소 비용을 찾을 수 있다. 걸리는 시간은 3nlog(n) + 3nα(3n))으로 최종 시간복잡도는 O(nlogn)이다.
#include <iostream>
#include <cmath>
#include <algorithm>
#include <vector>
#define ll long long int
using namespace std;
struct pos { ll x; ll y; ll z; int index; };
struct now { ll val; int u; int v; };
vector<pos> graph;
vector<now> edges;
bool cmp1(const pos &a, const pos &b) {
return a.x < b.x;
}
bool cmp2(const pos& a, const pos& b) {
return a.y < b.y;
}
bool cmp3(const pos& a, const pos& b) {
return a.z < b.z;
}
bool cmp4(const now& a, const now& b) {
return a.val < b.val;
}
int parent[100001], sz[100001];
int Find(int x) {
if (parent[x] == x) return x;
return parent[x] = Find(parent[x]);
}
void Union(int u, int v) {
int x = Find(u), y = Find(v);
if (x == y) return;
if (sz[x] < sz[y]) swap(x, y);
parent[y] = x;
sz[x] += sz[y];
sz[y] = 0;
}
bool isUnion(int u, int v) {
return Find(u) == Find(v);
}
int main() {
ios::sync_with_stdio(false); cin.tie(nullptr);
int n;
ll x, y, z;
cin >> n;
for (int i = 0; i < n; i++) {
cin >> x >> y >> z;
graph.push_back({ x,y,z,i });
parent[i] = i;
sz[i] = 1;
}
sort(graph.begin(), graph.end(), cmp1);
for (int i = 0; i < n - 1; i++) {
int u = graph[i].index, v = graph[i + 1].index;
x = abs(graph[i].x - graph[i + 1].x);
edges.push_back({ x,u,v });
}
sort(graph.begin(), graph.end(), cmp2);
for (int i = 0; i < n - 1; i++) {
int u = graph[i].index, v = graph[i + 1].index;
y = abs(graph[i].y - graph[i + 1].y);
edges.push_back({ y,u,v });
}
sort(graph.begin(), graph.end(), cmp3);
for (int i = 0; i < n - 1; i++) {
int u = graph[i].index, v = graph[i + 1].index;
z = abs(graph[i].z - graph[i + 1].z);
edges.push_back({ z,u,v });
}
sort(edges.begin(), edges.end(), cmp4);
ll sum = 0;
int cnt = 0;
for (auto& it : edges) {
if (!isUnion(it.u, it.v)) {
Union(it.u, it.v);
sum += it.val;
cnt++;
}
if (cnt == n - 1) break;
}
cout << sum;
return 0;
}