안녕하세요. 오늘은 트리에서 개수를 셀 거예요.

문제

https://www.acmicpc.net/problem/15899

아이디어

일단 이 문제가 트리가 아닌 선형에서 풀었다고 생각해봅시다.
그러면 머지소트트리를 활용해서 모든 값을 저장한 뒤, upper_bound함수를 사용해서 개수를 셀 수 있습니다. 하지만 이 문제는 트리이므로 다른 방법을 써야합니다.

트리를 선형처럼, 특히 서브트리 문제에서 선형처럼 풀 수 있는 방법이 하나 있습니다. 바로 오일러 경로 테크닉입니다. 이 기술(?)은 들어가는 순서, 나오는 순서를 기준으로 잡아서 트리의 값들에 연속성을 부여하는 것입니다. 그래서 in값, out값, 역으로 인덱스가 있을때 원래 값을 알고싶으면 ReverseIdx값까지 있으면 됩니다.

소스코드

#include <iostream>
#include <vector>
#include <algorithm>
#define ll long long
using namespace std;

vector <ll> merge(vector <ll> v, vector <ll> v2)
{
    vector <ll> ans;
    ll size = v.size(), size2 = v2.size();
    ll i = 0, j = 0;

    while (i < size && j < size2)
    {
        if (v[i] < v2[j]) ans.push_back(v[i++]);
        else ans.push_back(v2[j++]);
    }
    while (i < size)
        ans.push_back(v[i++]);
    while (j < size2)
        ans.push_back(v2[j++]);
    return ans;
}

vector <ll> MergeSortTree[808080];
ll in[202020] = { 0 }, out[202020] = { 0 }, ReverseIdx[202020] = { 0 }; //inorder, outorder, inorder을 거꾸로 한 값 (inorder[i]=x, ReverseIdx[x]=i)
ll color[202020] = { 0 };
vector <ll> init(ll s, ll e, ll node)
{
    if (s == e)
    {
        MergeSortTree[node].push_back(color[ReverseIdx[s]]);
        return MergeSortTree[node];
    }
    ll mid = (s + e) / 2;
    MergeSortTree[node] = merge(init(s, mid, node * 2), init(mid + 1, e, node * 2 + 1));
    return MergeSortTree[node];
}
ll find(ll s, ll e, ll node, ll l, ll r, ll num)
{
    if (e < l || r < s) return 0;
    if (l <= s && e <= r) return (upper_bound(MergeSortTree[node].begin(), MergeSortTree[node].end(), num) - MergeSortTree[node].begin()); //num을 초과하는 가장 작은 인덱스=num이하인 수의 개수
    ll mid = (s + e) / 2;
    return find(s, mid, node * 2, l, r, num) + find(mid + 1, e, node * 2 + 1, l, r, num);
}

vector <ll> graph[202020];
ll cnt = 0;
void DFS(ll node, ll up)
{
    in[node] = ++cnt;
    ReverseIdx[cnt] = node;
    for (ll next : graph[node])
        if (next != up)
            DFS(next, node);
    out[node] = cnt;
}

int main(void)
{
    ios_base::sync_with_stdio(false); cin.tie(NULL);
    ll N, M, C, a, b, i;

    cin >> N >> M >> C;
    for (i = 1; i <= N; i++) cin >> color[i];
    for (i = 0; i < N - 1; i++)
    {
        cin >> a >> b;
        graph[a].push_back(b);
        graph[b].push_back(a);
    }
    DFS(1, 0);
    init(1, N, 1);

    ll sum = 0;
    for (i = 0; i < M; i++)
    {
        cin >> a >> b;
        sum += find(1, N, 1, in[a], out[a], b);
        sum %= 1000000007;
    }
    cout << sum;
}


감사합니다.

0개의 댓글