[Algorithm] 백준 2696 : 중앙값 구하기

채멈·2024년 1월 27일

Algorithm

목록 보기
16/24
post-thumbnail

문제
https://www.acmicpc.net/problem/2696
풀이
https://github.com/nowChae/algorithm/blob/master/%EB%B0%B1%EC%A4%80/Gold/2696.%E2%80%85%EC%A4%91%EC%95%99%EA%B0%92%E2%80%85%EA%B5%AC%ED%95%98%EA%B8%B0/%EC%A4%91%EC%95%99%EA%B0%92%E2%80%85%EA%B5%AC%ED%95%98%EA%B8%B0.py

처음 이 문제를 풀 때는 힙 자료구조 하나만을 사용해서 풀으려고 했었다. 배열 속에 숫자를 하나씩 넣다가, 중앙값을 찾아야하는 타이밍에 해당 배열 속 값들을 중앙값이 나올 때까지 heappop하도록 해야하나 생각했었다. 항상 최솟값을 pop하기 때문에 배열의 길이가 5일때는 3번 pop하게 되면 중앙값이고, 길이가 7일 때는 4번 pop하게 되면 중앙값이다. 여기서 규칙을 찾아서 적용해보려고 했다. 이러다 보니 처음 입력받은 배열을 계속 슬라이싱하여 복사하면서 사용했고, 이 때문에 처음 푼 방식으로 제출해보니 시간 초과가 발생하였다.

힙을 사용하여 중앙값을 찾기 위해서는 최소힙, 최대힙을 모두 사용하면 쉽게 찾을 수 있었다. 중앙값을 기준으로 왼쪽은 최대힙, 오른쪽은 최소힙을 사용해 중앙값보다 작은 값은 왼쪽의 최대힙에 넣어주고, 중앙값보다 큰 값은 오른쪽의 최소힙에 넣어준다. 1, 3, 5, 7 등 홀수 번째에서 중앙값을 찾아주기 때문에 왼쪽힙과 오른쪽힙의 길이가 같다면 기존의 중앙값이 여전히 중앙값이다. 만약 어느쪽 하나의 길이가 더 길다면 더 긴 쪽의 값을 하나 빼서 중앙값으로 갱신해주고, 중앙값이었던 것은 길이가 짧은 힙에 추가해준다.

중앙값을 구할 때 최소힙과 최대힙을 사용할 수 있다는 것을 알게되어서 이와 같은 중앙값을 찾는 문제에서 활용해봐야겠다고 생각했다.

< 풀이 코드 - 시간 초과 >

#시간 초과
import heapq
import sys
input = sys.stdin.readline

N = int(input())
for i in range(N):
    number = int(input())
    if number > 10 :
        c = (number//10)
        arr = list(map(int, input().split()))
        for _ in range(c):
            arr += list(map(int, input().split()))
    else:
        arr = list(map(int, input().split()))

    print_count = (number+1)//2
    print(print_count)
    result = []
    for i in range(print_count):
        heap_arr = arr[:2*i + 1]
        count = i
        while count >= 0:
            heapq.heapify(heap_arr)
            r = heapq.heappop(heap_arr)
            if count == 0:
                result.append(r)
            count -= 1

    for i in range(print_count):
        if i % 9 == 0:
            print(result[i])
        else:
            print(result[i], end=" ")
    print()

< 풀이 코드 >

#최대합 최소합을 사용하여 중앙값 구하기 
# 최대합 - 중앙값 - 최소합

import heapq
import sys
input = sys.stdin.readline

N = int(input())
for i in range(N):
    number = int(input())
    if number > 10 :
        c = (number//10)
        arr = list(map(int, input().split()))
        for _ in range(c):
            arr += list(map(int, input().split()))
    else:
        arr = list(map(int, input().split()))

    print_count = (number+1)//2
    print(print_count)
    
    result = [arr[0]] # mid
    min_heap = []
    max_heap = []

    mid = arr[0]
    for i, a in enumerate(arr[1:]):
        if a >= mid:
            heapq.heappush(min_heap, a)
        else:
            heapq.heappush(max_heap, -a)

        if i % 2 != 0:
            if len(min_heap) > len(max_heap):
                heapq.heappush(max_heap, -mid)
                mid = heapq.heappop(min_heap)
            elif len (min_heap) < len(max_heap):
                heapq.heappush(min_heap, mid)
                mid = -heapq.heappop(max_heap)
            result.append(mid)


    for i in range(print_count):
        if i % 9 == 0:
            if i != 0:
                print(result[i])
            else:
                print(result[i], end=" ")
        else:
            print(result[i], end=" ")
    print()
profile
공부 기록 차곡차곡 ( ੭ ・ᴗ・ )੭

0개의 댓글