[백준 10986] 나머지 합 - Rust로 알고리즘 풀기

이승규·2022년 12월 18일
1
post-thumbnail

썸네일 출처: https://ye-yo.github.io/thumbnail-maker/


📖 문제

10986번: 나머지 합

구간합 구하는 방법을 응용해서 푸는 문제이다.


💡 아이디어

1. 구간합 빠르게 구하기

A1, A2, ~, An이 있을 때 가장 빠르게 구간의 합 S[i, j]를 구하는 방법은 Ai부터 Aj까지 다 더하는 것이다.
하지만 구간 합을 여러개 구해야 한다면 이 방식은 중복되는 계산을 너무 많이 하게 된다.
따라서 0번째 수부터 i번째 수까지 모두 더한 값을 저장하는 배열 P를 만들어놓고 다음과 같이 구간합을 빠르게 계산하는 방법을 쓴다.

S[i, j] = P[j] - P[i]

이렇게 하면 초기에 P배열을 만드는 데 드는 비용을 제외하면 모든 구간 계산이 한 번의 덧셈으로 끝나므로 훨씬 효율적이 된다.

2. 구간합이 나누어 떨어지는지 판별하기

어떤 수가 다른 수로 나누어 떨어지는지 확인하기 위해서는 나머지 연산이 0이 나오는지 확인하면 되므로, 우리는 구간합의 나머지가 0인지 확인하면 된다.
하지만 구간의 모든 수를 다 더한 뒤 나머지를 구하면 더한 값이 너무 커져 오버플로우가 날 수 있다. 따라서 다음의 특성을 이용해 오버플로우가 나지 않도록 나머지를 구한다.

두 수의 합을 어떤 수로 나눈 나머지는 각각의 수를 어떤 수로 나눈 나머지를 더한 것의 나머지와 같다. 수식으로 쓰면 다음과 같다.

A, B 를 각각 C로 나눈 나머지가 Ra, Rb일 때,
A + B의 나머지 R = (Ra + Rb) mod C

즉, 원래 수를 모두 더한 후 총합을 나눈 나머지와, 각각의 수의 나머지를 더한 후 나머지의 총합을 나눈 나머지가 같다. 굳이 배열에 구간의 총합을 저장할 필요 없이 바로 나머지를 저장해도 된다는 것이다.

앞서 말한 0부터 i까지 수의 합을 담는 배열 P에 수의 합 대신 수의 합을 M으로 나눈 나머지를 저장할 것이다. 이렇게 하면 S[i, j]는 i부터 j까지 구간합을 M으로 나눈 나머지가 된다.

3. 빠르게 경우의 수 계산하기

이제 S[i, j] = 0이 되는 경우의 수를 구해야 한다. 가장 쉬운 방법은 모든 i, j에 대해 S[i, j]가 0인지를 확인하는 것이다. 하지만 이 경우 시간복잡도는 O(n^2)가 되는데, 문제에서 n의 최댓값이 10^6이므로 총 10^12회에 달하는 비교를 해야 한다.

더 빠른 방식을 찾아내기 위해, 문제에서 주어진 예제에서 j = 4일때 어떻게 경우의 수를 구하는지 생각해보자.
우선 위의 방식으로 P배열을 구하면 다음과 같다.

이제 j = 4일때 S[i, j] = 0이 되는 구간을 구해보자. j >= i이므로 i는 4 이하이고, P[4] = 1이기 때문에 P[i] = 1이어야 한다. 따라서 j = 4일때 M으로 나누어 떨어지는 부분합의 개수는 자신보다 앞에 있는 P[i]중 값이 1인 것의 개수이다.

이제 j 를 1부터 n까지 증가시키며 각각의 경우의 수의 합을 구하면 된다.

하지만, 여기서 또 문제가 생긴다. 자신과 값이 같은 P[i]를 구하기 위해 총 j번의 비교를 하게 되는데, 이렇게 되면 총 시간 복잡도는 또다시 O(n^2)가 된다.

그래서, 그동안 나온 P[i]의 수를 저장할 수 있는 배열 Q를 만들었다. 매번 이전 인덱스들을 하나씩 확인하며 카운터를 증가시키는 것 대신에 P[j]값에 해당하는 Q의 칸에 카운터를 하나씩 증가시킨다. 이렇게 하면 이전까지의 특정 수가 몇 번 나왔는지가 저장되므로 이전 인덱스들을 모두 확인할 필요 없이 바로 경우의 수를 계산할 수 있다. 이 경우 시간복잡도는 O(n)이 된다.

Q배열의 크기는 등장 가능한 모든 나머지의 수이므로 M이다. M의 최댓값은 10^3이고, 각 카운터의 크기가 4바이트이므로 4KB만 추가로 차지하게 된다. 따라서 메모리 문제도 걱정 없다.


✏️ 코드

use std::cmp::{max, min, Ordering};
use std::io::{stdin, Read, Write, stdout, BufWriter};
use std::fmt::Write as fWrite;

const N_MAX: usize = 1_000_000;
const M_MAX: usize = 1_000;

fn main() {
    let mut input = String::new();
    stdin().read_to_string(&mut input).unwrap();
    let mut input = input.split_ascii_whitespace().flat_map(str::parse::<usize>);
    let (n, m) = (input.next().unwrap(), input.next().unwrap());
    let mut counts = [0usize; M_MAX];
    let mut sum = 0usize;
    let mut range_sum = 0usize;
    counts[0] = 1;
    for i in 1..=n {
        range_sum = (range_sum + input.next().unwrap()) % m;
        sum += counts[range_sum];
        counts[range_sum] += 1;
    }

    println!("{sum}");
}
profile
웹 풀스택 개발 공부중입니다.

0개의 댓글