복습 요망!!
문제의 포인트는 세그먼트 트리(Segment Tree)
를 사용하는 것이다. 세그먼트 트리
에 잘 모른다면 해당 글을 참고하면 좋을 것이다. 왜 세그먼트 트리
를 사용할까?
문제에서 주어진 2가지 연산을 생각해보자.
구간 b, c가 주어질 때, arr[b] + ... + arr[c]
의 방식으로 하나씩 전부 더해주기
b번째 수를 c로 바꿀 때, arr[b] = c
로 바꾸기
M은 수의 변경이 일어나는 횟수이고, K는 구간의 합을 구하는 횟수이다. 1번 연산을 수행할 때 O(N)
이고, 2번 연산을 수행할 때 O(1)
의 시간이 걸린다. 즉, O(K*N) + O(M*1) = O(KN)
의 시간 복잡도로 1,000,000 * 10,000 = 10,000,000,000(100억)
으로 2초안에 풀 수 없어 시간 초과가 발생한다.
세그먼트 트리를 사용하면, 1번 연산과 2번 연산 모두 O(logN)
만에 수행할 수 있다. 따라서 총 O((K+M)logN)
의 시간 복잡도로 이전의 방법보다 훨씬 더 빠르게 구간 합들을 구할 수 있게된다.
# 세그먼트 트리 생성
def init(start, end, index):
if start == end:
segment_tree[index] = arr[start]
return segment_tree[index]
mid = (start + end) // 2
segment_tree[index] = init(start, mid, index * 2) + init(mid + 1, end, index * 2 + 1)
return segment_tree[index]
# 세그먼트 트리에서 조건에 맞는 구간 합 구하기
def interval_sum(start, end, index, left, right):
if left > end or right < start:
return 0
if left <= start and right >= end:
return segment_tree[index]
mid = (start + end) // 2
return interval_sum(start, mid, index * 2, left, right) + interval_sum(mid + 1, end, index * 2 + 1, left, right)
# 배열의 원소가 변경되었을 때 세그먼트 트리도 변경하기
def update(start, end, index, what, value):
if what < start or what > end:
return
segment_tree[index] += value
if start == end:
return
mid = (start + end) // 2
update(start, mid, index * 2, what, value)
update(mid + 1, end, index * 2 + 1, what, value)
import sys
input = sys.stdin.readline
N, M, K = map(int, input().split())
arr = []
for _ in range(N):
arr.append(int(input()))
segment_tree = [0] * (N * 4)
init(0, N - 1, 1)
for _ in range(M + K):
a, b, c = map(int, input().split())
if a == 1:
b = b - 1
diff = c - arr[b]
arr[b] = c
update(0, N - 1, 1, b, diff)
elif a == 2:
print(interval_sum(0, N - 1, 1, b - 1, c - 1))