[자료구조] 세그먼트 트리

DAUN JO·2021년 6월 30일
0

TIL

목록 보기
6/17

💡 구간을 저장하기 위한 트리, 세그먼트 트리의 개념과 활용을 공부해보자.

세그먼트 트리(Segment Tree,구간 트리)

특정 구간 내 연산(쿼리)에 대해 빠르게 응답하기 위해 만들어진 자료구조

예를 들어 크기가 N인 int배열 arr이 있다면 1~N의 인덱스 내 숫자들이 위치해 있을 것이다.
이 때, 이 배열의 구간 arr[a] ~ arr[b]의 합을 구하고자 한다고 하자. 아래의 그림과 같다.
(1 <= a && b <= N && a < b)

간단히 반복문을 사용하여 l의 위치와 r의 위치를 찾아 덧셈을 수행하면
시간 복잡도는 최악의 경우 O(N)이다. 이 행위를 M번 수행하면 O(NM)이 된다.

세그먼트 트리는 이 연산의 시간 복잡도를 O(MlogN) 만에 수행할 수 있도록 만들어준다.


세그먼트 트리(Segment Tree)의 구조

기본적으로 이진트리의 구조를 갖는다.

기존 이진 트리에서는 각각의 모든 노드가 고유의 값을 가졌다면, 이번에는 부모 노드가 자식 노드들의 합을 저장하는 방식이라고 생각하면 된다.

기존 데이터의 배열의 크기를 통해서 트리 배열의 최대 크기를 알 수 있다. 기존 데이터 배열의 크기를 N 이라 하면, 리프 노드의 개수가 N 이 되고, 트리의 높이 H 는 [ logN ] 이 되고, 배열의 크기는 2^(H+1) 이 된다.


배열을 세그먼트 트리로

배열의 초기값을 세그먼트 트리에 넣는다.
이 때 루트 노드 0부터 아래로 내려간다.
한 노드의 값을 채울 때에는 자식 노드의 값을 먼저 채우고, 그 뒤에 합을 쓰면 된다.

현재 노드의 인덱스 index 에 대해 왼쪽 자식 노드의 인덱스는 index*2+1, 오른쪽 자식 노드의 인덱스는 index*2+2 로서 재귀적으로 구현하였다.

int init(int start, int end, int index)
{
    if (start == end)
        tree[index] = A[start];
    else{
        int mid = (start+end)/2;
        tree[index] = init(index*2+1, start, mid) + init(index*2+2, mid+1, end);
    }
    return tree[index];
}

어떤 배열의 0번째부터 n번째까지의 값을 세그트리에 넣고 싶다면 init(0, n, 0); 을 쓰면 된다.


구간합 구하기

세그먼트 트리를 이용해 구간합을 구해보자.
두 구간에 대해 왼쪽을 left, 오른쪽을 right라고 한다.
탐색 범위 [start, end]와 합의 구간 [left,end]의 관계는 다음과 같다.

  1. [left, right][start, end]가 전혀 겹치지 않는 경우

    탐색 범위 내에 구하는 범위가 존재하지 않는다.
    그렇다면 탐색 범위에 값들은 아무 의미없는 값이므로 0을 return 한다.

  2. [start, end][left, right]에 속해 있는 경우

    탐색 범위 내에 값들이 전부 구하는 범위의 값들이다.
    하위 노드들을 탐색할 필요없이 이미 하위 노드들의 합을 저장하고 있는 tree[index] 를 반환한다.

  3. [left, right][start, end]에 속해 있는 경우

  4. [left, right][start, end] 가 일부 겹치는 경우

    재귀적으로 더 들어가서 어디까지의 값들이 필요한지에 대해 구한다.


int sum(int start, int end, int left, int right, int index)
{
     // 구간이 전혀 겹치지 않는 경우
    if (start > right || end < left)
        return 0;
    else if (left <= start && end <=right)
        return tree[index];
    else {
        int mid = (start+end) / 2;
        return sum(index*2+1, start, mid, left, right) + sum(index*2+2, mid+1, end, left, right);
    }
}

응용하면 구간 합 뿐만 아니라 최소/최댓값 등 많은 부분에서 유용하게 쓸 수 있다.


값 변경하기

void update(int changed_index, int diff, int index, int start, int end)
{
	if (changed_index < start || changed_index > end) return;
    tree[index] += diff;
    
    if (start != end){
        int mid = (start+end) / 2;
        update(changed_index, diff, index*2+1, start, mid);
        update(changed_index, diff, index*2+2, mid+1, end);
    }
}

diff 는 바꿀 새로운 값이 아니다

diff = 새로 바꿀 값 - A[changed_index] (기존의 값)
A[changed_index] = 새로 바꿀 값

재귀적으로 양 쪽 자식노드로 나눠가며 start == end 가 될 때 까지, 즉 리프 노드가 될 때까지 탐색을 한다.
탐색을 할 시에는 탐색 범위 안에 없다면 return 하고 탐색 범위 안에 있다면 변경된 노드의 증가값 diff 만큼 노드에 더해준다.


출처

profile
🍕

0개의 댓글