
N개 정수가 주어지고, Q개의 명령이 주어질 때, x부터 y까지 합을 출력하고 a번째 수를 b로 바꾸는 문제이다.
N이 100,000으로 부분합을 구할 때 누적합을 이용하면, 수를 갱신할 때 너무 오래 걸린다.
따라서 세그먼트 트리를 이용해서 O(logN)으로 수를 갱신하는 방법을 생각해야 한다.
leaf노드에는 값들이, 그 외의 노드들에는 부분합들을 저장한 자료구조이다.
어떤 노드 x의 왼쪽 자식 노드는 2x, 오른쪽 자식 노드는 2x+1이 된다.
초기화
start와 end가 같아지면 leaf노드이므로 값을 넣는다.
그 외의 경우 왼쪽 자식 노드, 오른쪽 자식 노드를 탐색하고 return 됐을 때 값을 갱신한다.
부분합 구하기
left, right가 구하고자 하는 부분합의 범위라고 할 때,
범위가 벗어나면 0을 return한다.
노드의 구간을 포함하는 경우 더 탐색할 필요가 없으므로 tree[node]를 return한다.
그 외의 경우 왼쪽, 오른쪽의 부분합을 재귀로 구한 다음 두 값을 더하여 return한다.
업데이트
범위를 벗어나는 경우 탐색을 멈춘다.
start와 end가 같아졌을 경우 바꾸고자 하는 노드의 값을 갱신시켜준다.
왼쪽 자식노드와 오른쪽 자식노드도 모두 탐색하고 return한 뒤 tree의 값을 갱신한다.
해결언어 : Python
import sys
input = sys.stdin.readline
import math
n, q = map(int, input().split())
arr = [0]+list(map(int, input().split()))
h = math.ceil(math.log2(n))
tree = [0]*(1 << (h+1))
def segment(node, s, e):
if s == e:
tree[node] = arr[s]
else:
segment(node*2, s, (s+e)//2)
segment(node*2+1, (s+e)//2+1, e)
tree[node] = tree[node*2] + tree[node*2+1]
def query(node, s, e, l, r):
if s > r or e < l:
return 0
if l <= s and e <= r:
return tree[node]
lsum = query(node*2, s, (s+e)//2, l, r)
rsum = query(node*2+1, (s+e)//2+1, e, l, r)
return lsum + rsum
def update(node, s, e, idx, val):
if idx < s or idx > e:
return
if s == e:
arr[idx] = val
tree[node] = val
return
update(node*2, s, (s+e)//2, idx, val)
update(node*2+1, (s+e)//2+1, e, idx, val)
tree[node] = tree[node*2] + tree[node*2+1]
segment(1, 1, n)
for _ in range(q):
x, y, a, b = map(int, input().split())
if x > y: x, y = y, x
print(query(1, 1, n, x, y))
update(1, 1, n, a, b)

끝으로..
새로운 개념인 세그먼트 트리에 대해 배워봤다. 코드가 익숙해지도록 연습해야겠다.