2진법 인덱스 구조를 활용해 구간 합 문제를 효과적으로 해결해 줄 수 있는 자료구조를 의미합니다.
펜윅 트리(Fenwick Tree)라고도 합니다.
세그먼트 트리의 한 종류로 더 빠르고 효율적으로 동작합니다.
0이 아닌 마지막 비트를 찾는 방법은 특정한 숫자 K의 0이 아닌 마지막 비트를 찾기 위해서 (K & -K)를 계산하면 됩니다.
배열에 부분 합을 구할 때 사용하는 개념입니다.
제일 아래 리프 노드로 달린 것들이 실제 우리가 처음 받아온 데이터들을 의미합니다.
부모 노드 값은 아래 자식 노드 값들의 합입니다.
기존 데이터 배열의 크기를 이라 하면, 리프 노드의 개수가 이 되고, 트리의 높이 는 이 되고, 배열의 크기는 이 됩니다.
인덱스 a ~ b까지의 구간 합을 구하려면 { (인덱스 1부터 b까지의 구간 합) - (인덱스 1부터 a-1까지의 구간 합) }을 계산하면 됩니다.
바이너리 인덱스 트리를 구성하면 해당 인덱스의 2진수에서 제일 마지막(제일 오른쪽)에 있는 1을 1씩 빼주면 필요한 구간들이 나옵니다.
예를 들어 2에서 7까지의 구간합 [2,7]을 구해보겠습니다.
먼저 7의 2진수 0111에서 처음 구간 [7]과 1을 뺀 0110인 6의 구간 [5,6] 그리고 또 1을 뺀 0100인 4의 구간 [1,4]를 모두 더하면 [1,7]을 구할 수 있습니다.
그리고 2의 2진수 0010의 처음 구간 [1,2]를 빼주면 구간합 [2,7]을 구할 수 있습니다.
위의 계산 결과 표를 보면 K가 7일 때, 2진수 표기시 가장 마지막 1의 위치가 1의 자리인 것을 알 수 있습니다. 즉 K & -K를 통해서 해당 인덱스가 가지고 있는 구간 합 범위를 알 수 있습니다.
K의 K & -K값이 1이면 해당 인덱스는 K값 하나만 가지고 있습니다. 2이면 (K-1 + K)의 값을 가지고 있습니다. 4이면 (K-3 + K-2 + K-1 + K)의 값을 가집니다.
어떤 N개의 수가 주어져 있다. 그런데 중간에 수의 변경이 빈번히 일어나고 그 중간에 어떤 부분의 합을 구하려 한다.
만약에 1,2,3,4,5 라는 수가 있고, 3번째 수를 6으로 바꾸고 2번째부터 5번째까지 합을 구하라고 한다면 17을 출력하면 되는 것이다. 그리고 그 상태에서 다섯 번째 수를 2로 바꾸고 3번째부터 5번째까지 합을 구하라고 한다면 12가 될 것이다.
첫째 줄에 수의 개수 N(1 ≤ N ≤ 1,000,000)과 M(1 ≤ M ≤ 10,000), K(1 ≤ K ≤ 10,000) 가 주어진다. M은 수의 변경이 일어나는 횟수이고, K는 구간의 합을 구하는 횟수이다.
그리고 둘째 줄부터 N+1번째 줄까지 N개의 수가 주어진다. 그리고 N+2번째 줄부터 N+M+K+1번째 줄까지 세 개의 정수 a, b, c가 주어지는데, a가 1인 경우 b(1 ≤ b ≤ N)번째 수를 c로 바꾸고 a가 2인 경우에는 b(1 ≤ b ≤ N)번째 수부터 c(b ≤ c ≤ N)번째 수까지의 합을 구하여 출력하면 된다.
입력으로 주어지는 모든 수는 -263보다 크거나 같고, 263-1보다 작거나 같은 정수이다.
값이 중간에 계속 변경되면서 구간 합을 구하므로 바이너리 인덱스 트리 알고리즘을 사용하는 문제입니다.
최악의 경우에도 의 시간 복잡도를 보장합니다.
import sys
input = sys.stdin.readline
# 데이터의 개수(n), 변경 횟수(m), 구간 합 계산 횟수(k)
n, m, k = map(int, input().split())
# 전체 데이터의 개수는 최대 1,000,000개
arr = [0] * (n + 1)
tree = [0] * (n + 1)
# i번째 수까지의 누적 합을 계산하는 함수
def prefix_sum(i):
result = 0
while i > 0:
result += tree[i]
# 0이 아닌 마지막 비트만큼 빼가면서 이동
i -= (i & -i)
return result
# i번째 수를 dif만큼 더하는 함수
def update(i, dif):
while i <= n:
tree[i] += dif
i += (i & -i)
# start부터 end까지의 구간 합을 계산하는 함수
def interval_sum(start, end):
return prefix_sum(end) - prefix_sum(start - 1)
for i in range(1, n + 1):
x = int(input())
arr[i] = x
update(i, x)
for i in range(m + k):
a, b, c = map(int, input().split())
# 업데이트(update) 연산인 경우
if a == 1:
update(b, c - arr[b]) # 바뀐 크기(dif)만큼 적용
arr[b] = c
# 구간 합(interval sum) 연산인 경우
else:
print(interval_sum(b, c))