17633번: 제곱수의 합 (More Huge)

YongChan Cho·2022년 2월 14일

백준

목록 보기
2/3
post-thumbnail

시작

Problem: 제곱수의 합 (More Huge)

문제에도 나와 있듯이 라그랑주의 네 제곱수 정리에 따라 모든 수는 4개 이하의 제곱수의 합으로 나타낼 수 있다.

제곱수의 개수를 4개부터 하나씩 줄여나가며 풀어보기로 했다.

세 제곱수 정리

Legendre's three-square theorem

르장드르의 세 제곱수 정리에 따르면 nnn=4a(8b+7)n = 4^a(8b+7)의 형태가 아닐 경우 3개 이하의 제곱수의 합으로 나타낼 수 있다.

즉, nnn=4a(8b+7)n = 4^a(8b+7)의 형태라면 4를 반환하면 된다.

def four_square(n):
    while n % 4 == 0:
        n //= 4
    return n % 8 == 7

...

def solve(n):
    if four_square(n):
        return 4

두 제곱수 정리

Fermat's theorem on sums of two squares

페르마의 두 제곱수 정리에 따르면 2를 제외한 모든 소수는 4n+14n+1 또는 4n14n-1의 형태로 나타낼 수 있으며,

4n+14n+1의 형태로 나타낼 수 있는 소수는 서로 다른 두 제곱수의 합으로 나타낼 수 있다.

주어진 수를 소인수분해하여 4n+34n+3형태로 나타내지는 소수가 짝수 개라면 주어진 수는 2개 이하의 제곱수의 합으로 나타낼 수 있다.

즉, 4n+34n+3형태로 나타낼 수 있는 소수가 홀수 개라면 3을 반환한다.

def three_square(n):
    ...
    for i, n in c: # i: 소수, n: 해당 소수의 개수
        if i % 4 == 3 and n % 2 == 1:
            return True
    return False

...

def solve(n):
    ...
    elif three_square(n):
        return 3

제곱수

이제 남은 경우의 수는 제곱수냐, 제곱수가 아니냐 둘 밖에 없으니 sqrt함수를 사용해서 확인해 주면 된다.

def solve(n):
    ...
    elif sqrt(n) ** 2 != n:
        return 2
    else:
        return 1

문제 풀이

def three_square(n):
    l = []
    while n > 1:
        d = pollard_rho(n)
        l.append(d)
        n //= d

세 제곱수 판별 부분에서는 시간 제한이랑 주어지는 수의 범위가 저러니 폴라드 로밀러 라빈 알고리즘을 사용해서 소인수분해했다.

    c = list(Counter(l).items())
    for i, n in c:
        if i % 4 == 3 and n % 2 == 1:
            return True
    return False

collectionsCounter를 사용하여 계산해주었다.

그 외에는 단순 구현인데.. 제곱수 판별 부분에서 한참을 삽질했다.

>>> m = sqrt(1234)
>>> m ** 2
1234.0000000000002

콘솔에도 저렇게 뜨길래 처음엔 아무 생각 없이 코드를 위처럼 적었었다.

>>> sqrt(567) ** 2
567.0
>>> sqrt(1333) ** 2
1333.0

그런데 반례가 생각보다 많았다.

    ...
    elif int(sqrt(n)) ** 2 != n:
        return 2
    else:
        return 1

위와 같이 해결했다. 근삿값 관련 문제였던 듯.

전체 소스

from random import randrange
from sys import stdin, setrecursionlimit
from math import gcd, sqrt
from collections import Counter

setrecursionlimit(10 ** 4)
input = stdin.readline


def powmod(a, e, m):
    ret = 1
    t = a % m
    while e > 0:
        if e & 1:
            ret = ret * t % m
        t = t * t % m
        e >>= 1
    return ret


def miller_rabin(n, a):
    d = n - 1
    while d % 2 == 0:
        if powmod(a, d, n) == n - 1:
            return True
        d //= 2
    t = powmod(a, d, n)
    return t == n - 1 or t == 1


def is_prime(n):
    if n == 1 or n % 2 == 0:
        return False

    for a in [2, 3, 5, 7, 11, 13, 17, 19, 23, 29, 31, 37]:
        if n == a:
            return True
        if not miller_rabin(n, a):
            return False
    return True


def pollard_rho(n):
    if is_prime(n):
        return n

    if n == 1:
        return 1
    if n % 2 == 0:
        return 2

    x = randrange(2, n)
    y = x
    c = randrange(1, n)
    d = 1

    while d == 1:
        x = ((x ** 2 % n) + c + n) % n
        y = ((y ** 2 % n) + c + n) % n
        y = ((y ** 2 % n) + c + n) % n
        d = gcd(abs(x - y), n)

        if d == n:
            return pollard_rho(n)
    if is_prime(d):
        return d
    return pollard_rho(d)


def four_square(n):
    while n % 4 == 0:
        n //= 4
    return n % 8 == 7


def three_square(n):
    l = []
    while n > 1:
        d = pollard_rho(n)
        l.append(d)
        n //= d

    c = list(Counter(l).items())
    for i, n in c:
        if i % 4 == 3 and n % 2 == 1:
            return True
    return False


def solve(n):
    if four_square(n):
        return 4
    elif three_square(n):
        return 3
    elif int(sqrt(n)) ** 2 != n:
        return 2
    else:
        return 1


print(solve(int(input())))

구현은 쉬운 문제다.

profile
나개발자아님

0개의 댓글