Segment Tree 자료구조 - 구간에 대한 Update/Query

코딩몬스터TV·2020년 9월 2일
0

자료구조

목록 보기
2/2
post-thumbnail

세그먼트 트리는 인덱스 트리와 비슷한 구조를 가지지만, 추가적인 공간을 활용하여 좀 더 복잡한 형태의 대표값(0이 아닌 값, 1인 비트의 수, 등)을 표현할 수 있을 뿐만 아니라 가장 큰 차이점은 동시에 구간적인 변화가 일어나도 이를 O(log2(N))만에 처리가 가능하다는 점입니다. 다양한 형태의 Segment트리들 코드를 공유합니다.

범위에 대한 구간합 Segment Tree

#define vi vector<int>
#define mid (lm+rm)/2
#define LQR l, r, lm, mid, pos*2
#define RQR l, r, mid+1, rm, pos*2+1
 
//범위적인 증가, 감소가 빈번히 일어나는 데이터의 구간합
struct ST{
        int lv, w;
        vi node;
        vi sum;
        // [l, r]에¯val을 더함 단, [l,r]은 0-based+
        void update( int val,int l, int r, int lm=0,int rm=0, int pos = 1 ){
                if(pos == 1) rm = w-1;
                if( l <= lm && rm <= r )
                        node[pos] += val;
                else{
                        if( l <= mid ) update(val,LQR);
                        if( mid+1<=r ) update(val,RQR);
                }
                sum[pos] = node[pos] * (rm-lm+1);
                if(lm < rm)
                        sum[pos] += sum[pos*2] + sum[pos*2+1];
        }
        ST(int N ){
                for(lv=0, w=1;  w<N;  w*=2,lv++);
                node = sum = vi(w*2+1);
        }
        //구간합 : sumq(l,r)  단, [l,r ]은 0-based
        int sumq(int l, int r, int lm=0,int rm=0, int pos = 1){
                if(pos == 1) rm = w-1;
                if( r < lm || rm < l )        return 0;
                if( l <= lm && rm <= r )      return sum[pos];
                return sumq(LQR) + sumq(RQR);
        }
};

범위에 대한 비트 수를 세는 Segment Tree

#define vi vector<int>
#define mid (lm+rm)/2
#define LQR l, r, lm, mid, pos*2
#define RQR l, r, mid+1, rm, pos*2+1
 
// 구간을 토글하는 비트에서 특정 구간의 특정 비트 개수 카운트
#define vb vector<bool>
struct TST{
        int lv, w;
        vb node;
        vi cnt;
        // [l, r]을 토글함 단, [l,r] 은¨¬ 0-based
        void toggle(int l, int r, int lm=0,int rm=0, int pos = 1 ){
                if(pos == 1) rm = w-1;
                if( l <= lm && rm <= r ){
                        node[pos] = !node[pos];
                        cnt[pos] = (rm - lm + 1 ) - cnt[pos];
                }
                else{
                        if( l <= mid ) toggle(LQR);
                        if( mid+1<=r ) toggle(RQR);
                        cnt[pos] = cnt[pos*2] + cnt[pos*2+1];
                        if(node[pos])
                                cnt[pos] = (rm - lm + 1 ) - cnt[pos];
                }
        }
        TST(int N ){
                for(lv=0, w=1;  w<N;  w*=2,lv++);
                node = vb(w*2+1);
                cnt = vi(w*2+1);
        }
        // bit 1 count : cntq(l,r)  단¥U, [l,r ]은¨¬ 0-based
        int cntq(int l, int r, int lm=0,int rm=0, int pos = 1,bool tog = false){
                if(pos == 1) rm = w-1;
                if( r < lm || rm < l )        
                        return 0;
                if( l <= lm && rm <= r ){
                        if(tog)
                                return (rm-lm+1) - cnt[pos];
                        return cnt[pos];
                }
                else
                        return cntq(LQR, node[pos] ^ tog) + cntq(RQR,node[pos] ^ tog);
        }
};

구간에서 0이 아닌 칸을 세는 Segment Tree (Plane Sweep)

// 구간에 특정 값을 더함
// 0이 아닌 칸의 개수를 샌다
struct PST{
        int lv, w;
        vi node;
        vi cnt;
        // 0-based [l, r]에 val을 더함
        void update(int val, int l, int r, int lm=0,int rm=0, int pos = 1 ){
                if(pos == 1) rm = w-1;
                if( l <= lm && rm <= r )
                        node[pos] += val;
                else{
                        if( l <= mid ) update(val,LQR);
                        if( mid+1<=r ) update(val,RQR);
                }
                cnt[pos] = 0;
                if(node[pos]) cnt[pos] = rm - lm + 1;
                if(!node[pos] && lm < rm) cnt[pos] = cnt[pos*2] + cnt[pos*2+1];
        }
        PST(int N ){
                for(lv=0, w=1;  w<N;  w*=2,lv++);
                node = cnt = vi(w*2+1);
        }
        // 0-based [l,r ] 구간체크 cntq(l,r);
        int cntq(int l, int r, int lm=0,int rm=0, int pos = 1){
                if(pos == 1) rm = w-1;
                if( r < lm || rm < l )        return 0;
                if( l <= lm && rm <= r )      return cnt[pos];
                return cntq(LQR) + cntq(RQR);
        }
};

https://edu.goorm.io/learn/lecture/554/알고리즘-문제해결기법-입문

profile
개발자 구직과 성장에 대한 정보! 유튜브 [코딩몬스터TV] 채널을 구독해주세요

0개의 댓글