오름차순으로 정렬된 두 개의 배열 A, B가 있다고 생각해보자
A = [1, 3, 8]
B = [2, 4, 7]
두 배열 A, B를 '정렬이 유지되도록 합친' 배열 C를 만들려면 어떻게 해야 할까?
A[0]
과B[0]
의 크기를 비교하여, 더 작은A[0]
를C[0]
에 넣는다.
A[1]
과B[0]
의 크기를 비교하여, 더 작은B[0]
를C[1]
에 넣는다.
A[1]
과B[1]
의 크기를 비교하여, 더 작은A[1]
를C[2]
에 넣는다.
A[2]
과B[1]
의 크기를 비교하여, 더 작은B[1]
를C[3]
에 넣는다.
A[2]
과B[2]
의 크기를 비교하여, 더 작은B[2]
를C[4]
에 넣는다.
- 마지막으로, 남은
A[2]
를C[5]
에 넣는다.
짠. 정렬된 배열 C가 만들어졌다.
C = [1, 2, 3, 4, 7, 8]
만드는 방법은 아래의 내용을 반복한다.
각 배열 A, B의 가장 앞부분 부터 비교를 시작해서
작은 값을 배열 C에 넣어주고,
작은 값을 가지고 있던 배열의 비교 위치를 1 증가시킨다.
그리고, 마지막에 남은 배열의 모든 값을 차례로 C에 넣어주며 종료된다.
이렇게, 정렬된 두 부분 배열을 병합하여
전체 배열이 정렬되게 하는 전략을 사용하는 정렬 방식이
바로 Merge Sort
이다.
정렬된 두 부분배열을 만들기 위해서는 원본 배열의 분할이 필요하다.
원본 배열 arr
을 두개로 나누기 위해,
전달 받은 배열의 첫번째 index인 start
와 마지막 index인 end
가 필요하고
중간 지점을 계산하여 전달 받은 배열을 둘로 나눠준다.
def merge_sort(arr, start, end):
# 1
mid = (start + end) // 2
left_arr = merge_sort(arr, start, mid)
right_arr = merge_sort(arr, mid+1, end)
배열을 return
해주는 조건은
분할된 부분배열의 길이가 1이되어 더 이상 분할 할 수 없을 때이다.
길이가 1
인 배열은 항상 정렬된 상태인 부분 배열이므로
부분 배열의 길이가 1인 배열 2개를 처음 설명한 방법으로 합치고,
병합한 배열을 return
하며 진행해 나가면
최종적으로정렬된 전체 배열을 return
받을 수 있다.
def merge_sort(arr, start, end):
# 2
if end == start:
return [arr[start]]
# 1
mid = (start + end) // 2
left_arr = merge_sort(arr, start, mid)
right_arr = merge_sort(arr, mid+1, end)
분할받은 배열을 병합하자.
처음에 설명한대로, 병합은 정렬된 두 부분 배열을 정렬 시키며 합쳐야 한다.
코드의 left_arr
와 right_arr
의 index를 각각 기록하는 변수를 할당하고
두 배열을 합치기 위한 배열을 선언한 뒤
병합이 진행되어야 할 것이다.
def merge_sort(arr, start, end):
# 2
if end == start:
return [arr[start]]
# 1
mid = (start + end) // 2
left_arr = merge_sort(arr, start, mid)
right_arr = merge_sort(arr, mid+1, end)
# 3
l_idx = 0
r_idx = 0
m_idx = 0
merged_arr = [0] * (len(left_arr) + len(right_arr))
이후, 조건에 따라 반복문을 수행시키기 위해 while문을 사용해
merged_arr
에 조건에 맞는 값을 넣어주고 해당 배열의 index를 증가시킨다.
반복문이 수행되는 동안 merged_arr
에는 항상 값이 들어가므로,
m_idx
는 반복문 수행시 마다 값을 증가시켜야 한다.
def merge_sort(arr, start, end):
# 2
if end == start:
return [arr[start]]
# 1
mid = (start + end) // 2
left_arr = merge_sort(arr, start, mid)
right_arr = merge_sort(arr, mid+1, end)
# 3
l_idx = 0
r_idx = 0
m_idx = 0
merged_arr = [0] * (len(left_arr) + len(right_arr))
# 4
while l_idx < len(left_arr) and r_idx < len(right_arr):
if left_arr[l_idx] < right_arr[r_idx]:
merged_arr[m_idx] = left_arr[l_idx]
l_idx += 1
else:
merged_arr[m_idx] = right_arr[r_idx]
r_idx += 1
m_idx += 1
거의 다 왔다.
이제 left_arr
와 right_arr
중,
아직 merged_arr
에 들어가지 못한 값을 차례대로 merged_arr
에 넣어주고
병합된 merged_arr
을 return 해주면 된다.
def merge_sort(arr, start, end):
# 2
if end == start:
return [arr[start]]
# 1
mid = (start + end) // 2
left_arr = merge_sort(arr, start, mid)
right_arr = merge_sort(arr, mid+1, end)
# 3
l_idx = 0
r_idx = 0
m_idx = 0
merged_arr = [0] * (len(left_arr) + len(right_arr))
# 4
while l_idx < len(left_arr) and r_idx < len(right_arr):
if left_arr[l_idx] < right_arr[r_idx]:
merged_arr[m_idx] = left_arr[l_idx]
l_idx += 1
else:
merged_arr[m_idx] = right_arr[r_idx]
r_idx += 1
m_idx += 1
# 5
right_arr = right_arr if r_idx < len(right_arr) else left_arr
r_idx = r_idx if r_idx < len(right_arr) else l_idx
while r_idx < len(right_arr):
merged_arr[m_idx] = right_arr[r_idx]
r_idx += 1
m_idx += 1
return merged_arr
입력으로 주어지는 배열을 오름차순으로 정렬해서 출력하는 문제이다.
배열의 크기가 이므로,
시간 제한 안에 풀기 위해서는 시간 복잡도 의 정렬 알고리즘을 사용해야 한다.
[ 전체 코드 ]
import sys
def merge_sort(arr, start, end):
# 2
if end == start:
return [arr[start]]
# 1
mid = (start + end) // 2
left_arr = merge_sort(arr, start, mid)
right_arr = merge_sort(arr, mid + 1, end)
# 3
l_idx = 0
r_idx = 0
m_idx = 0
merged_arr = [0] * (len(left_arr) + len(right_arr))
# 4
while l_idx < len(left_arr) and r_idx < len(right_arr):
if left_arr[l_idx] < right_arr[r_idx]:
merged_arr[m_idx] = left_arr[l_idx]
l_idx += 1
else:
merged_arr[m_idx] = right_arr[r_idx]
r_idx += 1
m_idx += 1
# 5
right_arr = right_arr if r_idx < len(right_arr) else left_arr
r_idx = r_idx if r_idx < len(right_arr) else l_idx
while r_idx < len(right_arr):
merged_arr[m_idx] = right_arr[r_idx]
r_idx += 1
m_idx += 1
return merged_arr
n = int(sys.stdin.readline().rstrip())
num_list = [0] * n
for i in range(n):
num_list[i] = int(sys.stdin.readline().rstrip())
result = merge_sort(num_list, 0, n - 1)
for r in result:
sys.stdout.write(f'{r}\n')
참고
[알고리즘] 합병 정렬(merge sort)이란
thanks to jnl1128