
N의 최대 범위가 1,000,000이므로 의 시간복잡도로 정렬을 수행하면 된다.
이 경우 병합 정렬은 의 시간복잡도를 보장하니 해당 알고리즘을 통해 풀이를 진행해 보겠다.
먼저 병합 정렬이란 우리가 앞서 배운 분할 정복과 투 포인터를 합친 개념이다.
예를 들어 아래와 같은 데이터가 있다고 가정해보자.
e.g. [5, 3, 8, 4, 2, 7, 1, 6]
최초에는 각 숫자마다 그룹을 할당하여 8개 그룹으로 나눈다.
이 상태에서 2개씩 그룹을 합치며 오름차순으로 정렬하는 과정을 통해 해결한다.
8개 그룹 나누기
(5), (3), (8), (4), (2), (7), (1), (6)
2개씩 그룹을 합치기
(5)와 (3) 합친 후 정렬 => (3, 5)
(8)과 (4) 합친 후 정렬 => (4, 8)
...
2개 그룹을 구체적으로 병합하는 방법에서 투 포인터를 사용한다.
왼쪽 포인터와 오른쪽 포인터의 값을 비교하여 작은 값을 결과 배열에 추가하고 포인터를 오른쪽으로 1칸 이동시키는 것이다.
따라서 문제풀이 방법은 아래와 같다.
merge_sort(start, end) 함수 선언
end - start < 1인 경우, 다시 말해 그룹화한 원소가 모두 분할되어 1개인 경우 함수 종료
start와 end의 중간값 mid 계산
재귀호출을 통해 가장 처음의 mid값을 기준으로 왼쪽 그룹의 각각 원소들이 모두 그룹을 가지도록 분할
=> merge_sort(start, mid)
재귀호출을 통해 가장 처음의 mid값을 기준으로 오른쪽 그룹의 각각 원소들이 모두 그룹을 가지도록 분할
=> merge_sort(mid + 1, end)
왼쪽 그룹 시작지점인 idx1은 start, 오른쪽 그룹 시작지점인 idx2는 mid + 1에 위치하고 쵭종 병합된 배열에서 시작위치인 k는 start에 위치
양쪽 그룹의 index가 가리키는 값을 비교한 후 더 작은 수를 선택해 리스트에 저장한 후, 선택된 데이터의 index 값을 오른쪽으로 1칸 이동
왼쪽 또는 오른쪽 그룹에 대해 데이터가 남은 경우 삽입하여 데이터 정리
이를 코드로 작성하면 아래와 같다.
import sys
input = sys.stdin.readline
print = sys.stdout.write
# 병합정렬 알고리즘
def merge_sort(start, end):
# 그룹화한 원소가 1개인 경우 모두 분할하였으므로 종료
if end - start < 1: return
# start와 end의 중간값 mid 계산
mid = int((start + end) / 2)
# 재귀호출을 통해 가장 처음의 mid값을 기준으로
# 왼쪽 그룹의 각각 원소들이 모두 그룹을 가지도록 분할
merge_sort(start, mid)
# 재귀호출을 통해 가장 처음의 mid값을 기준으로
# 오른쪽 그룹의 각각 원소들이 모두 그룹을 가지도록 분할
merge_sort(mid + 1, end)
# tmp의 역할은 아래 2가지와 같다.
# 1. 각 그룹에 해당하는 값을 복사하여 인덱스에 대한 저장된 값을 비교한 후 저장함
# 2. 병합이 진행된 그룹에 한정하여 저장하는 방식 (for문이 재귀호출하는 점을 감안해라!)
# e.g. merge(1, 2, 4) 인 경우
# A = [0, 3, 5, 4, 8, 2, 7, 1, 6]
# tmp = [0, 3, 5, 4, 8, 0, 0, 0]
for i in range(start, end + 1):
tmp[i] = A[i]
k = start # 병합 정렬에서 정렬된 값을 원본 배열 A에 다시 저장하는 위치 인덱스
idx1 = start # 첫 번째 부분 배열의 시작 인덱스
idx2 = mid + 1 # 두 번째 부분 배열의 시작 인덱스
while idx1 <= mid and idx2 <= end:
# 양쪽 그룹의 index가 가리키는 값을 비교한 후 더 작은 수를 선택해 리스트에 저장한 후
# 선택된 데이터의 index 값을 오른쪽으로 1칸 이동
# index1이 가리키는 데이터가 더 큰 경우
if tmp[idx1] > tmp[idx2]:
A[k] = tmp[idx2]
k += 1 # 값을 넣었으니 최종 인덱스도 1칸 이동
idx2 += 1 # 1칸 이동
# index2가 가리키는 데이터가 더 큰 경우
else:
A[k] = tmp[idx1]
k += 1 # 값을 넣었으니 최종 인덱스도 1칸 이동
idx1 += 1 # 1칸 이동
# 한 쪽 그룹의 데이터가 모두 선택된 후 남은 다른 한쪽 그룹에 대한 남은 데이터 삽입
# 왼쪽 그룹에 대해 데이터가 남은 경우 삽입
while idx1 <= mid:
A[k] = tmp[idx1]
k += 1
idx1 += 1
# 오른쪽 그룹에 대해 데이터가 남은 경우 삽입
while idx2 <= end:
A[k] = tmp[idx2]
k += 1
idx2 += 1
N = int(input())
A = [0] * (N + 1)
tmp = [0] * (N + 1)
for i in range(1, N + 1):
A[i] = int(input())
merge_sort(1, N)
for i in range(1, N + 1):
print(str(A[i]) + '\n')
재귀함수를 필수로 사용하다보니, 코드의 실행흐름이 어떻게 동작하는지 이해하기 힘들 수 있다.
하지만 아래 사실만 기억하고 차근차근 접근해보면 깨달을 수 있을 것이다.
재귀함수는 스택 방식을 통해 코드가 실행된다.
e.g. 배열 A가 [5, 3, 8, 4, 2, 7, 1, 6] 인 경우 함수 재귀호출 흐름
merge_sort(1, 8) => [5, 3, 8, 4, 2, 7, 1, 6]
├── merge_sort(1, 4) => [5, 3, 8, 4]
│ ├── merge_sort(1, 2) => [5, 3]
│ │ ├── merge_sort(1, 1) → 종료
│ │ ├── merge_sort(2, 2) → 종료
│ │ ├── merge(1, 1, 2) → [3, 5]
│ ├── merge_sort(3, 4) => [8, 4]
│ │ ├── merge_sort(3, 3) → 종료
│ │ ├── merge_sort(4, 4) → 종료
│ │ ├── merge(3, 3, 4) → [4, 8]
│ ├── merge(1, 2, 4) → [3, 4, 5, 8]
│
├── merge_sort(5, 8) => [2, 7, 1, 6]
│ ├── merge_sort(5, 6) => [2, 7]
│ │ ├── merge_sort(5, 5) → 종료
│ │ ├── merge_sort(6, 6) → 종료
│ │ ├── merge(5, 5, 6) → [2, 7]
│ ├── merge_sort(7, 8) => [1, 6]
│ │ ├── merge_sort(7, 7) → 종료
│ │ ├── merge_sort(8, 8) → 종료
│ │ ├── merge(7, 7, 8) → [1, 6]
│ ├── merge(5, 6, 8) → [1, 2, 6, 7]
│
├── merge(1, 4, 8) → [1, 2, 3, 4, 5, 6, 7, 8]