아래와 같이 주어진 모든 수를 리스트에 담아 정렬한 뒤, n번째로 큰 수를 출력하는 것이 쉽게 떠올릴 수 있는 접근이다.
# 오답(메모리 초과)
import sys
n = int(sys.stdin.readline())
nums = []
for _ in range(n):
nums += list(map(int, sys.stdin.readline().split()))
nums.sort(reverse=True)
print(nums[n - 1])
하지만.. 메모리 제한이 12MB라서 메모리 초과가 뜨게 되며, 그 이유는 다음과 같다.
n의 최댓값이 1,500이기 때문에, 입력 데이터는 1,500 * 1,500 = 2,250,000개가 될 수 있다.
Python에서 정수는 일반적으로 4bytes를 사용하므로, 2,250,000 * 4 = 9,000,000bytes = 9MB가 필요하다.
입력 데이터만 보면 메모리 초과가 안 뜰 것 같지만, 정렬을 위한 추가 메모리도 필요하다.
Python의 sort()는 Timesort 알고리즘을 사용해 추가로 약 N 크기의 메모리를 사용하게 된다고 한다.
따라서 위의 코드는 9 + 9 = 18MB를 필요로 하게 되는 것이다.
이러한 문제는 최소 힙(min_heap)을 통해 해결할 수 있는데, min_heap에 가장 큰 n개의 값만 유지되도록 하면 된다.
이 경우 메모리 사용량이 O(n)이기 때문에, 주어진 메모리 제한을 넘지 않는다.
참고로 Python의 heapq는 최소 힙이 디폴트 값이므로, min_heap[0]이 구하는 값이 된다.
# 정답
import sys
import heapq
n = int(sys.stdin.readline())
min_heap = []
for _ in range(n):
row = list(map(int, sys.stdin.readline().split()))
for num in row:
heapq.heappush(min_heap, num)
if len(min_heap) > n:
heapq.heappop(min_heap)
print(min_heap[0])