BOJ 17429 - 국제 메시 기구 링크
(2023.07.11 기준 D4)
(치팅하지 마세요)
N개의 금고가 트리 모양으로 연결되어 있다. 모든 금고는 처음엔 0원이 있으며, Q개의 쿼리를 알맞게 처리 및 출력
- 1 X V: 금고 X의 서브트리에 있는 모든 금고에 V원을 더합니다. (1 ≤ X ≤ N)
- 2 X Y V: 금고 X부터 금고 Y까지의 경로에 있는 모든 금고에 V원을 더합니다. (1 ≤ X, Y ≤ N)
- 3 X V: 금고 X의 서브트리에 있는 모든 금고의 돈을 V배 합니다. (1 ≤ X ≤ N)
- 4 X Y V: 금고 X부터 금고 Y까지의 경로에 있는 모든 금고의 돈을 V배 합니다. (1 ≤ X, Y ≤ N)
- 5 X: 금고 X의 서브트리에 있는 모든 금고의 돈을 합한 값을 출력합니다. (1 ≤ X ≤ N)
- 6 X Y: 금고 X부터 금고 Y까지의 경로에 있는 모든 금고의 돈을 합한 값을 출력합니다. (1 ≤ X, Y ≤ N)
Lazy propagation 및 트리 경로는 HLD로 관리
쿼리 자체는 BOJ 13925 - 수열과 쿼리 13 풀이와 동일하다.
단, 이 문제는 트리 모양에서의 쿼리 처리기 때문에, HLD로 하여금 경로를 관리해야 한다.
서브트리 쿼리는 오일러 경로에서의 in, out을 이용해 직접 세그먼트 트리에 접근하면 되고,
경로 쿼리는 HLD에서의 체인으로 접근하면 된다.
MOD가 2^32인데, 이를 MOD로 직접 나머지 연산으로 처리하면 안된다. 그러면 두 수를 곱할 때, long long 범위를 벗어나기 때문이다.
그러므로 정수 타입을 long long 대신 unsigned int로 선언하면 2^32까지 지원하기 때문에 저절로 오버플로우가 일어나서 MOD가 저절로 적용된다.
아 물론 C++에서만..
#include <bits/stdc++.h>
using namespace std;
typedef unsigned int ll; // 1 << 32 까지의 범위를 지원하는 unsigned int를 사용
const int MAXN = 500000, MAXH = 1 << (int)ceil(log2(MAXN) + 1);
// MOD = (long long)1 << 32;
// unsigned int를 사용함으로써 MOD가 자동으로 적용된다.
int N, Q;
// Lazy 구조체
struct Lazy{
ll mul = 1, add = 0; // (1, 0)은 항등원이다.
bool is_default(){ // 초기값인지 확인
return mul == 1 && add == 0;
}
void make_default(){ // 초기값으로 만들기
mul = 1; add = 0;
}
void calc(ll _mul, ll _add){ // (ax + b)c + d = acx + bc + d
mul = mul * _mul;
add = add * _mul + _add;
}
};
// 세그먼트 트리
struct ST{
ll tree[MAXH];
Lazy lazy[MAXH];
void init();
void _pull(int nd){ // child -> parent
tree[nd] = tree[nd << 1] + tree[nd << 1 | 1];
}
void _push(int nd, int st, int en){ // parent -> child
if (lazy[nd].is_default()) return;
tree[nd] = tree[nd] * lazy[nd].mul + (en - st + 1) * lazy[nd].add;
if (st != en){
lazy[nd << 1].calc(lazy[nd].mul, lazy[nd].add);
lazy[nd << 1 | 1].calc(lazy[nd].mul, lazy[nd].add);
}
lazy[nd].make_default();
}
void _update(int nd, int st, int en, int l, int r, ll mul, ll add){
_push(nd, st, en);
if (r < st || en < l) return;
if (l <= st && en <= r){
lazy[nd].calc(mul, add);
_push(nd, st, en);
return;
}
int mid = (st + en) >> 1;
_update(nd << 1, st, mid, l, r, mul, add);
_update(nd << 1 | 1, mid + 1, en, l, r, mul, add);
_pull(nd);
}
void update(int l, int r, ll mul, ll add){
_update(1, 0, N - 1, l, r, mul, add);
}
ll _query(int nd, int st, int en, int l, int r){
_push(nd, st, en);
if (r < st || en < l) return 0;
if (l <= st && en <= r) return tree[nd];
int mid = (st + en) >> 1;
return _query(nd << 1, st, mid, l, r) + _query(nd << 1 | 1, mid + 1, en, l, r);
}
ll query(int l, int r){
return _query(1, 0, N - 1, l, r);
}
}st;
// heavy-light 분할
struct HLD{
int sz[MAXN], pa[MAXN], lv[MAXN], in[MAXN], out[MAXN], head[MAXN], idx;
vector<int> graph[MAXN];
void init();
int _dfs(int here){
sz[here] = 1;
for (int i = 0, gsz = graph[here].size(); i < gsz; i++){
int there = graph[here][i];
if (pa[here] == there) continue;
pa[there] = here;
lv[there] = lv[here] + 1;
sz[here] += _dfs(there);
if (sz[graph[here][0]] < sz[there]) swap(graph[here][0], graph[here][i]);
}
return sz[here];
}
void _hld(int here){
in[here] = ++idx;
for (auto there: graph[here]){
if (pa[here] == there) continue;
if (graph[here][0] == there) head[there] = head[here];
else head[there] = there;
_hld(there);
}
out[here] = idx;
}
void update(int u, int v, ll mul, ll add){
while (head[u] != head[v]){
if (lv[head[u]] < lv[head[v]]) swap(u, v);
st.update(in[head[u]], in[u], mul, add);
u = pa[head[u]];
}
if (in[u] > in[v]) swap(u, v);
st.update(in[u], in[v], mul, add);
}
ll query(int u, int v){
ll result = 0;
while (head[u] != head[v]){
if (lv[head[u]] < lv[head[v]]) swap(u, v);
result += st.query(in[head[u]], in[u]);
u = pa[head[u]];
}
if (in[u] > in[v]) swap(u, v);
result += st.query(in[u], in[v]);
return result;
}
}hld;
void ST::init(){
fill(tree, tree + MAXH, 0);
}
void HLD::init(){
fill(sz, sz + N, 0);
fill(pa, pa + N, 0);
fill(lv, lv + N, 0);
fill(in, in + N, 0);
fill(out, out + N, 0);
fill(head, head + N, 0);
idx = -1;
_dfs(0); _hld(0);
}
int main(){
ios_base::sync_with_stdio(0);
cin.tie(0);
cin >> N >> Q;
for (int i = 1, S, E; i < N; i++){
cin >> S >> E;
hld.graph[--S].push_back(--E);
hld.graph[E].push_back(S);
}
hld.init(); st.init();
int q, X, Y, V;
while (Q--){
cin >> q;
if (q == 1){ // X의 서브트리에 V 더하기
cin >> X >> V;
st.update(hld.in[X - 1], hld.out[X - 1], 1, V);
}
else if (q == 2){ // X와 Y의 경로에 V 더하기
cin >> X >> Y >> V;
hld.update(X - 1, Y - 1, 1, V);
}
else if (q == 3){ // X의 서브트리에 V 곱하기
cin >> X >> V;
st.update(hld.in[X - 1], hld.out[X - 1], V, 0);
}
else if (q == 4){ // X와 Y의 경로에 V 곱하기
cin >> X >> Y >> V;
hld.update(X - 1, Y - 1, V, 0);
}
else if (q == 5){ // X의 서브트리 합 쿼리
cin >> X;
cout << st.query(hld.in[X - 1], hld.out[X - 1]) << '\n';
}
else{ // X와 Y의 경로 합 쿼리
cin >> X >> Y;
cout << hld.query(X - 1, Y - 1) << '\n';
}
}
}
import sys; input = sys.stdin.readline
sys.setrecursionlimit(500000)
from math import ceil, log2
MOD = 1 << 32
# Lazy 객체
class Lazy:
def __init__(self): # (1, 0)은 항등원이다.
self.mul = 1
self.add = 0
def is_default(self): # 초기값인지 확인
return self.mul == 1 and self.add == 0
def make_default(self): # 초기값으로 만들기
self.mul = 1
self.add = 0
def calc(self, mul, add): # (ax + b)c + d = acx + bc + d
self.mul = (self.mul * mul) % MOD
self.add = (self.add * mul + add) % MOD
# 세그먼트 트리
class ST:
def __init__(self):
self.N = N
self.H = 1 << ceil(log2(self.N) + 1)
self.tree = [0] * self.H
self.lazy = [Lazy() for _ in range(self.H)]
def _pull(self, nd): # child -> parent
self.tree[nd] = (self.tree[nd << 1] + self.tree[nd << 1 | 1]) % MOD
def _push(self, nd, st, en): # parent -> child
if self.lazy[nd].is_default():
return
self.tree[nd] = (self.tree[nd] * self.lazy[nd].mul + (en - st + 1) * self.lazy[nd].add) % MOD
if st != en:
self.lazy[nd << 1].calc(self.lazy[nd].mul, self.lazy[nd].add)
self.lazy[nd << 1 | 1].calc(self.lazy[nd].mul, self.lazy[nd].add)
self.lazy[nd].make_default()
def _update(self, nd, st, en, l, r, mul, add):
self._push(nd, st, en)
if r < st or en < l:
return
if l <= st and en <= r:
self.lazy[nd].calc(mul, add)
self._push(nd, st, en)
return
mid = (st + en) >> 1
self._update(nd << 1, st, mid, l, r, mul, add)
self._update(nd << 1 | 1, mid + 1, en, l, r, mul, add)
self._pull(nd)
def update(self, l, r, mul, add):
return self._update(1, 0, self.N - 1, l, r, mul, add)
def _query(self, nd, st, en, l, r):
self._push(nd, st, en)
if r < st or en < l:
return 0
if l <= st and en <= r:
return self.tree[nd]
mid = (st + en) >> 1
return (self._query(nd << 1, st, mid, l, r) + self._query(nd << 1 | 1, mid + 1, en, l, r)) % MOD
def query(self, l, r):
return self._query(1, 0, self.N - 1, l, r)
# heavy-light 분할
class HLD:
def __init__(self):
self.N = N
self.graph = [[] for _ in range(self.N)]
self.sz = [0] * self.N
self.pa = [0] * self.N
self.lv = [0] * self.N
self.inn = [0] * self.N
self.out = [0] * self.N
self.head = [0] * self.N
self.idx = -1
def init(self):
self._dfs(0)
self._hld(0)
def _dfs(self, here):
self.sz[here] = 1
for i in range(len(self.graph[here])):
there = self.graph[here][i]
if self.pa[here] == there:
continue
self.pa[there] = here
self.lv[there] = self.lv[here] + 1
self.sz[here] += self._dfs(there)
if self.sz[self.graph[here][0]] < self.sz[there]:
self.graph[here][0], self.graph[here][i] = self.graph[here][i], self.graph[here][0]
return self.sz[here]
def _hld(self, here):
self.idx += 1
self.inn[here] = self.idx
for there in self.graph[here]:
if self.pa[here] == there:
continue
if self.graph[here][0] == there:
self.head[there] = self.head[here]
else:
self.head[there] = there
self._hld(there)
self.out[here] = self.idx
def update(self, u, v, mul, add):
while self.head[u] != self.head[v]:
if self.lv[self.head[u]] < self.lv[self.head[v]]:
u, v = v, u
st.update(self.inn[self.head[u]], self.inn[u], mul, add)
u = self.pa[self.head[u]]
if self.inn[u] > self.inn[v]:
u, v = v, u
st.update(self.inn[u], self.inn[v], mul, add)
def query(self, u, v):
result = 0
while self.head[u] != self.head[v]:
if self.lv[self.head[u]] < self.lv[self.head[v]]:
u, v = v, u
result += st.query(self.inn[self.head[u]], self.inn[u])
u = self.pa[self.head[u]]
if self.inn[u] > self.inn[v]:
u, v = v, u
result += st.query(self.inn[u], self.inn[v])
return result
N, Q = map(int, input().split())
hld = HLD()
for _ in range(N - 1):
S, E = map(int, input().split())
S -= 1; E -= 1
hld.graph[S].append(E)
hld.graph[E].append(S)
hld.init(); st = ST()
for _ in range(Q):
q, *query = map(int, input().split())
if q == 1: # X의 서브트리에 V 더하기
X, V = query
st.update(hld.inn[X - 1], hld.out[X - 1], 1, V)
elif q == 2: # X와 Y의 경로에 V 더하기
X, Y, V = query
hld.update(X - 1, Y - 1, 1, V)
elif q == 3: # X의 서브트리에 V 곱하기
X, V = query
st.update(hld.inn[X - 1], hld.out[X - 1], V, 0)
elif q == 4: # X와 Y의 경로에 V 곱하기
X, Y, V = query
hld.update(X - 1, Y - 1, V, 0)
elif q == 5: # X의 서브트리 합 쿼리
X = query[0]
print(st.query(hld.inn[X - 1], hld.out[X - 1]))
else: # X와 Y의 경로 합 쿼리
X, Y = query
print(hld.query(X - 1, Y - 1))
Python은 별짓을 다해봐도 TLE다.
PyPy3 + 비재귀 세그 + FASTIO + NTT를 이용한 빠른 곱셈 = TLE
ㅋㅋㅋㅋㅋㅋㅋㅋㅋㅋㅋㅋㅋㅋㅋㅋㅋㅋㅋㅋㅋㅋㅋㅋㅋㅋ
짜증난다...