#include <stdio.h>
#include <string.h> //memset용
#define NMAX 1000
int tsize = NMAX;
int segTree[2 * NMAX];
int src[NMAX];
int N;
void update(int pos, int val) {
pos += tsize; // 0번 -> tsize로 변환
segTree[pos] = val;
while (pos > 1) {
int par = pos / 2;
int sibling = pos ^ 1;
segTree[par] = segTree[sibling] + segTree[pos];
pos = par;
}
}
//l번 index부터 r번 index까지의 구간합
int query(int l, int r) {
int res = 0;
l += tsize; r += tsize;
while (l <= r) {
//left가 홀수면? 오른쪽 자식
if (l % 2 == 1) {
res += segTree[l];
l += 1;
}
//right가 짝수면? 왼쪽 자식
if (r % 2 == 0) {
res += segTree[r];
r -= 1;
}
l /= 2; r /= 2;
}
return res;
}
int main() {
N = 10;
tsize = N;
for (int i = 0; i < N; ++i) src[i] = i;
memset(segTree, 0, sizeof(segTree));
// N ~ 부터는 단말노드
for (int i = 0; i < N; ++i) {
segTree[tsize + i] = src[i];
}
//1 ~ N-1 까지는 구간합 노드들
for (int i = N - 1; i > 0; --i) {
segTree[i] = segTree[i * 2] + segTree[i * 2 + 1];
}
int left = 3;
int right = 9;
printf("sum(%d ~ %d): %d\n",left, right, query(3, 9));
update(5, 5);
printf("sum(%d ~ %d): %d\n",left, right, query(3, 9));
return 0;
}
#include <stdio.h>
#include <cmath>
#include <vector>
using namespace std;
#define MAXN 100001
int src[MAXN];
/*
segment tree
- 크기: 4*N, 2*N
- 높이: ceiling(log2(N))
- len : 1 << (h+1)
*/
int segTree[4 * MAXN]; //주로 4 곱해서 사용
//vector<int> src;
//vector<int> segTree;
int N;
//node마다 어떤 구간인지 알아야 하는 경우
//init
int init(int node, int start, int end) {
//leaf node인 경우
if (start == end) {
return segTree[node] = src[start];
}
int mid = (start + end) / 2;
int leftNode = node << 1; //node *2
return segTree[node] = init(leftNode, start, mid) + init(leftNode+1, mid + 1, end);
}
//update: 해당 index의 값을 diff만큼 더하고 싶다.
void update(int node, int start, int end, int idx, int diff) {
//범위 벗어나는 경우 pass
if (idx < start || end < idx) return;
segTree[node] += diff;
//단말노드가 아니라면, 밑에 자식들까지 update 해줘야함.
if (start != end) {
int mid = (start + end) / 2;
int leftNode = node << 1;
update(leftNode, start, mid, idx, diff);
update(leftNode + 1, mid + 1, end, idx, diff);
}
}
//sum: [left, right] 구간의 합을 알고싶다.
int sum (int node, int start, int end, int left, int right) {
//1. 해당 노드의 구간이 모두 벗어나는 경우
if (right < start || end < left) return 0;
//2. 해당 노드의 구간이 모두 포함되는 경우
else if (left <= start && end <= right) {
return segTree[node];
}
//3. 해당 노드의 구간이 포함하는 경우(관련 없는 구간 존재)
//4. 해당 노드의 구간이 일부만 포함(관련 없는 구간 존재)
int mid = (start + end) / 2;
int leftNode = node << 1;
return sum(leftNode, start, mid, left, right) + sum(leftNode + 1, mid + 1, end, left, right);
}
int main() {
N = 10;
//src.clear();
//src.resize(N+1);
//segTree.clear();
//int h = ceil(log2(N));
//segTree.resize((h+1) >>1);
//segTree.resize(N * 4);
for (int i = 1; i <= N; ++i) {
src[i] = i;
}
init(1, 1, N); //root node부터 초기화
int left = 3; int right = 9;
printf("sum(%d ~ %d ): %d\n", left, right, sum(1, 1, N, left, right));
update(1, 1, N, 5, 5);
printf("sum(%d ~ %d ): %d\n", left, right, sum(1, 1, N, left, right));
return 0;
}
: 연속된 구간합 빠르게 구하기
항상 원소는 1부터!!!
자료구조 크기는 FenwickTree[src개수] -> O(n)
Full binary tree(Segment tree의 진화형)
구간합 rangesum[i,j] = psum[0,j] - psum[0, i-1]
부분합 psum[0, j] = a[0] + a[1] ...
2의 승수 k : FenwickTree[k] = psum[1, k]
홀수 k : FenwickTree[k] = A[k]
그 외 k : FenwickTree[k] = A[k-1] + A[k]
idx = idx + (idx & -idx) //뒤로 업데이트 해야하는 idx(증가)
idx = idx - (idx & -idx) //앞으로 업데이트 해야하는 idx (감소)
: pos위치에 val 값 업데이트
while(pos <= NMAX){
FenwickTree[pos] += val;
pos += (pos & -pos);
}
int sum = 0;
while(pos <= NMAX) {
sum += Fenwick[pos];
pos -= (pos & -pos);
}
x & -x : x의 마지막 1 비트 추출
x - (x&-x): x의 마지막 1부트 제거