우선 간단한 완전 탐색 알고리즘으로 생각해 보겠습니다. 모든 신호를 만들고 앞에서부터 차례차례 구간합을 구해서 K를 만들 수 있는지 검사하고, K를 넘어가면 다음 신호부터 구간합을 검사하면 됩니다. 아래는 이 아이디어를 구현한 코드입니다.
public static int simple(final ArrayList<Integer> signals, int k) {
int result = 0;
for (int head = 0; head < signals.size(); head++) {
int sum = 0;
for (int tail = head; tail < signals.size(); tail++) {
// sum은 [head, tail] 구간의 합이다.
sum += signals.get(tail);
if (sum == k) result++;
if (sum >= k) break;
}
}
return result;
}
시간 복잡도를 계산해보면 인데(숫자들이 모두 1이면 최대 K개까지 더해서 비교해야 하므로) N의 최댓값이 5천만이므로 딱봐도 안됩니다. 그런데 여기서 한 가지 통찰을 얻으면 더 최적화된 알고리즘을 만들 수 있습니다. 바로 이전 구간합의 끝 부분의 위치는 앞으로 돌아가지 않는다는 것입니다. 즉, 위 코드에서 tail에 저장되는 값은 더 낮아지지 않습니다. 왜냐하면 head가 증가했는데 tail이 감소했다면 이 후보 구간은 이전 후보 구간의 부분 구간이 됩니다. 이 구간의 합이 이미 K 이상이라면 이전 후보 구간은 더 일찍 끝났어야 하므로, 이런 경우는 있을 수가 없습니다.
여기서 tail이 증가하기 전엔 구간합이 K 미만이였다가 tail이 증가해서 K를 초과했을 때 초과한 값이 head의 위치에 저장된 값보다 크면 head가 증가해도 부분 구간 합이 K를 초과하지 않나? 라는 의문이 들 수 있습니다. 예를 들어, K=7이고 현재의 구간의 숫자들이 {2, 1, 3}일 때 다음 숫자가 4라면 {2, 1, 3, 4}가 되므로 K를 초과하므로 head를 증가시켜 {1, 3, 4}가 됩니다. 여전히 K를 초과합니다. 하지만 이럴 경우 tail을 감소시킬 필요가 없습니다. 왜냐하면 tail이 감소되면 이전 구간이였던 {2, 1, 3}의 부분 구간이되므로 {2, 1, 3} 구간의 합보다 클 수 없으므로 검사할 필요가 없습니다. 아래의 코드는 이 통찰을 이용해 코드를 구현했습니다.
public static int optimized(final ArrayList<Integer> signals, int k) {
int result = 0, tail = 0, rangeSum = signals.get(0);
for (int head = 0; head < signals.size(); head++) {
// rangeSum이 k 이상인 최초의 구간을 만날 때까지 tail을 옮긴다.
while (rangeSum < k && tail + 1 < signals.size()) {
rangeSum += signals.get(++tail);
}
if (rangeSum == k) result++;
// signals[head]는 이제 구간에서 빠진다.
rangeSum -= signals.get(head);
}
return result;
}
위 코드는 메모리 제한을 초과합니다. N이 5천만 일때 N을 short로 저장해도 메모리 제한인 바이트를 초과하기 때문입니다. 따라서 신호를 미리 생성해서 저장하는 오프라인 알고리즘을 사용할 수 없습니다. 온라인 알고리즘으로 대체해야 합니다. 계산 중에 신호를 새로 하나씩 생성하는 전략을 취해야 합니다. 이 전략을 사용할 수 있는 이유는 알고리즘 실행 중엔 구간의 숫자들만 필요로 하기 때문입니다. 아래는 이 아이디어를 구현한 코드입니다.
import java.util.*;
public class Main {
public static int K, N;
public static long[] A;
public static ArrayList<Integer> signals;
public static int result;
static class RNG {
long seed;
RNG() {
seed = 1983;
}
int next() {
long result = seed;
seed = (seed * 214013 + 2531011) % (1L << 32);
return (int) (result % 10000 + 1);
}
}
public static void input(Scanner scanner) {
K = scanner.nextInt();
N = scanner.nextInt();
}
public static void solve() {
result = countRanges(K, N);
}
public static int countRanges(int k, int n) {
RNG rng = new RNG();
Queue<Integer> range = new LinkedList<>();
int result = 0, rangeSum = 0;
for (int i = 0; i < n; i++) {
// 구간에 숫자를 추가한다.
int newSignal = rng.next();
rangeSum += newSignal;
range.offer(newSignal);
// 구간의 합이 k를 초과하는 동안 구간에서 숫자를 뺀다.
while (rangeSum > k) {
rangeSum -= range.poll();
}
if (rangeSum == k) result++;
}
return result;
}
public static void output() {
System.out.println(result);
}
public static void main(String[] args) {
Scanner scanner = new Scanner(System.in);
int testCase = scanner.nextInt();
for (int i = 0; i < testCase; i++) {
input(scanner);
solve();
output();
}
}
}
함수에서 구간 안의 숫자들을 저장하는 방법은 큐를 사용했습니다.
반복문이 두 개라서 언뜻 보기에 이지 않을까 생각이 들 수 있습니다. 하지만 분할 상환 분석을 사용해보면 K는 절대 감소하지 않기 때문에 바깥 반복문이 전체가 실행될 동안 안쪽 반복문이 실행되는 횟수는 최대 N입니다. 이므로 충분히 시간 제한 내에 문제를 해결할 수 있습니다.
솔직히 통찰을 이해하기 어려웠습니다.