백준 20148 트리와 쿼리 18
23년 5월 8일 기준 다이아 1
먼저 3번 쿼리부터 보자.
3 v : 를 출력한다. 는 정점 x에서 정점 y로 가는 경로에 있는 간선의 수를 의미한다.
모든 노드에 대해 각 노드의 가중치 * v 노드에서의 거리를 구하는 것이다.
이에 대해서 각 노드가 '서브 트리의 가중치'을 알고 있다면 아래의 식으로 해결된다.
(모든 노드의 가중치 합) + (v 노드에서 루트 노드까지의 노드 수) * (루트 노드 서브 트리의 가중치) - (v 노드에서 루트 노드까지의 모든 노드의 가중치 합) * 2
여기서 말하는 v 노드에서 루트 노드까지 는 위로 쭉 올라가는 부모 체인을 의미한다.
v 노드 서브 트리 내 노드를 보자면
1. '모든 노드의 가중치 합'에서 본인 depth + 1만큼 더해진다.
2. '(v 노드에서 루트 노드까지의 노드 수) (루트 노드 서브 트리의 가중치)'에서 v 노드 depth + 1만큼 더해진다.
3. '- (v 노드에서 루트 노드까지의 모든 노드의 가중치 합) 2'에서 (v 노드 depth + 1) * 2만큼 빼진다.
v 노드 바깥 노드를 보자면
1. '모든 노드의 가중치 합'에서 본인 depth + 1만큼 더해진다.
2. '(v 노드에서 루트 노드까지의 노드 수) (루트 노드 서브 트리의 가중치)'에서 v 노드 depth + 1만큼 더해진다.
3. '- (v 노드에서 루트 노드까지의 모든 노드의 가중치 합) 2'에서 (v 노드 depth + 1) * 2만큼 빼진다.
v 노드에서 루트 노드까지의 모든 노드의 가중치 합에서 각 노드의 위치에 따라 적절히 빼져서 해당 거리에 맞게 값을 도출할 수 있다.
그러면 우리는 각 노드가 서브 트리의 가중치를 알고 있는 세그먼트 트리를 만들어야 한다.
이제 1번 쿼리, 2번 쿼리로 해당 세그먼트 트리를 유지할 수 있으면 된다.
1 u v: 트리의 루트를 정점 u라 하였을 때, 정점 v를 루트로 하는 서브트리의 모든 정점 i의 에 1을 더한다.
루트를 정점 u로 지정하는 부분에서 처리가 나뉜다.
오일러 투어 id 기준으로
u가 v 내부에 있을 때
u가 v 외부에 있을 때
1번 쿼리 과정에서 세그먼트 트리에 더해주거나 뺄 때 '그 값'을 더해줘야 할 때도 있고, 해당 노드의 서브 트리 노드 개수만큼 더해줘야 할 때도 있다. update 함수를 구별하여 잘 짜주자.
나는 SUB_ADD, SUB_MIN 이면 서브 트리 노드 개수만큼 더하거나 빼도록 처리해두고 다른 값이 들어오면 그 값을 더하거나 빼도록 처리해두었다.
2 u v: 정점 u에서 정점 v로 가는 유일한 경로에 있는 모든 정점 i의 에 1을 더한다.
2번 쿼리는 하떨별의 응용이라 보면 된다.
하떨별에서는 range update, point find로 AC를 맞을 수 있는데, 이 문제의 2번 쿼리는 range update, range find라 보면 된다.
이를 위해서 하떨별을 range find로 바꿔서 풀어보았다.
1번 쿼리에서 보았듯이, u와 v 노드의 최소 공통 조상까지 값 처리를 해주고, 더해진 값만큼 최소 공통 조상의 부모에서부터 루트 체인까지 쭉 빼주면 된다.
3번 쿼리에서 범위 find를 할 때, seg[root]만 가져가도 되도록 처리했다.
lazy 처리를 할 때, 해당 범위 값들을 seg[root]에 바로바로 넣어주었다.
이를 위해서 ps 배열을 만들어 두었고, 이는 오일러 투어 id 기준 서브 트리 노드 개수의 누적합이다.
알고리즘 분류가 담백한데 난이도가 높다면 무서운 문제이다..
#include <algorithm>
#include <iostream>
#include <vector>
using namespace std;
typedef long long ll;
const int vmax = 200001;
const int SUB_ADD = 2000000000, SUB_MIN = -2000000000;
int N, Q, a, b, c;
vector<vector<int>> tree(vmax);
vector<vector<int>> parent(19, vector<int>(vmax));
vector<int> depth(vmax);
vector<int> siz(vmax);
vector<int> in(vmax);
vector<int> rin(vmax);
vector<int> out(vmax);
vector<int> top(vmax, 1);
vector<ll> ps(vmax);
ll seg[vmax * 4];
ll lazy1[vmax * 4];
ll lazy2[vmax * 4];
ll del[vmax * 4];
int cnt[vmax * 4];
void makeParent() {
for (int i = 1; i < 19; i++) {
for (int j = 1; j <= N; j++) {
parent[i][j] = parent[i - 1][parent[i - 1][j]];
}
}
}
int lca(int a, int b) {
if (depth[a] > depth[b])
swap(a, b);
int diff = depth[b] - depth[a];
int j = 0;
while (diff) {
if (diff & 1) {
b = parent[j][b];
}
j++;
diff /= 2;
}
if (a != b) {
for (int i = 18; i >= 0; i--) {
int pa = parent[i][a];
int pb = parent[i][b];
if (pa != pb) {
a = pa;
b = pb;
}
}
a = parent[0][a];
}
return a;
}
void dfs1(int here = 1, int prev = 0) {
siz[here] = 1;
int maxx = -1, maxxIdx = -1;
for (auto &next : tree[here]) {
int idx = &next - &tree[here][0];
if (next == prev) {
swap(next, tree[here][0]);
if (maxxIdx == 0)
maxxIdx = idx;
continue;
}
parent[0][next] = here;
depth[next] = depth[here] + 1;
dfs1(next, here);
siz[here] += siz[next];
if (siz[next] > maxx) {
maxx = siz[next];
maxxIdx = idx;
}
}
if (maxxIdx != -1)
swap(tree[here][0], tree[here][maxxIdx]);
}
int idx = 1;
void dfs2(int here = 1, int prev = 0) {
rin[idx] = here;
in[here] = idx++;
for (auto &next : tree[here]) {
if (next == prev)
continue;
int forIdx = &next - &tree[here][0];
top[next] = forIdx == 0 ? top[here] : next;
dfs2(next, here);
}
out[here] = idx;
}
void lazyUpdate(int root, int s, int e) {
if (lazy1[root] == 0)
return;
if (s != e) {
lazy1[root * 2] += lazy1[root];
lazy1[root * 2 + 1] += lazy1[root];
}
seg[root] += lazy1[root] * (ps[e] - ps[s - 1]);
lazy1[root] = 0;
}
void lazyUpdate2(int root, int s, int e) {
if (lazy2[root] == 0)
return;
seg[root] += lazy2[root] * (e - s + 1);
seg[root] += cnt[root] * ((ll)(e - s + 1) * (e - s) / 2);
if (s != e) {
int m = (s + e) / 2;
lazy2[root * 2] += lazy2[root] + cnt[root] * (e - m);
lazy2[root * 2 + 1] += lazy2[root];
cnt[root * 2] += cnt[root];
cnt[root * 2 + 1] += cnt[root];
}
lazy2[root] = 0;
cnt[root] = 0;
}
void lazyDel(int root, int s, int e) {
if (del[root] == 0)
return;
if (s != e) {
del[root * 2] += del[root];
del[root * 2 + 1] += del[root];
}
seg[root] -= del[root] * (e - s + 1);
del[root] = 0;
}
void lazy(int root, int s, int e) {
lazyUpdate(root, s, e);
lazyUpdate2(root, s, e);
lazyDel(root, s, e);
}
void ru(int v, int l, int r, int root = 1, int s = 1, int e = N) {
lazy(root, s, e);
if (r < s || l > e)
return;
if (l <= s && e <= r) {
if (v == SUB_ADD)
lazy1[root]++;
else if (v == SUB_MIN)
lazy1[root]--;
else
del[root] += v;
lazy(root, s, e);
return;
}
int m = (s + e) / 2;
ru(v, l, r, root * 2, s, m);
ru(v, l, r, root * 2 + 1, m + 1, e);
seg[root] = seg[root * 2] + seg[root * 2 + 1];
}
void ru2(int v, int l, int r, int root = 1, int s = 1, int e = N) {
lazy(root, s, e);
if (r < s || l > e)
return;
if (l <= s && e <= r) {
lazy2[root] += v + r - e;
cnt[root]++;
lazy(root, s, e);
return;
}
int m = (s + e) / 2;
ru2(v, l, r, root * 2, s, m);
ru2(v, l, r, root * 2 + 1, m + 1, e);
seg[root] = seg[root * 2] + seg[root * 2 + 1];
}
ll rf(int l, int r, int root = 1, int s = 1, int e = N) {
lazy(root, s, e);
if (r < s || l > e)
return 0;
if (l <= s && e <= r)
return seg[root];
int m = (s + e) / 2;
return rf(l, r, root * 2, s, m) + rf(l, r, root * 2 + 1, m + 1, e);
}
void hld2(int t, int a, int fixed = 1) {
while (top[t] != top[a]) {
ru2(fixed, in[top[a]], in[a]);
fixed += in[a] - in[top[a]] + 1;
a = parent[0][top[a]];
}
ru2(fixed, in[t], in[a]);
}
ll hld3(int t, int a) {
ll res = 0;
while (top[t] != top[a]) {
res += rf(in[top[a]], in[a]);
a = parent[0][top[a]];
}
res += rf(in[t], in[a]);
return res;
}
int main() {
ios::sync_with_stdio(false);
cin.tie(NULL);
cout.tie(NULL);
#ifndef ONLINE_JUDGE
freopen("data.txt", "r", stdin);
#endif
cin >> N;
for (int i = 1; i < N; i++) {
cin >> a >> b;
tree[a].push_back(b);
tree[b].push_back(a);
}
dfs1();
dfs2();
makeParent();
vector<int> temp(vmax, 0);
for (int i = 1; i <= N; i++)
temp[in[i]] = siz[i];
for (int i = 1; i <= N; i++)
ps[i] = ps[i - 1] + temp[i];
cin >> Q;
for (int i = 0; i < Q; i++) {
cin >> a >> b;
if (a == 1) {
cin >> c;
if (in[c] <= in[b] && in[b] <= out[c] - 1) {
ru(SUB_ADD, 1, out[1] - 1);
if (b == c)
continue;
while (top[c] != top[b]) {
if (parent[0][top[b]] == c) {
b = top[b];
break;
}
b = parent[0][top[b]];
}
if (top[c] == top[b])
b = rin[in[c] + 1];
ru(SUB_MIN, in[b], out[b] - 1);
int del = siz[b];
b = parent[0][b];
while (1 != top[b]) {
ru(del, in[top[b]], in[b]);
b = parent[0][top[b]];
}
ru(del, 1, in[b]);
} else {
ru(SUB_ADD, in[c], out[c] - 1);
int add = siz[c];
if (c == 1)
continue;
c = parent[0][c];
while (1 != top[c]) {
ru(-add, in[top[c]], in[c]);
c = parent[0][top[c]];
}
ru(-add, 1, in[c]);
}
} else if (a == 2) {
cin >> c;
int t = lca(c, b);
if (in[b] > in[c])
swap(b, c);
hld2(t, c);
int add = depth[c] - depth[t] + 1;
if (out[b] - 1 < in[c]) {
add += depth[b] - depth[t];
hld2(t, b);
ru(1, in[t], in[t]);
}
if (t == 1)
continue;
t = parent[0][t];
while (1 != top[t]) {
ru(-add, in[top[t]], in[t]);
t = parent[0][top[t]];
}
ru(-add, 1, in[t]);
} else {
ll res = (ll)(depth[b] + 1) * rf(1, 1);
res += rf(1, out[1] - 1);
res -= hld3(1, b) * 2;
cout << res << '\n';
}
}
return 0;
}