[알고리즘] 세그먼트 트리

김제현·2023년 6월 27일
0

알고리즘

목록 보기
9/10
post-thumbnail
2023. 06. 27.

세그먼트 트리 📌

세그먼트 트리는 주어진 데이터의 구간 합과 데이터 업데이트를 빠르게 수행하기 위해 고안해낸 자료구조의 형태이다. 세그먼트 트리의 종류는 구간합, 최대-최소 구하기로 나눌 수 있고, 구현 단계는 트리 초기화하기, 질의값 구하기, 데이터 업데이트하기로 나눌 수 있다.


세그먼트 트리의 핵심 이론

📢 세그먼트 트리 구현 과정

1. 트리 초기화하기

  • 리프 노드의 개수가 데이터의 개수(N) 이상이 되도록 트리 배열을 만든다. 트리 배열의 크기를 구하는 방법은 2^k >= N을 만족하는 k의 최솟값을 구한 후 2^k 2를 트리 배열의 크기로 정의하면 된다. 예를 들어 N=8이라면 2^3 >= 8이므로 배열의 크기를 2^3 2 = 16으로 정의하면 된다.
    ex) {5, 8, 4, 3, 7, 2, 1, 6}

  • 원본 데이터는 자식이 없는 리프 노드에 있으므로 원본 데이터를 입력할 때는 2^k를 시작 인덱스로 취하면 된다. 아래 사진을 예시로 보면 2^3 = 8을 시작 인덱스로 입력하면 된다. 그리고 리프노드를 제외한 7번노드부터 1번노드까지 채워나가면 된다.

  • 위 예시를 이용해 3개의 케이스와 관련된 세크먼트 트리를 구성하면 위와 같이 나온다.

2. 질의값 구하기

질의 인덱스를 세그먼트 트리 인덱스로 변경하는 방법
주어진 질의 인덱스를 세그먼트 트리의 리프 노드에 해당하는 인덱스로 변경한다. 기존 예시를 기준으로 한 인덱스값과 세그먼트 트리 리스트에서의 인덱스값이 다르기 때문에 인덱스를 변경해야 한다.

세그먼트 트리 인덱스 = 주어진 질의 인덱스 + 2^k - 1
k=3이라고 가정할 때, 1~3까지의 구간합을 구하는 질의가 있다고 하자.
1~3까지의 구간합을 구하는 것은 8(1+2^3-1)~10(3+2^3-1)의 구간합을 구하는 것과 같다. 굳이 수식으로 정의하면 이렇게 되지만 상식적으로 위 트리를 참고하면 1번부터 3번까지의 구간합은 노드 8,9,10의 값을 합치는 것과 같다.

질의값 구하는 과정
1. START_INDEX % 2 == 1일 때 해당 노드를 선택한다.
2. END_INDEX % 2 == 0일 때 해당 노드를 선택한다.
3. START_INDEX DEPTH 변경: START_INDEX = (START_INDEX + 1) / 2
4. END_INDEX DEPTH 변경: END_INDEX = (END_INDEX - 1) /2
5. 1~4과정을 반복하다가 END_INDEX < START_INDEX (교차될 때)가 되면 종료한다.

질의에 해당하는 노드 선택 방법
1. 구간 합: 선택된 노드들을 모두 더한다.
2. 최댓값 구하기: 선택된 노드들 중 MAX값을 선택해 출력한다.
3. 최솟값 구하기: 선택된 노드들 중 MIN값을 선택해 출력한다.

3. 데이터 업데이트 하기

  • 업데이트 방식은 자신의 부모 노드로 이동하면서 업데이트 한다는 것은 동일하지만, 어떤 값으로 업데이트할 것인지에 관해서는 트리의 종류별로 다르다.

  • 부모 노드로 이동하는 방식: INDEX = INDEX / 2 로 변경하며 업데이트를 하면 된다. 구간 합에서는 원래 데이터와 변경 데이터의 차이만큼 부모 노드로 올라가면서 변경하지만, 최댓값-최솟값 찾기에서는 변경 데이터와 자신과 같은 부모를 지니고 있는 다른 자식 노드와 비교하여 업데이트가 일어나지 않으면 종료한다.


2023. 06. 27. 오늘의 문제풀이 ✍

BOJ 2042 - 구간곱구하기
import sys

# n = 수 개수, m = 변경이 일어나는 횟수, k = 구간의 곱을 구하는 횟수
n,m,k = map(int,input().split())  

MOD = 1000000007
tree_height = 0
length = n

while length != 0:
    length //= 2
    tree_height += 1

tree_size = 2 ** (tree_height+1)
start_index = tree_size // 2 - 1
tree = [1] * (tree_size + 1)

for i in range(start_index + 1, start_index + n + 1):
    tree[i] = int(input())
    
# 리프노드를 제외한 거 채우기
def getIndexTree(i):
    while i != 1:
        tree[i // 2] = tree[i // 2] * tree[i] % MOD
        i -= 1

getIndexTree(tree_size - 1)
        
# 값 변경 함수
def change_value(index, change_val):
    tree[index] = change_val
    while index > 1:
        index //= 2
        tree[index] = tree[index*2] % MOD * tree[index*2 + 1] % MOD

def getMul(s,e):
    Mul = 1
    
    while s <= e:
        if s % 2 == 1:
            Mul = Mul * tree[s] % MOD
            s += 1
        if e % 2 == 0:
            Mul = Mul * tree[e] % MOD
            e -= 1
            
        s //= 2
        e //= 2
        
    return Mul
        

for _ in range(m+k):
    data, s, e = map(int,input().split())
    
    if data == 1:
        change_value(start_index+s,e)
        
    elif data == 2:
        s = s + start_index
        e = e + start_index
        
        print(getMul(s,e))

출처

Do it algorithm

0개의 댓글