#include<iostream>
#include<vector>
#include<queue>
#include<string>
#include<algorithm>
using namespace std;
typedef long long ll;
typedef pair<int, int> pii;
// a는 b보다 크고, b는 c보다 크다.
// 1은 탈락
// 소인수분해 했을 때 나오는 약수들은 탈락
// 나누는 수가 나머지와 같을 수는 없음
// 즉, b는 c와 같지 않음.
// b를 고정했을 때, a는 b의 배수만 아니면 됨. (c!=0이면 됨)
// N-b+1에서 b의 배수의 개수만큼 빼면 된다는 이야기
// 그런데 b를 1부터 N까지 연산하면 시간 초과.
// b를 N으로 나눴을 때 나머지를 버린 수를 q라고 할 때,
// q - 1이 빼야 하는 개수
// 그런데 q가 같은 수가 여러 개임. 이들을 구간 별로 묶어서 연산하면 됨.
// b = 1부터 N-1까지 (a,b)의 경우의 수를 더해놓은 뒤,
// q가 같은 구간을 곱해놓은 후 개수를 곱해 빼기 시작한다.
// l = 1부터 시작.
// q가 바뀌는 다음 구간은 N/q + 1임. r을 N/q + 1로 설정.
// [l,r)의 구간은 q - 1만큼을 빼야 함. 즉, (r-l)(q-1)을 빼고 다음 구간으로 넘어감.
// l = r로 할당하고 같은 연산을 반복함.
// q는 구간이 올라감에 따라 작아짐. q = 1일때, r = N+1임.
// 이 구간은 q = 1이므로 뺄 것도 없음. l = N+1이 되어 l > N이면 반복을 끝내도록 설정.
ll N;
const ll m = 998244353;
ll SumOfExceptCases()
{
ll sum = 0;
ll l = 1;
while (l < N)
{
ll q = N / l;
ll r = N / q + 1;
ll add = ((r - l) % m) * ((q - 1) % m) % m;
sum = (sum + add) % m;
l = r;
}
return sum;
}
int main()
{
cin >> N;
ll inv2 = (m + 1) / 2; // = 499122177
ll a = N % m;
ll b = (N - 1) % m;
ll res = a * b % m * inv2 % m;
res = (res - SumOfExceptCases() + m) % m;
cout << res;
}
Editorial 이해한 대로 코드 작성.
ll res = a * b % m * inv2 % m; 이 부분은 강력하게 gpt 도움 받음.
오버플로우 때문에 이러는 건 알고 있었는데,
inv2를 페르마의 소정리로 바꾸면 m승을 또 구해야 했음.
근데 gpt는 아무렇지도 않게 inv2라는 걸 가져와서 씀.
M이 2가 아닌 소수일 때 2의 역원은 (M + 1) / 2 이기 때문.
3의 역원은 페르마의 소정리 써야됨. 소수여도 (M + 1) / 3이 성립한다는 보장이 없기 때문.
이걸 굳이 공식과 증명을 다 외우고 다닐 이유는 없고, 옵시디언에 적어두고 필요할 때 쓰면 될듯.
평생 필요할 일 없을 것 같긴 함 근데.
#include <bits/stdc++.h>
using namespace std;
typedef long long ll;
const ll mod = 998244353;
ll n, ans;
int main()
{
ios::sync_with_stdio(false);
cin.tie(nullptr);
cin >> n;
ans = (n % mod) * ((n + 1) % mod) % mod * ((mod + 1) / 2) % mod;
for (ll l = 1, r; l <= n; l = r + 1)
{
r = n / (n / l);
ans = (ans + mod - (r - l + 1) % mod * (n / l) % mod) % mod;
}
cout << ans << '\n';
return 0;
}
그와 별개로
다른 답 봤는데 진짜 짧고 간단하게 품.
로직은 동일한데,
while을 for문으로 바꿔서 압축,
함수 호출도 안 해서 압축,
빼는 것도 한꺼번에 모았다가 따로 하는게 아니라 바로바로 빼서 압축.
int main()
{
ll res = 0, N, r, m = 998244353;
cin >> N;
for (ll l = 1; l <= N; l = r)
{
ll q = N / l;
r = N / q + 1;
res = (res + ((N - l)+(N - r+1))*(r-l)/2 - (r - l) * (q - 1)) % m;
}
cout << res;
}
살짝 수정하여 inv2가 없는 코드를 짜보려고 했음.
합을 미리 구하는게 아니라 뺄 때 더하는 것도 같이 하는 걸로 바꿈
하지만 평균을 구하는 과정에서 결국 inv2가 들어가야하고,
무엇보다 연산력이 낭비됨.