재귀를 응용하는 알고리즘, 분할 정복을 익혀 봅시다.
Java / Python
가장 가까운 두 점을 구하는 문제. 잘 알려진 문제지만 상당히 어렵기 때문에 검색을 추천드립니다.
이번 문제는 2차원 평면상에 n개의 점이 주어졌을 때, 이 점들 중 가장 가까운 두 점을 구하는 프로그램을 작성하는 문제입니다. x좌표가 가까운 점들끼리 비교를 하기 위해 점들을 x좌표를 기준으로 오름차순 정렬하고 중앙을 기준으로 왼쪽, 오른쪽으로 나눈 영역에서의 최솟값을 비교하고 x좌표를 추려낸 후 y좌표를 기준으로 최솟값을 구해나가는 방식입니다..
import java.awt.Point;
import java.io.*;
import java.util.*;
public class Main {
public static void main(String[] args) throws NumberFormatException, IOException {
BufferedReader br = new BufferedReader(new InputStreamReader(System.in));
BufferedWriter bw = new BufferedWriter(new OutputStreamWriter(System.out));
StringTokenizer st;
ArrayList<Point> arrList = new ArrayList<>();
int N = Integer.parseInt(br.readLine());
for (int i = 0; i < N; i++) {
st = new StringTokenizer(br.readLine());
int x = Integer.parseInt(st.nextToken());
int y = Integer.parseInt(st.nextToken());
arrList.add(new Point(x, y));
}
Collections.sort(arrList, (p1, p2) -> p1.x - p2.x); // x좌표 기준 오름차순 정렬
bw.write(divCon(arrList, 0, N - 1) + "\n");
br.close();
bw.close();
}
// 두 점 사이 거리의 제곱 계산 함수
public static int getDist(Point p, Point q) {
return (p.x - q.x) * (p.x - q.x) + (p.y - q.y) * (p.y - q.y);
}
// 완전 탐색으로 가장 가까운 거리 찾는 함수
static int bruteForce(ArrayList<Point> arrList, int start, int end) {
int minDist = Integer.MAX_VALUE;
for (int i = start; i < end; i++) {
for (int j = i + 1; j <= end; j++) {
int k = getDist(arrList.get(i), arrList.get(j));
minDist = Math.min(k, minDist);
}
}
return minDist;
}
public static int divCon(ArrayList<Point> arrList, int start, int end) {
int n = end - start + 1;
// n <= 3 -> 가장 가까운 두 점 사이의 거리 찾기
if (n <= 3) {
return bruteForce(arrList, start, end);
}
int mid = (start + end) / 2;
// 중앙선을 기준으로 왼쪽 점들 중 가장 작은 거리(k1)
// 오른쪽 점들 중 가장 작은 거리 (k2)
// min(k1, k2)구하기
int d = Math.min(divCon(arrList, start, mid), divCon(arrList, mid + 1, end));
// 중앙선을 기준으로 양쪽으로 d 거리 이내에 들어오는 점들 고려
// y좌표 기준으로 오름차순 정렬
ArrayList<Point> band = new ArrayList<>();
for (int i = start; i <= end; i++) {
int t = arrList.get(mid).x - arrList.get(i).x;
if (t * t < d) {
band.add(arrList.get(i));
}
}
Collections.sort(band, (p1, p2) -> p1.y - p2.y);
// y좌표 기준으로 오름차순 정렬된 band의 각 요소
// -> 현재 좌표보다 큰 요소만 보면서 거리를 측정
for (int i = 0; i < band.size() - 1; i++) {
for (int j = i + 1; j < band.size(); j++) {
int t = band.get(j).y - band.get(i).y;
if (t * t < d) {
d = Math.min(d, getDist(band.get(i), band.get(j)));
} else { // d보다 거리가 큰 순간이 오면 반복문 종료
break;
}
}
}
return d;
}
}
import sys
N = int(sys.stdin.readline())
point = []
for _ in range(N):
x, y = list(map(int, sys.stdin.readline().split()))
point.append((x, y))
point.sort()
def getDist(a, b):
return (a[0] - b[0]) ** 2 + (a[1] - b[1]) ** 2
def solution(left, right):
if left == right:
return float('inf')
else:
mid = (left + right) // 2
min_dist = min(solution(left, mid), solution(mid + 1, right))
target_list = []
for i in range(mid, left - 1, -1):
if (point[i][0] - point[mid][0]) ** 2 < min_dist:
target_list.append(point[i])
else:
break
for j in range(mid + 1, right + 1):
if (point[j][0] - point[mid][0]) ** 2 < min_dist:
target_list.append(point[j])
else:
break
target_list.sort(key=lambda x: x[1])
for i in range(len(target_list) - 1):
for j in range(i + 1, len(target_list)):
if (target_list[i][1] - target_list[j][1]) ** 2 < min_dist:
dist = getDist(target_list[i], target_list[j])
min_dist = min(min_dist, dist)
else:
break
return min_dist
if len(point) != len(set(point)):
print(0)
else:
print((solution(0, len(point) - 1)))