코딩테스트 문제풀이- 알고리즘 수업 병합정렬

jadive study·2022년 12월 10일
0

우선 병합정렬에대한 지식이 부족해서
여러 자료를 찾아보면서 공부하였다.

https://www.youtube.com/watch?v=ctkuGoJPmAE
https://www.daleseo.com/sort-merge/
-> 참고
O(logN) 시간이 필요하며, 각 패스에서 병합할 때 모든 값들을 비교해야 하므로 O(N) 시간이 소모됩니다. 따라서 총 시간 복잡도는 O(NlogN) ,
두 개의 배열을 병합할 때 병합 결과를 담아 놓을 배열이 추가로 필요합니다. 따라서 공간 복잡도는 O(N) 이다.
다른 정렬 알고리즘과 달리 인접한 값들 간에 상호 자리 교대(swap)이 일어나지 않는다.



요약하자면,
[병합 정렬] 절반으로 나누고 병합하면서 정렬한다
mergeSort(a, m, middle); (left는 왼쪽으로 나눌게 없어질때까지 실행된다)
mergeSort(a, middle + 1, n); (right는 left가 끝나면 실행된다)
merge(a, m, middle, n); (right가 끝나면 병합이 실행된다)

예시 7 6 5 8 3 5 9 1

  1. left가 실행된다 7 6 5 8 -> 7 6 -> 7
  2. left가 끝났으니 right가 실행된다 6
  3. right가 끝났으니 merge함수가 실행된다 7과 6을 비교 후 6 7로 정렬
  4. 정렬될 때까지 반복한다

전체적인 실행순서

7 6 5 8 -> 7 6 -> 7 -> 6 -> merge(6, 7) -> 5 8 -> 5 -> 8 -> merge(5, 8) -> merge(5, 6, 7, 8)
-> merge(1, 3, 5, 5, 6, 7, 8, 9) -> 1 3 5 5 6 7 8 9
3 5 9 1 -> 3 5 -> 3 -> 5 -> merge(3, 5) -> 9 1 -> 9 -> 1 -> merge(1, 9) -> merge(1, 3, 5, 9)

Java 코드
자바도 비슷한 방식으로 구현할 수 있습니다. Arrays 클래스의 copyOfRange() 정적 메서드를 사용해서 배열을 원하는 크기로 복제할 수 있습니다.

public class MergeSorter {
    public static int[] sort(int[] arr) {
        if (arr.length < 2) return arr;

        int mid = arr.length / 2;
        int[] low_arr = sort(Arrays.copyOfRange(arr, 0, mid));
        int[] high_arr = sort(Arrays.copyOfRange(arr, mid, arr.length));

        int[] mergedArr = new int[arr.length];
        int m = 0, l = 0, h = 0;
        while (l < low_arr.length && h < high_arr.length) {
            if (low_arr[l] < high_arr[h])
                mergedArr[m++] = low_arr[l++];
            else
                mergedArr[m++] = high_arr[h++];
        }
        while (l < low_arr.length) {
            mergedArr[m++] = low_arr[l++];
        }
        while (h < high_arr.length) {
            mergedArr[m++] = high_arr[h++];
        }
        return mergedArr;
    }

2
public class MergeSorter {

    public static void mergeSort(int[] arr) {
        sort(arr, 0, arr.length);
    }

    private static void sort(int[] arr, int low, int high) {
        if (high - low < 2) {
            return;
        }

        int mid = (low + high) / 2;
        sort(arr, 0, mid);
        sort(arr, mid, high);
        merge(arr, low, mid, high);
    }

    private static void merge(int[] arr, int low, int mid, int high) {
        int[] temp = new int[high - low];
        int t = 0, l = low, h = mid;

        while (l < mid && h < high) {
            if (arr[l] < arr[h]) {
                temp[t++] = arr[l++];
            } else {
                temp[t++] = arr[h++];
            }
        }

        while (l < mid) {
            temp[t++] = arr[l++];
        }

        while (h < high) {
            temp[t++] = arr[h++];
        }

        for (int i = low; i < high; i++) {
            arr[i] = temp[i - low];
        }
    }
}

핵심은 일단반으로 나누고 나중에합치면어떨까?이다.

문제

오늘도 서준이는 병합 정렬 수업 조교를 하고 있다. 아빠가 수업한 내용을 학생들이 잘 이해했는지 문제를 통해서 확인해보자.

N개의 서로 다른 양의 정수가 저장된 배열 A가 있다. 병합 정렬로 배열 A를 오름차순 정렬할 경우 배열 A에 K 번째 저장되는 수를 구해서 우리 서준이를 도와주자.

크기가 N인 배열에 대한 병합 정렬 의사 코드는 다음과 같다.

merge_sort(A[p..r]) { # A[p..r]을 오름차순 정렬한다.
if (p < r) then {
q <- ⌊(p + r) / 2⌋; # q는 p, r의 중간 지점
merge_sort(A, p, q); # 전반부 정렬
merge_sort(A, q + 1, r); # 후반부 정렬
merge(A, p, q, r); # 병합
}
}

A[p..q]와 A[q+1..r]을 병합하여 A[p..r]을 오름차순 정렬된 상태로 만든다.

A[p..q]와 A[q+1..r]은 이미 오름차순으로 정렬되어 있다.

merge(A[], p, q, r) {
i <- p; j <- q + 1; t <- 1;
while (i ≤ q and j ≤ r) {
if (A[i] ≤ A[j])
then tmp[t++] <- A[i++]; # tmp[t] <- A[i]; t++; i++;
else tmp[t++] <- A[j++]; # tmp[t] <- A[j]; t++; j++;
}
while (i ≤ q) # 왼쪽 배열 부분이 남은 경우
tmp[t++] <- A[i++];
while (j ≤ r) # 오른쪽 배열 부분이 남은 경우
tmp[t++] <- A[j++];
i <- p; t <- 1;
while (i ≤ r) # 결과를 A[p..r]에 저장
A[i++] <- tmp[t++];
}

풀이

import java.io.BufferedReader;
import java.io.IOException;
import java.io.InputStreamReader;
import java.util.StringTokenizer;

public class Main {

static int[] A, tmp;
static int count = 0;
static int result = -1;
static int K;

public static void main(String[] args) throws IOException {
	BufferedReader br = new BufferedReader(new InputStreamReader(System.in));
	StringTokenizer st = new StringTokenizer(br.readLine());
	
	int N = Integer.parseInt(st.nextToken()); //n길이 배열
	K = Integer.parseInt(st.nextToken()); 	  //K번 병합병렬
	
	st = new StringTokenizer(br.readLine());
	
	A = new int[N]; //N길이의 배열
	for(int i = 0; i < N; i++) {
		A[i] = Integer.parseInt(st.nextToken()); //N길이의 배열에 값을 넣는다.
	}
	tmp = new int[N];			 //N길이의 임시배열
	merge_sort(A, 0, N - 1);	 //병합정렬할 함수호출
	System.out.println(result);  //출력
	
}
// 위 병합 정렬내용 참고 
public static void merge_sort(int[] A, int p, int r) {//A배열 , 0 , 마지막인덱스
	if (count > K) return ; 	// K가 음수면 리턴
	if (p < r) {			 	// p는 정렬할 인덱스 , r은 마지막 인덱스
		int q = (p + r) / 2; 	// 0+마지막 /2 == 중간
		merge_sort(A, p, q); 	// A배열,0,중간 
		merge_sort(A, q + 1, r); //A배열,중간+1,마지막
		merge(A, p, q, r);		//A배열,0,중간,마지막
	}
}

public static void merge(int[] A, int p, int q, int r) {
	int i = p;				//0인덱스
	int j = q + 1;			//중간+1
	int t = 0;
	
	while (i <= q && j <= r) {
		if(A[i] <= A[j]) {  //i번째가  j보다 크거나 작으면
			tmp[t] = A[i];  //임시 배열에 넣어준다.
			i++;			//다음 인덱스 정렬
		}else {				//i가 더 크다면
			tmp[t] = A[j];  //j가 더 작으므로 t번째에 넣어준다.
			j++;			//j 다음 탐색
		}
		t++; 정렬을 마치면 다음 인덱스 탐색
	}
	
	while (i <= q) {  //왼쪽배열부분남 음 
		tmp[t++] = A[i++];
	}	
	while (j <= r) { //오른쪽배열부분이 남은경우
		tmp[t++] = A[j++];
	}
	i = p;
	t = 0;
	while (i <= r) { 
		count++;    //0~마지막 인덱스 
		if (count == K) { // K번째 카운트값 찾기
			result = tmp[t]; 
			break;
		} 
		A[i++] = tmp[t++]; //다시 A배열에 넣어준다.
	}
}
}

문제 자체는 어렵지 않았지만,문법 생성자 컴파일 백준 오류 때문에 소스코드를 참고 하면서 풀어 보았다.

파이썬 문제도 찾아보았다..

python 문제 풀이

import sys
input = sys.stdin.readline


def mergeSort(L):
    if len(L) == 1:
        return L
    
    mid = (len(L) + 1)//2
    left = mergeSort(L[:mid])
    right = mergeSort(L[mid:])
    
    L2 = []
    i = 0
    j = 0
    
    while i < len(left) and j < len(right):
        if left[i] < right[j]:
            L2.append(left[i])
            ans.append(left[i])
            i += 1
        else:
            L2.append(right[j])
            ans.append(right[j])
            j += 1
    
    while i < len(left):
        L2.append(left[i])
        ans.append(left[i])
        i += 1
        
    while j < len(right):
        L2.append(right[j])
        ans.append(right[j])
        j += 1
        
    return L2

n, k = map(int, input().split())
a = list(map(int, input().split()))

ans = []
mergeSort(a)

if len(ans) >= k:
    print(ans[k-1])
else:
    print(-1
```) 
profile
개발 메모창고

0개의 댓글