[코딩스터디] 24061번&1874번

Chaejung·2022년 2월 22일
0

알고리즘 수업 - 병합 정렬 2

풀기 전

제목이 대놓고 "병합 정렬"이면 개념 문제겠네~
스터디원들이랑 같이 풀면 어렵지 않겠다!

풀면서

내가 왜 이걸 스터디에서 풀자고 했지...
좀 푸는 시늉이라도 하고 정할 걸 미쳤나봐...

결론: 새벽 5시까지 붙잡고 있었음

병합정렬이란?

첫 시도

import sys
listNum, savingNum = sys.stdin.readline().split()

# 중요!
global save_num
save_num = 0


def merge_sort(array):
	# 탈출 조건: mergeSort하는 요소가 1개가 될 때까지
    if len(array) <= 1:
        return array
    mid = len(array) // 2
    left_array = merge_sort(array[:mid])
    right_array = merge_sort(array[mid:])
    print(left_array, right_array)

    return merge(left_array, right_array)

# 앞의 병합정렬 merge와 동일
def merge(array1, array2):
    result = []
    array1_index = 0
    array2_index = 0
    global save_num 
    breakValue = True
    
    while array1_index < len(array1) and array2_index < len(array2):
        if array1[array1_index] < array2[array2_index]:
            result.append(array1[array1_index])
            save_num += 1
            if numberK(result, save_num)==-1:
                return
            array1_index += 1
        else:
            result.append(array2[array2_index])
            save_num += 1
            if numberK(result, save_num)==-1:
                return
            array2_index += 1

    if breakValue == False:
        return 

    if array1_index == len(array1):
        while array2_index < len(array2):
            result.append(array2[array2_index])
            save_num += 1
            if numberK(result, save_num)==-1:
                return
            array2_index += 1

    if array2_index == len(array2):
        while array1_index < len(array1):
            result.append(array1[array1_index])
            save_num += 1
            if numberK(result, save_num)==-1:
                return
            array1_index += 1

    return result

global kthNum 
kthNum= 0
def numberK(array, save_num):
    global kthNum
    if save_num == savingNum:
        kthNum=array[-1]
        return -1
    else:
        pass


array = [4, 5, 1, 3, 2]
print(merge_sort(array), save_num)
print(kthNum)
  • 결과값
    >>> [1, 2, 3, 4, 5] 12
    >>> 0

리스트를 입력받는 것을 고려하지 않고 어떻게 작동하는지만 보려고 만들었으나
우선 대략 보기만 해도 코드가 매우 더럽고,
총 저장 횟수는 잘 도출되나 정해진 저장 횟수에 저장되는 수를 뽑아내지 못했다.
아무래도 함수끼리 얽혀서 리턴값을 제대로 읽어내지 못하거나 전역 변수 설정 문제인 듯 하다.
답이 없는 듯 하여 엎고 새로 다시 짰다.

두 번째 시도

그러던 와중 스터디원분 중 한 명께서 나는 눈치채지 못한 문제를 집어주셨는데,

def merge_sort(array):
	# 탈출 조건: mergeSort하는 요소가 1개가 될 때까지
    if len(array) <= 1:
        return array
    mid = len(array) // 2
    left_array = merge_sort(array[:mid])
    right_array = merge_sort(array[mid:])
    print(left_array, right_array)

    return merge(left_array, right_array)

정리했던 병합정렬 같은 경우, 중간 지점을
mid = len(array)//2 를 하여 홀수인 경우 나누는 배열의 앞부분이 적은 부분이 된다.

그런데 문제에서 행하는 병합정렬 방식은 달랐다.
위의 경우에 따르면 45/132 - 4/5//1/32 이렇게 나뉘어야 하는데,
문제에서는 451/32 - 45/1//32 이렇게 분할하고 있었다.

그래서 오밤 중에 같이 고민을 하던 팀원분과는
"mid+1을 해야하나...?"
"그러면 짝수인 경우 정확한 반으로 나눠지지 않는데요!"
"롸...?"

결국 답을 찾지 못한 채 열심히 파보던 와중
또다른 팀원분께서 힌트를 제시해 주셨다!

import sys
import math

# 리스트 크기와 저장 횟수/정렬할 리스트 입력
listNum, saveNum = map(int, sys.stdin.readline().split())
listElement = list(map(int, sys.stdin.readline().split()))

# 병합정렬mergeSort
def mergeSort(array):
    if len(array)<= 1:
        return array
    
    # 분할 정복 위한 나누기
    mid = math.ceil(len(array)/2)
    
    left_array = mergeSort(array[:mid])
    right_array = mergeSort(array[mid:])
    
    # 서로 다른 두 배열을 순서대로 merge하기 
    return merge(left_array, right_array)

# 저장횟수 카운트
global hmSavingNum 
hmSavingNum= 0

# 병합정렬merge
def merge(array1, array2):
    result = []
    
    # 병합하면서 새롭게 넣은 배열의 경우 다음 인덱스로 넘기기 위해 설정
    index1 = 0
    index2 = 0

    global hmSavingNum

    # 둘 다 새로운 배열에 옮겨담지 못했을 때
    while index1<len(array1) and index2<len(array2):
        if array1[index1] < array2[index2]:
            result.append(array1[index1])
            hmSavingNum += 1
            # 저장 횟수에 달하면 바로 해당값 출력하기
            if hmSavingNum == saveNum:
                print(result[-1])
                quit() 
            index1 += 1
            #print(hmSavingNum, result)
        else:
            result.append(array2[index2])
            hmSavingNum += 1
            # 저장 횟수에 달하면 바로 해당값 출력하기
            if hmSavingNum == saveNum:
                print(result[-1])
                quit() 
            index2 += 1
            #print(hmSavingNum, result)
    
    # 둘 중에 하나라도 다 옮겨 담았을 때
    if index1==len(array1):
        while index2<len(array2):
            result.append(array2[index2])
            hmSavingNum += 1
            # 저장 횟수에 달하면 바로 해당값 출력하기
            if hmSavingNum == saveNum:
                print(result[-1])
                quit() 
            index2 += 1
            #print(hmSavingNum, result)
    
    if index2==len(array2):
        while index1<len(array1):
            result.append(array1[index1])
            hmSavingNum += 1
            # 저장 횟수에 달하면 바로 해당값 출력하기
            if hmSavingNum == saveNum:
                print(result[-1])
                quit() 
            index1 += 1
            #print(hmSavingNum, result)
    # 전부 정렬했으나 입력받은 저장횟수가 전체 저장횟수보다 크면 -1 출력
    if len(result)==listNum and saveNum > hmSavingNum:
        print(-1)
        return

    return result

mergeSort(listElement)

<핵심 코드>

  • math.ceil(len(array)/2)
    import math를 하여 유리수 올림을 할 수 있게 하는 math.ceil(N)
    len(array)가 짝수여도 아까 고민했던 문제는 발생하지 않는다!

  • global hmSavingNum
    hmSavingNum = 0
    처음에는 전역 변수 선언 시 함수 내에서만,
    그리고 global hmSavingNum = 0 이렇게만 하면 되는 줄 알았으나,
    함수 밖에서도 선언해야하고,
    두 줄에 걸쳐서 선언해야한다는 것을 알았다.
    이것때문에 거의 한 시간은 잡아먹은 듯

  • quit()
    병합정렬의 강력한 기능이자 고질적인 문제가 바로,
    return 값이 서로 연결돼 있어 다른 값을 함부로 못 넣는다는 것이다.
    처음에는 return 값을 배열로 넘겨 인덱스로 분리할까했지만
    공간은 물론이고 복잡할 것 같아서 포기.
    그래서 지금의 if문에 빈 return 을 적었으나

    다음과 같이 None으로 다음 병합에 들어가버린다.
    저장횟수에 해당하는 숫자는 잘 나와서 더욱 화났다.

    결국 구글링을 해서 프로그램 자체를 끝내버리는 코드를 찾아 넣게됐다.
    조금 극단적인 것 같지만 지금 생각해보면
    내가 찾고 있던 유일한 방법인 듯 하다.

    결과는 성공!

세 번째 시도

오랜 시간에 걸쳐 성공했기에 더 수정하지 않고 넘어가려 했으나
스터디원분들이 나누는 얘기로부터 얻은 영감이 있어
살짝만 더 손 대보았다.

그런데 결과가 생각보다 너무 잘 나와서 너무 뿌듯하다.

키워드는 메모이제이션과 -1!

import sys
import math

# 리스트 크기와 저장 횟수/정렬할 리스트 입력
listNum, saveNum = map(int, sys.stdin.readline().split())
listElement = list(map(int, sys.stdin.readline().split()))

# 메모이제이션, 병합정렬 최댓값을 기록
memo = {
    1 : 0, # 정렬할 필요 없음
    2 : 2, # [2, 1] -> [1, _ ] -> [1, 2]
    3 : 5  # memo[2] + memo[1] + 3
}

# 병합정렬 중 저장횟수 구하기
def mergeSortCount(n, msMemo):
    if n in msMemo:
        return msMemo[n]
    nthMemo = mergeSortCount(math.ceil(n/2), msMemo) + mergeSortCount(n-math.ceil(n/2), msMemo) + n
    memo[n] = nthMemo
    return nthMemo

# 만약 배열 크기에 따른 저장횟수가 입력받은 저장횟수보다 작으면 바로 -1 출력 
if mergeSortCount(listNum, memo)<saveNum:
    print(-1)
    quit()

# 병합정렬mergeSort
def mergeSort(array):
    if len(array)<= 1:
        return array

    # 분할 정복 위한 나누기
    mid = math.ceil(len(array)/2)
    
    left_array = mergeSort(array[:mid])
    right_array = mergeSort(array[mid:])

    # 서로 다른 두 배열을 순서대로 merge하기 
    return merge(left_array, right_array)

#전역 변수로 숫자 정렬할 때마다 카운트하는 변수 선언
global hmSavingNum 
hmSavingNum= 0

# 병합정렬merge
def merge(array1, array2):
    result = []

    # 병합하면서 새롭게 넣은 배열의 경우 다음 인덱스로 넘기기 위해 설정
    index1 = 0
    index2 = 0

    global hmSavingNum

    # 둘 다 새로운 배열에 옮겨담지 못했을 때
    while index1<len(array1) and index2<len(array2):
        if array1[index1] < array2[index2]:
            result.append(array1[index1]) 
            index1 += 1
        else:
            result.append(array2[index2])
            index2 += 1
        # 위의 두 조건문에 반복되어 밖으로 뺐다.
        # 숫자 하나 집어넣을 때 횟수 증가+검사하기
        hmSavingNum += 1
        if hmSavingNum == saveNum:
            print(result[-1])
            quit()

    # 둘 중에 하나라도 다 옮겨 담았을 때
    if index1==len(array1):
        while index2<len(array2):
            result.append(array2[index2])
            hmSavingNum += 1
            if hmSavingNum == saveNum:
                print(result[-1])
                quit() 
            index2 += 1
    
    if index2==len(array2):
        while index1<len(array1):
            result.append(array1[index1])
            hmSavingNum += 1
            if hmSavingNum == saveNum:
                print(result[-1])
                quit() 
            index1 += 1

    # 전부 정렬했으나 입력받은 저장횟수가 전체 저장횟수보다 크면 -1 출력
    if len(result)==listNum and saveNum > hmSavingNum:
        print(-1)
        return

    return result

mergeSort(listElement)

<핵심 코드>

  • def mergeSortCount(n, msMemo)
    피보나치수열을 메모이제이션을 이용해 함수로 나타난 걸 본 적이 있는데,
    병합정렬도 재귀함수와 같은 유사한 규칙이 있었다.
    그래서 동적계획법이 가능하길래, 바로 하나씩 구해보았다.
    참고로 어떻게 정렬돼있던 간데 배열의 크기가 동일하면 저장 횟수도 동일하다.

    n = 1
    memo[1] = 0
    정렬할 필요가 없음

    n = 2
    memo[2] = 2
    [2, 1] -> [1, _ ] -> [1, 2]

    n = 3
    memo[3] = memo[2] + memo[1] + 3
    [2, 3, 1] -> [2, _, _] -> [2, 3, _] (여기까지 memo[2])-> [1, 3, _] -> [1, 2, _] -> [1, 2, 3]

    n = 4
    memo[4] = memo[2] + memo[2] + 4
    [2, 1, 4, 3] -> [1, _, _, _] -> [1, 2, _, _] -> [1, 2, 3, _] -> [1, 2, 3, 4] (여기까지 memo[2] 두 번)-> [**1**, 2, 3, 4] -> [**1**, **2**, 3, 4] -> [**1**, **2**, **3**, 4] -> [**1**, **2**, **3**, **4**]
    여기서 [1, 2, 3, 4] 이미 완성이 됐는데 왜 다시 저장을 하냐면,
    merge([1,2],[3,4]) 가 실행되기 전 또는 중에 완전한 [1, 2, 3, 4] 배열이 아직 다 만들어지지 않았기 때문이다. 이해를 위해 합쳐서 표현한 것이기, 중간의 [1, 2, 3, 4]는 사실상 [1, 2]과 [3, 4]가 각각 있는 상태.

    n = 5
    memo[5] = memo[3] + memo[2] + 5
    생략

    .
    .
    .
    일반화를 하자면
    S(N) = 배열의 크기가 N인 배열의 병합정렬 저장횟수

    S(N) = S(math.ceil(N/2)) + S(N-math.ceil(N/2)) + N

    코드로 변환하게 되면,
    nthMemo = mergeSortCount(math.ceil(n/2), msMemo) + mergeSortCount(n-math.ceil(n/2), msMemo) + n

    • if mergeSortCount(listNum, memo)<saveNum: print(-1) quit()

      결국엔 왜 메모이제이션을 했나면,
      -1을 제외한 아래 함수는 중간에 저장횟수에 달하면 바로 출력하고 프로그램을 종료하게 된다.
      그런데 -1을 도출하기 위해서는 병합정렬을 끝까지 해야하는데,
      여기서 시간 초과나 메모리 초과가 날 확률이 높다.
      그래서 배열의 크기로 병합정렬의 저장횟수가 정해진다면, 굳이 병합정렬을 하지 않아도, -1을 도출할 수 있는 것이다.

      메모이제이션은 재귀함수 성격을 가진 규칙에서 메모리를 아주 적게 써서 유용하다.
      이 부분이 시간 단축에 크게 기여한 듯 하다.

스택 수열

첫 번째 시도

import sys

def stackPut(n, sequence):
    stack = []
    num = 1
    index = 0
    result = []

    while True:
        if len(stack) == 0:
            stack.append(num)
            result.append("+")
            num += 1
        elif sequence[index] == stack[-1]:
            stack.pop()
            result.append("-")
            index += 1
            if index == n:
                break
        else:
            stack.append(num)
            result.append("+")
            num += 1
    if len(stack) == 0:
        print(*result, sep="\n")

totalNum = int(sys.stdin.readline())
numList = []
for i in range(totalNum):
    numList.append(int(sys.stdin.readline()))

# 최댓값 이후로 내림차순 아니면 NO
checkList = numList[numList.index(max(numList)):]
if checkList != sorted(checkList, reverse=True):
    print("NO")

stackPut(totalNum, numList)

예제를 몇 개 보면서 스택이 가능한 것과 불가능한 것의 차이를 비교하니,
불가능한 것은 최댓값 이후로 내림차순이 아닌 경우에 해당했다.
왜냐하면 최댓값이 결과 리스트에 들어간 후 스택에 남아있는 수들은 당연하게도 내림차순이다. 이것을 pop하여 결과 리스트에 붙일 때 내림차순으로 들어가기 때문에,
최댓값 이후가 내림차순이 아닌 경우 스택으로 만들지 못하는 것이다.

이 원리를 이용해서 NO인 경우를 먼저 빼려고 했으나

첫 번째와 두 번째 시도는 실패

성공한 시도

import sys

def stackPut(n, sequence):
    stack = []
    num = 1
    index = 0
    result = [] # push pop의 여부가 들어감

    while True:
        
        if len(stack) == 0:
            stack.append(num)
            result.append("+")
            num += 1
        elif sequence[index] == stack[-1]:
            stack.pop()
            result.append("-")
            index += 1
            # 전체 수를 입력된 배열처럼 담았을 때 탈출 가능
            if index == n:
                break
        else:
            # 스택에서 입력된 배열에 못 담을 때
            # 담아오는 숫자가 전체 숫자보다 큰 경우
            if n < num:
                print("NO")
                break
            stack.append(num)
            result.append("+")
            num += 1
    # 스택이 비었을 때 전체 push/pop 결과 출력
    if len(stack) == 0:
        for i in result:
            print(i)

totalNum = int(sys.stdin.readline())
numList = []
for i in range(totalNum):
    numList.append(int(sys.stdin.readline()))

stackPut(totalNum, numList)

아쉽게도 처음 생각했던 아이디어는 버려야 했다.
직관적으로 스택을 입력받은 수열과 동일하게 만드는 구조이다.

참고한 사이트

반올림, 내림, 올림
전역 변수 지역 변수
quit()메서드
리스트 정렬

앞에서 너무 힘을 빼서 그런가...
여기서 급하게 마무리!

profile
프론트엔드 기술 학습 및 공유를 활발하게 하기 위해 노력합니다.

0개의 댓글

관련 채용 정보