BOJ 11401 이항 계수 3

박국현·2022년 5월 26일
0

코테 알고리즘

목록 보기
10/20

(NR){N}\choose{R}을 구하는 간단한 문제이지만 NN의 범위가 너무 커 단순 계산으로 풀 수 없는 문제이다. (N4,000,000N\le4,000,000)
수의 크기를 줄이기 위해 N!(NR)!R!\frac{N!}{(N-R)!R!}에서 팩토리얼을 미리 1,000,000,0071,000,000,007으로 나누며 나머지로 계산을 해야 하는데 이렇게 하면 분모에 00이 올 수 있기 때문 불가능하다.

이 문제에서 주목해야 하는 것은 1,000,000,0071,000,000,007이라는 숫자이다. 이 숫자는 컴퓨터에서 정수로 표현할 수 있는 수 중 가장 큰 소수라고 한다.
소수라는 점이 이 숫자의 중요한 점인데, 덕분에 페르마의 소정리를 이용할 수 있다.

페르마의 소정리

pp가 소수이고 aa가 정수일 때 다음이 성립한다.

ap11 (mod p)a^{p-1} \equiv 1 \text{ (mod p)}

위키피디아 출처

이 식을 비틀어서(?) 생각하면 본 문제를 풀 수 있다. p=1,000,000,007p=1,000,000,007을 대입한다고 가정하고 N!(NR)!R!\frac{N!}{(N-R)!R!}을 다시 써보면서 문제를 구조화해보자.

N!(NR)!R!=N!1(NR)!R!N!1(NR)!R!((NR)!R!)p1 (mod p)N!((NR)!R!)p2 (mod p)\begin{matrix} \frac{N!}{(N-R)!R!} &=& N! * \frac{1}{(N-R)!R!} \\ &\equiv& N! * \frac{1}{(N-R)!R!} * ((N-R)!R!)^{p-1} &\text{ (mod p)} \\ &\equiv& N! * ((N-R)!R!)^{p-2} &\text{ (mod p)} \end{matrix}

따라서 N!((NR)!R!)p2N! * ((N-R)!R!)^{p-2}pp로 나눈 나머지를 구하면 이 문제를 풀 수 있는 것이다. 이 목표를 가지고 문제를 코드화해보자.

코드화

이 문제에는 큰 값의 지수 계산이 요구되므로 계산 함수부터 설정한다. 이 계산에서 수가 커지는 것을 막기 위해 나머지 계산을 미리 해놓는다.

# a^b 를 mod로 나눈 나머지를 구하는 함수
def power(a, b, mod):
    if b == 1:
        return a
    elif b == 0:
        return 1
    return (power(a, b // 2, mod) ** 2) * power(a, b % 2, mod) % mod

이어서 팩토리얼을 구한다. 입력값이 4백만 이상일 수 있으므로 팩토리얼을 구할 때부터 나머지로 계산한다.

factorial = list(range(N + 1))
factorial[0] = 1
for i in range(3, N + 1):
    factorial[i] = factorial[i] * factorial[i - 1] % mod

조합 공식의 분모 분자를 미리 변수에 할당해둔다.

A = factorial[N]
B = (factorial[N - K] * factorial[K]) % mod

위 분모 분자를 A!((AB)!B!)p2A! * ((A-B)!B!)^{p-2}의 형태로 정리해준다.

answer = (A % mod) * (power(B, mod - 2, mod) % mod) % mod

전체 코드

import sys

input = sys.stdin.readline


# power
def power(a, b, mod):
    if b == 1:
        return a
    elif b == 0:
        return 1
    return (power(a, b // 2, mod) ** 2) * power(a, b % 2, mod) % mod


def main():
    N, K = map(int, input().split())
    mod = 1000000007

    # factorial
    factorial = list(range(N + 1))
    factorial[0] = 1
    for i in range(3, N + 1):
        factorial[i] = factorial[i] * factorial[i - 1] % mod

    A = factorial[N]
    B = (factorial[N - K] * factorial[K]) % mod
    answer = (A % mod) * (power(B, mod - 2, mod) % mod) % mod
    sys.stdout.write(str(answer))


if __name__ == '__main__':
    main()
profile
공부하자!!

0개의 댓글