BOJ 1450 냅색문제(투 포인터 알고리즘)

박국현·2022년 8월 15일
0

코테 알고리즘

목록 보기
13/20

브루트 포스 방식으로 접근할 경우 30개의 원소를 가진 배열의 부분집합을 모두 구해야 하므로 2302^{30}의 경우의 수를 계산해야하므로 불가능하다.

따라서 수의 크기를 줄여야 한다. 두 개의 2152^{15}짜리 문제로 바꾼 후 투 포인터 알고리즘(O(N)O(N))으로 두 문제를 동시에 접근하여 풀이가 가능하다.

풀이법

  1. 주어진 배열을 반으로 나눠 각각 arr1, arr2로 저장한다.
N, C = map(int, input().split())
arr = list(map(int, input().split()))
arr1, arr2 = arr[:N // 2], arr[N // 2:]
  1. arr1, arr2의 부분집합을 다 구한다. 부분집합은 각각 subset1, subset2 에 저장할 것인데, 정확히는 부분집합의 합을 저장한다. 이때 문제에서 주어진 CC(냅색 문제에서 '가방의 크기'에 해당하는 값)보다 큰 경우는 저장하지 않는다. 이 둘을 오름차순 정렬한다.
def subset(arr: list, max_sum: int):
    result = []
    for i in range(1 << len(arr)):
        sub_sum = 0
        is_over = False
        for j in range(len(arr)):
            if i & (1 << j):
                sub_sum += arr[j]
                if sub_sum > max_sum:
                    is_over = True
                    break
        if not is_over:
            result.append(sub_sum)
    result.sort()
    return result
  1. subset1의 포인터를 pointer1이라 하고 0에서 시작한다. subset2의 포인터는 pointer2라 하고 len(subset2)에서 시작한다.
arr1, arr2 = arr[:N // 2], arr[N // 2:]
subset1 = subset(arr1, C)
subset2 = subset(arr2, C)
pointer1 = 0
pointer2 = len(subset2) - 1
  1. subset1[pointer1] + subset2[pointer2]의 값을 확인하며 CC보다 크면 pointer2를 감소시킨다. CC보다 작거나 같을 경우 answer 값에 (pointer2 + 1)를 더해주고 pointer1을 증가시킨다.

    pointer2 + 1answer에 더해주는 이유 - 현재 pointer가 가리키는 값의 의미를 생각해보자. arr1의 부분집합과 arr2의 부분집합을 합친 결과가 CC보다 작거나 같다는 뜻이다. 이 조건은 오름차순 정렬된 subset2의 모든 값 중 pointer2 앞에 있는 값 모두 만족하므로, (pointer2 + 1)개 만큼 만족하는 경우의 수가 존재한다.

answer = 0
while pointer1 < len(subset1) and pointer2 >= 0:
    if subset1[pointer1] + subset2[pointer2] > C:
        pointer2 -= 1
    else:  # (subset1[pointer1] + subset2[pointer2]) <= C
        answer += pointer2 + 1
        pointer1 += 1
print(answer)

전체 코드

import sys

input = sys.stdin.readline


# arr의 부분집합 구하는 함수
def subset(arr: list, max_sum: int):
    result = []
    for i in range(1 << len(arr)):
        sub_sum = 0
        is_over = False
        for j in range(len(arr)):
            if i & (1 << j):
                sub_sum += arr[j]
                if sub_sum > max_sum:
                    is_over = True
                    break
        if not is_over:
            result.append(sub_sum)
    result.sort()
    return result


def main():
    N, C = map(int, input().split())
    arr = list(map(int, input().split()))
    # 두 리스트로 나눠서 사용
    arr1, arr2 = arr[:N // 2], arr[N // 2:]
    subset1 = subset(arr1, C)
    subset2 = subset(arr2, C)
    # 각 리스트별로 다른 포인터 사용
    pointer1 = 0
    pointer2 = len(subset2) - 1
    answer = 0
    while pointer1 < len(subset1) and pointer2 >= 0:
        if subset1[pointer1] + subset2[pointer2] > C:
            pointer2 -= 1
        else:  # (subset1[pointer1] + subset2[pointer2]) <= C
            answer += pointer2 + 1
            pointer1 += 1
    print(answer)


if __name__ == '__main__':
    main()
profile
공부하자!!

0개의 댓글