Binomial-Coefficient

LixFlora·2022년 12월 29일
0

공부

목록 보기
6/7

알고리즘 문제에서 조합론의 단골 주제인 이항계수를 계산하는 방법에 대해 알아보겠습니다.
이항계수를 계산하는 공식은 크게 두가지로 나뉘어집니다.

팩토리얼 공식

(nk)=n!(nk)!×k!\left(\begin{array}{c}n\\ k\end{array}\right)=\frac{n!}{(n-k)!\times k!}

점화식 공식

(nk)=(n1k1)+(n1k)\left(\begin{array}{c}n\\ k\end{array}\right)=\left(\begin{array}{c}n-1\\ k-1\end{array}\right)+\left(\begin{array}{c}n-1\\ k\end{array}\right)

팩토리얼 공식

먼저 이항계수의 정의에 따라 팩토리얼을 계산하여 구하는 방식입니다.

(nk)=n!(nk)!×k!=n1×n12×n23×...×nk+1k\left(\begin{array}{c}n\\ k\end{array}\right)=\frac{n!}{(n-k)!\times k!}=\frac{n}{1}\times \frac{n-1}{2}\times \frac{n-2}{3}\times ...\times \frac{n-k+1}{k}

위 식을 자세히 보면, (n-1+1)부터 (n-k+1)까지 곱해주고, 1부터 k까지로 나누어 줌을 알 수 있습니다.

long long nck(int n, int k) {
    long long ret = 1;
    for(int i = 1; i <= k; i++)
        ret = ret * (n-i+1) / i;
    return ret;
}

식을 그대로 계산하지 않고 위와 같이 해석하여 적용하는 이유는 속도를 개선하고 정수범위 overflow를 최대한 늦추기 위함입니다.
한편 i로 나눌 때 나누어 떨어지는지가 신경쓰일 수도 있으나, 연속된 t개의 자연수 중에는 반드시 t의 배수가 하나씩 존재하기 때문에 항상 나누어 떨어짐을 직관적으로 알 수 있습니다.
시간복잡도는 O(k)O(k)입니다.
즉, k가 커질수록 느려지게 됩니다.

(nk)=(nnk)\left(\begin{array}{c}n\\ k\end{array}\right)=\left(\begin{array}{c}n\\ n-k\end{array}\right)

만약 n-k가 더 작다면 위 식에 따라 k와 바꿔줄 수 있습니다.

long long nck(int n, int k) {
    k = min(k, n-k);
    long long ret = 1;
    for(int i = 1; i <= k; i++)
        ret = ret * (n-i+1) / i;
    return ret;
}

점화식 공식

만약, 이항계수를 m번 구해야하는 문제가 주어진다면 위에서 작성한 팩토리얼 코드로는 O(mk)O(mk)의 시간복잡도를 갖게 될 것입니다.
이렇게 m번 반복하여 계산하는 문제가 주어질 때는 점화식으로 접근하는 것이 유리합니다.

(nk)=(n1k1)+(n1k)\left(\begin{array}{c}n\\ k\end{array}\right)=\left(\begin{array}{c}n-1\\ k-1\end{array}\right)+\left(\begin{array}{c}n-1\\ k\end{array}\right)

이항계수의 점화식은 파스칼의 삼각형을 통해서도 확인할 수 있으며, 조합을 통해 직관적으로 이해할 수도 있습니다.
조합으로 생각해보면, n개 중 k개를 선택할 때의 경우의 수는 한개를 무조건 포함하는 경우의 수와 한개를 무조건 비포함하는 경우의 수를 계산하는 것과 같습니다.
한개를 무조건 포함한다면 n-1개 중 k-1개를 선택하는 상황이 되고,
한개를 무조건 비포함하면 n-1개 중 k개를 선택하는 상황이 됩니다.

long long nck(int n, int k) {
    k = min(k, n-k);
    if(k == 0) return 1;
    if(k == 1) return n;
    long long ret = nck(n-1, k-1) + nck(n-1, k);
    return ret;
}

이렇게 코드를 구성한다면, 시간복잡도가 O(n2)O(n^2)가 됩니다.
시간복잡도를 보면 팩토리얼 접근법보다 더 느리다는 것을 알 수 있습니다.
그럼에도 굳이 점화식으로 하는 이유는 메모이제이션을 적용하기 적합하기 때문입니다.
메모이제이션을 사용하게 된다면, 연산을 몇번 반복하더라도 O(n2)O(n^2)의 시간복잡도로 유지됩니다.

vector<vector<long long>> nck_(SIZE, vector<long long>(SIZE/2, -1));
long long nck(int n, int k) {
    k = min(k, n-k);
    if(k == 0) return 1;
    if(k == 1) return n;
    long long &ret = nck_[n][k];
    if(ret == -1) ret = nck(n-1, k-1) + nck(n-1, k);
    return ret;
}

이항계수의 값이 커지면서 정수형의 overflow를 겪게 되기 때문에 일반적으로 문제에서는 mod값을 주고, 나머지를 구하도록 시킵니다.
또한 편의를 위해 n값에 따라 메모이제이션 배열의 크기를 자동으로 조절하게 해준다면 아래의 코드가 됩니다.

int mod = 1e9+7;
vector<vector<long long>> nck_;
long long nck(int n, int k) {
    int s = nck_.size();
    if(n-3 > s) nck_.resize(n-3, vector<long long>((n>>1)-1, -1));
    k = min(k, n-k);
    if(k <= 1) return k?n:1;
    long long &r = nck_[n-4][k-2];
    if(r == -1) r = (nck(n-1, k-1) + nck(n-1, k)) % mod;
    return r;
}

여기서 메모이제이션에 사용한 공간복잡도를 생각해본다면, 시간복잡도와 마찬가지로 O(n2)O(n^2)이라는 것을 알 수 있습니다.
n값이 작을 때에는 문제가 없지만, 값이 커지면 메모리 제한에 걸리게 됩니다.

다시, 팩토리얼

팩토리얼 공식에 메모이제이션을 적용한다면 공간복잡도가 O(n)O(n)이 되어 메모리를 절약할 수 있게 됩니다.
방법은 팩토리얼 값들을 미리 배열에 담아두는 것입니다.

vector<long long> fac_(SIZE, -1);
long long fac(int x) {
    if(x <= 1) return 1;
    long long &ret = fac_[x];
    if(ret == -1) ret = fac(x-1) * x;
    return ret;
}
long long nck(int n, int k) {
    return fac[n] / fac[n-k] / fac[k];
}

이 코드는 n이 작을 때는 잘 작동하지만, n이 커지면 overflow가 발생합니다.
합동식에서 나눗셈이 성립하지 않기 때문에 mod값을 그냥은 적용할 수 없습니다.
그렇기 때문에 합동식의 역원을 찾아야 하고, 이는 페르마 소정리를 통해 해결할 수 있습니다.

페르마 소정리

Np2×N1 (mod p)N^{p-2}\times N \equiv 1\space (mod\space p)

페르마 소정리를 통하여 N으로 나누는 것과 Np2N^{p-2}를 곱해주는 것이 합동식에서는 동치라는 것을 확인할 수 있습니다. (p가 소수임을 가정)
이를 곱셈의 역원이라고 하고, 이 역원들 역시 메모이제이션으로 배열에 담아두면 mod를 이용할 수 있게 됩니다.
한편, mod값은 알고리즘 문제에서 보편적으로 109+710^9+7 과 같이 큰 수로 주어집니다.
단순하게 Np2N^{p-2}에서 N을 하나씩 곱해주기만 한다면, 그것만으로 1초가 넘을 수 있습니다.
그래서 거듭제곱 알고리즘을 이용합니다.

long long x = 1;
for(int i = p-2; i; i /= 2) {
    if(i % 2 == 1) x *= N;
    N *= N;
}

속도개선을 위해 비트연산을 적용할 수도 있습니다.

long long x = 1;
for(int i = p-2; i; i>>=1) {
    if(i&1) x *= N;
    N *= N;
}

이 거듭제곱 알고리즘으로 1부터 n까지의 팩토리얼과 그 역원의 배열을 만드는 것을 코드로 작성한다면,

int mod = 1e9+7;
vector<long long> fac_(n+1), fac_i(n+1, 1);
for(int i = 0; i <= n; i++) fac_[i] = i ? i * fac_[i-1] % mod : 1;
long long x = fac_[n];
long long &t = fac_i[n];
for(int m = mod-2; m; m >>= 1, x = x * x % mod) if(m&1) t = t * x % mod;
for(int i = n; i >= 0 && i; i--) fac_i[i-1] = i * fac_i[i] % mod;

위와 같은 코드가 될 것이고, 이를 통해 이항계수를 계산하는 것은 간단합니다.

long long nck(int n, int k) {
    return fac_[n] * fac_i[k] % mod * fac_i[n-k] % mod;
}

여기에 n값에 따라 자동으로 메모이제이션 크기를 조정하도록 한다면, 다음의 코드가 완성됩니다.

int mod = 1e9+7;
vector<long long> nck_, nck_i;
long long nck(int n, int k) {
    int s=nck_.size();
    if(n>=s) {
        nck_.resize(n+1), nck_i.resize(n+1,1);
        for(int i=s; i<=n; i++)
            nck_[i]=i?i*nck_[i-1]%mod:1;
        long long x=nck_[n];
        long long &t=nck_i[n];
        for(int m = mod-2; m; m>>=1, x=x*x%mod)
            if(m&1) t=t*x%mod;
        for(int i=n; i>=s && i; i--)
            nck_i[i-1]=i*nck_i[i]%mod;
    }
    return nck_[n]*nck_i[k]%mod*nck_i[n-k]%mod;
}

여기서 주의할 점은 자동으로 메모이제이션 크기를 조정하도록 했지만, n이 점차 증가하는 식을 입력받으면 오버헤딩이 크게 발생하게 됩니다.
따라서 편의를 위해 이 코드를 이용한다면, 입력받을 n의 최대값으로 미리 한번 실행시켜둘 필요가 있습니다.
한편 모듈러를 이용하는 식이므로, n값이 mod값보다 큰 경우에는 사용할 수 없습니다.
마지막으로 시간복잡도와 공간복잡도는 O(n)O(n)이 되는데, k값이 충분히 작다면 O(k)O(k)로 구하는 게 더 빠릅니다.

문제 추천

백준 13977 - 이항 계수와 쿼리
백준 6591 - 이항 쇼다운
백준 20296 - 폰친구

0개의 댓글