[BOJ 17429] - 국제 메시 기구 (heavy-light 분할, 세그먼트 트리, 오일러 경로 테크닉, 트리, C++, Python)

보양쿠·2023년 7월 11일
0

BOJ

목록 보기
154/260
post-custom-banner

BOJ 17429 - 국제 메시 기구 링크
(2023.07.11 기준 D4)
(치팅하지 마세요)

문제

N개의 금고가 트리 모양으로 연결되어 있다. 모든 금고는 처음엔 0원이 있으며, Q개의 쿼리를 알맞게 처리 및 출력

  • 1 X V: 금고 X의 서브트리에 있는 모든 금고에 V원을 더합니다. (1 ≤ X ≤ N)
  • 2 X Y V: 금고 X부터 금고 Y까지의 경로에 있는 모든 금고에 V원을 더합니다. (1 ≤ X, Y ≤ N)
  • 3 X V: 금고 X의 서브트리에 있는 모든 금고의 돈을 V배 합니다. (1 ≤ X ≤ N)
  • 4 X Y V: 금고 X부터 금고 Y까지의 경로에 있는 모든 금고의 돈을 V배 합니다. (1 ≤ X, Y ≤ N)
  • 5 X: 금고 X의 서브트리에 있는 모든 금고의 돈을 합한 값을 출력합니다. (1 ≤ X ≤ N)
  • 6 X Y: 금고 X부터 금고 Y까지의 경로에 있는 모든 금고의 돈을 합한 값을 출력합니다. (1 ≤ X, Y ≤ N)

알고리즘

Lazy propagation 및 트리 경로는 HLD로 관리

풀이

쿼리 자체는 BOJ 13925 - 수열과 쿼리 13 풀이와 동일하다.

단, 이 문제는 트리 모양에서의 쿼리 처리기 때문에, HLD로 하여금 경로를 관리해야 한다.
서브트리 쿼리는 오일러 경로에서의 in, out을 이용해 직접 세그먼트 트리에 접근하면 되고,
경로 쿼리는 HLD에서의 체인으로 접근하면 된다.

주의사항

MOD가 2^32인데, 이를 MOD로 직접 나머지 연산으로 처리하면 안된다. 그러면 두 수를 곱할 때, long long 범위를 벗어나기 때문이다.
그러므로 정수 타입을 long long 대신 unsigned int로 선언하면 2^32까지 지원하기 때문에 저절로 오버플로우가 일어나서 MOD가 저절로 적용된다.

아 물론 C++에서만..

코드

  • C++
#include <bits/stdc++.h>
using namespace std;

typedef unsigned int ll; // 1 << 32 까지의 범위를 지원하는 unsigned int를 사용

const int MAXN = 500000, MAXH = 1 << (int)ceil(log2(MAXN) + 1);
// MOD = (long long)1 << 32;
// unsigned int를 사용함으로써 MOD가 자동으로 적용된다.

int N, Q;

// Lazy 구조체
struct Lazy{
    ll mul = 1, add = 0; // (1, 0)은 항등원이다.

    bool is_default(){ // 초기값인지 확인
        return mul == 1 && add == 0;
    }

    void make_default(){ // 초기값으로 만들기
        mul = 1; add = 0;
    }

    void calc(ll _mul, ll _add){ // (ax + b)c + d = acx + bc + d
        mul = mul * _mul;
        add = add * _mul + _add;
    }
};

// 세그먼트 트리
struct ST{
    ll tree[MAXH];
    Lazy lazy[MAXH];

    void init();

    void _pull(int nd){ // child -> parent
        tree[nd] = tree[nd << 1] + tree[nd << 1 | 1];
    }

    void _push(int nd, int st, int en){ // parent -> child
        if (lazy[nd].is_default()) return;
        tree[nd] = tree[nd] * lazy[nd].mul + (en - st + 1) * lazy[nd].add;
        if (st != en){
            lazy[nd << 1].calc(lazy[nd].mul, lazy[nd].add);
            lazy[nd << 1 | 1].calc(lazy[nd].mul, lazy[nd].add);
        }
        lazy[nd].make_default();
    }

    void _update(int nd, int st, int en, int l, int r, ll mul, ll add){
        _push(nd, st, en);
        if (r < st || en < l) return;
        if (l <= st && en <= r){
            lazy[nd].calc(mul, add);
            _push(nd, st, en);
            return;
        }
        int mid = (st + en) >> 1;
        _update(nd << 1, st, mid, l, r, mul, add);
        _update(nd << 1 | 1, mid + 1, en, l, r, mul, add);
        _pull(nd);
    }

    void update(int l, int r, ll mul, ll add){
        _update(1, 0, N - 1, l, r, mul, add);
    }

    ll _query(int nd, int st, int en, int l, int r){
        _push(nd, st, en);
        if (r < st || en < l) return 0;
        if (l <= st && en <= r) return tree[nd];
        int mid = (st + en) >> 1;
        return _query(nd << 1, st, mid, l, r) + _query(nd << 1 | 1, mid + 1, en, l, r);
    }

    ll query(int l, int r){
        return _query(1, 0, N - 1, l, r);
    }
}st;


// heavy-light 분할
struct HLD{
    int sz[MAXN], pa[MAXN], lv[MAXN], in[MAXN], out[MAXN], head[MAXN], idx;
    vector<int> graph[MAXN];

    void init();

    int _dfs(int here){
        sz[here] = 1;
        for (int i = 0, gsz = graph[here].size(); i < gsz; i++){
            int there = graph[here][i];
            if (pa[here] == there) continue;
            pa[there] = here;
            lv[there] = lv[here] + 1;
            sz[here] += _dfs(there);
            if (sz[graph[here][0]] < sz[there]) swap(graph[here][0], graph[here][i]);
        }

        return sz[here];
    }

    void _hld(int here){
        in[here] = ++idx;

        for (auto there: graph[here]){
            if (pa[here] == there) continue;
            if (graph[here][0] == there) head[there] = head[here];
            else head[there] = there;
            _hld(there);
        }

        out[here] = idx;
    }

    void update(int u, int v, ll mul, ll add){
        while (head[u] != head[v]){
            if (lv[head[u]] < lv[head[v]]) swap(u, v);
            st.update(in[head[u]], in[u], mul, add);
            u = pa[head[u]];
        }

        if (in[u] > in[v]) swap(u, v);
        st.update(in[u], in[v], mul, add);
    }

    ll query(int u, int v){
        ll result = 0;

        while (head[u] != head[v]){
            if (lv[head[u]] < lv[head[v]]) swap(u, v);
            result += st.query(in[head[u]], in[u]);
            u = pa[head[u]];
        }

        if (in[u] > in[v]) swap(u, v);
        result += st.query(in[u], in[v]);

        return result;
    }
}hld;

void ST::init(){
    fill(tree, tree + MAXH, 0);
}

void HLD::init(){
    fill(sz, sz + N, 0);
    fill(pa, pa + N, 0);
    fill(lv, lv + N, 0);
    fill(in, in + N, 0);
    fill(out, out + N, 0);
    fill(head, head + N, 0);
    idx = -1;

    _dfs(0); _hld(0);
}

int main(){
    ios_base::sync_with_stdio(0);
    cin.tie(0);

    cin >> N >> Q;

    for (int i = 1, S, E; i < N; i++){
        cin >> S >> E;
        hld.graph[--S].push_back(--E);
        hld.graph[E].push_back(S);
    }

    hld.init(); st.init();

    int q, X, Y, V;
    while (Q--){
        cin >> q;
        if (q == 1){ // X의 서브트리에 V 더하기
            cin >> X >> V;
            st.update(hld.in[X - 1], hld.out[X - 1], 1, V);
        }
        else if (q == 2){ // X와 Y의 경로에 V 더하기
            cin >> X >> Y >> V;
            hld.update(X - 1, Y - 1, 1, V);
        }
        else if (q == 3){ // X의 서브트리에 V 곱하기
            cin >> X >> V;
            st.update(hld.in[X - 1], hld.out[X - 1], V, 0);
        }
        else if (q == 4){ // X와 Y의 경로에 V 곱하기
            cin >> X >> Y >> V;
            hld.update(X - 1, Y - 1, V, 0);
        }
        else if (q == 5){ // X의 서브트리 합 쿼리
            cin >> X;
            cout << st.query(hld.in[X - 1], hld.out[X - 1]) << '\n';
        }
        else{ // X와 Y의 경로 합 쿼리
            cin >> X >> Y;
            cout << hld.query(X - 1, Y - 1) << '\n';
        }
    }
}
  • Python (TLE)
import sys; input = sys.stdin.readline
sys.setrecursionlimit(500000)
from math import ceil, log2
MOD = 1 << 32

# Lazy 객체
class Lazy:
    def __init__(self): # (1, 0)은 항등원이다.
        self.mul = 1
        self.add = 0

    def is_default(self): # 초기값인지 확인
        return self.mul == 1 and self.add == 0

    def make_default(self): # 초기값으로 만들기
        self.mul = 1
        self.add = 0

    def calc(self, mul, add): # (ax + b)c + d = acx + bc + d
        self.mul = (self.mul * mul) % MOD
        self.add = (self.add * mul + add) % MOD

# 세그먼트 트리
class ST:
    def __init__(self):
        self.N = N
        self.H = 1 << ceil(log2(self.N) + 1)
        self.tree = [0] * self.H
        self.lazy = [Lazy() for _ in range(self.H)]

    def _pull(self, nd): # child -> parent
        self.tree[nd] = (self.tree[nd << 1] + self.tree[nd << 1 | 1]) % MOD

    def _push(self, nd, st, en): # parent -> child
        if self.lazy[nd].is_default():
            return
        self.tree[nd] = (self.tree[nd] * self.lazy[nd].mul + (en - st + 1) * self.lazy[nd].add) % MOD
        if st != en:
            self.lazy[nd << 1].calc(self.lazy[nd].mul, self.lazy[nd].add)
            self.lazy[nd << 1 | 1].calc(self.lazy[nd].mul, self.lazy[nd].add)
        self.lazy[nd].make_default()

    def _update(self, nd, st, en, l, r, mul, add):
        self._push(nd, st, en)
        if r < st or en < l:
            return
        if l <= st and en <= r:
            self.lazy[nd].calc(mul, add)
            self._push(nd, st, en)
            return
        mid = (st + en) >> 1
        self._update(nd << 1, st, mid, l, r, mul, add)
        self._update(nd << 1 | 1, mid + 1, en, l, r, mul, add)
        self._pull(nd)

    def update(self, l, r, mul, add):
        return self._update(1, 0, self.N - 1, l, r, mul, add)

    def _query(self, nd, st, en, l, r):
        self._push(nd, st, en)
        if r < st or en < l:
            return 0
        if l <= st and en <= r:
            return self.tree[nd]
        mid = (st + en) >> 1
        return (self._query(nd << 1, st, mid, l, r) + self._query(nd << 1 | 1, mid + 1, en, l, r)) % MOD

    def query(self, l, r):
        return self._query(1, 0, self.N - 1, l, r)

# heavy-light 분할
class HLD:
    def __init__(self):
        self.N = N
        self.graph = [[] for _ in range(self.N)]
        self.sz = [0] * self.N
        self.pa = [0] * self.N
        self.lv = [0] * self.N
        self.inn = [0] * self.N
        self.out = [0] * self.N
        self.head = [0] * self.N
        self.idx = -1

    def init(self):
        self._dfs(0)
        self._hld(0)

    def _dfs(self, here):
        self.sz[here] = 1
        for i in range(len(self.graph[here])):
            there = self.graph[here][i]
            if self.pa[here] == there:
                continue
            self.pa[there] = here
            self.lv[there] = self.lv[here] + 1
            self.sz[here] += self._dfs(there)
            if self.sz[self.graph[here][0]] < self.sz[there]:
                self.graph[here][0], self.graph[here][i] = self.graph[here][i], self.graph[here][0]

        return self.sz[here]

    def _hld(self, here):
        self.idx += 1
        self.inn[here] = self.idx

        for there in self.graph[here]:
            if self.pa[here] == there:
                continue
            if self.graph[here][0] == there:
                self.head[there] = self.head[here]
            else:
                self.head[there] = there
            self._hld(there)

        self.out[here] = self.idx

    def update(self, u, v, mul, add):
        while self.head[u] != self.head[v]:
            if self.lv[self.head[u]] < self.lv[self.head[v]]:
                u, v = v, u
            st.update(self.inn[self.head[u]], self.inn[u], mul, add)
            u = self.pa[self.head[u]]

        if self.inn[u] > self.inn[v]:
            u, v = v, u
        st.update(self.inn[u], self.inn[v], mul, add)

    def query(self, u, v):
        result = 0

        while self.head[u] != self.head[v]:
            if self.lv[self.head[u]] < self.lv[self.head[v]]:
                u, v = v, u
            result += st.query(self.inn[self.head[u]], self.inn[u])
            u = self.pa[self.head[u]]

        if self.inn[u] > self.inn[v]:
            u, v = v, u
        result += st.query(self.inn[u], self.inn[v])

        return result

N, Q = map(int, input().split())

hld = HLD()

for _ in range(N - 1):
    S, E = map(int, input().split())
    S -= 1; E -= 1
    hld.graph[S].append(E)
    hld.graph[E].append(S)

hld.init(); st = ST()

for _ in range(Q):
    q, *query = map(int, input().split())
    if q == 1: # X의 서브트리에 V 더하기
        X, V = query
        st.update(hld.inn[X - 1], hld.out[X - 1], 1, V)
    elif q == 2: # X와 Y의 경로에 V 더하기
        X, Y, V = query
        hld.update(X - 1, Y - 1, 1, V)
    elif q == 3: # X의 서브트리에 V 곱하기
        X, V = query
        st.update(hld.inn[X - 1], hld.out[X - 1], V, 0)
    elif q == 4: # X와 Y의 경로에 V 곱하기
        X, Y, V = query
        hld.update(X - 1, Y - 1, V, 0)
    elif q == 5: # X의 서브트리 합 쿼리
        X = query[0]
        print(st.query(hld.inn[X - 1], hld.out[X - 1]))
    else: # X와 Y의 경로 합 쿼리
        X, Y = query
        print(hld.query(X - 1, Y - 1))

여담

Python은 별짓을 다해봐도 TLE다.
PyPy3 + 비재귀 세그 + FASTIO + NTT를 이용한 빠른 곱셈 = TLE
ㅋㅋㅋㅋㅋㅋㅋㅋㅋㅋㅋㅋㅋㅋㅋㅋㅋㅋㅋㅋㅋㅋㅋㅋㅋㅋ
짜증난다...

profile
GNU 16 statistics & computer science
post-custom-banner

0개의 댓글