안녕하세요. 오늘은 트리에서 개수를 셀 거예요.
https://www.acmicpc.net/problem/15899
일단 이 문제가 트리가 아닌 선형에서 풀었다고 생각해봅시다.
그러면 머지소트트리를 활용해서 모든 값을 저장한 뒤, upper_bound함수를 사용해서 개수를 셀 수 있습니다. 하지만 이 문제는 트리이므로 다른 방법을 써야합니다.
트리를 선형처럼, 특히 서브트리 문제에서 선형처럼 풀 수 있는 방법이 하나 있습니다. 바로 오일러 경로 테크닉입니다. 이 기술(?)은 들어가는 순서, 나오는 순서를 기준으로 잡아서 트리의 값들에 연속성을 부여하는 것입니다. 그래서 in값, out값, 역으로 인덱스가 있을때 원래 값을 알고싶으면 ReverseIdx값까지 있으면 됩니다.
#include <iostream>
#include <vector>
#include <algorithm>
#define ll long long
using namespace std;
vector <ll> merge(vector <ll> v, vector <ll> v2)
{
vector <ll> ans;
ll size = v.size(), size2 = v2.size();
ll i = 0, j = 0;
while (i < size && j < size2)
{
if (v[i] < v2[j]) ans.push_back(v[i++]);
else ans.push_back(v2[j++]);
}
while (i < size)
ans.push_back(v[i++]);
while (j < size2)
ans.push_back(v2[j++]);
return ans;
}
vector <ll> MergeSortTree[808080];
ll in[202020] = { 0 }, out[202020] = { 0 }, ReverseIdx[202020] = { 0 }; //inorder, outorder, inorder을 거꾸로 한 값 (inorder[i]=x, ReverseIdx[x]=i)
ll color[202020] = { 0 };
vector <ll> init(ll s, ll e, ll node)
{
if (s == e)
{
MergeSortTree[node].push_back(color[ReverseIdx[s]]);
return MergeSortTree[node];
}
ll mid = (s + e) / 2;
MergeSortTree[node] = merge(init(s, mid, node * 2), init(mid + 1, e, node * 2 + 1));
return MergeSortTree[node];
}
ll find(ll s, ll e, ll node, ll l, ll r, ll num)
{
if (e < l || r < s) return 0;
if (l <= s && e <= r) return (upper_bound(MergeSortTree[node].begin(), MergeSortTree[node].end(), num) - MergeSortTree[node].begin()); //num을 초과하는 가장 작은 인덱스=num이하인 수의 개수
ll mid = (s + e) / 2;
return find(s, mid, node * 2, l, r, num) + find(mid + 1, e, node * 2 + 1, l, r, num);
}
vector <ll> graph[202020];
ll cnt = 0;
void DFS(ll node, ll up)
{
in[node] = ++cnt;
ReverseIdx[cnt] = node;
for (ll next : graph[node])
if (next != up)
DFS(next, node);
out[node] = cnt;
}
int main(void)
{
ios_base::sync_with_stdio(false); cin.tie(NULL);
ll N, M, C, a, b, i;
cin >> N >> M >> C;
for (i = 1; i <= N; i++) cin >> color[i];
for (i = 0; i < N - 1; i++)
{
cin >> a >> b;
graph[a].push_back(b);
graph[b].push_back(a);
}
DFS(1, 0);
init(1, N, 1);
ll sum = 0;
for (i = 0; i < M; i++)
{
cin >> a >> b;
sum += find(1, N, 1, in[a], out[a], b);
sum %= 1000000007;
}
cout << sum;
}
감사합니다.