분할 정복 알고리즘 중 하나인 합병 정렬을 구현해본다. 재귀적인 구조로 구현할 수 있다.
#include <stdio.h>
#include <time.h>
#include <stdlib.h>
#include "vld.h"
#define MAX_VALUE 30
#define SIZE 10
void merge(int arr[], int l, int r);
void merge_sort(int arr[], int l, int r);
void swap(int* a, int* b);
int* generate_random_arr(int size);
void print_arr(int arr[], int size);
int main(void)
{
srand((unsigned)time(NULL));
int* arr = generate_random_arr(SIZE);
print_arr(arr, SIZE);
merge_sort(arr, 0, SIZE - 1);
print_arr(arr, SIZE);
free(arr);
}
void swap(int* a, int* b)
{
int tmp = *a;
*a = *b;
*b = tmp;
}
void merge_sort(int arr[], int l, int r)
{
if (l != r) {
int m = (l + r) / 2;
merge_sort(arr, l, m);
merge_sort(arr, m + 1, r);
merge(arr, l, r);
print_arr(arr, SIZE);
}
}
void merge(int arr[], int l, int r)
{
int m = (l + r) / 2;
int p = l;
int q = m + 1;
//int* tmp_arr = (int*)malloc(sizeof(int) * (r - l + 1));
int* tmp_arr = (int*)malloc(sizeof(int) * SIZE);
if (tmp_arr == NULL) return;
for (int i = l; i <= r; i++)
tmp_arr[i] = arr[i];
for (int i = l; i <= r; i++) {
if (p > m) {
while (i <= r)
arr[i++] = tmp_arr[q++];
}
else if (q > r) {
while (i <= r)
arr[i++] = tmp_arr[p++];
}
else {
if (tmp_arr[p] > tmp_arr[q])
arr[i] = tmp_arr[q++];
else arr[i] = tmp_arr[p++];
}
}
free(tmp_arr);
}
int* generate_random_arr(int size)
{
int* arr = (int*)malloc(sizeof(int) * size);
if (arr == NULL) return NULL;
for (int i = 0; i < size; i++)
arr[i] = rand() % MAX_VALUE + 1;
return arr;
}
void print_arr(int arr[], int size)
{
for (int i = 0; i < size; i++)
printf("%d ", arr[i]);
printf("\n\n");
}
void merge_sort(int arr[], int l, int r)
{
if (l != r) {
int m = (l + r) / 2;
merge_sort(arr, l, m);
merge_sort(arr, m + 1, r);
merge(arr, l, r);
print_arr(arr, SIZE);
}
}
합병 정렬 함수 자체는 큰 내용이 없다. 구조만 담겨 있다.
l(left)부터 r(right)까지의 요소들을 정렬하기 위해선, 중간 점 m을 구한 다음 l부터 m까지 정렬하고 m+1부터 r까지 정렬하여 합치는 과정이 필요하다.
합병 정렬의 핵심이 되는 함수이다. 잘게 분할한 문제를 다시 합치는 과정이다. 그냥 합치는게 아니라 정렬하면서 합쳐야 한다.
void merge(int arr[], int l, int r)
{
int m = (l + r) / 2;
int p = l;
int q = m + 1;
int* tmp_arr = (int*)malloc(sizeof(int) * SIZE);
if (tmp_arr == NULL) return;
for (int i = l; i <= r; i++)
tmp_arr[i] = arr[i];
for (int i = l; i <= r; i++) {
if (p > m) {
while (i <= r)
arr[i++] = tmp_arr[q++];
}
else if (q > r) {
while (i <= r)
arr[i++] = tmp_arr[p++];
}
else {
if (tmp_arr[p] > tmp_arr[q])
arr[i] = tmp_arr[q++];
else arr[i] = tmp_arr[p++];
}
}
free(tmp_arr);
}
p와 q는 분할한 두 배열의 시작 점이다. 그림으로 나타내면 아래와 같다.
이후에 p와 q를 비교하면서 작은 순서대로 합쳐주면 된다.
분할과정 후에 다시 합칠 때 원본 배열의 데이터가 훼손되는 것을 생각하지 못해서 특정 요소가 여러번 복사되었다.
위의 코드에서는 tmp_arr 배열을 동적으로 생성하고, 원본 배열의 데이터를 백업한 뒤에야 비로소 배열 arr에 데이터를 쓸 수 있었다.