문제

  • 0, 1, 2 ... n-1로 이름이 부여된 n마리의 거위가 있습니다.
  • k마리의 거위들이 탈출했습니다.
  • 탈출한 거위들의 이름의 합은 n으로 나누어 떨어집니다.
  • 탈출한 거위들의 집합이 총 몇 가지인지를 구하시오.
  • n(1 <= n <= 500) 전체 거위 수, k(1 <= k <= min(n, 100)) 탈출한 거위 수
  • 시간 제한 3초
  • 문제 링크

이거 풀다가 토할 뻔, 3일 동안... 그래도 냅색을 이렇게 꼬아버릴 수 있구나 하고 새로웠다!

접근 과정

1. DP, 0/1 knapsack(배낭 문제)

  • 집합을 구할 때 DFS(재귀적인 성질을 이용하는 것인데, 탐색하는 과정이 DFS와 같음)를 사용하면 O(2^n) 이 걸립니다. n(1 <= n <= 500) 이기 때문에 2^500 이라는 어마어마한 시간이 소비되어 사용할 수 없습니다. (DFS는 보통 n<=20인 경우에 사용합니다.)
  • DFS로 집합을 구할 수 없을 때는 DP를 사용할 수 있습니다. 한 원소에 대해 선택하고 안하는 경우를 0/1 knapsack (배낭 문제) 이라고 합니다.
  • 하지만, 이 문제에서 0/1 knapsack에서 조금 더 생각해야할 부분이 있습니다. 그대로 적용하게 되면 1) 메모리 초과 또는 2) 시간 초과가 발생합니다.

2. 점화식

  • 정보를 저장할 배열 설계
    1) 거위 이름의 합 i를 정보로 생각하지 말고, 이름의 합을 n으로 나눈 나머지가 i일때를 정보로 생각합니다.
    2) n=7, k=2 라고 하면 정답이 되는 집합은 {1, 6}, {2, 5}, {3, 4} 가 있고 배열로 나타내면 d[2][0] = 3(탈출한 거위의 수는 2, 탈출한 거위의 이름의 합을 n으로 나눈 나머지가 0)
    d[i][j] = 탈출한 거위수가 i마리이고, 거위 이름의 합을 n으로 나눈 나머지가 j일 때의 경우의 수
  • 점화식
    // j마리가 탈출했고 숫자의 합을 n으로 나눈 나머지가 k일때는
    // j-1마리가 탈출했고 이름의 합을 나눈 나머지가 (k-i+n)%n 인 경우에서 올 수 있습니다.
    // 원래는 i번째가 탈출했기 때문에 d[j][k] += d[j-1][k-i]라고 할 수 있지만,
    // 나머지를 사용하기 때문에 아래와 같이 점화식이 세워집니다.
    d[j][k] = d[j][k] + d[j-1][(k-i+n)%n];

3. 시간 복잡도 계산

  • 1) O(kn^2) 반복문 3개 중첩

  • n(1 <= n <= 500) 전체 거위 수, k(1 <= k <= min(n, 100) 탈출한 거위 수 이기 때문에 O(kn^2)은 O(100 * 500 * 500) = O(2천5백만) 문제의 시간 제한이 3초 이기 때문에 시간안에 풀 수 있습니다.

코드

1. C++

#include <iostream>
#include <cstring>

#define max_int 501
using namespace std;

//시간 복잡도: O(kn^2)
//공간 복잡도: O(kn)
//사용한 알고리즘: DP Bottom-up(0/1 knapsack)
//사용한 자료구조: 2차원 배열

int t, n, m;
// d[i][j] = 탈출한 거위수가 i마리이고, 거위 숫자의 합을 n으로 나눈 나머지가 j일 때의 경우의 수
int d[101][max_int];
int mod = 1000000007;

int main(){
    scanf("%d", &t);
    for(int test_case=1; test_case<=t; test_case++){
        scanf("%d %d", &n, &m);

        // 1. d 배열 초기화
        memset(d, 0, sizeof(d));

        // 2. 초기값 설정
        d[0][0] = 1; //0마리가 탈출했고 합을 n으로 나눈 나머지가 0인 경우의 수는 1입니다.
        d[1][0] = 1; //1마리가 탈출했고 합을 n으로 나눈 나머지가 0인 경우의 수는 1입니다.

        // 3. 점화식 실행
        // i번째 거위부터 n-1번 거위까지 검사합니다.
        for(int i=1; i<n; i++){
            // min(n, 100)마리가 탈출 할 수 있습니다.
            for(int j=min(n, 100); j>0; j--){
                // 거위 숫자의 합을 n으로 나눌 경우 0~n-1 까지의 합을 가질 수 있습니다.
                for(int k=n-1; k>=0; k--){

                    // j마리가 탈출했고 숫자의 합을 n으로 나눈 나머지가 k일때는
                    // j-1마리가 탈출했고 숫자의 합을 나눈 나머지가 (k-i+n)%n 일때의 경우의 수를 받습니다.
                    d[j][k] = (d[j][k] + d[j-1][(k-i+n)%n])%mod;
                }
            }
        }

        // 4. 결과 출력
        printf("%d\n", d[m][0]);
    }
}