오늘도 서준이는 병합 정렬 수업 조교를 하고 있다. 아빠가 수업한 내용을 학생들이 잘 이해했는지 문제를 통해서 확인해보자.
N개의 서로 다른 양의 정수가 저장된 배열 A가 있다. 병합 정렬로 배열 A를 오름차순 정렬할 경우 배열 A에 K 번째 저장되는 수를 구해서 우리 서준이를 도와주자.
크기가 N인 배열에 대한 병합 정렬 의사 코드는 다음과 같다.
merge_sort(A[p..r]) { // A[p..r]을 오름차순 정렬한다.
if (p < r) then {
q <- ⌊(p + r) / 2⌋; # q는 p, r의 중간 지점
merge_sort(A, p, q); # 전반부 정렬
merge_sort(A, q + 1, r); # 후반부 정렬
merge(A, p, q, r); # 병합
}
}
// A[p..q]와 A[q+1..r]을 병합하여 A[p..r]을 오름차순 정렬된 상태로 만든다.
// A[p..q]와 A[q+1..r]은 이미 오름차순으로 정렬되어 있다.
merge(A[], p, q, r) {
i <- p; j <- q + 1; t <- 1;
while (i ≤ q and j ≤ r) {
if (A[i] ≤ A[j])
then tmp[t++] <- A[i++]; # tmp[t] <- A[i]; t++; i++;
else tmp[t++] <- A[j++]; # tmp[t] <- A[j]; t++; j++;
}
while (i ≤ q) # 왼쪽 배열 부분이 남은 경우
tmp[t++] <- A[i++];
while (j ≤ r) # 오른쪽 배열 부분이 남은 경우
tmp[t++] <- A[j++];
i <- p; t <- 1;
while (i ≤ r) # 결과를 A[p..r]에 저장
A[i++] <- tmp[t++];
}
첫째 줄에 배열 A의 크기 N(5 ≤ N ≤ 500,000), 저장 횟수 K(1 ≤ K ≤ 10^8)가 주어진다.
다음 줄에 서로 다른 배열 A의 원소A1, A2, ..., AN
이 주어진다. (1 ≤ Ai ≤ 10^9)
배열 A에 K 번째 저장 되는 수를 출력한다. 저장 횟수가 K 보다 작으면
-1
을 출력한다.
❗ 문제를 풀기 전,
c
로 주어진 코드를java
로 변환하였다!
병합 정렬이란 배열을 분할하여 각 배열을 오름차순 정렬시킨 후, 분할된 배열을 합병하여 최종적으로 오름차순 정렬하는 과정이다.
아래 코드를 보면 병합 정렬을 분할과 결합 두 단계로 나누어 메서드로 나타냈다.
merge_sort(int[], int, int)
: 배열a[]
와 배열의 첫번째 인덱스p
, 마지막 인덱스r
을 입력받아 중간 인덱스q
를 구한 후,q
를 기준으로 배열a
를 분할하여 분할된 각 배열을 오름차순 정렬하고,merge()
메서드를 호출하여 결합을 수행한다.
merge(int[], int, int, int)
: 배열a[]
와 첫번째, 중간, 마지막 인덱스p
,q
,r
을 입력받아 분할된 배열들을 합병하여 최종적으로 오름차순 정렬을 수행한다.
static void merge_sort(int[] a, int p, int r) {
if(p < r) {
int q = (p + r) / 2; // 중간 인덱스
merge_sort(a, p, q); // 전반부 정렬
merge_sort(a, q+1, r); // 후반부 정렬
merge(a, p, q, r); // 병합
}
}
static void merge(int[] a, int p, int q, int r) {
int i = p; int j = q + 1; int t = 0;
while(i <= q && j <= r) {
if(a[i] <= a[j]) tmp[t++] = a[i++];
else tmp[t++] = a[j++];
}
while(i <= q) {
tmp[t++] = a[i++];
}
while(j <= r) {
tmp[t++] = a[j++];
}
i = p; t = 0;
while(i <= r) { // 최종 값 저장
cnt++;
a[i++] = tmp[t++];
}
}
✅ 위 코드를 보면
merge()
메서드의 마지막 while문에서 최종적으로 오름차순 정렬을 수행한 값을 저장하므로, 해당 반복문이 실행될 때마다 저장 횟수cnt
를 증가시킨다.
문제에서k
번째 저장 값을 요구하므로cnt++
을 수행한 후,cnt
의 값이k
와 같은 경우에 저장할 값tmp[t]
를 최종 결과result
에 저장한 후 메서드를 종료한다. 만약 저장 횟수가k
보다 작다면 if문에 걸리지 않고 모든 반복을 수행하고 종료되므로,result
의 값은 초기값인-1
이 된다.
import java.io.*;
import java.util.*;
public class Main {
static int[] tmp;
static int n, k;
static int cnt = 0;
static int result = -1;
public static void main(String[] args) throws IOException {
BufferedReader br = new BufferedReader(new InputStreamReader(System.in));
BufferedWriter bw = new BufferedWriter(new OutputStreamWriter(System.out));
StringTokenizer st = new StringTokenizer(br.readLine());
n = Integer.parseInt(st.nextToken());
k = Integer.parseInt(st.nextToken());
int[] a = new int[n]; tmp = new int[n];
st = new StringTokenizer(br.readLine());
for(int i=0;i<n;i++) {
a[i] = Integer.parseInt(st.nextToken());
}
merge_sort(a, 0, n-1);
bw.write(result + "");
br.close();
bw.close();
}
static void merge_sort(int[] a, int p, int r) {
if(p < r) {
int q = (p + r) / 2; // 중간 인덱스
merge_sort(a, p, q); // 전반부 정렬
merge_sort(a, q+1, r); // 후반부 정렬
merge(a, p, q, r); // 병합
}
}
static void merge(int[] a, int p, int q, int r) {
int i = p;
int j = q + 1;
int t = 0;
while(i <= q && j <= r) {
if(a[i] <= a[j]) tmp[t++] = a[i++];
else tmp[t++] = a[j++];
}
while(i <= q) {
tmp[t++] = a[i++];
}
while(j <= r) {
tmp[t++] = a[j++];
}
i = p; t = 0;
while(i <= r) { // 결과 저장
cnt++;
if(cnt == k) {
result = tmp[t];
return;
}
a[i++] = tmp[t++];
}
}
}