시간 초과 코드는 아래와 같다.
#include <iostream>
#include <vector>
using namespace std;
int main() {
ios::sync_with_stdio(0);
cin.tie(0);
int n, m;
cin >> n >> m;
vector<int> arr(n + 1, 0);
for (int i = 1; i <= n; i++)
cin >> arr[i];
vector<long long> v = arr;
for (int i = 2; i <= n; i++)
v[i] += v[i - 1];
long long cnt = 0;
for (int i = 1; i <= n; i++) {
for (int j = i; j <= n; j++) {
if ((v[j] - v[i - 1]) % m == 0)
cnt++;
}
}
cout << cnt;
return 0;
}
수의 범위로 봐서 당연히 시간 초과가 날 수 밖에 없다.
그래서 뭔가 신박한 생각을 해내야만 풀 수 있는 문제임을 알았다.
문제 풀이는 아래 수식으로부터 시작한다.
if ((누적합[j] - 누적합[i - i]) mod M == 0) cnt++;
여기서 누적합[index]
는 배열[0]부터 배열[index]까지의 합
이다.
따라서 구간 (i, j)의 누적합은 누적합[j] - 누적합[i - i]
으로 나타낼 수 있다.
즉, 저 i부터 j까지 누적합이 M으로 나누어떨어지는 횟수를 카운팅하는 식이다.
수식을 길게 풀어헤쳐보자.
if (누적합[j] mod M == 누적합[i - 1] mod M) cnt++
즉, 어떤 두 인덱스 i, j까지의 누적합을 각각 M으로 나눴을 때 나머지가 서로 같으면
문제의 의도와 정확히 같은 조건을 의미하게 되는 것이다.
그런데 쓰면 쓸수록 아무도 못알아먹겠다는 확신이 든다.
만약 궁금하시다면 아래 코드를 한 줄 한 줄 디버깅해보고
또 그림판을 이용해서 얘가 뭔 말을 하나 알아먹어주면 참 좋을 것 같다.
#include <iostream>
#include <vector>
using namespace std;
int main() {
ios::sync_with_stdio(0);
cin.tie(0);
int n, m;
cin >> n >> m;
vector<int> v(m, 0);
long long sum = 0;
long long cnt = 0;
v[0] = 1;
for (int i = 0; i < n; i++) {
int num;
cin >> num;
sum = (sum + num) % m;
cnt += v[sum];
v[sum]++;
}
cout << cnt;
return 0;
}