트리는 데이터를 효율적으로 저장하고, 빠른 시간에 탐색, 수정 및 추가 삭제 연산을 진행할 수 있도록 하는 자료구조 이다.
세그먼트 트리는 이 트리 구조의 이점을 살려서 특정 구간의 특정 정보를 빠르게 구할 수 있도록 한다.
예를 들어, {1, 2, 3, 4, 5, 6}의 배열에서 2-5 구간의 합을 구하려고 할 때,
한번의 연산이야 중요하지 않지만, 이러한 연산을 만 번씩 해야 할 때 걸리는 시간은 어마어마 할 것이다.
이러한 문제를 해결하기 위해 세그먼트 트리는 자식의 구간 정보를 부모 노드에 기록하는 방식을 사용한다.
아래의 그림을 확인하자.
위 그림에서 보면 이해 할 수 있듯이, 가장 높은 레벨의 리프 노드들은 원본 배열의 실제적인 값이 저장되있고, 리프 노드들의 부모 노드들은 자기 자식들에 대한 정보를 저장하고 있다.
이러한 방식을 사용하면,
배열의 구간 합을 구하는 연산 O(N) -> O(logN) 으로 변하는 마법을 볼 수 있다.
사실 세그먼트 트리를 공부하고자 찾아볼때, 이론적으로 구조가 이해가 안가진 않을 것이다. 굉장히 간단하다.
하지만 구현 하고자 할 때가 까다롭다.
일반적인 트리는 루트 노드부터 작성해 나가며 연결하는데, 세그먼트 트리는 가장 마지막 노드부터 작성해서 루트 노드까지 올라가는 구조이다.
이러한 구조를 이해하기 위해 찾다 crocus 님의 글을 보게 되었는데, 너무너무 자세한 설명이라 꼭 봐야한다.
세그먼트 트리를 구현할 때 가장 이해가 안되던 것이 높이를 설정하는 부분이었는데 이부분도 명쾌하게 이해가 되었다.
그림을 보면 알다시피 현재 우리는 포인터로 동적할당을 통한 트리가 아닌 배열로 트리를 만들고 있다.
그 이유는 세그먼트 트리는 full binary tree에 가깝기에 배열에 모든 값들이 꽉꽉차서 올 가능성이 매우 높기때문에
포인터보다는 배열을 이용하게 된다. 그리고 각 노드마다의 왼쪽, 오른쪽 자식 노드는 항상 규칙이 정해져 있다.
출처:https://www.crocus.co.kr/648[Crocus]
package tree;
import java.io.*;
import java.util.*;
class SegmentTree {
long[] tree;
/**
* @param tree - 세그먼트 트리의 배열
* @param list - 원본 숫자의 배열
*/
SegmentTree(int N, long[] list) {
tree = new long[N*4];
init(0, N-1, 1, list);
}
/**
* @param start - 시작 인덱스
* @param end - 끝 인덱스
*/
long init(int start, int end, int node, long[] list) {
// 리프 노드
if(start == end) return tree[node] = list[start];
int mid = (start + end) / 2;
return tree[node] = init(start, mid, node*2, list) + init(mid+1, end, node*2 + 1, list);
}
/**
* @param left, right - 구간 합 구하고자 하는 범위
* @return - 구간 합
*/
long sum(int node, int start, int end, int left, int right) {
// 범위를 벗어난 경우 종료
if(left > end || right < start) return 0;
// 범위 안에 start - end 가 포함된 경우 node의 자식도 모두 포함되기 때문에 tree[node] 리턴
if(left <= start && end <= right) return tree[node];
int mid = (start + end) / 2;
// 위 두 경우가 아닌 경우 자식 노드들 탐색
return sum(node*2, start, mid, left, right) + sum(node*2 + 1, mid+1, end, left, right);
}
/**
* @param idx - 수정할 인덱스
* @param diff - 바뀐 정도(차)
*/
void update(int node, int start, int end, int idx, long diff) {
// 범위를 벗어난 경우 종료
if(idx < start || end < idx) return;
// 노드 찾아 내려가면서 갱신
tree[node] += diff;
// 리프 노드가 아닌 경우 자식 노드도 갱신
if(start != end) {
int mid = (start + end) / 2;
update(node * 2, start, mid, idx, diff);
update(node * 2 + 1, mid+1, end, idx, diff);
}
}
}
public class Segtree {
static int N, K;
public static void main(String[] args) throws IOException{
BufferedReader br = new BufferedReader(new InputStreamReader(System.in));
StringTokenizer st = new StringTokenizer(br.readLine());
StringBuilder sb = new StringBuilder();
N = Integer.parseInt(st.nextToken());
K = Integer.parseInt(st.nextToken()) + Integer.parseInt(st.nextToken());
long[] list = new long[N];
for(int i = 0; i < N; i++) {
list[i] = Long.parseLong(br.readLine());
}
// segment tree 초기화
SegmentTree segTree = new SegmentTree(N, list);
while(K-- > 0) {
st = new StringTokenizer(br.readLine(), " ");
int a = Integer.parseInt(st.nextToken());
int b = Integer.parseInt(st.nextToken()) - 1;
long c = Long.parseLong(st.nextToken());
switch(a) {
case 1:
segTree.update(1, 0, N-1, b, c - list[b]);
list[b] = c;
break;
case 2: sb.append(segTree.sum(1, 0, N-1, b, (int)(c-1))).append("\n"); break;
}
}
System.out.println(sb.toString());
}
}
가장 기본적인 구간 합을 구하는 문제다. 가장 기초적인 유형이라 세그먼트 트리를 구현해보고자 할 때 좋은 것 같다.
기초 예제로 백준 2357 : 최솟값과 최댓값 문제도 좋은 예제이다.
문제 내의 update 함수로 시간 복잡도가 O(logn) 이다.