import sys
input = sys.stdin.readline
# input
N, M = map(int, input().split())
arr = [int(input()) for _ in range(N)]
pair = [map(int, input().split()) for _ in range(M)]
세그먼트 트리의 가장 일반적인 형태는 구간합을 구하는 형태일 것이다.
다음은 구간합을 구하는 세그먼트 트리의 코드 예시이다.
class SegmentTree:
def __init__(self, arr):
self.n = len(arr)
self.tree = [0] * (4 * self.n) # 트리의 크기는 원래 배열 크기의 4배
self.build(arr, 0, self.n - 1, 1) # 세그먼트 트리 구축
def build(self, arr, left, right, node):
if left == right: # 리프 노드에 도달한 경우
self.tree[node] = arr[left]
return
mid = (left + right) // 2
self.build(arr, left, mid, node * 2) # 왼쪽 자식 노드 구축
self.build(arr, mid + 1, right, node * 2 + 1) # 오른쪽 자식 노드 구축
# 왼쪽 자식 노드와 오른쪽 자식 노드의 값을 요약하여 현재 노드에 저장
self.tree[node] = self.tree[node * 2] + self.tree[node * 2 + 1]
def query_sum(self, left, right, node, node_left, node_right):
if right < node_left or left > node_right: # 구간이 완전히 벗어난 경우
return 0
if left <= node_left and node_right <= right: # 구간이 완전히 포함되는 경우
return self.tree[node]
mid = (node_left + node_right) // 2
# 왼쪽 자식 노드와 오른쪽 자식 노드로 분할하여 구간 합 계산
return self.query_sum(left, right, node * 2, node_left, mid) + \
self.query_sum(left, right, node * 2 + 1, mid + 1, node_right)
def get_sum(self, left, right):
return self.query_sum(left, right, 1, 0, self.n - 1)
또는 아래와 같이 나타낼 수도 있다.
class SegmentTree:
def __init__(self, arr):
self.arr = arr
self.tree = [0] * (4 * len(arr)) # 세그먼트 트리를 저장할 배열
def build(self, node, start, end):
if start == end:
self.tree[node] = self.arr[start]
else:
mid = (start + end) // 2
self.build(2 * node, start, mid) # 왼쪽 자식 노드를 빌드
self.build(2 * node + 1, mid + 1, end) # 오른쪽 자식 노드를 빌드
self.tree[node] = self.tree[2 * node] + self.tree[2 * node + 1] # 요약 정보 업데이트
def query(self, node, start, end, left, right):
if left > end or right < start: # 구간이 완전히 벗어난 경우
return 0
if left <= start and right >= end: # 구간이 완전히 포함되는 경우
return self.tree[node]
mid = (start + end) // 2
left_sum = self.query(2 * node, start, mid, left, right) # 왼쪽 자식 노드로 재귀 호출
right_sum = self.query(2 * node + 1, mid + 1, end, left, right) # 오른쪽 자식 노드로 재귀 호출
return left_sum + right_sum
사용 예시
arr = [1, 3, 5, 7, 9, 11]
tree = SegmentTree(arr)
tree.build(1, 0, len(arr) - 1) # 세그먼트 트리 빌드
print(tree.query(1, 0, len(arr) - 1, 1, 4)) # 구간 [1, 4]의 합 출력
+) 위 문제와 상관없으나 update 함수는 다음과 같다.
def update(self, node, start, end, index, diff):
if index < start or index > end: # 인덱스가 구간에 속하지 않는 경우
return
self.tree[node] += diff
if start != end: # 리프 노드가 아닌 경우
mid = (start + end) // 2
self.update(2 * node, start, mid, index, diff) # 왼쪽 자식 노드로 재귀 호출
self.update(2 * node + 1, mid + 1, end, index, diff) # 오른쪽 자식 노드로 재귀 호출
tree.update(1, 0, len(arr) - 1, 2, 2) # 인덱스 2의 값을 2만큼 증가
위 문제에서는 각 구간에 대해 최솟값과 최댓값을 구해야 하므로 총 두개의 트리가 필요하다. build
함수와 query
함수 에서 자식 트리의 값을 합하는 부분을 적절히 변경하면 된다. 다음은 구현 코드이다.
class SegmentTree():
def __init__(self, arr):
self.n = len(arr)
self.mintree = [0]*(4*self.n)
self.maxtree = [0]*(4*self.n)
self.build_min(arr, 0, self.n-1, 1)
self.build_max(arr, 0, self.n-1, 1)
def build_min(self, arr, left, right, node):
if left == right:
self.mintree[node] = arr[left]
return
mid = (left + right) // 2
self.build_min(arr, left, mid, node*2)
self.build_min(arr, mid+1, right, node*2+1)
self.mintree[node] = min(self.mintree[node * 2], self.mintree[node*2+1])
def build_max(self, arr, left, right, node):
if left == right:
self.maxtree[node] = arr[left]
return
mid = (left + right) // 2
self.build_max(arr, left, mid, node*2)
self.build_max(arr, mid+1, right, node*2+1)
self.maxtree[node] = max(self.maxtree[node*2], self.maxtree[node*2+1])
def query_min(self, left, right, node, node_left, node_right):
if node_right < left or right < node_left: # 구간을 벗어남
return float('inf')
if left <= node_left and node_right <= right: # 구간 내 포함
return self.mintree[node]
node_mid = (node_left + node_right) // 2
return min(self.query_min(left, right, node*2, node_left, node_mid),
self.query_min(left, right, node*2+1, node_mid+1, node_right))
def query_max(self, left, right, node, node_left, node_right):
if node_right < left or right < node_left: # 구간을 벗어남
return 0
if left <= node_left and node_right <= right: # 구간 내 포함
return self.maxtree[node]
node_mid = (node_left + node_right) // 2
return max(self.query_max(left, right, node*2, node_left, node_mid),
self.query_max(left, right, node*2+1, node_mid+1, node_right))
def get_min(self, left, right):
return self.query_min(left, right, 1, 0, self.n-1)
def get_max(self, left, right):
return self.query_max(left, right, 1, 0, self.n-1)
문제에서 입력받은 a,b 값은 실제 인덱스가 아닌 입력 순서이므로, -1
씩 연산한다.
segment_tree = SegmentTree(arr)
for a, b in pair:
print(segment_tree.get_min(a-1, b-1), segment_tree.get_max(a-1, b-1))