N
개의 정수로 이루어진 수열이 주어지고, 크기가 양수인 부분수열 중에서 그 수열의 원소를 다 더한 값이 S
가 되는 경우의 수를 구하는 문제다.
입력 값의 범위는 1 <= N <= 20
, s <= |1,000,000|
이다.
- 백트래킹 방식을 이용해 입력 배열에 포함된 값을 더하고 빼고를 반복하며
S
와 동일한 결과가 나올 때만count
를 1씩 증가시킨다.
- 가능한 조합만을 고려하기 위해 재귀를 수행할 때 인자로 현재 값 인덱스 + 1을 수행한다.
count
는 전역으로 설정하고 함수 안에서global
키워드로 선언하여 사용한다.
itertools
의combinations
를 사용해 모든 조합을 고려하는 방법도 존재한다.
import sys
input = sys.stdin.readline
n, s = map(int, input().split())
num_list = list(map(int, input().split()))
result = []
count = 0
visit = []
def back(num, result):
if sum(result) == s and len(result) != 0:
global count
count += 1
for i in range(num, n):
if i in visit:
continue
visit.append(i)
result.append(num_list[i])
back(i+1, result)
result.pop()
visit.pop()
back(0, result)
print(count)
통과는 했으나, 644ms 만큼의 시간이 소요되었다. 아마 sum(n)
, len(result) != 0
의 과정에서 발생하는 것 같다.
이를 개선하기 위해 2번, 3번을 작성해보았다.
import sys
from itertools import combinations
input = sys.stdin.readline
n, s = map(int, input().split())
num_list = list(map(int, input().split()))
count = 0
for i in range(1, n+1):
for j in combinations(num_list, i):
if sum(j) == s:
count += 1
print(count)
itertools
의 combinations
를 이용했다. 가능한 조합 중에서 sum()
을 수행해 값이 s
와 동일한 경우에만 count
를 증가시켰다.
for문
의 range
가 (1, n+1)
인 이유는 부분수열 크기를 정의하는 부분이기 때문이다. 360ms의 시간이 소요되었다.
import sys
input = sys.stdin.readline
n, s = map(int, input().split())
num_list = list(map(int, input().split()))
count = 0
def back(idx, temp_sum):
global count
if idx >= n:
return
temp_sum += num_list[idx]
if temp_sum == s:
count += 1
back(idx+1, temp_sum)
back(idx+1, temp_sum - num_list[idx])
back(0,0)
print(count)
불필요한 연산을 제거하고, 완전 병렬은 아니지만 재귀를 나누어 수행한다.
268ms가 소요되었다.