Segment Tree

열수철·2025년 3월 5일

1. 세그먼트 트리란?

세그먼트 트리는 배열의 구간에 대한 정보를 빠르게 계산하고 업데이트하기 위한 트리 자료구조입니다. 주로 다음과 같은 작업을 효율적으로 수행할 수 있습니다:

  • 구간 합 구하기 (Range Sum Query)
  • 구간 최소값 구하기 (Range Minimum Query)
  • 구간 최대값 구하기 (Range Maximum Query)
  • 기타 구간에 대한 연산 (곱셈, XOR 등)

세그먼트 트리는 특히 배열의 값이 자주 변하고, 구간 쿼리가 빈번하게 발생하는 상황에서 매우 유용합니다.

2. 세그먼트 트리의 특징

  • 시간 복잡도:
    • 구축: O(n)O(n)
    • 쿼리: O(logn)O(log n)
    • 업데이트: O(logn)O(log n)
  • 공간 복잡도: O(n)O(n)
  • 완전 이진 트리 형태로, 각 노드는 배열의 특정 구간에 대한 정보를 저장

3. 세그먼트 트리의 구조

세그먼트 트리는 다음과 같은 규칙을 따릅니다:

  • 리프 노드: 원본 배열의 각 원소에 해당
  • 내부 노드: 자식 노드들의 정보를 결합한 값 (예: 두 자식 구간의 합)
  • 루트 노드: 전체 배열에 대한 정보 (예: 전체 배열의 합)

4. 세그먼트 트리 구현 (구간 합 예시)

다음은 구간 합을 위한 세그먼트 트리의 구현입니다:

#include <vector>
#include <iostream>

class SegmentTree {
private:
    std::vector<int> tree;
    int n;

    // 세그먼트 트리 구축
    void build(const std::vector<int>& arr, int node, int start, int end) {
        // 리프 노드인 경우
        if (start == end) {
            tree[node] = arr[start];
            return;
        }
        
        // 내부 노드인 경우
        int mid = (start + end) / 2;
        // 왼쪽 자식 노드 구축
        build(arr, 2 * node, start, mid);
        // 오른쪽 자식 노드 구축
        build(arr, 2 * node + 1, mid + 1, end);
        // 부모 노드 값 계산 (두 자식 노드의 합)
        tree[node] = tree[2 * node] + tree[2 * node + 1];
    }

    // 구간 합 쿼리
    int query(int node, int start, int end, int left, int right) {
        // 쿼리 범위가 현재 노드 범위와 겹치지 않는 경우
        if (right < start || end < left) {
            return 0;
        }
        
        // 쿼리 범위가 현재 노드 범위를 완전히 포함하는 경우
        if (left <= start && end <= right) {
            return tree[node];
        }
        
        // 쿼리 범위가 현재 노드 범위와 일부 겹치는 경우
        int mid = (start + end) / 2;
        int left_sum = query(2 * node, start, mid, left, right);
        int right_sum = query(2 * node + 1, mid + 1, end, left, right);
        return left_sum + right_sum;
    }

    // 값 업데이트
    void update(int node, int start, int end, int idx, int val) {
        // 리프 노드에 도달한 경우
        if (start == end) {
            tree[node] = val;
            return;
        }
        
        int mid = (start + end) / 2;
        if (idx <= mid) {
            // 왼쪽 자식으로 이동
            update(2 * node, start, mid, idx, val);
        } else {
            // 오른쪽 자식으로 이동
            update(2 * node + 1, mid + 1, end, idx, val);
        }
        
        // 부모 노드 값 갱신
        tree[node] = tree[2 * node] + tree[2 * node + 1];
    }

public:
    SegmentTree(const std::vector<int>& arr) {
        n = arr.size();
        // 세그먼트 트리 배열의 크기는 4*n으로 설정 (충분한 공간)
        tree.resize(4 * n, 0);
        build(arr, 1, 0, n - 1);
    }
    
    // 간편한 인터페이스를 위한 래퍼 함수들
    int get_sum(int left, int right) {
        return query(1, 0, n - 1, left, right);
    }
    
    void update_value(int idx, int val) {
        update(1, 0, n - 1, idx, val);
    }
};

5. 세그먼트 트리 작동 예시

다음 배열을 사용하여 세그먼트 트리의 작동을 살펴보겠습니다:

arr = [1, 3, 5, 7, 9, 11]

세그먼트 트리 구축

세그먼트 트리를 구축하면 다음과 같은 구조가 됩니다:

  • 노드 1 (루트): 합 = 36 (전체 배열의 합)
  • 노드 2: 합 = 9 (인덱스 0~2 범위의 합)
  • 노드 3: 합 = 27 (인덱스 3~5 범위의 합)
  • 노드 4: 합 = 4 (인덱스 0~1 범위의 합)
  • 노드 5: 합 = 5 (인덱스 2의 값)
  • 노드 6: 합 = 16 (인덱스 3~4 범위의 합)
  • 노드 7: 합 = 11 (인덱스 5의 값)
  • 노드 8: 합 = 1 (인덱스 0의 값)
  • 노드 9: 합 = 3 (인덱스 1의 값)
  • 노드 12: 합 = 7 (인덱스 3의 값)
  • 노드 13: 합 = 9 (인덱스 4의 값)

쿼리 예시: 구간 [2, 4]의 합

  1. 루트 노드(1)에서 시작
  2. 쿼리 범위 [2, 4]는 노드 범위 [0, 5]와 일부 겹침
  3. 왼쪽 자식(노드 2)의 범위 [0, 2]와 쿼리 범위 [2, 4]는 일부 겹침
  4. 노드 4의 범위 [0, 1]은 쿼리 범위와 겹치지 않음 -> 0 반환
  5. 노드 5의 범위 [2, 2]는 쿼리 범위에 완전히 포함 -> 5 반환
  6. 오른쪽 자식(노드 3)의 범위 [3, 5]와 쿼리 범위 [2, 4]는 일부 겹침
  7. 노드 6의 범위 [3, 4]는 쿼리 범위에 완전히 포함 -> 16 반환
  8. 노드 7의 범위 [5, 5]는 쿼리 범위와 겹치지 않음 -> 0 반환
  9. 최종 결과: 5 + 16 = 21

이는 arr[2] + arr[3] + arr[4] = 5 + 7 + 9 = 21과 일치합니다.

업데이트 예시: arr[1] = 10으로 변경

  1. 루트 노드(1)에서 시작
  2. 인덱스 1은 중간값보다 작으므로 왼쪽 자식(노드 2)으로 이동
  3. 인덱스 1은 중간값보다 작으므로 왼쪽 자식(노드 4)으로 이동
  4. 인덱스 1은 중간값보다 크므로 오른쪽 자식(노드 9)으로 이동
  5. 리프 노드 도달, 값을 10으로 업데이트
  6. 상위 노드들의 값 갱신:
    • 노드 4의 합 = 11 (1 + 10)
    • 노드 2의 합 = 16 (11 + 5)
    • 노드 1의 합 = 43 (16 + 27)

6. 세그먼트 트리의 응용

최소값 세그먼트 트리

최소값을 구하는 세그먼트 트리는 다음과 같이 구현할 수 있습니다:

#include <vector>
#include <iostream>
#include <climits>
#include <algorithm>

class MinSegmentTree {
private:
    std::vector<int> tree;
    int n;

    void build(const std::vector<int>& arr, int node, int start, int end) {
        if (start == end) {
            tree[node] = arr[start];
            return;
        }
        
        int mid = (start + end) / 2;
        build(arr, 2 * node, start, mid);
        build(arr, 2 * node + 1, mid + 1, end);
        tree[node] = std::min(tree[2 * node], tree[2 * node + 1]);
    }
    
    int query(int node, int start, int end, int left, int right) {
        if (right < start || end < left) {
            return INT_MAX;
        }
        
        if (left <= start && end <= right) {
            return tree[node];
        }
        
        int mid = (start + end) / 2;
        int left_min = query(2 * node, start, mid, left, right);
        int right_min = query(2 * node + 1, mid + 1, end, left, right);
        return std::min(left_min, right_min);
    }
    
    void update(int node, int start, int end, int idx, int val) {
        if (start == end) {
            tree[node] = val;
            return;
        }
        
        int mid = (start + end) / 2;
        if (idx <= mid) {
            update(2 * node, start, mid, idx, val);
        } else {
            update(2 * node + 1, mid + 1, end, idx, val);
        }
        
        tree[node] = std::min(tree[2 * node], tree[2 * node + 1]);
    }

public:
    MinSegmentTree(const std::vector<int>& arr) {
        n = arr.size();
        tree.resize(4 * n, INT_MAX);
        build(arr, 1, 0, n - 1);
    }
    
    int get_min(int left, int right) {
        return query(1, 0, n - 1, left, right);
    }
    
    void update_value(int idx, int val) {
        update(1, 0, n - 1, idx, val);
    }
};

구간 업데이트를 지원하는 세그먼트 트리 (Lazy Propagation)

대규모 구간 업데이트가 필요한 경우, 지연 전파(Lazy Propagation) 기법을 사용하여 효율성을 높일 수 있습니다:

#include <vector>
#include <iostream>

class LazySegmentTree {
private:
    std::vector<int> tree, lazy;
    int n;

    void build(const std::vector<int>& arr, int node, int start, int end) {
        if (start == end) {
            tree[node] = arr[start];
            return;
        }
        
        int mid = (start + end) / 2;
        build(arr, 2 * node, start, mid);
        build(arr, 2 * node + 1, mid + 1, end);
        tree[node] = tree[2 * node] + tree[2 * node + 1];
    }
    
    void propagate(int node, int start, int end) {
        if (lazy[node] != 0) {
            // 현재 노드에 지연된 업데이트 적용
            tree[node] += (end - start + 1) * lazy[node];
            
            // 리프 노드가 아니면 자식 노드로 지연 전파
            if (start != end) {
                lazy[2 * node] += lazy[node];
                lazy[2 * node + 1] += lazy[node];
            }
            
            // 현재 노드의 지연 값 초기화
            lazy[node] = 0;
        }
    }
    
    void update_range(int node, int start, int end, int left, int right, int val) {
        // 지연된 업데이트 먼저 처리
        propagate(node, start, end);
        
        // 업데이트 범위와 겹치지 않는 경우
        if (right < start || end < left) {
            return;
        }
        
        // 업데이트 범위가 현재 노드 범위를 완전히 포함하는 경우
        if (left <= start && end <= right) {
            // 현재 노드에 업데이트 적용
            tree[node] += (end - start + 1) * val;
            
            // 리프 노드가 아니면 자식 노드로 지연 전파
            if (start != end) {
                lazy[2 * node] += val;
                lazy[2 * node + 1] += val;
            }
            
            return;
        }
        
        // 업데이트 범위가 현재 노드 범위와 일부 겹치는 경우
        int mid = (start + end) / 2;
        update_range(2 * node, start, mid, left, right, val);
        update_range(2 * node + 1, mid + 1, end, left, right, val);
        
        // 부모 노드 값 갱신
        tree[node] = tree[2 * node] + tree[2 * node + 1];
    }
    
    int query(int node, int start, int end, int left, int right) {
        // 지연된 업데이트 먼저 처리
        propagate(node, start, end);
        
        // 쿼리 범위가 현재 노드 범위와 겹치지 않는 경우
        if (right < start || end < left) {
            return 0;
        }
        
        // 쿼리 범위가 현재 노드 범위를 완전히 포함하는 경우
        if (left <= start && end <= right) {
            return tree[node];
        }
        
        // 쿼리 범위가 현재 노드 범위와 일부 겹치는 경우
        int mid = (start + end) / 2;
        int left_sum = query(2 * node, start, mid, left, right);
        int right_sum = query(2 * node + 1, mid + 1, end, left, right);
        return left_sum + right_sum;
    }

public:
    LazySegmentTree(const std::vector<int>& arr) {
        n = arr.size();
        tree.resize(4 * n, 0);
        lazy.resize(4 * n, 0);
        build(arr, 1, 0, n - 1);
    }
    
    // 간편한 인터페이스를 위한 래퍼 함수들
    void range_update(int left, int right, int val) {
        update_range(1, 0, n - 1, left, right, val);
    }
    
    int get_sum(int left, int right) {
        return query(1, 0, n - 1, left, right);
    }
};

7. 세그먼트 트리 활용 문제

세그먼트 트리는 다음과 같은 문제에서 자주 활용됩니다:

  1. 구간 합 쿼리 - 배열의 특정 구간의 합을 빠르게 계산
  2. 구간 최소/최대값 쿼리 - 배열의 특정 구간에서 최소값 또는 최대값을 찾기
  3. 구간 GCD/LCM 쿼리 - 배열의 특정 구간에서 최대공약수 또는 최소공배수 계산
  4. 히스토그램에서 가장 큰 직사각형 찾기 - 세그먼트 트리로 최소값의 위치를 빠르게 찾아 분할 정복 적용
  5. 구간 업데이트 문제 - 배열의 특정 구간에 일정 값을 더하거나 곱하는 문제
profile
그래픽스, 수학, 물리, 게임 만세

0개의 댓글