문제 풀이 - 가장 가까운 두 점(JAVA)

지식 저장소·2021년 11월 26일
0

코딩테스트

목록 보기
8/29

가장 가까운 두 점

풀이

무식하게 풀 수 있을까?

모든 점끼리 거리를 구해서 비교하면 풀 수 있습니다. 하지만 시간 복잡도가 O(n2)O(n^2)이므로 절대 시간 안에 해결할 수 없습니다.

분할 정복 알고리즘

분할 정복 알고리즘으로 문제를 해결하려면 세 가지 구성 요소를 만들어야 합니다.
첫 번째 구성 요소를 만들기 위해선 무엇을 기준으로 문제를 분할할지 결정해야 합니다. 가운데에 있는 점의 x좌표 기준으로 반으로 나누는 방식을 쉽게 떠올릴 수 있습니다. 하지만 이렇게 나누면 기준선을 사이에 두는 두 점의 거리는 비교하지 못합니다. 따라서 세 가지 문제를 해결해야 합니다.

  • 기준선의 왼쪽 부분에서 두 점 사이의 거리의 최솟값
  • 기준선의 오른쪽 부분에서 두 점 사이의 거리의 최솟값
  • 기준선을 사이에 두는 두 점 사이의 거리의 최솟값

병합 과정은 분할된 세 가지 문제의 답 중 가장 작은 값을 반환하면 됩니다. 하지만 문제가 있습니다. 세 번째 문제를 해결하기 위한 방법을 떠올려야 합니다. 기준선을 기준으로 모든 왼쪽의 점들과 모든 오른쪽의 점들을 비교하면 시간 복잡도가 O(n2)O(n^2)이므로 구할 수 없습니다. 여기서 한 가지 통찰을 떠올려야 합니다. 기준선을 사이에 두는 두 점은 기준선으로부터 첫 번째 문제의 답과 두 번째 문제의 답 중 최솟값 minDistanceminDistance보다 더 멀리 떨어진 점은 비교할 필요가 없습니다. 더 멀리 떨어져 있다면 첫 번째 문제의 답과 두 번째 문제의 답보다 더 크기 때문입니다. 우리는 최솟값을 구해야 하기 때문에 더 큰 값은 고려할 필요가 없습니다.
여기까지 생각했다면 만약 x좌표는 같고 y좌표만 다른 점들이 매우 많으면 비효율적이라는 것을 알 수 있습니다. 이 문제를 해결하기 위해서는 가운데에 있는 점들을 y좌표 기준으로 정렬하고 가장 아래에 있는 점부터 다른 위에 있는 점들을 순서대로 거리를 비교하는데, 순서대로 비교하다가 minDistanceminDistance보다 y좌표의 차이가 더 큰 점과 비교를 한다면 그 점은 비교를 그만 두고 다른 점을 비교해야 합니다. 그러면 어떤 점은 최대 minDistance1minDistance - 1개의 점들과의 거리를 구해야 하므로 시간 복잡도는 O(n)O(n)입니다.
이제 세 번째 구성 요소를 정해야 합니다. 나눠진 구역의 점의 개수가 몇 개 이하일 때부터 분할을 그만해야 하는지 정해야 합니다. 2개 이하일 때 분할을 그만 두고 두 점 사이의 거리를 반환하면 3개가 남았을 때 1개와 2개로 분할될 수도 있습니다. 나눠진 구역의 점의 갯수가 1개라면 거리를 구할 수 없으므로 3개가 남았을 때 세 점 사이의 거리를 모두 비교해서 최솟값을 반환하도록 해야 합니다.

시간 복잡도 분석

한 번 호출하면 반으로 분할하고 위의 세 번째 문제를 해결하기 위해 시간 복잡도가 O(n)O(n)인 함수를 호출하므로 T(n)=2T(n2)+nT(n)=2T({n\over 2})+n입니다. 마스터 정리에 의해 우리가 구현한 알고리즘의 시간 복잡도는 O(nlogn)O(n\log n)입니다.

구현

import java.util.*;

class Point {
    public int x;
    public int y;

    public Point(int x, int y) {
        this.x = x;
        this.y = y;
    }

    public int getDistance(Point point) {
        return (this.x - point.x) * (this.x - point.x) + (this.y - point.y) * (this.y - point.y);
    }
}

public class Main {

    // 모든 점의 개수
    public static int n;
    // 모든 점
    public static Point[] points;
    // 답
    public static int result;

    public static void input() {
        Scanner scanner = new Scanner(System.in);
        n = scanner.nextInt();
        points = new Point[n];

        for (int i = 0; i < n; i++) {
            int x = scanner.nextInt();
            int y = scanner.nextInt();
            points[i] = new Point(x, y);
        }

        Arrays.sort(points, new Comparator<Point>() {
            @Override
            public int compare(Point o1, Point o2) {
                return o1.x - o2.x;
            }
        });
    }

    public static void solve() {
        result = getMinDistance(0, n - 1);
    }

    // 두 점들 간의 거리의 최솟값을 반환합니다.
    public static int getMinDistance(int left, int right) {
        // 기저 사례: 점들의 개수가 3개 이하면 브루트 포스로 계산
        if (right - left + 1 <= 3) {
            return getMinDistanceByBruteForce(left, right);
        }
        // 가운데 점의 x좌표 기준으로 문제 분할
        int mid = (right + left) / 2;
        int distance = Math.min(getMinDistance(left, mid), getMinDistance(mid + 1, right));
        // 가운데 영역 안의 문제 해결
        int middleLeft = left;
        int middleRight = right;
        int temp = points[mid].x - points[middleLeft].x;
        while (temp * temp > distance) {
            middleLeft++;
            temp = points[mid].x - points[middleLeft].x;
        }
        temp = points[middleRight].x - points[mid].x;
        while (temp * temp > distance) {
            middleRight--;
            temp = points[middleRight].x - points[mid].x;
        }
        distance = Math.min(distance, getMinDistanceInMiddle(distance, middleLeft, middleRight));
        return distance;
    }
    // 가운데 영역 안의 두 점들 간의 거리의 최솟값 계산
    public static int getMinDistanceInMiddle(int minDistance, int left, int right) {
        int distance = minDistance;
        // 기저 사례: 점들의 개수가 3개 이하면 브루트 포스로 계산
        int size = right - left + 1;
        if (size <= 3) {
            return getMinDistanceByBruteForce(left, right);
        }
        // 가운데 영역 안의 점들은 y좌표 기준으로 정렬함
        Point[] midPoints = new Point[size];
        for (int i = 0; i < size; i++) {
            midPoints[i] = points[left + i];
        }
        Arrays.sort(midPoints, new Comparator<Point>() {
            @Override
            public int compare(Point o1, Point o2) {
                return o1.y - o2.y;
            }
        });
        // 2중 반복문이라서 시간 복잡도가 O(n^2)이라고 생각할 수 있지만 조건문 때문에 안쪽의 반복문은 최대 minDistance번 실행됨
        for (int i = 0; i < size; i++) {
            for (int j = i + 1; j < size; j++) {
                int temp = midPoints[j].y - midPoints[i].y;
                if (temp * temp >= distance) {
                    break;
                }
                distance = Math.min(distance, midPoints[i].getDistance(midPoints[j]));
            }
        }
        return distance;
    }
    // 브루트 포스로 영역 안의 두 점들 간의 거리의 최솟값 계산
    public static int getMinDistanceByBruteForce(int left, int right) {
        int distance = Integer.MAX_VALUE;
        for (int i = left; i <= right - 1; i++) {
            for (int j = i + 1; j <= right; j++) {
                distance = Math.min(distance, points[i].getDistance(points[j]));
            }
        }
        return distance;
    }

    public static void output() {
        System.out.println(result);
    }

    public static void main(String[] args) {
        input();
        solve();
        output();
    }
}

회고

구현하기 전엔 알고리즘에 문제가 없다고 생각했는데 계속 시간 초과가 떠서 코드를 하나하나 분석해보고 원인을 찾기 위해 여러 가지 가설을 세웠습니다.
1. Scanner를 써서 데이터를 입력받을 때 너무 오래 걸리나?
Scanner로 입력받던 걸 BufferedReader로 입력받아도 시간 초과가 떳습니다. 문제를 해결하고 나서 BufferedReader를 사용했을 때와 Scanner를 사용했을 때 시간을 비교했을 때 약 1.3배의 시간 차이가 있었지만 Scanner의 문제는 아니였습니다.
2. 거리 구할 때 Math.pow() 함수가 더 오래 걸리나?
부동소수점의 곱셈이 정수의 곱셈보다 오래 걸리긴 하지만 그게 문제는 아니였습니다.
3. 변수를 저장하지 않아서 계산을 한 번 더해서 오래 걸리나?

        for (int i = 0; i < size; i++) {
            for (int j = i + 1; j < size; j++) {
                if ((midPoints[j].y - midPoints[i].y) * (midPoints[j].y - midPoints[i].y) >= distance) {
                    break;
                }
                distance = Math.min(distance, midPoints[i].getDistance(midPoints[j]));
            }
        }

가운데 영역의 점들의 거리를 비교하기 위한 반복문에서 두 점의 y좌표의 차이를 변수에 저장하지 않고 비교할 때마다 계산해서 느린게 아닐까? 생각했지만

        for (int i = 0; i < size; i++) {
            for (int j = i + 1; j < size; j++) {
                int temp = (midPoints[j].y - midPoints[i].y);
                if (temp * temp >= distance) {
                    break;
                }
                distance = Math.min(distance, midPoints[i].getDistance(midPoints[j]));
            }
        }

이렇게 바꿔도 결과는 달라지지 않았습니다.
4. 배열을 정렬하는 것이 느린가? 배열리스트로 만들고 정렬해야하나?
혹시 몰라서 Collections.sort() 함수의 구현 코드를 살펴봤는데 배열리스트를 배열로 만들고 Arrays.sort()로 정렬한 뒤 그 배열을 다시 배열리스트로 만들고 반환하는 것을 알았습니다. 따라서 Collections.sort() 함수가 Arrays.sort()보다 빠를리 없습니다.
진짜 원인은 getMinDistance() 함수 안에 int middleLeft = 0;으로 구현했던 탓입니다. 가운데 영역을 정의할 때 함수를 호출할 때 입력받은 인자인 left부터 고려해야 하는데 0부터 고려했기 때문에 함수를 호출할 때마다 인자와 상관없이 시간 복잡도가 O(n)O(n)인 작업을 수행했습니다.
코딩할 때 이런 잔실수가 많은데 좀더 집중력을 길러야 할 것 같습니다.

profile
그리디하게 살자.

0개의 댓글