class SegmentTree:
def __init__(self, arr):
self.n = len(arr) # 입력 배열의 크기
self.arr = arr # 원본 배열
self.tree = [0] * (4 * self.n) # 세그먼트 트리 배열 (충분히 크게 할당)
self.build(1, 0, self.n - 1) # 세그먼트 트리 빌드 시작 (루트 노드 번호는 1)
def build(self, node: int, start: int, end: int):
# 리프 노드인 경우 (start == end)
if start == end:
self.tree[node] = self.arr[start]
else:
# 자식 노드로 분할하여 재귀적으로 빌드
mid = (start + end) // 2
self.build(2 * node, start, mid) # 왼쪽 자식
self.build(2 * node + 1, mid + 1, end) # 오른쪽 자식
# 현재 노드는 두 자식 노드의 합
self.tree[node] = self.tree[2 * node] + self.tree[2 * node + 1]
def _query(self, node: int, start: int, end: int, l: int, r: int) -> int:
# [start, end]와 [l, r]이 겹치지 않는 경우
if r < start or end < l:
return 0
# [start, end]가 [l, r]에 완전히 포함되는 경우
if l <= start and end <= r:
return self.tree[node]
# 부분적으로 겹치는 경우 왼쪽/오른쪽 자식으로 내려감
mid = (start + end) // 2
left_sum = self._query(2 * node, start, mid, l, r)
right_sum = self._query(2 * node + 1, mid + 1, end, l, r)
# 왼쪽과 오른쪽 자식의 결과를 합산
return left_sum + right_sum
def query(self, l: int, r: int) -> int:
# 사용자에게 제공되는 쿼리 인터페이스
# 배열 인덱스 l ~ r 구간의 합을 반환
return self._query(1, 0, self.n - 1, l, r)
def update(self, idx: int, value: int, node: int, start: int, end: int):
# 리프 노드에 도달한 경우 (idx 위치를 찾은 경우)
if start == end:
self.arr[idx] = value # 원본 배열도 갱신
self.tree[node] = value # 세그먼트 트리 노드 갱신
else:
# 자식 노드로 내려감
mid = (start + end) // 2
if idx <= mid:
self.update(idx, value, 2 * node, start, mid) # 왼쪽 자식
else:
self.update(idx, value, 2 * node + 1, mid + 1, end) # 오른쪽 자식
# 자식 노드가 변경되었으니 현재 노드도 다시 계산
self.tree[node] = self.tree[2 * node] + self.tree[2 * node + 1]
def _update(self, idx: int, value: int):
# 사용자에게 제공되는 업데이트 인터페이스
# idx 위치의 값을 value로 변경
self.update(idx, value, 1, 0, self.n - 1)