느낀점
문제 설명 최근 누적합 문제를 많이 풀고 있습니다. 그런 류의 문제가 요새 좀 많이 등장하는 것 같더군요. 문제는 되게 단순합니다. 하지만, 생각없이 코드짜면 바로 시간 초과나는 문제입니다.
문제 아이디어
당연히 누적합을 구해야합니다. 하지만, 나머지에 초점이 맞춰져 있는 문제이기 때문에 누적합을 하면서 m으로 나누어 떨어지는 인덱스를 기록하면서 쭉 순환하되, 각각의 나머지가 몇번 등장하는지 알아야 합니다. 그래야 문제에서 요구하는 모든 구간에 대한 합이 m과 나누어 떨어지는지 알 수 있습니다. 문제에서 주어진예시를 보고 좀 더 자세히 설명드려보겠습니다.
5 3
1 2 3 1 2
위의 예시에서 누적으로 합하면서 3으로 나누어 떨어지는 인덱스는
1, 2, 4 입니다.
이 3개 중 아무거나 2개 택하면, 택한 모든 구간은 3으로 나누어 떨어집니다.
자 이런 메커니즘으로 3으로 나눴을 때 나머지가 1인 구간은
0, 3밖에 없습니다.
이때는 딱 한 쌍만 나오네요. i = 1, j = 3을 택한 경우와 동일합니다.
감이 오신분은 아마 아 그러면 나머지가 동일한 부분만 찾아서 조합의 수를 체크하면 되겠다! 라고 생각드실겁니다.
네 바로 그거에요.
그렇기 때문에, 나머지가 동일한 부분을 빠르게 찾기 위해서 하나의 배열을 더 만듭니다.
어차피 m으로 나누기 때문에, 원소의 개수가 m인 것을 따로 만들면 됩니다.
그렇게 모든 순회를 돌면서 진행하게 됩니다.
즉 위 문제를 O(n+m)
으로 끝내게 됩니다. 아래는 풀이 코드에요!
import sys
n, m = map(int, input().split())
arr_n = list(map(int, sys.stdin.readline().split()))
arr_m = [0 for _ in range(m)]
s = 0
num = 0
for i in range(n):
s += arr_n[i]
res = s % m
if res == 0:
num += 1
arr_m[res] += 1
for i in range(m):
if arr_m == 0:
continue
num += arr_m[i] * (arr_m[i] - 1) / 2
print(int(num))