Segment Tree

SangHoon Lee·2022년 8월 16일
0

세그먼트 트리는 구간을 업데이트 해 주면서 여러가지 알고리즘을 구현하기에 편리한 트리 구조이다.

특히 구간 합을 구할 때 많이 사용하며 일반적으로 구간 합을 구할 때 예시를 보게 된다면,

// sum : 2 to 5
int arr[5] = { 1, 2, 3, 4, 5 };
for(int i = 1; i< 5; i++) {
    int sum += arr[i];
}

// time complexity : O(n)

이렇게 배열 인덱스에 하나씩 접근하는 방법이 있다. 이렇게 작성하게 된다면, 시간 복잡도가 O(N)이므로, 데이터가 커질수록 느린 속도를 가질 수 있다. 이런 경우 사용하는것이 O(logN) 만큼 걸리는 세그먼트 트리이다.

쿼리 부분을 어떻게 작성하는지에 따라 여러가지 용도로 사용이 된다.

원리는 그림과 같다.

세그먼트 트리를 하기 전, 꼭 알아두어야 할 사항은 아래와 같다.

Root node 기준 (Root node index : 1)

  • 왼쪽 노드 : node * 2
  • 오른쪽 노드 : node * 2 + 1
  • 부모 노드 : node / 2

세그먼트 트리를 생성하는 코드는 다음과 같다.

// seg tree를 위한 struct
typedef struct tree {
	ll value;
	ll lazy; // lazy propagation
}tree;

int tree_size = 0;

ll init(tree *T,int node,int start,int end) {
	if(start == end) return T[node].value = v[start];
	else {
		ll mid = (start + end) / 2;
		return T[node].value = init(T,node *2, start,mid) + init(T,node * 2 + 1, mid +1, end);
	}	
}

그림을 보게 되면 F의 값이 새로 추가되는 과정이다. 값이 추가되면, 부모 노드를 거쳐서 Root 노드까지 계속 더해주면 된다.

void update(tree *T, int node,int start,int end, int i, int j, ll dif) {
	if(T[node].lazy != 0) {
		T[node].value += (end-start + 1) * T[node].lazy;
		if(start != end) {
			T[node * 2].lazy += T[node].lazy;
			T[node * 2 + 1].lazy += T[node].lazy;
		}
		T[node].lazy = 0;
	}
	
	if(j < start || i > end) return;
	
	if(i <= start && end <= j) {
		T[node].value += (end - start + 1) * dif;
		if(start != end) {
			T[node * 2].lazy += dif;
			T[node * 2 + 1].lazy += dif;
		}
		return;
	}
	
	int mid = (start + end) / 2;
	
	update(T,node * 2, start,mid,i,j,dif);
	update(T,node * 2 + 1, mid+1, end,i,j,dif);
	
	T[node].value = T[node * 2].value + T[node * 2 + 1].value;
}

세그먼트 트리의 구간 합 구하는 코드는 아래와 같다.

ll segtree_sum(tree *T,int node, int start, int end, int i, int j) {
	if(T[node].lazy != 0) {
		T[node].value += (end - start + 1) * T[node].lazy;
		if(start != end) {
			T[node *2].lazy += T[node].lazy;
			T[node * 2 + 1].lazy += T[node].lazy;
		}
		T[node].lazy = 0;
	}
	
	if(i > end || j < start) return 0;
	if(i <= start && end <= j) return T[node].value;
	
	ll mid = (start + end) / 2;
	
	return segtree_sum(T,node *2,start,mid,i,j) + segtree_sum(T, node * 2 + 1,mid +1, end,i,j);
}

i와 j는 i(시작 점) 부터 j(끝 점) 까지의 구간이며, 예외처리 후에 재귀를 통하여 계속 세그먼트 트리의 합을 구하여 준다.

풀 코드 (백준 2042)

#include <iostream>
#include <cstdio>
#include <cstdlib>
#include <vector>

#define MAX_LENGTH 1000001
#define ll long long int

using namespace std;

typedef struct tree {
	ll value;
	ll lazy;
}tree;


ll v[MAX_LENGTH];

ll init(tree *T,int node,int start,int end) {
	if(start == end) return T[node].value = v[start];
	else {
		ll mid = (start + end) / 2;
		return T[node].value = init(T,node *2, start,mid) + init(T,node * 2 + 1, mid +1, end);
	}	
}

void update(tree *T, int node,int start,int end, int i, int j, ll dif) {
	if(T[node].lazy != 0) {
		T[node].value += (end-start + 1) * T[node].lazy;
		if(start != end) {
			T[node * 2].lazy += T[node].lazy;
			T[node * 2 + 1].lazy += T[node].lazy;
		}
		T[node].lazy = 0;
	}
	
	if(j < start || i > end) return;
	
	if(i <= start && end <= j) {
		T[node].value += (end - start + 1) * dif;
		if(start != end) {
			T[node * 2].lazy += dif;
			T[node * 2 + 1].lazy += dif;
		}
		return;
	}
	
	int mid = (start + end) / 2;
	
	update(T,node * 2, start,mid,i,j,dif);
	update(T,node * 2 + 1, mid+1, end,i,j,dif);
	
	T[node].value = T[node * 2].value + T[node * 2 + 1].value;
}

ll segtree_sum(tree *T,int node, int start, int end, int i, int j) {
	if(T[node].lazy != 0) {
		T[node].value += (end - start + 1) * T[node].lazy;
		if(start != end) {
			T[node *2].lazy += T[node].lazy;
			T[node * 2 + 1].lazy += T[node].lazy;
		}
		T[node].lazy = 0;
	}
	
	if(i > end || j < start) return 0;
	if(i <= start && end <= j) return T[node].value;
	
	ll mid = (start + end) / 2;
	
	return segtree_sum(T,node *2,start,mid,i,j) + segtree_sum(T, node * 2 + 1,mid +1, end,i,j);
}

int main() {
	tree *T;
	int n,m,k;
	scanf("%d %d %d",&n,&m,&k);

	for(int i = 1; i<=n; i++) {
		scanf("%lld",&v[i]);
	}
	T = (tree *)malloc(sizeof(tree) * 4 * MAX_LENGTH);
	init(T,1,1,n);
	ll change_value = 0;
	for(int i = 1; i<=m+k; i++) {
		int a,b,c;
		ll d;
		scanf("%d",&a);
		if(a == 1) {
			scanf("%d %lld",&b,&d);
			if(v[b] != d) {			
				if((d < 0 && v[b] < 0)){
					change_value = d - v[b];
				}
				else {
					change_value = d - v[b];
				}
				v[b] = d;
			}
			
			update(T,1,1,n,b,b,change_value);
			change_value = 0; 
		}
		else {
			scanf("%d %d",&b,&c);
			cout<<segtree_sum(T,1,1,n,b,c)<<endl;
		}
	}
	free(T);
	
	return 0;
}
    

다음에는 이 코드에서 쓰인 lazy propagation에 대하여 작성을 해 볼 예정이다.

profile
C++ 공부하고있는 대학생입니다.

0개의 댓글