[c++] 백준 20148 트리와 쿼리 18

Yoonlang·2023년 5월 8일
1
post-custom-banner

백준 20148 트리와 쿼리 18

23년 5월 8일 기준 다이아 1

백준 20148 트리와 쿼리 18 문제

사용한 개념

  • Heavy-light decomposition
  • 오일러 경로 테크닉
  • 세그먼트 트리 with lazy propagation

3번 쿼리 O(log2N)O(log^2 N)

먼저 3번 쿼리부터 보자.

3 v : i=1NAi×dist(v,i)\sum_{i = 1}^{N} A_i \times dist(v, i)를 출력한다. dist(x,y)dist(x, y)는 정점 x에서 정점 y로 가는 경로에 있는 간선의 수를 의미한다.

모든 노드에 대해 각 노드의 가중치 * v 노드에서의 거리를 구하는 것이다.

이에 대해서 각 노드가 '서브 트리의 가중치'을 알고 있다면 아래의 식으로 해결된다.

(모든 노드의 가중치 합) + (v 노드에서 루트 노드까지의 노드 수) * (루트 노드 서브 트리의 가중치) - (v 노드에서 루트 노드까지의 모든 노드의 가중치 합) * 2

여기서 말하는 v 노드에서 루트 노드까지 는 위로 쭉 올라가는 부모 체인을 의미한다.

v 노드 서브 트리 내 노드를 보자면
1. '모든 노드의 가중치 합'에서 본인 depth + 1만큼 더해진다.
2. '(v 노드에서 루트 노드까지의 노드 수) (루트 노드 서브 트리의 가중치)'에서 v 노드 depth + 1만큼 더해진다.
3. '- (v 노드에서 루트 노드까지의 모든 노드의 가중치 합)
2'에서 (v 노드 depth + 1) * 2만큼 빼진다.

v 노드 바깥 노드를 보자면
1. '모든 노드의 가중치 합'에서 본인 depth + 1만큼 더해진다.
2. '(v 노드에서 루트 노드까지의 노드 수) (루트 노드 서브 트리의 가중치)'에서 v 노드 depth + 1만큼 더해진다.
3. '- (v 노드에서 루트 노드까지의 모든 노드의 가중치 합)
2'에서 (v 노드 depth + 1) * 2만큼 빼진다.

v 노드에서 루트 노드까지의 모든 노드의 가중치 합에서 각 노드의 위치에 따라 적절히 빼져서 해당 거리에 맞게 값을 도출할 수 있다.

그러면 우리는 각 노드가 서브 트리의 가중치를 알고 있는 세그먼트 트리를 만들어야 한다.

이제 1번 쿼리, 2번 쿼리로 해당 세그먼트 트리를 유지할 수 있으면 된다.

1번 쿼리 O(log2N)O(log^2 N)

1 u v: 트리의 루트를 정점 u라 하였을 때, 정점 v를 루트로 하는 서브트리의 모든 정점 i의 AiA_i에 1을 더한다.

루트를 정점 u로 지정하는 부분에서 처리가 나뉜다.

오일러 투어 id 기준으로

  1. u가 v 내부에 있을 때

    1. 모든 노드에 서브 트리 노드 개수만큼 더해주고
    2. u에서 v 바로 밑의 노드를 찾아서
    3. 해당 노드의 서브 트리 노드들에 해당 서브 트리 노드 개수만큼 빼주고
    4. hld로 루트 노드까지 빼준 값만큼 다 빼준다.
  2. u가 v 외부에 있을 때

    1. 본인이 설정한 루트 노드 기준으로 v 노드 서브 트리 노드들에 해당 서브 트리 노드 개수만큼 더해준다.
    2. 더해준 만큼 hld로 루트 노드까지 더해준다.

1번 쿼리 과정에서 세그먼트 트리에 더해주거나 뺄 때 '그 값'을 더해줘야 할 때도 있고, 해당 노드의 서브 트리 노드 개수만큼 더해줘야 할 때도 있다. update 함수를 구별하여 잘 짜주자.

나는 SUB_ADD, SUB_MIN 이면 서브 트리 노드 개수만큼 더하거나 빼도록 처리해두고 다른 값이 들어오면 그 값을 더하거나 빼도록 처리해두었다.

2번 쿼리 O(log2N)O(log^2 N)

2 u v: 정점 u에서 정점 v로 가는 유일한 경로에 있는 모든 정점 i의 AiA_i 에 1을 더한다.

2번 쿼리는 하떨별의 응용이라 보면 된다.

하떨별에서는 range update, point find로 AC를 맞을 수 있는데, 이 문제의 2번 쿼리는 range update, range find라 보면 된다.
이를 위해서 하떨별을 range find로 바꿔서 풀어보았다.

1번 쿼리에서 보았듯이, u와 v 노드의 최소 공통 조상까지 값 처리를 해주고, 더해진 값만큼 최소 공통 조상의 부모에서부터 루트 체인까지 쭉 빼주면 된다.

다른 처리

3번 쿼리에서 범위 find를 할 때, seg[root]만 가져가도 되도록 처리했다.
lazy 처리를 할 때, 해당 범위 값들을 seg[root]에 바로바로 넣어주었다.
이를 위해서 ps 배열을 만들어 두었고, 이는 오일러 투어 id 기준 서브 트리 노드 개수의 누적합이다.

최종 Time Complexity : O(Qlog2N)O(Q log^2 N)

20148 AC

후기

알고리즘 분류가 담백한데 난이도가 높다면 무서운 문제이다..

AC 코드

#include <algorithm>
#include <iostream>
#include <vector>

using namespace std;

typedef long long ll;
const int vmax = 200001;
const int SUB_ADD = 2000000000, SUB_MIN = -2000000000;

int N, Q, a, b, c;
vector<vector<int>> tree(vmax);
vector<vector<int>> parent(19, vector<int>(vmax));
vector<int> depth(vmax);
vector<int> siz(vmax);
vector<int> in(vmax);
vector<int> rin(vmax);
vector<int> out(vmax);
vector<int> top(vmax, 1);
vector<ll> ps(vmax);
ll seg[vmax * 4];
ll lazy1[vmax * 4];
ll lazy2[vmax * 4];
ll del[vmax * 4];
int cnt[vmax * 4];

void makeParent() {
    for (int i = 1; i < 19; i++) {
        for (int j = 1; j <= N; j++) {
            parent[i][j] = parent[i - 1][parent[i - 1][j]];
        }
    }
}

int lca(int a, int b) {
    if (depth[a] > depth[b])
        swap(a, b);
    int diff = depth[b] - depth[a];
    int j = 0;
    while (diff) {
        if (diff & 1) {
            b = parent[j][b];
        }
        j++;
        diff /= 2;
    }

    if (a != b) {
        for (int i = 18; i >= 0; i--) {
            int pa = parent[i][a];
            int pb = parent[i][b];
            if (pa != pb) {
                a = pa;
                b = pb;
            }
        }
        a = parent[0][a];
    }

    return a;
}

void dfs1(int here = 1, int prev = 0) {
    siz[here] = 1;

    int maxx = -1, maxxIdx = -1;
    for (auto &next : tree[here]) {
        int idx = &next - &tree[here][0];
        if (next == prev) {
            swap(next, tree[here][0]);
            if (maxxIdx == 0)
                maxxIdx = idx;
            continue;
        }

        parent[0][next] = here;
        depth[next] = depth[here] + 1;

        dfs1(next, here);
        siz[here] += siz[next];
        if (siz[next] > maxx) {
            maxx = siz[next];
            maxxIdx = idx;
        }
    }

    if (maxxIdx != -1)
        swap(tree[here][0], tree[here][maxxIdx]);
}

int idx = 1;
void dfs2(int here = 1, int prev = 0) {
    rin[idx] = here;
    in[here] = idx++;

    for (auto &next : tree[here]) {
        if (next == prev)
            continue;
        int forIdx = &next - &tree[here][0];
        top[next] = forIdx == 0 ? top[here] : next;
        dfs2(next, here);
    }

    out[here] = idx;
}

void lazyUpdate(int root, int s, int e) {
    if (lazy1[root] == 0)
        return;

    if (s != e) {
        lazy1[root * 2] += lazy1[root];
        lazy1[root * 2 + 1] += lazy1[root];
    }
    seg[root] += lazy1[root] * (ps[e] - ps[s - 1]);
    lazy1[root] = 0;
}

void lazyUpdate2(int root, int s, int e) {
    if (lazy2[root] == 0)
        return;

    seg[root] += lazy2[root] * (e - s + 1);
    seg[root] += cnt[root] * ((ll)(e - s + 1) * (e - s) / 2);

    if (s != e) {
        int m = (s + e) / 2;
        lazy2[root * 2] += lazy2[root] + cnt[root] * (e - m);
        lazy2[root * 2 + 1] += lazy2[root];
        cnt[root * 2] += cnt[root];
        cnt[root * 2 + 1] += cnt[root];
    }
    lazy2[root] = 0;
    cnt[root] = 0;
}

void lazyDel(int root, int s, int e) {
    if (del[root] == 0)
        return;

    if (s != e) {
        del[root * 2] += del[root];
        del[root * 2 + 1] += del[root];
    }
    seg[root] -= del[root] * (e - s + 1);
    del[root] = 0;
}

void lazy(int root, int s, int e) {
    lazyUpdate(root, s, e);
    lazyUpdate2(root, s, e);
    lazyDel(root, s, e);
}

void ru(int v, int l, int r, int root = 1, int s = 1, int e = N) {
    lazy(root, s, e);
    if (r < s || l > e)
        return;
    if (l <= s && e <= r) {
        if (v == SUB_ADD)
            lazy1[root]++;
        else if (v == SUB_MIN)
            lazy1[root]--;
        else
            del[root] += v;
        lazy(root, s, e);
        return;
    }

    int m = (s + e) / 2;
    ru(v, l, r, root * 2, s, m);
    ru(v, l, r, root * 2 + 1, m + 1, e);
    seg[root] = seg[root * 2] + seg[root * 2 + 1];
}

void ru2(int v, int l, int r, int root = 1, int s = 1, int e = N) {
    lazy(root, s, e);
    if (r < s || l > e)
        return;
    if (l <= s && e <= r) {
        lazy2[root] += v + r - e;
        cnt[root]++;
        lazy(root, s, e);
        return;
    }

    int m = (s + e) / 2;
    ru2(v, l, r, root * 2, s, m);
    ru2(v, l, r, root * 2 + 1, m + 1, e);
    seg[root] = seg[root * 2] + seg[root * 2 + 1];
}

ll rf(int l, int r, int root = 1, int s = 1, int e = N) {
    lazy(root, s, e);
    if (r < s || l > e)
        return 0;
    if (l <= s && e <= r)
        return seg[root];

    int m = (s + e) / 2;
    return rf(l, r, root * 2, s, m) + rf(l, r, root * 2 + 1, m + 1, e);
}

void hld2(int t, int a, int fixed = 1) {
    while (top[t] != top[a]) {
        ru2(fixed, in[top[a]], in[a]);
        fixed += in[a] - in[top[a]] + 1;
        a = parent[0][top[a]];
    }
    ru2(fixed, in[t], in[a]);
}

ll hld3(int t, int a) {
    ll res = 0;
    while (top[t] != top[a]) {
        res += rf(in[top[a]], in[a]);
        a = parent[0][top[a]];
    }
    res += rf(in[t], in[a]);
    return res;
}

int main() {
    ios::sync_with_stdio(false);
    cin.tie(NULL);
    cout.tie(NULL);
#ifndef ONLINE_JUDGE
    freopen("data.txt", "r", stdin);
#endif

    cin >> N;
    for (int i = 1; i < N; i++) {
        cin >> a >> b;
        tree[a].push_back(b);
        tree[b].push_back(a);
    }

    dfs1();
    dfs2();

    makeParent();

    vector<int> temp(vmax, 0);
    for (int i = 1; i <= N; i++)
        temp[in[i]] = siz[i];
    for (int i = 1; i <= N; i++)
        ps[i] = ps[i - 1] + temp[i];

    cin >> Q;
    for (int i = 0; i < Q; i++) {
        cin >> a >> b;
        if (a == 1) {
            cin >> c;
            if (in[c] <= in[b] && in[b] <= out[c] - 1) {
                ru(SUB_ADD, 1, out[1] - 1);
                if (b == c)
                    continue;
                while (top[c] != top[b]) {
                    if (parent[0][top[b]] == c) {
                        b = top[b];
                        break;
                    }
                    b = parent[0][top[b]];
                }
                if (top[c] == top[b])
                    b = rin[in[c] + 1];

                ru(SUB_MIN, in[b], out[b] - 1);
                int del = siz[b];
                b = parent[0][b];
                while (1 != top[b]) {
                    ru(del, in[top[b]], in[b]);
                    b = parent[0][top[b]];
                }
                ru(del, 1, in[b]);
            } else {
                ru(SUB_ADD, in[c], out[c] - 1);
                int add = siz[c];
                if (c == 1)
                    continue;
                c = parent[0][c];
                while (1 != top[c]) {
                    ru(-add, in[top[c]], in[c]);
                    c = parent[0][top[c]];
                }
                ru(-add, 1, in[c]);
            }
        } else if (a == 2) {
            cin >> c;
            int t = lca(c, b);
            if (in[b] > in[c])
                swap(b, c);

            hld2(t, c);
            int add = depth[c] - depth[t] + 1;
            if (out[b] - 1 < in[c]) {
                add += depth[b] - depth[t];
                hld2(t, b);
                ru(1, in[t], in[t]);
            }

            if (t == 1)
                continue;
            t = parent[0][t];
            while (1 != top[t]) {
                ru(-add, in[top[t]], in[t]);
                t = parent[0][top[t]];
            }
            ru(-add, 1, in[t]);
        } else {
            ll res = (ll)(depth[b] + 1) * rf(1, 1);
            res += rf(1, out[1] - 1);
            res -= hld3(1, b) * 2;
            cout << res << '\n';
        }
    }

    return 0;
}
post-custom-banner

0개의 댓글