[TIL] Merge Sort

Hyeon·2022년 10월 12일
0

TIL

목록 보기
5/8
post-thumbnail

Merge Sort

오름차순으로 정렬된 두 개의 배열 A, B가 있다고 생각해보자
A = [1, 3, 8]
B = [2, 4, 7]

두 배열 A, B를 '정렬이 유지되도록 합친' 배열 C를 만들려면 어떻게 해야 할까?

  1. A[0]B[0]의 크기를 비교하여, 더 작은 A[0]C[0]에 넣는다.
  1. A[1]B[0]의 크기를 비교하여, 더 작은 B[0]C[1]에 넣는다.
  1. A[1]B[1]의 크기를 비교하여, 더 작은 A[1]C[2]에 넣는다.
  1. A[2]B[1]의 크기를 비교하여, 더 작은 B[1]C[3]에 넣는다.
  1. A[2]B[2]의 크기를 비교하여, 더 작은 B[2]C[4]에 넣는다.
  1. 마지막으로, 남은 A[2]C[5]에 넣는다.

짠. 정렬된 배열 C가 만들어졌다.
C = [1, 2, 3, 4, 7, 8]

만드는 방법은 아래의 내용을 반복한다.

각 배열 A, B의 가장 앞부분 부터 비교를 시작해서
작은 값을 배열 C에 넣어주고,
작은 값을 가지고 있던 배열의 비교 위치를 1 증가시킨다.

그리고, 마지막에 남은 배열의 모든 값을 차례로 C에 넣어주며 종료된다.

이렇게, 정렬된 두 부분 배열을 병합하여
전체 배열이 정렬되게 하는 전략을 사용하는 정렬 방식이
바로 Merge Sort 이다.

분할

정렬된 두 부분배열을 만들기 위해서는 원본 배열의 분할이 필요하다.
원본 배열 arr을 두개로 나누기 위해,
전달 받은 배열의 첫번째 index인 start와 마지막 index인 end가 필요하고
중간 지점을 계산하여 전달 받은 배열을 둘로 나눠준다.

def merge_sort(arr, start, end):
	# 1
	mid = (start + end) // 2
	left_arr = merge_sort(arr, start, mid)
    right_arr = merge_sort(arr, mid+1, end)
    

배열을 return 해주는 조건은
분할된 부분배열의 길이가 1이되어 더 이상 분할 할 수 없을 때이다.

길이가 1인 배열은 항상 정렬된 상태인 부분 배열이므로
부분 배열의 길이가 1인 배열 2개를 처음 설명한 방법으로 합치고,
병합한 배열을 return 하며 진행해 나가면
최종적으로 정렬된 전체 배열을 return 받을 수 있다.

def merge_sort(arr, start, end):
	# 2
	if end == start:
    	return [arr[start]]
        
    # 1
	mid = (start + end) // 2    
	left_arr = merge_sort(arr, start, mid)
    right_arr = merge_sort(arr, mid+1, end)

병합

분할받은 배열을 병합하자.
처음에 설명한대로, 병합은 정렬된 두 부분 배열을 정렬 시키며 합쳐야 한다.

코드의 left_arrright_arr의 index를 각각 기록하는 변수를 할당하고
두 배열을 합치기 위한 배열을 선언한 뒤
병합이 진행되어야 할 것이다.

def merge_sort(arr, start, end):
	# 2
	if end == start:
    	return [arr[start]]
    # 1 
	mid = (start + end) // 2    
	left_arr = merge_sort(arr, start, mid)
    right_arr = merge_sort(arr, mid+1, end)
    
    # 3
    l_idx = 0
    r_idx = 0
    m_idx = 0
    merged_arr = [0] * (len(left_arr) + len(right_arr))

이후, 조건에 따라 반복문을 수행시키기 위해 while문을 사용해
merged_arr에 조건에 맞는 값을 넣어주고 해당 배열의 index를 증가시킨다.

반복문이 수행되는 동안 merged_arr에는 항상 값이 들어가므로,
m_idx 는 반복문 수행시 마다 값을 증가시켜야 한다.

def merge_sort(arr, start, end):
	# 2
	if end == start:
    	return [arr[start]]
    # 1 
	mid = (start + end) // 2    
	left_arr = merge_sort(arr, start, mid)
    right_arr = merge_sort(arr, mid+1, end)
    
    # 3
    l_idx = 0
    r_idx = 0
    m_idx = 0
    merged_arr = [0] * (len(left_arr) + len(right_arr))
    
    # 4
    while l_idx < len(left_arr) and r_idx < len(right_arr):
    	if left_arr[l_idx] < right_arr[r_idx]:
        	merged_arr[m_idx] = left_arr[l_idx]
            l_idx += 1
        else:
        	merged_arr[m_idx] = right_arr[r_idx]
            r_idx += 1
        m_idx += 1

거의 다 왔다.
이제 left_arrright_arr 중,
아직 merged_arr에 들어가지 못한 값을 차례대로 merged_arr에 넣어주고
병합된 merged_arr 을 return 해주면 된다.

def merge_sort(arr, start, end):
	# 2
	if end == start:
    	return [arr[start]]
    # 1 
	mid = (start + end) // 2    
	left_arr = merge_sort(arr, start, mid)
    right_arr = merge_sort(arr, mid+1, end)
    
    # 3
    l_idx = 0
    r_idx = 0
    m_idx = 0
    merged_arr = [0] * (len(left_arr) + len(right_arr))
    
    # 4
    while l_idx < len(left_arr) and r_idx < len(right_arr):
    	if left_arr[l_idx] < right_arr[r_idx]:
        	merged_arr[m_idx] = left_arr[l_idx]
            l_idx += 1
        else:
        	merged_arr[m_idx] = right_arr[r_idx]
            r_idx += 1
        m_idx += 1
    
    # 5
    right_arr = right_arr if r_idx < len(right_arr) else left_arr
    r_idx = r_idx if r_idx < len(right_arr) else l_idx
    
    while r_idx < len(right_arr):
        merged_arr[m_idx] = right_arr[r_idx]
        r_idx += 1
        m_idx += 1
    
    return merged_arr

문제 풀이

BOJ 2751 수 정렬하기 2

입력으로 주어지는 배열을 오름차순으로 정렬해서 출력하는 문제이다.
배열의 크기가 1,000,0001,000,000 이므로,
시간 제한 안에 풀기 위해서는 시간 복잡도 O(n×logn)O(n\times \log n) 의 정렬 알고리즘을 사용해야 한다.

[ 전체 코드 ]

import sys


def merge_sort(arr, start, end):
    # 2
    if end == start:
        return [arr[start]]
    # 1
    mid = (start + end) // 2
    left_arr = merge_sort(arr, start, mid)
    right_arr = merge_sort(arr, mid + 1, end)

    # 3
    l_idx = 0
    r_idx = 0
    m_idx = 0
    merged_arr = [0] * (len(left_arr) + len(right_arr))

    # 4
    while l_idx < len(left_arr) and r_idx < len(right_arr):
        if left_arr[l_idx] < right_arr[r_idx]:
            merged_arr[m_idx] = left_arr[l_idx]
            l_idx += 1
        else:
            merged_arr[m_idx] = right_arr[r_idx]
            r_idx += 1
        m_idx += 1

    # 5
    right_arr = right_arr if r_idx < len(right_arr) else left_arr
    r_idx = r_idx if r_idx < len(right_arr) else l_idx

    while r_idx < len(right_arr):
        merged_arr[m_idx] = right_arr[r_idx]
        r_idx += 1
        m_idx += 1

    return merged_arr


n = int(sys.stdin.readline().rstrip())
num_list = [0] * n
for i in range(n):
    num_list[i] = int(sys.stdin.readline().rstrip())

result = merge_sort(num_list, 0, n - 1)

for r in result:
    sys.stdout.write(f'{r}\n')

참고
[알고리즘] 합병 정렬(merge sort)이란
thanks to jnl1128

profile
그럼에도 불구하고

0개의 댓글