세그먼트 트리(Segment Tree)란?
- 여러 개의 데이터가 존재할 때 특정 구간(중간)의 합(최솟값, 최댓값, 곱 등)을 구하는 데 사용하는 자료구조
- 이진 트리 형태
- 특정구간의 합을 가장 빠르게 구할 수 있음
- 시간복잡도 :
세그먼트 트리 관련 문제 구현 패턴 및 기능
- 세그먼트 트리 공간 할당
- Python 에서는 리스트로 생성
- 세그먼트 트리 생성 & 초기화
- 세그먼트 트리의 인덱스는 무조건 1부터 시작
- 이 과정에서 구간의 합을 구할것인지 아니면 최솟값, 최댓값, 곱을 구할 것인지 결정해서 트리를 생성
- 원하는(특정) 구간의 합(최솟값, 최댓값, 곱 등)을 구하는 함수 생성
- 범위 안에 있는 경우에 한해서만 더해주면 됨
- 중간에(구간의) 어떤 부분의 합(최솟값, 최댓값, 곱 등) 구하는 함수 생성
- 특정 원소의 값을 수정하는 함수 생성
- 중간에 수의 변경이 일어남
- 해당 원소를 포함하고 있는 모든 구간 합 노드들을 갱신
- 해당 원소를 포함하고 있는 부분적인 노드들만 바꿔주면 댐
- 항상 트리는 루트부터 시작
- 트리와 관련된 거의 모든 구현은 재귀적으로 구현된다.
세그먼트 트리(Segment Tree) 자세한 설명 보러가기 (Python 기준)
어떤 N개의 수가 주어져 있다. 그런데 중간에 수의 변경이 빈번히 일어나고 그 중간에 어떤 부분의 합을 구하려 한다. 만약에 1,2,3,4,5 라는 수가 있고, 3번째 수를 6으로 바꾸고 2번째부터 5번째까지 합을 구하라고 한다면 17을 출력하면 되는 것이다. 그리고 그 상태에서 다섯 번째 수를 2로 바꾸고 3번째부터 5번째까지 합을 구하라고 한다면 12가 될 것이다.
첫째 줄에 수의 개수 N(1 ≤ N ≤ 1,000,000)과 M(1 ≤ M ≤ 10,000), K(1 ≤ K ≤ 10,000) 가 주어진다. M은 수의 변경이 일어나는 횟수이고, K는 구간의 합을 구하는 횟수이다. 그리고 둘째 줄부터 N+1번째 줄까지 N개의 수가 주어진다. 그리고 N+2번째 줄부터 N+M+K+1번째 줄까지 세 개의 정수 a, b, c가 주어지는데, a가 1인 경우 b(1 ≤ b ≤ N)번째 수를 c로 바꾸고 a가 2인 경우에는 b(1 ≤ b ≤ N)번째 수부터 c(b ≤ c ≤ N)번째 수까지의 합을 구하여 출력하면 된다.
입력으로 주어지는 모든 수는 -263보다 크거나 같고, 263-1보다 작거나 같은 정수이다.
첫째 줄부터 K줄에 걸쳐 구한 구간의 합을 출력한다. 단, 정답은 -263보다 크거나 같고, 263-1보다 작거나 같은 정수이다.
전형적인 세그먼트 트리를 활용한 문제
import sys
sys.setrecursionlimit(10**9)
input = sys.stdin.readline
# 세그먼트 트리를 배열의 각 구간 합으로 채워주기 (세그먼트 트리 생성 & 초기화)
def init(start, end, index):
# 리프노드에 도달했으면
if start == end:
tree[index] = data[start]
return tree[index]
# 두개의 서브트리로 쪼갬
mid = (start + end)//2
# 후위 순회(LRV) 방식으로 값을 채워 나간다.
tree[index] = init(start, mid, index*2) + init(mid+1, end, index*2+1)
return tree[index]
# 원하는(특정) 구간의 합(최솟값, 최댓값, 곱 등)을 구하는 함수
def interval_sum(start, end, index, left, right):
# 범위를 완전히 벗어난 경우 (내가 원하는 구간이 아닌 경우)
if left > end or right < start:
return 0
# 범위 안에 있는 경우 (내가 원하는 구간안에 속해 있는 경우)
if left <= start and end <= right:
return tree[index]
# 그렇지 않다면 두 서브트리로 쪼개서 합을 구함
mid = (start + end)//2
return interval_sum(start, mid, index * 2, left, right) + interval_sum(mid+1, end, index*2+1, left, right)
# 특정 원소의 값을 수정하는 함수
def update(start, end, index, up_inx, diff):
# 구간안에 수정할 구간(인덱스) 없는 경우
if up_inx < start or up_inx > end:
return
# 구간안에 수정할 구간(인덱스) 있는 경우 수정 해야할 diff(차이) 만큼 갱신해줌
tree[index] += diff
# 리프노드까지 바꿔주었으면 다시 전 단계로 돌아감
if start == end:
return
# 한 단계 아래로 내려가서 탐색함
mid = (start + end)//2
update(start, mid, index*2, up_inx, diff) # 왼쪽 서브트리
update(mid+1, end, index*2+1, up_inx, diff) # 오른쪽 서브트리
# N개의 수, M: 변경이 일어나는 횟수, K: 구간의 합을 구하는 횟수
N, M, K = map(int, input().split())
data = [0] + [int(input()) for _ in range(N)]
tree = [0] * (N * 4)
# 트리 생성
init(1, N, 1)
for _ in range(M+K):
a, b, c = map(int, input().split())
# a가 1인경우 b번째 수를 c로 바꿈
if a == 1:
val = c - data[b] # 변경할값 - 원래값 = 차이를 구해줌
data[b] = c # 원래 배열에서 값 수정
update(1, N, 1, b, val)
# a가 2인경우 b번째 수부터 c번째 수까지의 합을 구함
elif a == 2:
print(interval_sum(1, N, 1, b, c))
import sys
sys.setrecursionlimit(10**6)
input = sys.stdin.readline
def init(start, end, index):
if start == end:
tree[index] = data[start]
return tree[index]
mid = (start + end)//2
tree[index] = init(start, mid, index*2) + init(mid+1, end, index*2+1)
return tree[index]
def interval_sum(start, end, index, left, right):
if left > end or right < start:
return 0
if left <= start and end <= right:
return tree[index]
mid = (start + end)//2
return interval_sum(start, mid, index * 2, left, right) + interval_sum(mid+1, end, index*2+1, left, right)
def update(start, end, index, up_inx, diff):
if up_inx < start or up_inx > end:
return
tree[index] += diff
if start == end:
return
mid = (start + end)//2
update(start, mid, index*2, up_inx, diff)
update(mid+1, end, index*2+1, up_inx, diff)
N, M, K = map(int, input().split())
data = [0] + [int(input()) for _ in range(N)]
tree = [0] * (N * 4)
init(1, N, 1)
for _ in range(M+K):
a, b, c = map(int, input().split())
if a == 1:
val = c - data[b]
data[b] = c
update(1, N, 1, b, val)
elif a == 2:
print(interval_sum(1, N, 1, b, c))
def update(start, end, index, up_inx, diff):
if up_inx < start or up_inx > end:
return
tree[index] += diff
if start == end:
return
mid = (start + end)//2
update(start, mid, index*2, up_inx, diff)
update(mid+1, end, index*2+1, up_inx, diff)
👉 어떤 방식으로 값이 변경되는 것인가?
위 함수는 원래의 배열에 있는 원소의 값을 수정했을때 트리도 수정된 값으로 구간마다 갱신해주는 함수이다.
처음 세그먼트 트리를 공부하면서 tree[index] += diff
다음과 같은 코드를 봤을 때 살짝 멈칫 했다. 한 번 손으로 따라가보면서 확인해보니 해당 코드는 트리내에서 동작할때 원래 원소와 수정할 원소간의 차이만큼 업데이트 해주는 것이였다. 처음에는 원소자체를 업데이트 한다고 생각을 했었다.
혹시나 설명이 잘못된 부분이 있으면 댓글 부탁드립니다.