세그먼트 트리는 인덱스 트리와 비슷한 구조를 가지지만, 추가적인 공간을 활용하여 좀 더 복잡한 형태의 대표값(0이 아닌 값, 1인 비트의 수, 등)을 표현할 수 있을 뿐만 아니라 가장 큰 차이점은 동시에 구간적인 변화가 일어나도 이를 O(log2(N))만에 처리가 가능하다는 점입니다. 다양한 형태의 Segment트리들 코드를 공유합니다.
#define vi vector<int>
#define mid (lm+rm)/2
#define LQR l, r, lm, mid, pos*2
#define RQR l, r, mid+1, rm, pos*2+1
//범위적인 증가, 감소가 빈번히 일어나는 데이터의 구간합
struct ST{
int lv, w;
vi node;
vi sum;
// [l, r]에¯val을 더함 단, [l,r]은 0-based+
void update( int val,int l, int r, int lm=0,int rm=0, int pos = 1 ){
if(pos == 1) rm = w-1;
if( l <= lm && rm <= r )
node[pos] += val;
else{
if( l <= mid ) update(val,LQR);
if( mid+1<=r ) update(val,RQR);
}
sum[pos] = node[pos] * (rm-lm+1);
if(lm < rm)
sum[pos] += sum[pos*2] + sum[pos*2+1];
}
ST(int N ){
for(lv=0, w=1; w<N; w*=2,lv++);
node = sum = vi(w*2+1);
}
//구간합 : sumq(l,r) 단, [l,r ]은 0-based
int sumq(int l, int r, int lm=0,int rm=0, int pos = 1){
if(pos == 1) rm = w-1;
if( r < lm || rm < l ) return 0;
if( l <= lm && rm <= r ) return sum[pos];
return sumq(LQR) + sumq(RQR);
}
};
#define vi vector<int>
#define mid (lm+rm)/2
#define LQR l, r, lm, mid, pos*2
#define RQR l, r, mid+1, rm, pos*2+1
// 구간을 토글하는 비트에서 특정 구간의 특정 비트 개수 카운트
#define vb vector<bool>
struct TST{
int lv, w;
vb node;
vi cnt;
// [l, r]을 토글함 단, [l,r] 은¨¬ 0-based
void toggle(int l, int r, int lm=0,int rm=0, int pos = 1 ){
if(pos == 1) rm = w-1;
if( l <= lm && rm <= r ){
node[pos] = !node[pos];
cnt[pos] = (rm - lm + 1 ) - cnt[pos];
}
else{
if( l <= mid ) toggle(LQR);
if( mid+1<=r ) toggle(RQR);
cnt[pos] = cnt[pos*2] + cnt[pos*2+1];
if(node[pos])
cnt[pos] = (rm - lm + 1 ) - cnt[pos];
}
}
TST(int N ){
for(lv=0, w=1; w<N; w*=2,lv++);
node = vb(w*2+1);
cnt = vi(w*2+1);
}
// bit 1 count : cntq(l,r) 단¥U, [l,r ]은¨¬ 0-based
int cntq(int l, int r, int lm=0,int rm=0, int pos = 1,bool tog = false){
if(pos == 1) rm = w-1;
if( r < lm || rm < l )
return 0;
if( l <= lm && rm <= r ){
if(tog)
return (rm-lm+1) - cnt[pos];
return cnt[pos];
}
else
return cntq(LQR, node[pos] ^ tog) + cntq(RQR,node[pos] ^ tog);
}
};
// 구간에 특정 값을 더함
// 0이 아닌 칸의 개수를 샌다
struct PST{
int lv, w;
vi node;
vi cnt;
// 0-based [l, r]에 val을 더함
void update(int val, int l, int r, int lm=0,int rm=0, int pos = 1 ){
if(pos == 1) rm = w-1;
if( l <= lm && rm <= r )
node[pos] += val;
else{
if( l <= mid ) update(val,LQR);
if( mid+1<=r ) update(val,RQR);
}
cnt[pos] = 0;
if(node[pos]) cnt[pos] = rm - lm + 1;
if(!node[pos] && lm < rm) cnt[pos] = cnt[pos*2] + cnt[pos*2+1];
}
PST(int N ){
for(lv=0, w=1; w<N; w*=2,lv++);
node = cnt = vi(w*2+1);
}
// 0-based [l,r ] 구간체크 cntq(l,r);
int cntq(int l, int r, int lm=0,int rm=0, int pos = 1){
if(pos == 1) rm = w-1;
if( r < lm || rm < l ) return 0;
if( l <= lm && rm <= r ) return cnt[pos];
return cntq(LQR) + cntq(RQR);
}
};