페르마의 소정리는 아래와 같다.p가 소수이고 a가 p의 배수가 아니면 아래의 수식을 만족한다. 아래의 두 수식은 같은 뜻이다.
기억해야 할 문장은 딱 하나다.
소수 p는 위 수식을 만족한다. 고로 위 수식을 만족하지 않으면 합성수이다.
즉, 위 수식을 어떻게 변형해도 위 수식을 만족해야 한다.
이제, 위 수식을 변형시켜보자
을 로 변형시켜 보자.
그럼 수식이 와 같이 된다.
(당연히 위 수식도 만족해야 하고)
다시 위 수식은 와 동치이고 이 수식도 만족해야 한다.
이제부터 핵심이다. 수식을 어떻게 변형해도 여전히 만족해야 p가 소수라고 할 수 있다.(a가 충분히 많다는 가정하에)
는 양변에 루트를 씌워 도 만족해야 한다.
(참고: )
즉, 이 d(d는 홀수)가 될 때 까지 나누어 가면서 모두 만족하는지 검사한다. 이때 우항이 -1(코딩으로는 p-1)이 나오면 더 이상 루트를 씌울 수 없으므로 모든 항을 만족한다는 결과가 나온다.
를 계산하는데 있어서 2를 51번 곱할 필요는 없다.
=
=
=
=
=
위 알고리즘을 간단히 설명하자면 지수가 홀수일 경우 밑을 따로 곱하게 해서 위 수식처럼 전개할 수 있다.
using ull = unsigned long long;
ull fast_pow(ull x, ull y) {
ull tmp = 1;
while (y) {
if (y&1) //if y is odd number
tmp *= x;
y >>=1;
x = x * x;
}
return tmp;
}
코드와 수식을 대응하자면 코드에서의 x가 수식에서 괄호 안에 있는 수가 되겠고, tmp는 괄호 밖의 숫자들의 곱이 되겠다.
를 계산하는데 가 커지면 컴퓨터에서 정수오버플로우가 발생하므로 이런 방식을 사용한다.
모듈러 연산의 특징
(a + b + c) % m = (a%m + b%m + c%m) % m
(a - b - c) % m = (a%m - b%m - c%m) % m
(a * b * c) % m = (a%m * b%m * c%m) % m
곱해지는 숫자 전부에 나머지 연산을 적용해도 결과는 똑같다. 그렇다면 위 코드에서 x와 tmp가 갱신될 때 마다 모듈러 연산을 적용해 주면 를 계산할 수 있다.
using ull = unsigned long long;
ull fast_ipow(ull x, ull y, ull z) {
ull tmp = 1LL;
x %= z;
while (y) {
if (y & 1) //if y is odd number
tmp = ((tmp%z)*(x%z)) % z;
y >>= 1;
x = ((x%z) * (x%z)) % z;
}
return tmp;
}
위 알고리즘을 토대로 밀러라빈 소수판별 알고리즘을 구현하면 아래와 같다.
#include <iostream>
#include <vector>
using uint32 = unsigned int;
using uint64 = unsigned long long;
using uint128 = __uint128_t;
template<typename T>
T safe_mul_mod(T a, T b, T m) {
if constexpr (sizeof(T) == 4) {
return (static_cast<uint64>(a) * b) % m;
} else if constexpr (sizeof(T) == 8) {
return (static_cast<uint128>(a) * b) % m;
} else if constexpr (sizeof(T) == 16) {
if (a != 0 && b != 0 && a > std::numeric_limits<T>::max() / b) {
T r = 0;
a %= m;
while (b > 0) {
if (b & 1) r = (r + a) % m;
a = (a * 2) % m;
b >>= 1;
}
return r;
}
return (a * b) % m;
} else {
return (a * b) % m;
}
}
template<typename T>
T fast_pow(T x, T y, T z) {
T tmp = static_cast<T>(1);
x %= z;
while (y) {
if (y & 1)
tmp = safe_mul_mod(tmp, x, z);
y >>= 1;
x = safe_mul_mod(x, x, z);
}
return tmp;
}
template<typename T>
bool isPrime_MillerRabin(T n) {
if (n < 2) return false;
if (n == 2) return true;
if (n % 2 == 0) return false;
static const std::vector<T> witnesses = {2, 3, 5, 7, 11, 13, 17, 19, 23, 29, 31, 37};
T d = n - 1;
T s = 0;
while (d % 2 == 0) {
d /= 2;
++s;
}
for (T a: witnesses) {
if (a >= n) continue;
T x = fast_pow(a, d, n);
if (x == 1 || x == n - 1) continue;
bool composite = true;
for (T r = 0; r < s - 1; r++) {
x = safe_mul_mod(x, x, n);
if (x == n - 1) {
composite = false;
break;
}
}
if (composite) return false;
}
return true;
}