BOJ 13925 - 수열과 쿼리 13 링크
(2023.07.07 기준 P1)
길이가 N인 수열 A가 주어지고 M개의 쿼리가 주어진다. 각 쿼리에 맞게 출력
- 1 x y v: Ai = (Ai + v) % MOD (x ≤ i ≤ y)
- 2 x y v: Ai = (Ai × v) % MOD (x ≤ i ≤ y)
- 3 x y v: Ai = v (x ≤ i ≤ y)
- 4 x y: (ΣAi) % MOD 출력 (x ≤ i ≤ y)
Lazy propagation
더하기와 곱하기가 주어진다. 이를 차곡차곡 lazy에 쌓아야 하는데..
일단, lazy는 곱하기 변수, 더하기 변수. 총 2개를 만들자. 기본 초기값은 곱하기 1, 더하기 0이다. 어떤 수도 곱하기 1 더하기 0을 하면 바뀌는 수는 없다. 즉, 항등원이다.
만약 지금 노드 값은 x, 곱하기 lazy는 a, 더하기 lazy는 b라고 생각해보자. 값은 ax+b다.
여기에 곱하기 v를 하면? (ax+b) * v = avx+bv 가 나온다.
만약 곱하기 c, 더하기 d를 하면? (ax + b) *c + d = acx+bc+d 가 나온다.
결국, 곱하기 lazy에는 곱하기만, 더하기 lazy에는 곱한 후 더하면 된다.자, 이제 쿼리를 처리해보자.
1번 쿼리는 더하기다. 그러면 ax+b + v = (ax+b) * 1 + v 이므로 {1, v}를 뿌려주자. 물론, 앞이 곱하기, 뒤가 더하기다.
2번 쿼리는 곱하기다. 그러면 (ax+b) * v = (ax+b) * v + 0이므로 {v, 0}을 뿌려주자.
3번 쿼리는 변경이다. 모든 수는 0을 곱하면 0이 된다. 그러므로 {0, v}를 뿌려주자.
#include <bits/stdc++.h>
using namespace std;
typedef long long ll;
const int MAXN = 100000, MAXH = 1 << (int)ceil(log2(MAXN) + 1);
const ll MOD = 1e9 + 7;
int N;
ll A[MAXN];
struct Lazy{
    ll mul = 1, add = 0;
    bool is_default(){
        return mul == 1 && add == 0;
    }
    void make_default(){
        mul = 1; add = 0;
    }
    void calc(ll _mul, ll _add){ // (ax + b)c + d = acx + bc + d
        mul = (mul * _mul) % MOD;
        add = (add * _mul + _add) % MOD;
    }
};
struct ST{
    ll tree[MAXH];
    Lazy lazy[MAXH];
    void init();
    void _pull(int nd){ // child -> parent
        tree[nd] = (tree[nd << 1] + tree[nd << 1 | 1]) % MOD;
    }
    void _push(int nd, int st, int en){ // parent -> child
        if (lazy[nd].is_default()) return;
        tree[nd] = (tree[nd] * lazy[nd].mul + (en - st + 1) * lazy[nd].add) % MOD;
        if (st != en){
            lazy[nd << 1].calc(lazy[nd].mul, lazy[nd].add);
            lazy[nd << 1 | 1].calc(lazy[nd].mul, lazy[nd].add);
        }
        lazy[nd].make_default();
    }
    void _init(int nd, int st, int en){
        if (st == en){
            tree[nd] = A[st];
            return;
        }
        int mid = (st + en) >> 1;
        _init(nd << 1, st, mid);
        _init(nd << 1 | 1, mid + 1, en);
        _pull(nd);
    }
    void _update(int nd, int st, int en, int l, int r, ll mul, ll add){
        _push(nd, st, en);
        if (r < st || en < l) return;
        if (l <= st && en <= r){
            lazy[nd].calc(mul, add);
            _push(nd, st, en);
            return;
        }
        int mid = (st + en) >> 1;
        _update(nd << 1, st, mid, l, r, mul, add);
        _update(nd << 1 | 1, mid + 1, en, l, r, mul, add);
        _pull(nd);
    }
    void update(int l, int r, ll mul, ll add){
        _update(1, 0, N - 1, l, r, mul, add);
    }
    ll _query(int nd, int st, int en, int l, int r){
        _push(nd, st, en);
        if (r < st || en < l) return 0;
        if (l <= st && en <= r) return tree[nd];
        int mid = (st + en) >> 1;
        return (_query(nd << 1, st, mid, l, r) + _query(nd << 1 | 1, mid + 1, en, l, r)) % MOD;
    }
    ll query(int l, int r){
        return _query(1, 0, N - 1, l, r);
    }
}st;
void ST::init(){
    _init(1, 0, N - 1);
}
int main(){
    ios_base::sync_with_stdio(0);
    cin.tie(0);
    cin >> N;
    for (int i = 0; i < N; i++) cin >> A[i];
    st.init();
    int M, q, x, y, v;
    cin >> M;
    while (M--){
        cin >> q;
        if (q == 1){ // +
            cin >> x >> y >> v;
            st.update(x - 1, y - 1, 1, v);
        }
        else if (q == 2){ // *
            cin >> x >> y >> v;
            st.update(x - 1, y - 1, v, 0);
        }
        else if (q == 3){ // =
            cin >> x >> y >> v;
            st.update(x - 1, y - 1, 0, v);
        }
        else{
            cin >> x >> y;
            cout << st.query(x - 1, y - 1) << '\n';
        }
    }
}import sys; input = sys.stdin.readline
from math import ceil, log2
MOD = 1000000007
class Lazy:
    def __init__(self):
        self.mul = 1
        self.add = 0
    def is_default(self):
        return self.mul == 1 and self.add == 0
    def make_default(self):
        self.mul = 1
        self.add = 0
    def calc(self, mul, add): # (ax + b)c + d = acx + bc + d
        self.mul = (self.mul * mul) % MOD
        self.add = (self.add * mul + add) % MOD
class ST:
    def __init__(self):
        self.N = N
        self.H = 1 << ceil(log2(self.N) + 1)
        self.tree = [0] * self.H
        self.lazy = [Lazy() for _ in range(self.H)]
        self._init(1, 0, self.N - 1)
    def _pull(self, nd): # child -> parent
        self.tree[nd] = (self.tree[nd << 1] + self.tree[nd << 1 | 1]) % MOD
    def _push(self, nd, st, en): # parent -> child
        if self.lazy[nd].is_default():
            return
        self.tree[nd] = (self.tree[nd] * self.lazy[nd].mul + (en - st + 1) * self.lazy[nd].add) % MOD
        if st != en:
            self.lazy[nd << 1].calc(self.lazy[nd].mul, self.lazy[nd].add)
            self.lazy[nd << 1 | 1].calc(self.lazy[nd].mul, self.lazy[nd].add)
        self.lazy[nd].make_default()
    def _init(self, nd, st, en):
        if st == en:
            self.tree[nd] = A[st]
            return
        mid = (st + en) >> 1
        self._init(nd << 1, st, mid)
        self._init(nd << 1 | 1, mid + 1, en)
        self._pull(nd)
    def _update(self, nd, st, en, l, r, mul, add):
        self._push(nd, st, en)
        if r < st or en < l:
            return
        if l <= st and en <= r:
            self.lazy[nd].calc(mul, add)
            self._push(nd, st, en)
            return
        mid = (st + en) >> 1
        self._update(nd << 1, st, mid, l, r, mul, add)
        self._update(nd << 1 | 1, mid + 1, en, l, r, mul, add)
        self._pull(nd)
    def update(self, l, r, mul, add):
        return self._update(1, 0, self.N - 1, l, r, mul, add)
    def _query(self, nd, st, en, l, r):
        self._push(nd, st, en)
        if r < st or en < l:
            return 0
        if l <= st and en <= r:
            return self.tree[nd]
        mid = (st + en) >> 1
        return (self._query(nd << 1, st, mid, l, r) + self._query(nd << 1 | 1, mid + 1, en, l, r)) % MOD
    def query(self, l, r):
        return self._query(1, 0, self.N - 1, l, r)
N = int(input())
A = list(map(int, input().split()))
st = ST()
for _ in range(int(input())):
    q, *query = map(int, input().split())
    if q == 1: # +
        x = int(query[0]) - 1
        y = int(query[1]) - 1
        v = int(query[2])
        st.update(x, y, 1, v)
    elif q == 2: # *
        x = int(query[0]) - 1
        y = int(query[1]) - 1
        v = int(query[2])
        st.update(x, y, v, 0)
    elif q == 3: # =
        x = int(query[0]) - 1
        y = int(query[1]) - 1
        v = int(query[2])
        st.update(x, y, 0, v)
    else:
        x = int(query[0]) - 1
        y = int(query[1]) - 1
        print(st.query(x, y))