https://www.acmicpc.net/problem/14854
BJ 14854 이항 계수 6 - (1)에서 이어집니다.
어떤 이항 계수를 소수 로 나눈 나머지를 구하는 방법과, 어떤 소수의 거듭제곱 로 나눈 나머지를 구하는 방법을 모두 알아냈습니다. 각각의 법에 대한 해를 가지고 법법 142857에 대한 나머지를 구해야 하는데, 이를 해결할 수 있는 방법이 중국인의 나머지 정리입니다.
다음과 같은 연립합동식이 존재하고, 모든 는 쌍마다 서로소라고 가정합니다. 이 때, 0 이상 미만의 정수 중에 이 연립합동식의 해가 유일하게 존재합니다. 존재성과 유일성에 대한 내용은 여기에서 다루지 않고, 간단히 해를 구하는 알고리즘만 파악하도록 하겠습니다.
각 합동식의 를 모두 곱한 을 계산합니다.
번째 합동식에서 라 하고, 이고 인 를 정의합니다.
연립 합동식의 해는 다음과 같습니다:
from sys import stdin, stdout
from functools import reduce
from operator import mul
def get_modulo_inverse(a, b):
'''
Calculate inverse of a (mod b)
'''
b0 = b
x0, x1 = 0, 1
if b == 1:
return 1
while a > 1:
q = a // b
a, b = b, a % b
x0, x1 = x1 - q * x0, x0
if x1 < 0:
x1 += b0
return x1
factorial = {} # factorial[m][n] : n! (mod m)
inverse = {} # inverse[m][n]: inverse of n! (mod m)
def choose_modulo_prime(n, m, p):
'''Calculate C(n, m) % p
1 <= n <= 10^18, 0 <= m <= n, 2 <= p <= 2000
p must be prime
Return: int
'''
def get_factorial(num):
if p not in factorial:
factorial[p] = {0: 1, 1: 1}
for idx in range(2, num+1):
factorial[p][idx] = factorial[p][idx-1] * idx % p
return factorial[p][num]
def get_inverse(num):
if p not in inverse:
inverse[p] = {0: 1, 1: 1}
for idx in range(2, num+1):
inverse[p][idx] = get_modulo_inverse(get_factorial(idx), p)
return inverse[p][num]
choose_prod = 1
while n != 0 or m != 0:
n_digit = n % p
n //= p
m_digit = m % p
m //= p
if n_digit < m_digit:
choose_prod = 0
break
else:
choose_prod *= get_factorial(n_digit) * get_inverse(m_digit) * get_inverse(n_digit-m_digit)
choose_prod %= p
return choose_prod
p_adic_factorial = {} # p_adic_factorial[m][n]: (n!)_m (mod m)
p_adic_inverse = {} # p_adic_inverse[m][n]: inverse of (n!)_m (mod m)
def choose_modulo_prime_power(n, m, p, q):
'''Calculate C(n, m) % p^q.
1 <= n <= 10^9, 0 <= m <= n, q >= 1
p must be prime
Return: int
'''
modulo = p ** q
def expand_by_p(n, m, r):
'''Expand number with given base (p).
Each n, m, r will expanded until their length is more than d_min.
Return: tuple(list, list, list, int)
'''
n_expand, m_expand, r_expand = [], [], []
d = 0
d_min = q-1
while n > 0 or m > 0 or r > 0 or d <= d_min:
n_expand.append(n % p)
n //= p
m_expand.append(m % p)
m //= p
r_expand.append(r % p)
r //= p
d += 1
return n_expand, m_expand, r_expand, d-1
def least_positive_residue(n, m, r, d):
'''Calculate x // p^j % p^q where x = n, m, r and j = 0, 1, ..., d
Return: tuple(list, list, list)
'''
n_lpr, m_lpr, r_lpr = [], [], []
for _ in range(d+1):
n_lpr.append(n % modulo)
n //= p
m_lpr.append(m % modulo)
m //= p
r_lpr.append(r % modulo)
r //= p
return n_lpr, m_lpr, r_lpr
def carry_count(m_expand, r_expand, d):
'''Count number of carries occur when add m and r with base p.
Return: tuple(int, int)
'''
has_carry = [0] * (d+1)
prev_carry = 0
for idx in range(d+1):
value = m_expand[idx] + r_expand[idx] + prev_carry
if value >= p:
has_carry[idx] = 1
prev_carry = 1
else:
prev_carry = 0
eq1 = sum(has_carry[q-1:])
e0 = sum(has_carry[:q-1]) + eq1
return e0, eq1
def get_p_adic_factorial(num):
if p not in p_adic_factorial:
p_adic_factorial[p] = {0: 1, 1: 1}
begin = len(p_adic_factorial[p])
for idx in range(begin, num+1):
p_adic_factorial[p][idx] = p_adic_factorial[p][idx-1] * (1 if idx % p == 0 else idx) % modulo
return p_adic_factorial[p][num]
def get_p_adic_inverse(num):
if p not in p_adic_inverse:
p_adic_inverse[p] = {0: 1, 1: 1}
begin = len(p_adic_inverse[p])
for idx in range(begin, num+1):
p_adic_inverse[p][idx] = get_modulo_inverse(get_p_adic_factorial(idx), modulo)
return p_adic_inverse[p][num]
r = n - m
_, m_expand, r_expand, d = expand_by_p(n, m, r)
n_lpr, m_lpr, r_lpr = least_positive_residue(n, m, r, d)
e0, eq1 = carry_count(m_expand, r_expand, d)
n_factorial = map(get_p_adic_factorial, n_lpr)
m_inverse = map(get_p_adic_inverse, m_lpr)
r_inverse = map(get_p_adic_inverse, r_lpr)
p_adic_choose = (n * m * r % modulo for n, m, r in zip(n_factorial, m_inverse, r_inverse))
choose_prod = reduce(mul, p_adic_choose) % modulo
sign = 1 if p == 2 and q >= 3 else -1
pm = 1 if sign == 1 or eq1 % 2 == 0 else -1
return p ** e0 * pm * choose_prod % modulo
def get_crt_root(a_list, n_list):
'''Calculate root of x ≡ a_i (mod n_i) for each a_i and n_i.
Length of a_list and n_list must be same.
Result: int (in bound of 0 and N where N is product of each n_i)
'''
n_mul = reduce(mul, n_list)
root_list = (a * n_mul // n * (n_mul // n % a) for a, n in zip(a_list, n_list))
return sum(root_list) % n_mul
n_list = [27, 11, 13, 37]
input = stdin.readline
print = stdout.write
t = int(input())
for _ in range(t):
n, k = map(int, input().split())
a_list = [choose_modulo_prime_power(n, k, 3, 3), choose_modulo_prime(n, k, 11), choose_modulo_prime(n, k, 13), choose_modulo_prime(n, k, 37)]
print(str(get_crt_root(a_list, n_list)))
print('\n')
stdout.flush()
문제에서는 choose_modulo_prime_power
함수의 인자 p
와 q
가 각각 3만이 호출되므로 위와 같은 딕셔너리를 쓰는 것은 실제 수행 속도는 더 느려집니다. 하지만 이후의 이식 가능성을 염두에 두면 더 좋은 구현일 것입니다. 마찬가지로 choose_modulo_prime
도 실 수행 속도는 법 11, 13, 37에 대한 함수를 따로 만들어서 계산하는 것이 더 빠르지만, 범용적으로 활용 가능하도록 구현했습니다. 더 빠른 속도를 원할 경우 이처럼 인자의 단순화 뿐만 아니라, factorial
, inverse
, p_adic_factorial
, p_adic_inverse
를 미리 전처리시 37까지 계산하고 시작하는 것이 훨씬 빠르다는 점 상기 바랍니다.
reduce
함수는 functools
라는 파이썬의 built-in 모듈의 함수입니다. 초기값을 기준으로 데이터를 순회하면서 집계 함수를 계속 적용하는, 즉 누적으로 적용한 결과를 반환합니다.
본 프로그램에서는 reduce(mul, gen)
의 형식으로 사용하였고, 이는 gen
제너레이터에서 반환하는 값을 모두 곱한 값을 반환합니다. mul
역시 파이썬의 built-in 모듈인 operator
의 함수이며, 이외에도 add
, attrgetter
등 유용한 연산자들이 많습니다. 이는 다음에서 확인할 수 있습니다.
굳이 변수를 선언해서 여기에 값을 곱해가며 결과를 계산하는 대신, reduce
와 mul
을 사용하는 이유는 속도와 가독성을 위함임은 설명하지 않아도 알 것입니다.
쿼리형 문제는 단일 쿼리당 빠른 속도로 문제를 해결하는 것도 중요하지만, 입출력 속도도 성패에 영향을 미칠 수 있습니다. Python의 input
, print
함수는 상당히 느린 편에 속합니다. Python에서의 빠른 입출력을 위해서는 sys
모듈의 stdin.readline
과 stdout.write
를 사용하는 것이 좋습니다.
큰 성능의 향상을 돕는 것은 아니지만, 속도를 올리는 편법으로
input = stdin.readline
print = stdout.write
처럼 클래스 내 함수를 하나의 변수로 치환하면 처리 속도가 더 증가한다는 이점이 있습니다.