세그먼트 트리(Segment Tree)

Jaca·2021년 8월 11일
0

세그먼트 트리란?

트리는 데이터를 효율적으로 저장하고, 빠른 시간에 탐색, 수정 및 추가 삭제 연산을 진행할 수 있도록 하는 자료구조 이다.
세그먼트 트리는 이 트리 구조의 이점을 살려서 특정 구간의 특정 정보를 빠르게 구할 수 있도록 한다.

예를 들어, {1, 2, 3, 4, 5, 6}의 배열에서 2-5 구간의 합을 구하려고 할 때,
한번의 연산이야 중요하지 않지만, 이러한 연산을 만 번씩 해야 할 때 걸리는 시간은 어마어마 할 것이다.

이러한 문제를 해결하기 위해 세그먼트 트리는 자식의 구간 정보를 부모 노드에 기록하는 방식을 사용한다.
아래의 그림을 확인하자.

위 그림에서 보면 이해 할 수 있듯이, 가장 높은 레벨의 리프 노드들은 원본 배열의 실제적인 값이 저장되있고, 리프 노드들의 부모 노드들은 자기 자식들에 대한 정보를 저장하고 있다.

이러한 방식을 사용하면,
배열의 구간 합을 구하는 연산 O(N) -> O(logN) 으로 변하는 마법을 볼 수 있다.

사실 세그먼트 트리를 공부하고자 찾아볼때, 이론적으로 구조가 이해가 안가진 않을 것이다. 굉장히 간단하다.
하지만 구현 하고자 할 때가 까다롭다.

일반적인 트리는 루트 노드부터 작성해 나가며 연결하는데, 세그먼트 트리는 가장 마지막 노드부터 작성해서 루트 노드까지 올라가는 구조이다.

이러한 구조를 이해하기 위해 찾다 crocus 님의 글을 보게 되었는데, 너무너무 자세한 설명이라 꼭 봐야한다.

세그먼트 트리를 구현할 때 가장 이해가 안되던 것이 높이를 설정하는 부분이었는데 이부분도 명쾌하게 이해가 되었다.

그림을 보면 알다시피 현재 우리는 포인터로 동적할당을 통한 트리가 아닌 배열로 트리를 만들고 있다.

그 이유는 세그먼트 트리full binary tree에 가깝기에 배열에 모든 값들이 꽉꽉차서 올 가능성이 매우 높기때문에

포인터보다는 배열을 이용하게 된다. 그리고 각 노드마다의 왼쪽, 오른쪽 자식 노드는 항상 규칙이 정해져 있다.

출처:https://www.crocus.co.kr/648[Crocus]

기본 예제

백준 2042 : 구간 합 구하기

package tree;

import java.io.*;
import java.util.*;

class SegmentTree {
    long[] tree;

    /**
     * @param tree - 세그먼트 트리의 배열
     * @param list - 원본 숫자의 배열 
     */
    SegmentTree(int N, long[] list) {
        tree = new long[N*4];
        init(0, N-1, 1, list);
    }

    /**
     * @param start - 시작 인덱스
     * @param end - 끝 인덱스
     */
    long init(int start, int end, int node, long[] list) {
        // 리프 노드
        if(start == end) return tree[node] = list[start];

        int mid = (start + end) / 2;

        return tree[node] = init(start, mid, node*2, list) + init(mid+1, end, node*2 + 1, list);
    }

    /**
     * @param left, right - 구간 합 구하고자 하는 범위
     * @return - 구간 합
     */
    long sum(int node, int start, int end, int left, int right) {
        // 범위를 벗어난 경우 종료
        if(left > end || right < start) return 0;
        // 범위 안에 start - end 가 포함된 경우 node의 자식도 모두 포함되기 때문에 tree[node] 리턴
        if(left <= start && end <= right) return tree[node];

        int mid = (start + end) / 2;
        // 위 두 경우가 아닌 경우 자식 노드들 탐색
        return sum(node*2, start, mid, left, right) + sum(node*2 + 1, mid+1, end, left, right);
    }

    /**
     * @param idx - 수정할 인덱스
     * @param diff - 바뀐 정도(차)
     */
    void update(int node, int start, int end, int idx, long diff) {
        // 범위를 벗어난 경우 종료
        if(idx < start || end < idx) return;
        // 노드 찾아 내려가면서 갱신
        tree[node] += diff;
        // 리프 노드가 아닌 경우 자식 노드도 갱신
        if(start != end) {
            int mid = (start + end) / 2;
            update(node * 2, start, mid, idx, diff);
            update(node * 2 + 1, mid+1, end, idx, diff);
        }
    }
}

public class Segtree {
    static int N, K;

    public static void main(String[] args) throws IOException{
        BufferedReader br = new BufferedReader(new InputStreamReader(System.in));
        StringTokenizer st = new StringTokenizer(br.readLine());
        StringBuilder sb = new StringBuilder();

        N = Integer.parseInt(st.nextToken());
        K = Integer.parseInt(st.nextToken()) + Integer.parseInt(st.nextToken());

        long[] list = new long[N];

        for(int i = 0; i < N; i++) {
            list[i] = Long.parseLong(br.readLine());
        }

        // segment tree 초기화
        SegmentTree segTree = new SegmentTree(N, list);

        while(K-- > 0) {
            st = new StringTokenizer(br.readLine(), " ");
            int a = Integer.parseInt(st.nextToken());
            int b = Integer.parseInt(st.nextToken()) - 1;
            long c = Long.parseLong(st.nextToken());

            switch(a) {
            case 1: 
                segTree.update(1, 0, N-1, b, c - list[b]); 
                list[b] = c;
                break;
            case 2: sb.append(segTree.sum(1, 0, N-1, b, (int)(c-1))).append("\n"); break;
            }
        }

        System.out.println(sb.toString());
    }
}

가장 기본적인 구간 합을 구하는 문제다. 가장 기초적인 유형이라 세그먼트 트리를 구현해보고자 할 때 좋은 것 같다.
기초 예제로 백준 2357 : 최솟값과 최댓값 문제도 좋은 예제이다.

문제 내의 update 함수로 시간 복잡도가 O(logn) 이다.

profile
I am me

0개의 댓글