문제 : BOJ 2261 가장 가까운 두 점
코드를 읽고 이해하긴 했으나, 다시 공부가 필요하다...
내가 이해한 내용은 아래와 같다.
좌표 평면상의 점들을 x축으로 분할하고, y축으로 거리를 계산한다.
좌표평면상의 점들을 x축으로 정렬해서 중간 지점을 기준으로 분할하고,
좌/우의 분할에서 구한 최단 거리 중 작은 값을
중간 지점으로부터의 좌/우 간격으로 설정해 새로운 범위의 배열을 만들어 준다.이때 만들어진 배열은 y축으로 정렬하고,
현재까지 구한 최단 거리를 기준으로 탐색 기준을 바꿔주며 거리를 계산해준다.
탐색 기준인 점과 탐색 대상인 점 사이의 거리가 최단 거리보다 더 짧으면
최단 거리를 두 점 사이의 거리로 갱신해준다.
다른 내용은 주석으로 작성하였다.
[ 전체 코드 ]
import sys
n = int(sys.stdin.readline().rstrip())
points = sorted([list(map(int, sys.stdin.readline().split())) for _ in range(n)])
# 분할 정복을 위한 solution()
def solution(s, e):
# 점이 두개 이하일 때
if e - s <= 1:
return get_distance(points[s], points[e])
# 점이 세개일 때
if e - s == 2:
return min(get_distance(points[s], points[s+1]), get_distance(points[s+1], points[e]), get_distance(points[s], points[e]))
# 분할된 두 배열에서 찾은 최소 거리 중 작은 값을 min_dist에 저장
m = (s + e) // 2
min_dist = min(solution(s, m), solution(m+1, e))
# m을 기준으로 양 끝으로 min_dist 만큼 벌려준 범위를 탐색하기 위해
# 새로운 탐색 대상 new_arr 를 만들어준다.
mid_range = abs(m - get_index(points, min_dist, points[m], s, e))
new_start = m-mid_range if m >= mid_range else 0
new_end = m+mid_range if m+mid_range <= e else e
# new_arr 은 y축 좌표를 기준으로 정렬한다.
new_arr = sorted(points[new_start:new_end+1], key=lambda x: x[1])
# new_arr 내의 가장 짧은 두 점간 거리를 계산하여 min_dist를 최소값으로 갱신한다.
min_dist = get_mid_range_dist(new_arr, min_dist)
return min_dist
# y축으로 정렬된 좌표 간 가장 짧은 거리를 구하는 get_mid_range_dist()
def get_mid_range_dist(new_arr, min_dist):
# 2중 for문으로 모든 점과 점을 순회한다.
for i in range(len(new_arr)):
for j in range(i+1, len(new_arr)):
y_dist = (new_arr[i][1]-new_arr[j][1])**2
# 탐색 기준(i)과 대상(j)의 y축 거리가 min_dist 보다 클 경우
if y_dist >= min_dist:
# 다음 탐색 기준을 j로 바꿔주고
i = j-1
# 현재 범위에서 탐색을 종료한다.
break
# y축 거리가 min_dist보다 짧을 경우
# x축 거리도 구해준다.
tmp_dist = (new_arr[i][0]-new_arr[j][0])**2 + y_dist
# 두 거리의 합이 min_dist보다 작을 경우,
if tmp_dist < min_dist:
# min_dist를 갱신해주고
min_dist = tmp_dist
# 탐색 기준을 j로 바꿔준다.
i = j-1
break
return min_dist
# 탐색 기준과 주어진 거리를 이용해서 index를 구하는 get_index()
def get_index(arr, key, target, start, end):
s = start
e = end
while s < e:
m = (s + e) // 2 + 1
if abs(arr[m][0]-target[0])**2 > key:
e = m - 1
else:
s = m
return e
# 두 점의 거리를 반환하는 get_distance()
def get_distance(p1, p2):
return (p1[0]-p2[0])**2 + (p1[1]-p2[1])**2
sys.stdout.write(f'{solution(0, n-1)}')