구간 합(Prefix Sum)과 세그먼트 트리(Segment Tree)

Yujeong·2025년 1월 6일
post-thumbnail

구간 합(Prefix Sum)

1. 구간 합 구하기

구간 합을 구할 때, 바로 떠올릴 수 있는 방법은 다음과 같을 것이다.
누적된 값을 구해놓고 차이를 구하는 방법이다.

N = 5
nums = [0, 1, 2, 3, 4, 5]
prefix_sum = [0] * (N+1)
for i in range(1, N+1):
	prefix_sum[i] = prefix_sum[i-1] + nums[i]

print(prefix_sum[5] - prefix_sum[2]) # 3에서 5까지 구간합
print(prefix_sum[4] - prefix_sum[1]) # 2에서 4까지 구간합

이렇게 구간 합만 구하는 경우에는 이 방법이 최선의 방법일 것이다.
구간 합을 계산하는 데 시간 복잡도 O(N)O(N), 구간 합을 구하는 데 시간 복잡도 O(1)O(1)이기 때문이다.

2. 값 업데이트 + 구간 합 구하기

하지만, nums의 값을 업데이트하는 경우에는 구간 합을 어떻게 구해야할까?
쉽게 떠올릴 수 있는 방법은 계산해놓은 prefix_sum의 값들을 업데이트 하는 방식일 것이다.

# M: update 횟수, K: compute 횟수
# a/b/c: update/변경할 index/변경할 값, compute/시작 index/끝 index

## ex1
for _ in range(M+K):
    a, b, c = map(int, input().split())
    if a == 'update':
        nums[b] = c
        for _ in range(b, N+1):
        	prefix_sum[b] += (c - nums[b])
    elif a == 'compute':
    	print(prefix_sum[c] - prefix_sum[b-1])

## ex2
diff = [0] * (N+1)
for _ in range(M+K):
    a, b, c = map(int, input().split())
    if a == 'update':
        diff[b:] = [diff[b] + c - nums[b]] * (N+1-b)
        nums[b] = c
    elif a == 'compute':
        print(prefix_sum[c] - prefix_sum[b-1] + diff[c])

반복문을 통해 값을 하나하나 업데이트 한다면, 시간복잡도가 O(N2)O(N^2)이 된다.

어떻게 시간복잡도를 줄일 수 있을까?
세그먼트 트리를 이용하면, 값 업데이트로 인해 늘어난 시간을 O(logN)O(logN)으로 줄일 수 있다.

세그먼트 트리(Segment Tree)

세그먼트 트리란?

세그먼트 트리는 배열의 구간 정보를 효율적으로 관리하기 위해 사용하는 자료 구조이다.
배열의 특정 구간에 대한 합, 최솟값, 최댓값 등의 연산을 빠르게 수행할 수 있도록 설계되었다.

  1. 구조
  • 완전 이진 트리(Full Binary Tree)
  • 리프 노드(leaf node): 값
  • 리프 노드가 아닌 노드: 왼쪽 자식 + 오른쪽 자식
  • 루트 노드(root node): 배열 전체 구간
  1. 인덱싱
  • 왼쪽 자식: 배열 구간의 왼쪽 절반, 2×부모 인덱스2 \times 부모\ 인덱스
  • 오른쪽 자식: 오른쪽 절반, 2×부모 인덱스+12 \times 부모\ 인덱스 + 1

N = 10일 때, 세그먼트 트리는 다음과 같이 구성할 수 있다.

세그먼트 트리 구현 (python)

1. 세그먼트 트리 만들기

1) 배열 크기 정하기

완전 이진 트리에서 리프 노드가 NN개 라면, 리프 노드가 아닌 노드는 N1N-1개 있다. 따라서, 필요한 노드 수는 2N12N-1개가 된다.

트리의 높이 hhNN을 2로 계속 나눴을 때, 1이 될 때까지와 같다.
h=log2Nh = \lceil \log_2 N \rceil
2log2N+1=2h+12^{\lceil \log_2 N \rceil + 1} = 2^{h+1} 이므로, 전체 배열의 크기는 2h+12^{h+1}

  1. (N = 6):
    h=log26=3h = \lceil \log_2 6 \rceil = 3
    배열크기=23+1=16{배열 크기} = 2^{3+1} = 16

  2. (N = 10):
    h=log210=4h = \lceil \log_2 10 \rceil = 4
    배열크기=24+1=32{배열 크기} = 2^{4+1} = 32

2) 트리 만들기

N = 10
arr = list(range(1,11))
tree = [0] * (N*4)

def build_tree(idx, start, end):
    if start == end:
        tree[idx] = arr[start]
    else:
        mid = (start + end) // 2
        build_tree(idx*2, start, mid)
        build_tree(idx*2+1, mid+1, end)
        tree[idx] = tree[idx*2] + tree[idx*2+1]

build_tree(1, 0, N-1)
print(tree)
  • 배열 크기를 N×4N \times 4로 선언한 이유?
    2log2N+12^{\lceil \log_2 N \rceil + 1}은 트리 크기의 최댓값을 이론적으로 계산한 값이다.
    이 값은 대략 2×N2 \times N보다 조금 더 크다.
    4×N4 \times N은 이론적으로 필요한 크기보다 여유를 둘 수 있기 때문에, 크기 계산을 신경 쓰지 않고 안전하게 사용할 수 있다.

2. 구간 합 구하기

노드의 구간이 [start, end], 합을 구하려는 구간이 [left, right]라고 하자.
이때, 생길 수 있는 4가지의 경우가 있다.

  1. [left,right][left,right][start,end][start,end]가 겹치지 않는 경우
    → 구간 합에 영향을 주지 않음
    [start,end]=[5,10][start,end] = [5,10], [left,right]=[1,4][left,right] = [1,4]
  2. [left,right][left,right][start,end][start,end]를 완전히 포함하는 경우
    → 해당 노드에 저장된 값 바로 사용
    [start,end]=[3,7][start, end] = [3,7], [left,right]=[1,10][left, right] = [1,10]
  3. [start,end][start,end][left,right][left,right]를 완전히 포함하는 경우
    → 2개의 하위 구간으로 나누어 계산
    [start,end]=[1,10][start, end] = [1,10], [left,right]=[3,7][left, right] = [3,7]
  4. [left,right][left,right][start,end][start,end]가 겹치는 경우
    → 2개의 하위 구간으로 나누어 계산
    [start,end]=[3,8][start, end] = [3,8], [left,right]=[6,10][left, right] = [6 ,10]
# 노드 구간: [start,end]
# 합을 구하려는 구간: [left, right]

def compute(start, end, idx, left, right):
    if left > end or right < start:
        return 0
    if left <= start and end <= right:
        return tree[idx]
    
    mid = (start + end) // 2
    return compute(start, mid, idx*2, left, right) + \
        compute(mid+1, end, idx*2+1, left, right)

# 1부터 10까지 합
print(compute(0, 10, 1, 0, 10)) # 55
# 4부터 7까지 합
print(compute(0, 10, 1, 3, 7)) # 22

3. 값 업데이트하기

node번째 값을 value로 변경하려고 할 때, 2가지 경우를 고려해야 한다.

  1. [start,end][start,end]node가 포함되는 경우
  2. [start,end][start,end]node가 포함되지 않는 경우
# node: 수정할 노드
# value: 수정할 값

def update_tree(start, end, idx, node, diff):
    if node < start or end < node:
        return
    
    tree[idx] += diff
    if start != end:
        mid = (start + end) // 2
        update_tree(start, mid, idx*2, node, diff)
        update_tree(mid+1, end, idx*2+1, node, diff)

def update(node, value):
    diff = value - arr[node]
    arr[node] = value
    update_tree(0, N-1, 1, node, diff)

print("before:", arr, compute(0, 10, 1, 0, 3)) # before: [1, 2, 3, 4, 5, 6, 7, 8, 9, 10] 6
update(2, 6) # 3번째 값을 6으로 바꿈(3->6)
print("after:", arr, compute(0, 10, 1, 0, 3)) # after: [1, 2, 6, 4, 5, 6, 7, 8, 9, 10] 9

Prefix Sum과 Segment Tree 시간복잡도 비교

구분Prefix SumSegment Tree
초기화O(N)O(N)O(N)O(N)
구간 합 계산O(1)O(1)O(logN)O(logN)
값 업데이트O(N)O(N)O(logN)O(logN)

구간 합만 자주 구하는 경우에는 Prefix Sum이 효율적이고, 값 업데이트와 구간 합 모두를 수행해야하는 경우에는 Segment Tree가 효율적이다.


참고

Segment tree

세그먼트 트리 (Segment Tree)

profile
공부 기록

0개의 댓글