BJ 14854 이항 계수 6 - (2)

이경헌·2021년 1월 16일
1

백준 - 이항 계수

목록 보기
8/8

https://www.acmicpc.net/problem/14854

BJ 14854 이항 계수 6 - (1)에서 이어집니다.

중국인의 나머지 정리

어떤 이항 계수를 소수 pp로 나눈 나머지를 구하는 방법과, 어떤 소수의 거듭제곱 pqp^q로 나눈 나머지를 구하는 방법을 모두 알아냈습니다. 각각의 법에 대한 해를 가지고 법법 142857에 대한 나머지를 구해야 하는데, 이를 해결할 수 있는 방법이 중국인의 나머지 정리입니다.

{xa1(modn1)xa2(modn2)xak(modnk)\begin{cases} x \equiv& a_1 \pmod{n_1} \\ x \equiv& a_2 \pmod{n_2} \\ &\vdots \\ x \equiv& a_k \pmod{n_k} \end{cases}

다음과 같은 연립합동식이 존재하고, 모든 nin_i는 쌍마다 서로소라고 가정합니다. 이 때, 0 이상 N=n1×n2××nkN=n_1\times n_2\times \cdots\times n_k 미만의 정수 중에 이 연립합동식의 해가 유일하게 존재합니다. 존재성과 유일성에 대한 내용은 여기에서 다루지 않고, 간단히 해를 구하는 알고리즘만 파악하도록 하겠습니다.

  1. 각 합동식의 nin_i를 모두 곱한 NN을 계산합니다.

  2. ii번째 합동식에서 Ni=N/niN_i=N/n_i라 하고, MiNi(modni)M_i\equiv N_i \pmod{n_i}이고 0Mini0\le M_i \le n_iMiM_i를 정의합니다.

  3. 연립 합동식의 해는 다음과 같습니다:

    xi=1kaiNiMi(modN)x\equiv \prod_{i=1}^k a_iN_iM_i \pmod N

코드

from sys import stdin, stdout
from functools import reduce
from operator import mul

def get_modulo_inverse(a, b):
    '''
    Calculate inverse of a (mod b)
    '''
    b0 = b
    x0, x1 = 0, 1
    if b == 1:
        return 1
    while a > 1:
        q = a // b
        a, b = b, a % b
        x0, x1 = x1 - q * x0, x0
    if x1 < 0:
        x1 += b0
    return x1

factorial = {} # factorial[m][n] : n! (mod m)
inverse = {} # inverse[m][n]: inverse of n! (mod m)

def choose_modulo_prime(n, m, p):
    '''Calculate C(n, m) % p
    1 <= n <= 10^18, 0 <= m <= n, 2 <= p <= 2000
    p must be prime

    Return: int
    '''
    def get_factorial(num):
        if p not in factorial:
            factorial[p] = {0: 1, 1: 1}
        for idx in range(2, num+1):
            factorial[p][idx] = factorial[p][idx-1] * idx % p
        return factorial[p][num]

    def get_inverse(num):
        if p not in inverse:
            inverse[p] = {0: 1, 1: 1}
        for idx in range(2, num+1):
            inverse[p][idx] = get_modulo_inverse(get_factorial(idx), p)
        return inverse[p][num]

    choose_prod = 1
    while n != 0 or m != 0:
        n_digit = n % p
        n //= p
        m_digit = m % p
        m //= p

        if n_digit < m_digit:
            choose_prod = 0
            break
        else:
            choose_prod *= get_factorial(n_digit) * get_inverse(m_digit) * get_inverse(n_digit-m_digit)
            choose_prod %= p
        
    return choose_prod    

p_adic_factorial = {} # p_adic_factorial[m][n]: (n!)_m (mod m)
p_adic_inverse = {} # p_adic_inverse[m][n]: inverse of (n!)_m (mod m)

def choose_modulo_prime_power(n, m, p, q):
    '''Calculate C(n, m) % p^q.
    1 <= n <= 10^9, 0 <= m <= n, q >= 1
    p must be prime
    
    Return: int
    '''
    modulo = p ** q

    def expand_by_p(n, m, r):
        '''Expand number with given base (p).
        Each n, m, r will expanded until their length is more than d_min.
        
        Return: tuple(list, list, list, int)
        '''
        n_expand, m_expand, r_expand = [], [], []
        d = 0
        d_min = q-1
        while n > 0 or m > 0 or r > 0 or d <= d_min:
            n_expand.append(n % p)
            n //= p
            m_expand.append(m % p)
            m //= p
            r_expand.append(r % p)
            r //= p

            d += 1
        return n_expand, m_expand, r_expand, d-1
    
    def least_positive_residue(n, m, r, d):
        '''Calculate x // p^j % p^q where x = n, m, r and j = 0, 1, ..., d

        Return: tuple(list, list, list)
        '''
        n_lpr, m_lpr, r_lpr = [], [], []
        for _ in range(d+1):
            n_lpr.append(n % modulo)
            n //= p
            m_lpr.append(m % modulo)
            m //= p
            r_lpr.append(r % modulo)
            r //= p
        return n_lpr, m_lpr, r_lpr
    
    def carry_count(m_expand, r_expand, d):
        '''Count number of carries occur when add m and r with base p.

        Return: tuple(int, int)
        '''
        has_carry = [0] * (d+1)
        prev_carry = 0
        for idx in range(d+1):
            value = m_expand[idx] + r_expand[idx] + prev_carry
            if value >= p:
                has_carry[idx] = 1
                prev_carry = 1
            else:
                prev_carry = 0
        
        eq1 = sum(has_carry[q-1:])
        e0 = sum(has_carry[:q-1]) + eq1
        return e0, eq1
    
    def get_p_adic_factorial(num):
        if p not in p_adic_factorial:
            p_adic_factorial[p] = {0: 1, 1: 1}
        begin = len(p_adic_factorial[p])
        for idx in range(begin, num+1):
            p_adic_factorial[p][idx] = p_adic_factorial[p][idx-1] * (1 if idx % p == 0 else idx) % modulo
        return p_adic_factorial[p][num]
    
    def get_p_adic_inverse(num):
        if p not in p_adic_inverse:
            p_adic_inverse[p] = {0: 1, 1: 1}
        begin = len(p_adic_inverse[p])
        for idx in range(begin, num+1):
            p_adic_inverse[p][idx] = get_modulo_inverse(get_p_adic_factorial(idx), modulo)
        return p_adic_inverse[p][num]

    r = n - m
    _, m_expand, r_expand, d = expand_by_p(n, m, r)
    n_lpr, m_lpr, r_lpr = least_positive_residue(n, m, r, d)
    e0, eq1 = carry_count(m_expand, r_expand, d)

    n_factorial = map(get_p_adic_factorial, n_lpr)
    m_inverse = map(get_p_adic_inverse, m_lpr)
    r_inverse = map(get_p_adic_inverse, r_lpr)

    p_adic_choose = (n * m * r % modulo for n, m, r in zip(n_factorial, m_inverse, r_inverse))
    choose_prod = reduce(mul, p_adic_choose) % modulo

    sign = 1 if p == 2 and q >= 3 else -1
    pm = 1 if sign == 1 or eq1 % 2 == 0 else -1

    return p ** e0 * pm * choose_prod % modulo

def get_crt_root(a_list, n_list):
    '''Calculate root of x ≡ a_i (mod n_i) for each a_i and n_i.
    Length of a_list and n_list must be same.

    Result: int (in bound of 0 and N where N is product of each n_i)
    '''
    n_mul = reduce(mul, n_list)
    root_list = (a * n_mul // n * (n_mul // n % a) for a, n in zip(a_list, n_list))
    return sum(root_list) % n_mul

n_list = [27, 11, 13, 37]

input = stdin.readline
print = stdout.write
t = int(input())
for _ in range(t):
    n, k = map(int, input().split())
    a_list = [choose_modulo_prime_power(n, k, 3, 3), choose_modulo_prime(n, k, 11), choose_modulo_prime(n, k, 13), choose_modulo_prime(n, k, 37)]
    print(str(get_crt_root(a_list, n_list)))
    print('\n')
stdout.flush()

문제에서는 choose_modulo_prime_power함수의 인자 pq가 각각 3만이 호출되므로 위와 같은 딕셔너리를 쓰는 것은 실제 수행 속도는 더 느려집니다. 하지만 이후의 이식 가능성을 염두에 두면 더 좋은 구현일 것입니다. 마찬가지로 choose_modulo_prime도 실 수행 속도는 법 11, 13, 37에 대한 함수를 따로 만들어서 계산하는 것이 더 빠르지만, 범용적으로 활용 가능하도록 구현했습니다. 더 빠른 속도를 원할 경우 이처럼 인자의 단순화 뿐만 아니라, factorial, inverse, p_adic_factorial, p_adic_inverse를 미리 전처리시 37까지 계산하고 시작하는 것이 훨씬 빠르다는 점 상기 바랍니다.

알아두면 좋은 연산

reduce와 mul

reduce 함수는 functools라는 파이썬의 built-in 모듈의 함수입니다. 초기값을 기준으로 데이터를 순회하면서 집계 함수를 계속 적용하는, 즉 누적으로 적용한 결과를 반환합니다.

본 프로그램에서는 reduce(mul, gen)의 형식으로 사용하였고, 이는 gen 제너레이터에서 반환하는 값을 모두 곱한 값을 반환합니다. mul 역시 파이썬의 built-in 모듈인 operator의 함수이며, 이외에도 add, attrgetter등 유용한 연산자들이 많습니다. 이는 다음에서 확인할 수 있습니다.

굳이 변수를 선언해서 여기에 값을 곱해가며 결과를 계산하는 대신, reducemul을 사용하는 이유는 속도와 가독성을 위함임은 설명하지 않아도 알 것입니다.

구현시 주의할 점

쿼리형 문제는 단일 쿼리당 빠른 속도로 문제를 해결하는 것도 중요하지만, 입출력 속도도 성패에 영향을 미칠 수 있습니다. Python의 input, print 함수는 상당히 느린 편에 속합니다. Python에서의 빠른 입출력을 위해서는 sys 모듈의 stdin.readlinestdout.write를 사용하는 것이 좋습니다.

큰 성능의 향상을 돕는 것은 아니지만, 속도를 올리는 편법으로

input = stdin.readline
print = stdout.write

처럼 클래스 내 함수를 하나의 변수로 치환하면 처리 속도가 더 증가한다는 이점이 있습니다.

profile
Undergraduate student in Korea University. Major in electrical engineering and computer science.

0개의 댓글