썸네일 출처: https://ye-yo.github.io/thumbnail-maker/
구간합 구하는 방법을 응용해서 푸는 문제이다.
A1, A2, ~, An이 있을 때 가장 빠르게 구간의 합 S[i, j]를 구하는 방법은 Ai부터 Aj까지 다 더하는 것이다.
하지만 구간 합을 여러개 구해야 한다면 이 방식은 중복되는 계산을 너무 많이 하게 된다.
따라서 0번째 수부터 i번째 수까지 모두 더한 값을 저장하는 배열 P를 만들어놓고 다음과 같이 구간합을 빠르게 계산하는 방법을 쓴다.
S[i, j] = P[j] - P[i]
이렇게 하면 초기에 P배열을 만드는 데 드는 비용을 제외하면 모든 구간 계산이 한 번의 덧셈으로 끝나므로 훨씬 효율적이 된다.
어떤 수가 다른 수로 나누어 떨어지는지 확인하기 위해서는 나머지 연산이 0이 나오는지 확인하면 되므로, 우리는 구간합의 나머지가 0인지 확인하면 된다.
하지만 구간의 모든 수를 다 더한 뒤 나머지를 구하면 더한 값이 너무 커져 오버플로우가 날 수 있다. 따라서 다음의 특성을 이용해 오버플로우가 나지 않도록 나머지를 구한다.
두 수의 합을 어떤 수로 나눈 나머지는 각각의 수를 어떤 수로 나눈 나머지를 더한 것의 나머지와 같다. 수식으로 쓰면 다음과 같다.
A, B 를 각각 C로 나눈 나머지가 Ra, Rb일 때,
A + B의 나머지 R = (Ra + Rb) mod C
즉, 원래 수를 모두 더한 후 총합을 나눈 나머지와, 각각의 수의 나머지를 더한 후 나머지의 총합을 나눈 나머지가 같다. 굳이 배열에 구간의 총합을 저장할 필요 없이 바로 나머지를 저장해도 된다는 것이다.
앞서 말한 0부터 i까지 수의 합을 담는 배열 P에 수의 합 대신 수의 합을 M으로 나눈 나머지를 저장할 것이다. 이렇게 하면 S[i, j]는 i부터 j까지 구간합을 M으로 나눈 나머지가 된다.
이제 S[i, j] = 0이 되는 경우의 수를 구해야 한다. 가장 쉬운 방법은 모든 i, j에 대해 S[i, j]가 0인지를 확인하는 것이다. 하지만 이 경우 시간복잡도는 O(n^2)가 되는데, 문제에서 n의 최댓값이 10^6이므로 총 10^12회에 달하는 비교를 해야 한다.
더 빠른 방식을 찾아내기 위해, 문제에서 주어진 예제에서 j = 4일때 어떻게 경우의 수를 구하는지 생각해보자.
우선 위의 방식으로 P배열을 구하면 다음과 같다.
이제 j = 4일때 S[i, j] = 0이 되는 구간을 구해보자. j >= i이므로 i는 4 이하이고, P[4] = 1이기 때문에 P[i] = 1이어야 한다. 따라서 j = 4일때 M으로 나누어 떨어지는 부분합의 개수는 자신보다 앞에 있는 P[i]중 값이 1인 것의 개수이다.
이제 j 를 1부터 n까지 증가시키며 각각의 경우의 수의 합을 구하면 된다.
하지만, 여기서 또 문제가 생긴다. 자신과 값이 같은 P[i]를 구하기 위해 총 j번의 비교를 하게 되는데, 이렇게 되면 총 시간 복잡도는 또다시 O(n^2)가 된다.
그래서, 그동안 나온 P[i]의 수를 저장할 수 있는 배열 Q를 만들었다. 매번 이전 인덱스들을 하나씩 확인하며 카운터를 증가시키는 것 대신에 P[j]값에 해당하는 Q의 칸에 카운터를 하나씩 증가시킨다. 이렇게 하면 이전까지의 특정 수가 몇 번 나왔는지가 저장되므로 이전 인덱스들을 모두 확인할 필요 없이 바로 경우의 수를 계산할 수 있다. 이 경우 시간복잡도는 O(n)이 된다.
Q배열의 크기는 등장 가능한 모든 나머지의 수이므로 M이다. M의 최댓값은 10^3이고, 각 카운터의 크기가 4바이트이므로 4KB만 추가로 차지하게 된다. 따라서 메모리 문제도 걱정 없다.
use std::cmp::{max, min, Ordering};
use std::io::{stdin, Read, Write, stdout, BufWriter};
use std::fmt::Write as fWrite;
const N_MAX: usize = 1_000_000;
const M_MAX: usize = 1_000;
fn main() {
let mut input = String::new();
stdin().read_to_string(&mut input).unwrap();
let mut input = input.split_ascii_whitespace().flat_map(str::parse::<usize>);
let (n, m) = (input.next().unwrap(), input.next().unwrap());
let mut counts = [0usize; M_MAX];
let mut sum = 0usize;
let mut range_sum = 0usize;
counts[0] = 1;
for i in 1..=n {
range_sum = (range_sum + input.next().unwrap()) % m;
sum += counts[range_sum];
counts[range_sum] += 1;
}
println!("{sum}");
}