https://www.acmicpc.net/problem/1517
시간 1초, 메모리 512MB
input :
output :
조건 :
그래서 대부분의 경우 병합정렬이나, 세그 트리를 이용한다고 한다.
병합정렬의 경우. 입력을 받은 리스트 둘 중 하나를 기준으로 잡는다.
병합 정렬이 수행 될 때, 이제 merge 할 리스트들은 정렬이 되어서 올라온다.
그럴 경우 이미 왼쪽에 있던 애들이 오른쪽에 있는 애들보다 숫자가 커서 swap해 주는 것을 기록해야 한다.
그러면 왼쪽에 있는 애들이 들어가기 전에 이미 정렬이 된 개수를 세아려서 swap에 넣어주자.
그리고 이 cnt(오른쪽 애들이 정렬된 개수)는 초기화를 시키지 않는다. 모든 왼쪽 리스트에 대하여 적용이 되어야 하기 때문에 누적이 되어야 한다.
그리고 가장 고민 했던 것은 크기가 같은 숫자이면 정렬을 어떻게 하는가? 였는데
그냥.. 같은 숫자이면 스왑을 안 하기 때문에 왼쪽에 존재하는 숫자를 new_arr에 집어넣고 다시 merge를 진행하면 된다. 즉 따로 생각할 필요가 없다 ......
import sys
input = sys.stdin.readline
sys.setrecursionlimit(10 ** 9)
def merge(start, end):
# merge
global swap
new_arr = []
mid = (start + end) // 2
l_idx, r_idx = start, mid
cnt = 0
while l_idx < mid and r_idx < end:
if arr[l_idx] > arr[r_idx]:
new_arr.append(arr[r_idx])
r_idx += 1
cnt += 1
else: # arr[idx1] < arr[idx2]
new_arr.append(arr[l_idx])
l_idx += 1
swap += cnt
while l_idx < mid:
new_arr.append(arr[l_idx])
l_idx += 1
swap += cnt
while r_idx < end:
new_arr.append(arr[r_idx])
r_idx += 1
# reflect
for t in range(len(new_arr)):
arr[start + t] = new_arr[t]
def merge_sort(start, end):
global swap, arr
size = end - start
mid = (start + end) // 2
if size <= 1:
return
# divide
merge_sort(start, mid)
merge_sort(mid, end)
merge(start, end)
n = int(input())
arr = list(map(int, input().split()))
swap = 0
merge_sort(0, n)
print(swap)
배열의 인덱스를 이용했기 때문에 배열은 업데이트를 해줘야 정렬이 된 모양을 가지게 된다.
import sys
sys.setrecursionlimit(10 ** 9)
def merge_sort(start, end):
if start + 1 >= end:
return
mid = (end + start) // 2
merge_sort(start, mid)
merge_sort(mid, end)
merge(start, end)
def merge(left, right):
global cnt
mid = (left + right) // 2
i, j, ret, small_cnt = left, mid, [], 0
while i < mid and j < right:
if data[i] > data[j]:
ret.append(data[j])
j += 1
small_cnt += 1
else:
ret.append(data[i])
i += 1
cnt += small_cnt
if i == mid:
ret += data[j:right]
else:
ret += data[i:mid]
cnt += (mid - i) * small_cnt
for i in range(len(ret)):
data[left + i] = ret[i]
n = int(sys.stdin.readline())
data = list(map(int, sys.stdin.readline().split()))
cnt = 0
merge_sort(0, n)
print(cnt)