[자료구조] 세그먼트 트리(Segment Tree)

sinryuji·2024년 12월 12일
post-thumbnail

세그먼트 트리(Segment Tree)란?

출처: https://www.geeksforgeeks.org/segment-tree-data-structure/

세그먼트 트리는 구간합, 구간곱, 구간 최소값, 구간 최대값 등 구간에 대한 계산 결과를 빠르게(O(logN)O(logN)) 구할 수 있는 자료구조이다. 그리고 특정 노드에 대한 수정까지 O(logN)O(logN)의 시간 복잡도로 가능해서 잦은 변경이 일어나는 경우에도 사용하기에 적합한 자료구조이다.

위 이미지를 보면 세그먼트 트리는 각 구간에 대한 계산 결과를 미리 구한 후 그 값을 노드에 저장해놓은 이진트리임을 확인할 수 있다. 예를 들어 arr[3:5]에 대한 구간 합이 필요하다면 27의 값을 가지고 있는 노드를 바로 참조하여 구할 수 있다. 만약에 arr[2:5]에 대한 구간합이 필요하다면 arr[2:2]에 해당하는 노드와 arr[3:5]에 해당하는 노드의 값을 더해 구할 수 있는 것이다. 이미지로 보면 다음 동그라미를 친 노드들의 값만을 더하면 된다. 세그먼트 트리는 대략 이렇게 동작한다.


수정이 일어나지 않는 배열에 대한 구간 합을 구한다면 누적 합(Prefix Sum)을 사용하는 것이 효과적일 것이다.

list = [1, 3, 5, 7, 9, 11]

위와 같은 배열이 있다고 가정을 해보자. [1, 4, 9, 16, 25, 36]와 같은 누적 합 배열을 만드는데는 O(N)O(N)의 시간 복잡도를, 특정 구간 합을 구하는데는 O(1)O(1)의 시간 복잡도를 가진다.

하지만 2번 인덱스의 값이 8로 변한다면 어떻게 될까? 2번 인덱스 값의 변화는 누적합 배열에서 인덱스 2 이상의 모든 값에 영향을 끼치므로 그 모든 값들을 업데이트 해주는 과정이 필요하다. 다음과 같이 말이다. [1, 4, 12, 19, 28, 39]

즉, 수정이 일어날때마다 O(N)O(N)의 시간 복잡도가 필요해지게 된다. 배열이 크면 클수록, 수정이 많이 일어나면 일어날수록 연산량은 크게 늘어날 것이다.

그렇다면 세그먼트 트리에서는 어떨까? 똑같이 2번 인덱스의 값을 바꿨을 때 위 이미지에서 2번 인덱스가 포함하는 구간의 노드들만 수정을 해주면 된다. arr[0:5], arr[0:2], arr[2:2]에 해당하는 노드들의 수정만이 필요하다. 그림으로 보면 다음 동그라미를 친 노드들이다.

물론 이 예시는 배열의 길이가 매우 짧기에 수정해야 할 개수의 차이가 거의 나지 않지만 배열의 길이가 길면 길수록 그 차이는 매우 커지게 된다!!

세그먼트 트리 구현하기

세그먼트 트리 초기화

세그먼트 트리를 초기화 할 때 트리의 크기는 일반적으로 배열의 크기 * 4의 크기를 할당하면 된다.

arr = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
tree = [0] * (len(arr) * 4)

사실 정확히는 배열의 크기보다 큰 제곱수 중 가장 작은 제곱수 * 2의 크기면 된다. 예를 들어 배열의 크기가 10이라면 이보다 크면서 가장 작은 제곱수는 424^2인 16이므로 이의 2배인 32의 크기면 충분하다. 하지만 그런 제곱수를 구하는 과정이 귀찮기도 하고 그냥 4배를 해버려도 얼추 그 근사치에 낭비되는 공간도 그렇게 크지 않기 때문에 보통 그냥 4배를 해버리는 경우가 일반적인 것이다!

💡 왜 배열의 크기보다 큰 제곱수 중 가장 작은 제곱수 * 2의 크기가 필요할까?

세그먼트 트리를 생성할 때 왼쪽 자식 노드는 index * 2, 오른쪽 자식 노드는 index * 2 + 1의 규칙으로 생성을 한다. 만약에 어떤 노드가 인덱스 8에 위치하였다면 그 왼쪽 노드는 16에 위치하고 오른쪽 노드는 17에 위치하게 되는 것이다. 이렇게 자식 노드의 위치를 지정을 해놔야 트리 생성 후에 수정이나 탐색을 할 때에도 편하기 때문이다.

만약에 세그먼트 트리를 만들었는데 포화 이진 트리(Perfect Binary Tree)의 형태가 된다면 2N12N - 1의 크기로도 충분하다.

출처: https://iq.opengenus.org/perfect-binary-tree/

위와 같이 트리가 구성이 된다면 리프 노드에 NN만큼의 공간이 필요하고, 부모 노드로 올라가면서 무조건 2개의 노드가 하나의 노드로 합쳐지므로 N1N - 1의 공간이 필요해지게 되어 총 2N12N - 1의 공간이면 충분하다.

하지만 모든 세그먼트 트리가 포화 이진 트리로 구성되는건 아니다. 다음 이미지를 보자.

출처: https://velog.io/@kimdukbae/%EC%9E%90%EB%A3%8C%EA%B5%AC%EC%A1%B0-%EC%84%B8%EA%B7%B8%EB%A8%BC%ED%8A%B8-%ED%8A%B8%EB%A6%AC-Segment-Tree

경우에 따라서 세그먼트 트리는 위와 같이 전 이진 트리(Full Binary Tree)로 구성 될 수도 있다. 동그리미를 친 노드를 보면 인덱스가 12인 부모 노드의 왼쪽 자식 노드이므로 앞서 설명한 규칙에 따라 인덱스가 24가 된다. 그런데 크기를 2N12N - 1로 할당을 했다면 당연히 해당 인덱스에는 할당이 불가능해진다. 값이 10인 노드를 보면 인덱스가 15인데 이의 자식 노드는 인덱스 31에 위치해야 할 것이고 이게 가장 큰 인덱스가 필요한 노드이다. 앞서 크기가 10인 배열의 경우엔 그보다 큰 가장 작은 제곱수인 16에 2를 곱한 32의 공간이 필요하다고 했다. 이게 배열의 크기보다 큰 제곱수 중 가장 작은 제곱수 * 2의 크기가 필요한 이유이다.

세그먼트 트리를 초기화 하는 함수는 다음과 같다.

def init(start, end, idx):
    if start == end:
        tree[idx] = arr[start]
        return tree[idx]
        
    mid = (start + end) // 2
    tree[idx] = init(start, mid, idx * 2) + init(mid + 1, end, idx * 2 + 1)
    
    return tree[idx]
  • start: 배열의 시작 인덱스
  • end: 배열의 마지막 인덱스
  • idx: 트리의 인덱스

startend가 같다는 건 리프 노드에 도달했다는 의미이고 리프 노드에는 배열의 각 요소가 위치하므로 해당 트리의 인덱스에 배열의 값을 할당(tree[idx] = arr[start]) 해준다.

현재 처리 구간을 반으로 나눠 왼쪽 부분은 왼쪽 자식 노드에게 넘기고, 오른쪽 부분은 오른쪽 자식 노드에게 넘긴다. 그러기 위해 중간값인 mid를 구하고 앞서 설명한 규칙대로 왼쪽 자식 노드에게는 idx * 2 를 인덱스로 넘기고 오른쪽 자식 노드에게는 idx * 2 + 1를 인덱스로 넘긴다.

구간 합을 구해야 하므로 양 구간의 값을 더하여 tree[idx]에 저장해다.

구간 합 구하기

def find(start, end, idx, left, right):
    if left > end or right < start:
        return 0
        
    if left <= start and right >= end:
        return tree[idx]
        
    mid = (start + end) // 2
    return find(start, mid, idx * 2, left, right) + find(mid + 1, end, idx * 2 + 1, left, right)
  • start: 배열의 시작 인덱스
  • end: 배열의 마지막 인덱스
  • idx: 트리의 인덱스
  • left: 구할 구간의 시작 인덱스
  • right: 구할 구간의 마지막 인덱스

구할 구간의 시작 인덱스(left)가 배열의 마지막 인덱스보다 크거나 구할 구간의 마지막 인덱스(right)가 배열의 시작 인덱스보다 작다는 것은 배열이 구할 구간에서 벗어났다는 의미이다. 그럴 경우엔 합에 더해지면 안되므로 0을 리턴한다.

❗️ 주의해야 할 점

구간을 벗어났을때는 무조건 0을 리턴하는게 아니라, 의미가 없는 값을 리턴하는 것이다. 해당 케이스는 구간 합을 구하기 때문에 합에 의미가 없는 0을 리턴하는 것이고, 만약에 구간 곱을 구한다면 1을 리턴해야 할 것이고, 구간 최솟값을 구한다면 int('float')와 같은 값을 리턴해야 할 것이다!

만약에 배열의 구간이 구할 구간에 포함되어 있다면(if left <= start and right >= end) 해당 노드의 값을 리턴한다.

역시나 중간 값인 mid를 구하고 배열의 구간을 나눠가며 왼쪽 자식의 경우엔 index * 2 오른쪽 자식의 경우엔 index * 2 + 1을 넘기며 재귀를 진행한다. 그렇게 되면 구간을 벗어나는 경우엔 0을, 구간에 포함되는 경우엔 그 구간의 구간 합을 리턴할 것이니 그 구간 합을 모두 더하면 최종적으로 내가 구하고자 하는 구간의 합을 구할 수 있다!

배열의 값 수정하기

구간 합을 구하는 것보다 배열의 값을 수정할 때가 세그먼트 트리의 진면목이자 가장 큰 강점이다.

배열의 값을 수정 하는 데는 Top-DownBottom-Up 두 가지 방식이 존재한다.

구간 합을 구하는 경우에는 어떠한 방식으로 구현을 해도 상관이 없다. 하위 노드 값의 변화가 상위 노드의 값에 끼치는 영향이 수정 할 값 - 기존의 값으로 항상 일정하기 때문이다. 그 차이만큼 포함되는 구간의 노드에 더해주기만 하면 된다.

그런데 구간 곱이나 구간 최솟값을 처리하는 경우에는 무조건 Bottom-Up으로 처리를 해야 한다. 구간 합과 달리 이 경우들은 하위 노드 값의 변화가 상위 노드에 끼치는 영향이 모두 다르다. 구간 계산의 결과가 반대편 노드에 따라 달라지기 때문이다.

세그먼트 트리는 응용으로 갈수록 단순히 구간 합 보다는 구간에 대한 다양한 계산 결과를 요구하는 케이스가 많기 때문에 반드시 두 가지 방식 모두 숙지를 하고 있어야 한다!

Top-Down

def update_top_down(N, target, value):
    diff = value - arr[target]
    arr[target] = value
    update_tree_top_down(0, N - 1, 1, target, diff)

def update_tree_top_down(start, end, idx, target, diff):
    if target < start or target > end:
        return

    tree[idx] += diff
    
    if start == end:
        return
        
    mid = (start + end) // 2
    update_tree_top_down(start, mid, idx * 2, target, diff)
    update_tree_top_down(mid + 1, end, idx * 2 + 1, target, diff)
  • N: 배열의 크기
  • target: 수정할 배열의 인덱스
  • value: 수정할 값
  • diff: 기존 값 - 수정할 값

배열의 값을 우선적으로 바꿔주고(arr[target] = value) 재귀 함수로 진입한다.

역시나 범위에서 벗어나는 경우에는 바로 리턴을 해주고 아닐 경우 기존 값 - 수정할 값만큼을 트리의 노드에 더해준다.

start == end의 경우엔 리프 노드에 닿은 경우니 리턴을 해주는데, 그렇지 않을 경우 아직 내려갈 자식 노드가 남은 것이니 역시나 중간 값을 구해준 뒤 왼쪽 구간과 오른쪽 구간을 나눠가며 재귀를 진행하면 된다!

Bottom-Up

def update_bottom_up(start, end, idx, target, value):
    if target < start or target > end:
        return

    if start == end:
        arr[target] = value
        tree[idx] = value
        return

    mid = (start + end) // 2
    update_bottom_up(start, mid, idx * 2, target, value)
    update_bottom_up(mid + 1, end, idx * 2 + 1, target, value)

    tree[idx] = tree[idx * 2] + tree[idx * 2 + 1]

역시나 구간을 벗어나는 경우에는 바로 리턴을 해주고 리프 노드에 닿았을 경우 배열과 트리 모두 값을 수정을 해준다.

그리고 올라오며 값을 처리해야 하므로 먼저 왼쪽 구간 오른쪽 구간 나눠 재귀를 타고, 리프 노드에 닿아 리턴이 되면 올라오며 수정된 값을 업데이트(tree[idx] = tree[idx * 2] + tree[idx * 2 + 1]) 해준다.

왼쪽 자식 노드는 idx * 2이고 오른쪽 자식 노드는 idx * 2 + 1이다. 그리고 자식 노드로 타고 들어간 재귀가 끝났다는 것은 이미 자식 노드는 값이 업데이트 됐음을 의미하므로 그 후 왼쪽 자식 노드 값(tree[idx * 2])과 오른쪽 자식 노드 값(tree[idx * 2 + 1])을 더하여 현재 노드의 값을 업데이트 하는 것이다.

만약 구간 합이 아니라 구간 곱일 경우에는 tree[idx] = tree[idx * 2] * tree[idx * 2 + 1]와 같이 두 자식 노드의 합이 아니라 두 자식 노드의 곱으로만 수정을 해주면 된다.

전체 코드

arr = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
tree = [0] * (len(arr) * 4)

def init(start, end, idx):
    if start == end:
        tree[idx] = arr[start]
        return tree[idx]
        
    mid = (start + end) // 2
    tree[idx] = init(start, mid, idx * 2) + init(mid + 1, end, idx * 2 + 1)
    
    return tree[idx]
    
def find(start, end, idx, left, right):
    if left > end or right < start:
        return 0
        
    if left <= start and right >= end:
        return tree[idx]
        
    mid = (start + end) // 2
    return find(start, mid, idx * 2, left, right) + find(mid + 1, end, idx * 2 + 1, left, right)
    
def update_top_down(N, target, value):
    diff = value - arr[target]
    arr[target] = value
    update_tree_top_down(0, N - 1, 1, target, diff)

def update_tree_top_down(start, end, idx, target, diff):
    if target < start or target > end:
        return

    tree[idx] += diff
    
    if start == end:
        return
        
    mid = (start + end) // 2
    update_tree_top_down(start, mid, idx * 2, target, diff)
    update_tree_top_down(mid + 1, end, idx * 2 + 1, target, diff)
    
def update_bottom_up(start, end, idx, target, value):
    if target < start or target > end:
        return

    if start == end:
        arr[target] = value
        tree[idx] = value
        return

    mid = (start + end) // 2
    update_bottom_up(start, mid, idx * 2, target, value)
    update_bottom_up(mid + 1, end, idx * 2 + 1, target, value)

    tree[idx] = tree[idx * 2] + tree[idx * 2 + 1]
    
init(0, len(arr) - 1, 1)
print(find(0, len(arr) - 1, 1, 0, 9))  # 0부터 9까지의 구간 합 (1 + 2 + ... + 9 + 10)
print(find(0, len(arr) - 1, 1, 0, 2))  # 0부터 2까지의 구간 합 (1 + 2 + 3)
print(find(0, len(arr) - 1, 1, 6, 7))  # 0부터 2까지의 구간 합 (7 + 8)

# arr[0]을 5로 수정(탑 다운)
update_top_down(len(arr), 0, 5)
print(find(0, len(arr) - 1, 1, 0, 2))   # 0부터 2까지의 구간 합 (5 + 2 + 3)

# arr[9]를 -1 수정(바텀 업)
update_bottom_up(0, len(arr) - 1, 1, 9, -1)
print(find(0, len(arr) - 1, 1, 8, 9))   # 8부터 9까지의 구간 합 (9 + -1)

💡 참고로 매개변수 idx, 그러니까 트리의 인덱스가 1부터 시작하는 이유는 역시나 또 설명을 하지만 왼쪽 자식 노드는 index * 2, 오른쪽 자식 노드는 index * 2 + 1의 규칙을 따르기 때문이다. 0부터 시작해버리면 왼쪽 자식 노드는 영원히 0이 되어버린다!

세그먼트 트리는 구간에 대한 계산 결과를 O(logN)O(logN)으로 구하면서 수정까지도 할 수 있는 매우 유용한 자료구조이다. 분명 익혀두면 많이 도움이 될 것이라고 생각하고 무엇보다 백준에서 골드1 ~ 플래4 정도에 그렇게 높지 않은 난이도로 많은 세그먼트 트리 문제들이 분포를 하고 있는데 백준 티어 올리기에 정말 개꿀통(?)인 알고리즘이다...ㅋㅋ 세그먼트 트리 덕분에 플래티넘5 정체 구간을 뚫고 플래티넘4를 달성했다 ㅎㅎ

참고

https://velog.io/@kimdukbae/%EC%9E%90%EB%A3%8C%EA%B5%AC%EC%A1%B0-%EC%84%B8%EA%B7%B8%EB%A8%BC%ED%8A%B8-%ED%8A%B8%EB%A6%AC-Segment-Tree

profile
응애 개발자입니다.

0개의 댓글