누적 합

Noah·2024년 12월 11일

알고리즘

목록 보기
8/20

구간 합 알고리즘

구간 합 알고리즘은 n차원 배열에서 구간 합을 구하는 알고리즘으로 누적 합 배열을 빌드할때는 O(n)O(n), 구간 합을 구할때는 O(1)O(1)의 시간복잡도를 가집니다.

예를 들어 다음과 같은 배열이 있다고 합시다.

0123456
3456789

이때 2번~4번의 합을 구하면, 5 + 6 + 7 = 18 입니다. 이때 구간 합 알고리즘을 사용하지 않고 구한다면, 시간복잡도는 O(n),n:rangeO(n), n :range라고 할 수 있습니다. 그러나 구간 합 알고리즘을 사용한다면, O(1)O(1)에 구할 수 있습니다.

주의 할 점이, 구간 합 알고리즘은 업데이트가 불가하다는 점입니다. 업데이트를 하려면 세그먼트 트리나 펜윅 트리를 사용해야합니다.

구간 합 알고리즘 구현

구간 합 알고리즘은 누적 합을 사용합니다.

arrn=arrn+arrn1arr_n=arr_n+arr_{n-1}

위의 과정을 원래의 배열에 적용합니다. 따라서 누적 합 배열은 다음과 같습니다.

0123456
33 + 4 = 77 + 5 = 1212 + 6 = 1818 + 7 = 2525 + 8 = 3333 + 9 = 42

구간 합은 다음과 같이 구할 수 있습니다.

k=ijarrk=arrjarri1\sum_{k=i}^jarr_k=arr_j-arr_{i-1}

왜냐하면, 예를 들어 2~4번의 구간 합을 구한다고 하면, 0 ~ 4번의 누적 합에서 0~1번의 누적 합을 뺀 것과 동일하기 때문입니다.

Python 코드

arr = [3, 4, 5, 6, 7, 8, 9]
for i in range(1, len(arr)):
    arr[i] += arr[i-1]

print(arr[4]-arr[1])

2차원에서의 구간 합

2차원에서의 구간 합은 약간 복잡합니다. 예를 들어 다음과 같이 배열이 있다고 해봅시다. 일단 구간 합을 구하려면, 누적 합 먼저 처리해야합니다.

index0123
04567
13456
22345
31234

여기서 첫 행과 첫 열은 1차원 누적합 처럼 처리합니다.

index0123
0491522
17456
29345
310234

그러면 첫 행과 첫 열이 아닌 곳은 어떻게 처리할까요? (1, 1)을 봐보겠습니다.

(1, 1)에서 위의 원소는 9이고, 왼쪽 원소는 7입니다. 그리고 이는 배열의 (0, 0) + (0, 1) 과 (0, 0) + (1, 0)을 합한 값입니다. 그리고 (1, 1)은 (0, 0) + (0, 1) + (1, 0) + (1, 1) 입니다. 위의 원소와 왼쪽 원소에서 (0, 0) + (0, 1), (0, 0) + (1, 0)을 얻을 수 있습니다. 이때, 겹치는 원소가 있는데, 이것을 빼준다면 2차원에서 누적 합을 구할 수 있습니다. 즉, (0, 0) + (0, 1) + (0, 0) + (1, 0) - (0, 0) 입니다. 이를 일반화하면,

arrij=arr(i1)j+arri(j1)arr(i1)(j1)arr_{ij}=arr_{(i-1)j}+arr_{i(j-1)}-arr_{(i-1)(j-1)}

입니다.

index0123
0491522
17162740
29213654
310244264

이제 구간 합을 구해보겠습니다. 만약 (1, 1) ~ (2, 2)의 값을 구한다고 해보면, 일단 (2, 2)에서 시작합니다. (2, 2)는 (0, 0)부터 (2, 2) 까지의 합인데, 이때 필요없는 값들을 제거합니다. 따라서 0번째 열과 0번째 행을 제거해야합니다. 그런데, 이때도 겹치는 부분이 있습니다. 바로 (0, 0) 번째 원소인데, 이 값을 더해주면 2차원에서 구간 합을 구할 수 있습니다. 이를 일반화하면, (i, k)~(j, l)까지의 구간 합은

arrjlarr(i1)larrj(k1)+arr(j1)(k1)arr_{jl}-arr_{(i-1)l}-arr_{j(k-1)}+arr_{(j-1)(k-1)}

입니다.

Python 코드

arr = [[4, 5, 6, 7], [3, 4, 5, 6], [2, 3, 4, 5], [1, 2, 3, 4]]
for i in range(len(arr)):
    for j in range(len(arr)):
        if i == j == 0:
            continue
        if i == 0:
            arr[i][j] += arr[i][j-1]
        elif j == 0:
            arr[i][j] += arr[i-1][j]
        else:
            arr[i][j] += arr[i-1][j] + arr[i][j-1] - arr[i-1][j-1]

print(arr[2][2]-arr[0][2]-arr[2][0]+arr[0][0])
profile
부산소프트웨어마이스터고 4기 | 자세한 내용은 홈페이지(노션)의 테크 블로그에서 확인할 수 있습니다.

0개의 댓글