구간 트리(segment tree)란?

황준승·2021년 6월 6일
0

😂 구간트리(세그먼트 트리)란?

트리들은 모두 자료들을 특정 순서대로 저장하고, 추가/삭제하는 등 자료를 저장하는 용도로 사용되었습니다만, 이들 외에도 다양한 종류의 트리가 있습니다. 이 장에서 다루는 구간 트리는 저장된 자료들을 적절히 전처리해 그들에 대한 질의들을 빠르게 대답할 수 있도록 합니다. 구간 트리는 흔히 일차원 배열의 특정 구간에 대한 질문을 빠르게 대답하는데 사용합니다.

보통 1차원의 배열에서 특정 구간의 합을 구하기 위해서 선형적으로 구합니다. 이러한 방식을 고려했을 때 앞에서 하나씩 더해가므로 데이터의 개수가 n이면 시간 복잡도는 O(N)이 나옵니다. 따라서 이러한 방식을 이용했을 때 구간의 합을 구하는 속도가 너무 느립니다.

따라서 우리는 트리의 구조를 사용하여 좀 더 빠르게 풀 수 있습니다. 트리 구조의 특성상 합을 구할 때 시간 복잡도는 O(logN)이 됩니다. 그렇다면 어떻게 트리를 만들어 구간의 합을 빠르게 구하는 지 자세히 알아보자.

😘 1. 구간 합 트리 생성하기

인덱스는 아까와 동일하게 0부터 11까지 입니다. 이제 빠르게 합을 구하기 위해서 '구간 합 트리'를 새롭게 생성해주어야 합니다. 먼저 최상단 노드에는 전체 원소를 더한 값이 들어갑니다.

이후 두 번째 노드와 세번째 노드를 구합니다. 두번째 노드는 인덱스 0부터 5까지의 원소를 더한 값을 가지고, 세번째 노드는 인덱스 6부터 인덱스 11까지의 원소를 더한 값을 가집니다. 말 그대로 원래 데이터 범위를 반씩 분할하며 그 구간의 합들을 저장하도록 초기 설정하는 것이다. 이를 통해 구간 합 트리의 전체 노드를 구할 수 있습니다.

이와 같은 방식으로 구간 합 트리를 구하게 되면 결과적으로 다음과 같습니다. 재미있는 점은 구간 합 트리에 한해서는 인덱스 번호가 1부터 시작한다는 것입니다.

코드

#구간 합 트리 생성
def make_tree(start, end, node):
    if start == end: 
        tree[node] = lst[start]
        return tree[node]

    mid = (start + end) // 2

    tree[node] = make_tree(start, mid, node*2) + make_tree(mid+1,end,node*2+1)
    
    return tree[node]

또한 구간 합 트리의 원소 개수는 위 그림만 보면 알 수 있듯이 32개라는 것을 알 수 있습니다. 쉽게 말해 데이터의 개수가 n개일 때 n보다 큰 가장 가까운 n의 제곱수를 구한 뒤에 그것의 2배까지 미리 배열의 크기를 미리 만들어 놓아야 한다는 것이다.

예를 들어 데이터 개수가 12일 경우 16*2 = 32개의 크기가 필요했던 것입니다. 그래서 실제로는 데이터의 개수 N에 4를 곱한 크기만큼 미리 구간 합 트리의 공간을 할당합니다.

😁 2. 구간 합을 구하는 함수 만들기

이제 구간 합을 구하는 함수를 만들어 보자. 트리 구조를 가지고 있기 때문에 데이터를 탐색함에 있어 들이는 비용은 O(logN)입니다. 따라서 구간 합을 항상 O(logN)의 시간에 구할 수 있습니다. 예를 들어 4~8의 범위에 대한 합은 다음과 같이 세 노드의 합만 구해주면 됩니다.

위의 그림과 같이 구간의 합은 '범위 안에 있는 경우'에 한해서만 더해주면 됩니다.

코드

#구간 합 더하기
def tree_sum(start, end, node, left, right):
    if (left > end) or (right < start):
        return 0

    if (left <= start) and (end <= right):
        return tree[node]

    mid = (start + end) // 2

    return tree_sum(start, mid, node*2,left,right) + tree_sum(mid+1,end, node*2+1,left, right)

😘 3. 특정원소의 값을 수정하는 함수 만들기

특정 원소의 값을 수정할 때는 해당 원소를 포함하고 있는 모든 구간 합 노드들을 갱신해주면됩니다. 예를 들어 인덱스 7의 노드를 수정한다고 하면 다음과 같이 5개의 구간 합 노드를 모두 수정하면 됩니다.

이 함수 또한 재귀적으로 구현하면 어렵니 않게 작성할 수 있습니다. 마찬가지로 수정할 노드로는 '범위 안에 있는 경우' 에 한해서만 수정해주시면 됩니다.

코드

#특정 원소 바꾸기
#fixed = 바꾸려는 값 - 원래 값
def change(start, end, node, idx, fixed):
    if idx < start or idx > end:
        return 
    
    tree[node] += fixed 
    if start == end:
        return

    mid = (start + end) // 2
    change(start, mid, node*2, idx, fixed)
    change(mid + 1, end, node*2+1 , idx, fixed)       

make_tree(0,n-1,1)

대표적으로 백준의 구간 합 구하기 문제(2042번)를 풀어보시면 완벽하게 감이 잡히실 겁니다.

profile
다른 사람들이 이해하기 쉽게 기록하고 공유하자!!

0개의 댓글