
Problem: 제곱수의 합 (More Huge)
문제에도 나와 있듯이 라그랑주의 네 제곱수 정리에 따라 모든 수는 4개 이하의 제곱수의 합으로 나타낼 수 있다.
제곱수의 개수를 4개부터 하나씩 줄여나가며 풀어보기로 했다.
Legendre's three-square theorem
르장드르의 세 제곱수 정리에 따르면 이 의 형태가 아닐 경우 3개 이하의 제곱수의 합으로 나타낼 수 있다.
즉, 이 의 형태라면 4를 반환하면 된다.
def four_square(n):
while n % 4 == 0:
n //= 4
return n % 8 == 7
...
def solve(n):
if four_square(n):
return 4
Fermat's theorem on sums of two squares
페르마의 두 제곱수 정리에 따르면 2를 제외한 모든 소수는 또는 의 형태로 나타낼 수 있으며,
의 형태로 나타낼 수 있는 소수는 서로 다른 두 제곱수의 합으로 나타낼 수 있다.
주어진 수를 소인수분해하여 형태로 나타내지는 소수가 짝수 개라면 주어진 수는 2개 이하의 제곱수의 합으로 나타낼 수 있다.
즉, 형태로 나타낼 수 있는 소수가 홀수 개라면 3을 반환한다.
def three_square(n):
...
for i, n in c: # i: 소수, n: 해당 소수의 개수
if i % 4 == 3 and n % 2 == 1:
return True
return False
...
def solve(n):
...
elif three_square(n):
return 3
이제 남은 경우의 수는 제곱수냐, 제곱수가 아니냐 둘 밖에 없으니 sqrt함수를 사용해서 확인해 주면 된다.
def solve(n):
...
elif sqrt(n) ** 2 != n:
return 2
else:
return 1
def three_square(n):
l = []
while n > 1:
d = pollard_rho(n)
l.append(d)
n //= d
세 제곱수 판별 부분에서는 시간 제한이랑 주어지는 수의 범위가 저러니 폴라드 로와 밀러 라빈 알고리즘을 사용해서 소인수분해했다.
c = list(Counter(l).items())
for i, n in c:
if i % 4 == 3 and n % 2 == 1:
return True
return False
collections의 Counter를 사용하여 계산해주었다.
그 외에는 단순 구현인데.. 제곱수 판별 부분에서 한참을 삽질했다.
>>> m = sqrt(1234)
>>> m ** 2
1234.0000000000002
콘솔에도 저렇게 뜨길래 처음엔 아무 생각 없이 코드를 위처럼 적었었다.
>>> sqrt(567) ** 2
567.0
>>> sqrt(1333) ** 2
1333.0
그런데 반례가 생각보다 많았다.
...
elif int(sqrt(n)) ** 2 != n:
return 2
else:
return 1
위와 같이 해결했다. 근삿값 관련 문제였던 듯.
from random import randrange
from sys import stdin, setrecursionlimit
from math import gcd, sqrt
from collections import Counter
setrecursionlimit(10 ** 4)
input = stdin.readline
def powmod(a, e, m):
ret = 1
t = a % m
while e > 0:
if e & 1:
ret = ret * t % m
t = t * t % m
e >>= 1
return ret
def miller_rabin(n, a):
d = n - 1
while d % 2 == 0:
if powmod(a, d, n) == n - 1:
return True
d //= 2
t = powmod(a, d, n)
return t == n - 1 or t == 1
def is_prime(n):
if n == 1 or n % 2 == 0:
return False
for a in [2, 3, 5, 7, 11, 13, 17, 19, 23, 29, 31, 37]:
if n == a:
return True
if not miller_rabin(n, a):
return False
return True
def pollard_rho(n):
if is_prime(n):
return n
if n == 1:
return 1
if n % 2 == 0:
return 2
x = randrange(2, n)
y = x
c = randrange(1, n)
d = 1
while d == 1:
x = ((x ** 2 % n) + c + n) % n
y = ((y ** 2 % n) + c + n) % n
y = ((y ** 2 % n) + c + n) % n
d = gcd(abs(x - y), n)
if d == n:
return pollard_rho(n)
if is_prime(d):
return d
return pollard_rho(d)
def four_square(n):
while n % 4 == 0:
n //= 4
return n % 8 == 7
def three_square(n):
l = []
while n > 1:
d = pollard_rho(n)
l.append(d)
n //= d
c = list(Counter(l).items())
for i, n in c:
if i % 4 == 3 and n % 2 == 1:
return True
return False
def solve(n):
if four_square(n):
return 4
elif three_square(n):
return 3
elif int(sqrt(n)) ** 2 != n:
return 2
else:
return 1
print(solve(int(input())))
구현은 쉬운 문제다.