
문제 출처 : https://www.acmicpc.net/problem/1450
난이도 : 골드 1
냅색(Knapsack) 이란 가방의 뜻을 가지고 있다.
알고리즘에서 보통 냅색문제란,
가방에 물건을 넣는데,
가방의 최대 무게 제한이 있고 그것을 넘기지 않으며
가치를 최대화하라는 문제이다.
0/1 냅색
물건을 넣거나/ 안넣거나 해서 가치를 최대화하여 담는 문제이다.
Unbounded 냅색
물건을 무한히 사용 가능한 냅색문제이다.
보통 종류를 주고 담으라고 한다.
완전탐색으로 풀고자 하면 N이 30개 이상만 되어도 2^30이 되어 10억이 되어 시간복잡도가 터진다.
그래서 보통
1. DP 로 풀거나
2. Meet in the Middle (반씩 쪼개기)
로 푼다.
이 문제는 2번 Meet in the Middle 을 이용하여 풀고자 한다.
DP방식은 보통 무게 제한이 작을때, 적용하고
반씩 쪼개는 방식은 보통 N이 40이하 일때, 그리고 무게 제한이 매우 클 때 적용한다고 한다.
import sys
input = sys.stdin.readline
# 1450번 냅색(부분집합 합 개수 세기)
# 목표: N개의 물건(각 무게)이 있을 때, "무게 합이 C 이하"가 되는 부분집합의 개수를 구한다.
#
# 포인트:
# - N <= 30 이라서 전체 부분집합(2^N) 완전탐색은 빡셈.
# - 대신 Meet in the Middle(반으로 쪼개기)로 2^(N/2) 수준으로 줄인다.
# - bisect(이분탐색 라이브러리) 없이 "투 포인터"로 개수 세기.
# N 29이하 자연수, C 10^9 보다 작은 0포함 양의 정수
N, limit_weight = map(int,input().split())
weights = list(map(int,input().split()))
# 물건을 반으로 나눈다.
mid = N // 2
A = weights[:mid]
B = weights[mid:]
# 한 그룹에서 만들 수 있는 모든 "부분집합의 합" 을 구하는 함수
def get_sums(arr):
sums = []
def dfs(idx, total):
# 무게가 전부 양수이므로, 무게가 제한을 넘어간다면, early return
if total > limit_weight:
return
# 끝까지 왔으면 현재 total을 저장한다.
if idx == len(arr):
sums.append(total)
return
# idx번째 물건을 안 고르는 경우
dfs(idx+1, total)
# idx 물건을 고르는 경우
dfs(idx+1, total + arr[idx])
dfs(0,0)
return sums
# A, B 각각의 부분집합 합 리스트 생성
A_sums = get_sums(A)
B_sums = get_sums(B)
# 투 포인터를 쓰기 위해 정렬
A_sums.sort()
B_sums.sort()
# 5) 투 포인터로 개수 세기
# 아이디어:
# - A_sums를 작은 a부터 보면서,
# - B_sums는 큰 값부터 줄여가며 a + b <= C를 만족하는 최대 인덱스 j를 찾는다.
# - 그러면 B_sums[0..j]는 전부 a와 합쳐도 C 이하이므로 (j+1)개를 더하면 된다.
#
# 왜 빠르냐:
# - A_sums는 증가(오름차순)하니까 a가 커질수록 조건(a+b<=C)은 더 빡세진다.
# - 그래서 j는 절대 다시 커질 필요가 없고(오른쪽에서 왼쪽으로만 이동),
# 전체가 거의 O(len(A_sums) + len(B_sums))로 돈다.
j = len(B_sums) - 1 # B의 가장 큰 값부터 시작
answer = 0
for a in A_sums:
# a가 고정일 때, a + B_sums[j]가 C를 넘으면 j를 줄여서 b를 더 작게 만든다.
while j >= 0 and a + B_sums[j] > limit_weight:
j -= 1
# j가 -1이면, 현재 a는 어떤 b와도 합쳐서 C 이하가 될 수 없다.
# (A_sums는 오름차순이므로 이후 a는 더 커져서 더더욱 불가능)
if j < 0:
break
# B_sums[0..j]는 모두 a + b <= C를 만족
# 가능한 b의 개수는 (j+1)
answer += (j + 1)
print(answer)