교훈
heapq.heappush
는O(logn)
의 시간복잡도를 갖는다.
세계적인 도둑 상덕이는 보석점을 털기로 결심했다.
상덕이가 털 보석점에는 보석이 총 N개 있다. 각 보석은 무게 M_i와 가격 V_i를 가지고 있다. 상덕이는 가방을 K개 가지고 있고, 각 가방에 담을 수 있는 최대 무게는 C_i이다. 가방에는 최대 한 개의 보석만 넣을 수 있다.
상덕이가 훔칠 수 있는 보석의 최대 가격을 구하는 프로그램을 작성하시오.
입력
첫째 줄에 N과 K가 주어진다. (1 ≤ N, K ≤ 300,000)
다음 N개 줄에는 각 보석의 정보 M_i와 V_i가 주어진다. (0 ≤ M_i, V_i ≤ 1,000,000)
다음 K개 줄에는 가방에 담을 수 있는 최대 무게 Ci가 주어진다. (1 ≤ C_i ≤ 100,000,000)
모든 숫자는 양의 정수이다.
출력
첫째 줄에 상덕이가 훔칠 수 있는 보석 가격의 합의 최댓값을 출력한다.
예제 입력 1
2 1
5 10
100 100
11
예제 출력 1
10
예제 입력 2
3 2
1 65
5 23
2 99
10
2
예제 출력 2
164
가장 비싸고(값이 같다면 더 가벼운) 보석부터 시작하여 차례대로 가장 가벼운 가방부터 순서대로 매칭시켜나가는 알고리즘을 짰다.
모든 가방에 못 들어가는 보석이 있을 수 있겠지만, 그래도 모든 보석을 돌면 최적화된 값이 나올 것임에 분명했다.
N, K = map(int, input().split())
jewels = []
bags = []
for _ in range(N):
M, V = map(int, input().split())
jewels.append((M, V))
for _ in range(K):
bags.append(int(input()))
# 제일 값 비싸고, 값이 똑같다면 가벼운 것부터 매칭시켜야됨.
jewels.sort(key=lambda x: (-x[1], x[0]))
bags.sort()
is_full = [False] * len(bags)
rv = 0
for jewel in jewels:
for i, bag in enumerate(bags):
if (not is_full[i]) and jewel[0] <= bag:
is_full[i] = True
rv += jewel[1]
break
print(rv)
하지만 시간 초과!
생각해보니 보석별로 순회하는 게 아니라 가방별로 순회하는 게 더 직관적이겠다 싶어서 바꿔보았다.
그리고 당연히 아닐 것 같긴 했지만, 이미 가방에 담은 보석은 del
연산을 사용하여 없애주면 시간 복잡도가 줄어들지 않을까 생각했다. (하지만 시간복잡도는 O(n^2)
그대로다.)
import sys
N, K = map(int, sys.stdin.readline().rstrip().split())
jewels = []
bags = []
for _ in range(N):
M, V = map(int, sys.stdin.readline().rstrip().split())
jewels.append((M, V))
for _ in range(K):
bags.append(int(sys.stdin.readline().rstrip()))
# 제일 값 비싸고, 값이 똑같다면 가벼운 것부터 매칭시켜야됨.
jewels.sort(key=lambda x: (-x[1], x[0]))
bags.sort()
rv = 0
# 가장 가벼운 가방부터 순회, 거기에 담을 수 있는 가장 비싼 보석 catch, 보석을 list에서 제외
for bag in bags:
for i, jewel in enumerate(jewels):
if jewel[0] <= bag:
rv += jewel[1]
del jewels[i]
break
print(rv)
뇌피셜이긴 한데, 오히려 del
연산이 O(n)
이라서 시간복잡도가 오히려 늘어나는 결과를 초래했나 싶다.
이것 역시 시간 초과!
뭔가 이런 식으로 막힐 때는 heapq
를 사용하곤 했던 것 같은데...
이 문제에서는 최댓값이나 최솟값을 필요로 하는 곳이 없는 것 같아서 어디에 사용해야 할지 감이 안 잡혔다.
내가 전부 다 생각해낸 건 아니고, 이 풀이 링크를 보고 내가 읽기 편한 대로 코드를 수정했다.
import sys
import heapq
N, K = map(int, sys.stdin.readline().rstrip().split())
jewels = []
bags = []
for _ in range(N):
M, V = map(int, sys.stdin.readline().rstrip().split())
heapq.heappush(jewels, (M, V))
for _ in range(K):
bags.append(int(sys.stdin.readline().rstrip()))
bags.sort()
rv = 0
jewels_candidate = []
for bag in bags:
while jewels and jewels[0][0] <= bag: # 가벼운 가방부터 순회하기 때문에 한번 jewels_candidate로 간 보석은 계속 candidate에 해당한다.
lightest_jewel = heapq.heappop(jewels)
heapq.heappush(jewels_candidate, -lightest_jewel[1]) # 비싼 순으로 정렬. 지금 들어간 보석이면 다음 가방에서도 어차피 다 맞으니까 보석 가격만 max-heap 돌리면 됨.
if jewels_candidate:
rv += -heapq.heappop(jewels_candidate)
print(rv)
https://blog.naver.com/PostView.naver?isHttpsRedirect=true&blogId=sule47&logNo=220773937206
최소 힙, 최대 힙은 삽입 시 O(log_2(n))
, 삭제 시 O(log_2(n))
이므로!
위 코드의 시간복잡도는 O(n^2)
이 아니라 O(nlogn)
으로 줄어든 것이다.
(삭제가 아니라 그냥 최소 힙에서 최솟값 찾기, 최대 힙에서 최댓값 찾기라면 heap[0]
으로 바로 접근할 수 있기 때문에 get 연산은 O(1)
인 것이 특징이다.)
이제까진 input()
대신 sys.stdin.readline().rstrip()
쓰는 거 귀찮아서 안 했었는데,
생각해보니까 시간 초과로 쪼들리기 시작하면 혹시 input() 써서 시간 초과인가? 하는 생각에 2트를 무의미하게 날려버리는 일이 발생한다.
그럴거면 좀 귀찮더라도 1트에서 sys.stdin.readline().rstrip()
써버리는 게 어려운 문제에서는 안정적인 방법인 듯하다.
뭐 대부분의 경우... (그리고 이 문제에서조차도) input()
때문에 발생한 시간 초과는 아니었지만 말이다.