[day12] Segment Tree & Fenwick Tree

나는컴공생·2025년 3월 19일

SW검정준비

목록 보기
10/11

Segment Tree

  • 크기 주로 4를 곱해서 사용
  • tree[N] ~ tree[2*N -1] : leaf node
  • tree[1] ~ tree[N-1] : 중간 노드
  • 리프 노드가 N부터 시작하므로, 트리의 구조는 여전히 1-based 인덱스처럼 동작.
  • N이 홀수든 짝수든, 부모 인덱스 k에 대해 2k(짝수), 2k+1(홀수) 규칙은 변하지 않음.
  • 1-based 트리에서 자식 쌍은 항상 2k와 2k+1.
  • pos / 2로 부모를 찾고, pos ^ 1은 반대 자식을 정확히 가리킴.

방법1: tree_size = N으로 두기

#include <stdio.h>
#include <string.h> //memset용
#define NMAX 1000

int tsize = NMAX;
int segTree[2 * NMAX];
int src[NMAX]; 
int N;

void update(int pos, int val) {
	pos += tsize; // 0번 -> tsize로 변환
	segTree[pos] = val;
	while (pos > 1) {
		int par = pos / 2;
		int sibling = pos ^ 1;
		segTree[par] = segTree[sibling] + segTree[pos];
		pos = par;
	}
}
//l번 index부터 r번 index까지의 구간합
int query(int l, int r) {
	int res = 0;
	l += tsize; r += tsize;
	while (l <= r) {
		//left가 홀수면? 오른쪽 자식
		if (l % 2 == 1) {
			res += segTree[l];
			l += 1;
		}
		//right가 짝수면? 왼쪽 자식
		if (r % 2 == 0) {
			res += segTree[r];
			r -= 1;
		}
		l /= 2; r /= 2;
	}
	return res;
}

int main() {
	N = 10;
	tsize = N;
	for (int i = 0; i < N; ++i) src[i] = i;
	memset(segTree, 0, sizeof(segTree)); 
	// N ~ 부터는 단말노드
	for (int i = 0; i < N; ++i) {
		segTree[tsize + i] = src[i];
	}
	//1 ~ N-1 까지는 구간합 노드들
	for (int i = N - 1; i > 0; --i) {
		segTree[i] = segTree[i * 2] + segTree[i * 2 + 1];
	}

	int left = 3;
	int right = 9;
	printf("sum(%d ~ %d): %d\n",left, right, query(3, 9));
	update(5, 5);
	printf("sum(%d ~ %d): %d\n",left, right, query(3, 9));
	return 0;
}

방법 2: 노드마다 구간 적기

#include <stdio.h>
#include <cmath>
#include <vector>
using namespace std;
#define MAXN 100001
int src[MAXN];
/*
segment tree
	- 크기: 4*N, 2*N
	- 높이: ceiling(log2(N))
	- len : 1 << (h+1)
*/
int segTree[4 * MAXN]; //주로 4 곱해서 사용

//vector<int> src;
//vector<int> segTree;

int N;
//node마다 어떤 구간인지 알아야 하는 경우
//init
int init(int node, int start, int end) {
	//leaf node인 경우
	if (start == end) {
		return segTree[node] = src[start];
	}
	int mid = (start + end) / 2;
	int leftNode = node << 1; //node *2
	return segTree[node] = init(leftNode, start, mid) + init(leftNode+1, mid + 1, end);
}

//update: 해당 index의 값을 diff만큼 더하고 싶다.
void update(int node, int start, int end, int idx, int diff) {
	//범위 벗어나는 경우 pass
	if (idx < start || end < idx) return;

	segTree[node] += diff;
	//단말노드가 아니라면, 밑에 자식들까지 update 해줘야함.
	if (start != end) {
		int mid = (start + end) / 2;
		int leftNode = node << 1;
		update(leftNode, start, mid, idx, diff);
		update(leftNode + 1, mid + 1, end, idx, diff);
	}
}
//sum: [left, right] 구간의 합을 알고싶다.
int sum (int node, int start, int end, int left, int right) {
	//1. 해당 노드의 구간이 모두 벗어나는 경우
	if (right < start || end < left) return 0;
	//2. 해당 노드의 구간이 모두 포함되는 경우
	else if (left <= start && end <= right) {
		return segTree[node];
	}
	//3. 해당 노드의 구간이 포함하는 경우(관련 없는 구간 존재)
	//4. 해당 노드의 구간이 일부만 포함(관련 없는 구간 존재)
	int mid = (start + end) / 2;
	int leftNode = node << 1;
	return sum(leftNode, start, mid, left, right) + sum(leftNode + 1, mid + 1, end, left, right);
	
}

int main() {
	N = 10;
	//src.clear();
	//src.resize(N+1);
	//segTree.clear();
	//int h = ceil(log2(N));
	//segTree.resize((h+1) >>1);
	//segTree.resize(N * 4);
	for (int i = 1; i <= N; ++i) {
		src[i] = i;
	}
	init(1, 1, N); //root node부터 초기화

	int left = 3; int right = 9;
	printf("sum(%d ~ %d ): %d\n", left, right, sum(1, 1, N, left, right));

	update(1, 1, N, 5, 5);
	printf("sum(%d ~ %d ): %d\n", left, right, sum(1, 1, N, left, right));
	return 0;
}

Fenwick Tree(Binary Index Tree)

: 연속된 구간합 빠르게 구하기

  • 항상 원소는 1부터!!!

  • 자료구조 크기는 FenwickTree[src개수] -> O(n)

  • Full binary tree(Segment tree의 진화형)

  • 구간합 rangesum[i,j] = psum[0,j] - psum[0, i-1]

  • 부분합 psum[0, j] = a[0] + a[1] ...

  • 2의 승수 k : FenwickTree[k] = psum[1, k]

  • 홀수 k : FenwickTree[k] = A[k]

  • 그 외 k : FenwickTree[k] = A[k-1] + A[k]

  • idx = idx + (idx & -idx) //뒤로 업데이트 해야하는 idx(증가)

  • idx = idx - (idx & -idx) //앞으로 업데이트 해야하는 idx (감소)

연산

update(pos, val)

: pos위치에 val 값 업데이트

while(pos <= NMAX){
	FenwickTree[pos] += val;
    pos += (pos & -pos);
}

sum(r)

int sum = 0;
while(pos <= NMAX) {
	sum += Fenwick[pos];
    pos -= (pos & -pos);
}

x & -x : x의 마지막 1 비트 추출
x - (x&-x): x의 마지막 1부트 제거

0개의 댓글