모든 점끼리 거리를 구해서 비교하면 풀 수 있습니다. 하지만 시간 복잡도가 이므로 절대 시간 안에 해결할 수 없습니다.
분할 정복 알고리즘으로 문제를 해결하려면 세 가지 구성 요소를 만들어야 합니다.
첫 번째 구성 요소를 만들기 위해선 무엇을 기준으로 문제를 분할할지 결정해야 합니다. 가운데에 있는 점의 x좌표 기준으로 반으로 나누는 방식을 쉽게 떠올릴 수 있습니다. 하지만 이렇게 나누면 기준선을 사이에 두는 두 점의 거리는 비교하지 못합니다. 따라서 세 가지 문제를 해결해야 합니다.
병합 과정은 분할된 세 가지 문제의 답 중 가장 작은 값을 반환하면 됩니다. 하지만 문제가 있습니다. 세 번째 문제를 해결하기 위한 방법을 떠올려야 합니다. 기준선을 기준으로 모든 왼쪽의 점들과 모든 오른쪽의 점들을 비교하면 시간 복잡도가 이므로 구할 수 없습니다. 여기서 한 가지 통찰을 떠올려야 합니다. 기준선을 사이에 두는 두 점은 기준선으로부터 첫 번째 문제의 답과 두 번째 문제의 답 중 최솟값 보다 더 멀리 떨어진 점은 비교할 필요가 없습니다. 더 멀리 떨어져 있다면 첫 번째 문제의 답과 두 번째 문제의 답보다 더 크기 때문입니다. 우리는 최솟값을 구해야 하기 때문에 더 큰 값은 고려할 필요가 없습니다.
여기까지 생각했다면 만약 x좌표는 같고 y좌표만 다른 점들이 매우 많으면 비효율적이라는 것을 알 수 있습니다. 이 문제를 해결하기 위해서는 가운데에 있는 점들을 y좌표 기준으로 정렬하고 가장 아래에 있는 점부터 다른 위에 있는 점들을 순서대로 거리를 비교하는데, 순서대로 비교하다가 보다 y좌표의 차이가 더 큰 점과 비교를 한다면 그 점은 비교를 그만 두고 다른 점을 비교해야 합니다. 그러면 어떤 점은 최대 개의 점들과의 거리를 구해야 하므로 시간 복잡도는 입니다.
이제 세 번째 구성 요소를 정해야 합니다. 나눠진 구역의 점의 개수가 몇 개 이하일 때부터 분할을 그만해야 하는지 정해야 합니다. 2개 이하일 때 분할을 그만 두고 두 점 사이의 거리를 반환하면 3개가 남았을 때 1개와 2개로 분할될 수도 있습니다. 나눠진 구역의 점의 갯수가 1개라면 거리를 구할 수 없으므로 3개가 남았을 때 세 점 사이의 거리를 모두 비교해서 최솟값을 반환하도록 해야 합니다.
한 번 호출하면 반으로 분할하고 위의 세 번째 문제를 해결하기 위해 시간 복잡도가 인 함수를 호출하므로 입니다. 마스터 정리에 의해 우리가 구현한 알고리즘의 시간 복잡도는 입니다.
import java.util.*;
class Point {
public int x;
public int y;
public Point(int x, int y) {
this.x = x;
this.y = y;
}
public int getDistance(Point point) {
return (this.x - point.x) * (this.x - point.x) + (this.y - point.y) * (this.y - point.y);
}
}
public class Main {
// 모든 점의 개수
public static int n;
// 모든 점
public static Point[] points;
// 답
public static int result;
public static void input() {
Scanner scanner = new Scanner(System.in);
n = scanner.nextInt();
points = new Point[n];
for (int i = 0; i < n; i++) {
int x = scanner.nextInt();
int y = scanner.nextInt();
points[i] = new Point(x, y);
}
Arrays.sort(points, new Comparator<Point>() {
@Override
public int compare(Point o1, Point o2) {
return o1.x - o2.x;
}
});
}
public static void solve() {
result = getMinDistance(0, n - 1);
}
// 두 점들 간의 거리의 최솟값을 반환합니다.
public static int getMinDistance(int left, int right) {
// 기저 사례: 점들의 개수가 3개 이하면 브루트 포스로 계산
if (right - left + 1 <= 3) {
return getMinDistanceByBruteForce(left, right);
}
// 가운데 점의 x좌표 기준으로 문제 분할
int mid = (right + left) / 2;
int distance = Math.min(getMinDistance(left, mid), getMinDistance(mid + 1, right));
// 가운데 영역 안의 문제 해결
int middleLeft = left;
int middleRight = right;
int temp = points[mid].x - points[middleLeft].x;
while (temp * temp > distance) {
middleLeft++;
temp = points[mid].x - points[middleLeft].x;
}
temp = points[middleRight].x - points[mid].x;
while (temp * temp > distance) {
middleRight--;
temp = points[middleRight].x - points[mid].x;
}
distance = Math.min(distance, getMinDistanceInMiddle(distance, middleLeft, middleRight));
return distance;
}
// 가운데 영역 안의 두 점들 간의 거리의 최솟값 계산
public static int getMinDistanceInMiddle(int minDistance, int left, int right) {
int distance = minDistance;
// 기저 사례: 점들의 개수가 3개 이하면 브루트 포스로 계산
int size = right - left + 1;
if (size <= 3) {
return getMinDistanceByBruteForce(left, right);
}
// 가운데 영역 안의 점들은 y좌표 기준으로 정렬함
Point[] midPoints = new Point[size];
for (int i = 0; i < size; i++) {
midPoints[i] = points[left + i];
}
Arrays.sort(midPoints, new Comparator<Point>() {
@Override
public int compare(Point o1, Point o2) {
return o1.y - o2.y;
}
});
// 2중 반복문이라서 시간 복잡도가 O(n^2)이라고 생각할 수 있지만 조건문 때문에 안쪽의 반복문은 최대 minDistance번 실행됨
for (int i = 0; i < size; i++) {
for (int j = i + 1; j < size; j++) {
int temp = midPoints[j].y - midPoints[i].y;
if (temp * temp >= distance) {
break;
}
distance = Math.min(distance, midPoints[i].getDistance(midPoints[j]));
}
}
return distance;
}
// 브루트 포스로 영역 안의 두 점들 간의 거리의 최솟값 계산
public static int getMinDistanceByBruteForce(int left, int right) {
int distance = Integer.MAX_VALUE;
for (int i = left; i <= right - 1; i++) {
for (int j = i + 1; j <= right; j++) {
distance = Math.min(distance, points[i].getDistance(points[j]));
}
}
return distance;
}
public static void output() {
System.out.println(result);
}
public static void main(String[] args) {
input();
solve();
output();
}
}
구현하기 전엔 알고리즘에 문제가 없다고 생각했는데 계속 시간 초과가 떠서 코드를 하나하나 분석해보고 원인을 찾기 위해 여러 가지 가설을 세웠습니다.
1. Scanner를 써서 데이터를 입력받을 때 너무 오래 걸리나?
Scanner로 입력받던 걸 BufferedReader로 입력받아도 시간 초과가 떳습니다. 문제를 해결하고 나서 BufferedReader를 사용했을 때와 Scanner를 사용했을 때 시간을 비교했을 때 약 1.3배의 시간 차이가 있었지만 Scanner의 문제는 아니였습니다.
2. 거리 구할 때 Math.pow() 함수가 더 오래 걸리나?
부동소수점의 곱셈이 정수의 곱셈보다 오래 걸리긴 하지만 그게 문제는 아니였습니다.
3. 변수를 저장하지 않아서 계산을 한 번 더해서 오래 걸리나?
for (int i = 0; i < size; i++) {
for (int j = i + 1; j < size; j++) {
if ((midPoints[j].y - midPoints[i].y) * (midPoints[j].y - midPoints[i].y) >= distance) {
break;
}
distance = Math.min(distance, midPoints[i].getDistance(midPoints[j]));
}
}
가운데 영역의 점들의 거리를 비교하기 위한 반복문에서 두 점의 y좌표의 차이를 변수에 저장하지 않고 비교할 때마다 계산해서 느린게 아닐까? 생각했지만
for (int i = 0; i < size; i++) {
for (int j = i + 1; j < size; j++) {
int temp = (midPoints[j].y - midPoints[i].y);
if (temp * temp >= distance) {
break;
}
distance = Math.min(distance, midPoints[i].getDistance(midPoints[j]));
}
}
이렇게 바꿔도 결과는 달라지지 않았습니다.
4. 배열을 정렬하는 것이 느린가? 배열리스트로 만들고 정렬해야하나?
혹시 몰라서 Collections.sort() 함수의 구현 코드를 살펴봤는데 배열리스트를 배열로 만들고 Arrays.sort()로 정렬한 뒤 그 배열을 다시 배열리스트로 만들고 반환하는 것을 알았습니다. 따라서 Collections.sort() 함수가 Arrays.sort()보다 빠를리 없습니다.
진짜 원인은 getMinDistance() 함수 안에 int middleLeft = 0;으로 구현했던 탓입니다. 가운데 영역을 정의할 때 함수를 호출할 때 입력받은 인자인 left부터 고려해야 하는데 0부터 고려했기 때문에 함수를 호출할 때마다 인자와 상관없이 시간 복잡도가 인 작업을 수행했습니다.
코딩할 때 이런 잔실수가 많은데 좀더 집중력을 길러야 할 것 같습니다.