안녕하세요. 오늘은 트리의 두 정점을 잇는 경로에서 간선 가중치의 최댓값을 구해 보겠습니다.

문제: 백준 13510번 「트리와 쿼리 1」

https://www.acmicpc.net/problem/13510

아이디어

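기본적인 HLD(Heavy-Light Decomposition) 문제입니다. 각 간선의 가중치를 더 깊은 쪽 끝점(자식 정점)에 저장하면, 두 정점 사이 경로의 간선 최댓값은 HLD로 나눈 체인 구간들에 대한 세그먼트 트리 최댓값 쿼리로 구할 수 있습니다. 단, 두 정점이 같은 체인에 올라온 마지막 단계에서는 더 얕은 정점(LCA)에 저장된 값이 경로 밖의 간선이므로 제외해야 합니다.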

소스코드

#include <iostream>
#include <vector>
#include <algorithm>
#define ll long long
using namespace std;

// Max segment tree over positions 1..N; node k has children 2k and 2k+1.
ll tree[404040] = {};
// Leaf values indexed by position (filled by main before init is called).
ll num[404040] = {};

// Builds the subtree rooted at `node` that covers positions [s, e] from num[]
// and returns the maximum stored in that subtree.
ll init(ll s, ll e, ll node)
{
    if (s != e)
    {
        const ll mid = (s + e) / 2;
        const ll leftBest = init(s, mid, 2 * node);
        const ll rightBest = init(mid + 1, e, 2 * node + 1);
        tree[node] = max(leftBest, rightBest);
    }
    else
    {
        tree[node] = num[s]; // leaf: copy the value directly
    }
    return tree[node];
}
// Returns the maximum over positions [l, r] within the subtree of `node`
// that covers [s, e]. A subtree with no overlap contributes 0, so 0 acts as
// the identity here. NOTE(review): this assumes all stored values are >= 0.
ll query(ll s, ll e, ll node, ll l, ll r)
{
    const bool disjoint = (r < s) || (e < l);
    if (disjoint) return 0;

    const bool covered = (s >= l) && (e <= r);
    if (covered) return tree[node];

    const ll mid = (s + e) / 2;
    const ll leftBest = query(s, mid, node * 2, l, r);
    const ll rightBest = query(mid + 1, e, node * 2 + 1, l, r);
    return max(leftBest, rightBest);
}
// Point assignment: stores `value` at position idx in the subtree of `node`
// covering [s, e], then recomputes the maxima on the way back up.
// A position outside [s, e] leaves the tree untouched.
void change(ll s, ll e, ll node, ll idx, ll value)
{
    if (idx < s || e < idx) return;
    if (s == e)
    {
        tree[node] = value;
        return;
    }
    const ll mid = (s + e) / 2;
    // Only the child whose range contains idx can change.
    if (idx <= mid)
        change(s, mid, node * 2, idx, value);
    else
        change(mid + 1, e, node * 2 + 1, idx, value);
    tree[node] = max(tree[node * 2], tree[node * 2 + 1]);
}


// Per-vertex data for the heavy-light decomposition (vertices numbered from 1).
ll in[101010] = { 0 };   // position of the vertex in the DFS order used by the segment tree
ll out[101010] = { 0 };  // last DFS position inside the vertex's subtree
ll dpt[101010] = { 0 };  // depth below the root (the root stays at 0)
ll par[101010] = { 0 };  // parent vertex, filled by dfs1
ll top[101010] = { 0 };  // head of the vertex's heavy chain, filled by dfs2
ll sz[101010] = { 0 };   // subtree size, filled by dfs1
bool ck[101010] = { 0 }; // visited flags used by dfs while rooting the tree
vector <ll> down[101010];  // children of each vertex; heavy child first after dfs1
vector <ll> graph[101010]; // undirected adjacency lists read from input

// Roots the undirected tree: marks `node` visited, then records every
// still-unvisited neighbour as a child in down[node] and recurses into it.
void dfs(ll node)
{
    ck[node] = true;
    for (ll neighbour : graph[node])
    {
        if (!ck[neighbour])
        {
            down[node].push_back(neighbour);
            dfs(neighbour);
        }
    }
}
// Fills subtree sizes, depths and parents for the subtree below `node`, then
// moves the heavy child (largest subtree, earliest on ties) to the front of
// down[node] so that dfs2 can recognise it as down[node][0].
void dfs1(ll node)
{
    sz[node] = 1; // counts the node itself

    ll heavyPos = 0;  // index in down[node] of the largest child seen so far
    ll heavySize = 0; // that child's subtree size
    const ll childCount = down[node].size();
    for (ll k = 0; k < childCount; k++)
    {
        const ll child = down[node][k];
        par[child] = node;
        dpt[child] = dpt[node] + 1;
        dfs1(child);
        sz[node] += sz[child];
        if (heavySize < sz[child])
        {
            heavySize = sz[child];
            heavyPos = k;
        }
    }
    if (childCount > 0) swap(down[node][0], down[node][heavyPos]);
}
ll pv = 0; // last DFS position handed out by dfs2

// Numbers the vertices in DFS order, visiting the heavy child first, so each
// heavy chain occupies a contiguous range of positions. in[node] is the
// node's own position and out[node] the last position in its subtree. The
// heavy child (down[node][0], placed first by dfs1) inherits its parent's
// chain head; every other child starts a new chain headed by itself.
void dfs2(ll node)
{
    pv += 1;
    in[node] = pv;
    for (ll child : down[node])
    {
        const bool isHeavy = (child == down[node].front());
        top[child] = isHeavy ? top[node] : child;
        dfs2(child);
    }
    out[node] = pv;
}

ll N; // number of vertices; the segment tree spans positions 1..N

// Overwrites the value stored at vertex idx's DFS position with `value`.
// main passes the deeper endpoint of an edge, so this sets that edge's weight.
void update(ll idx, ll value)
{
    const ll position = in[idx];
    change(1, N, 1, position, value);
}
// Returns the maximum edge weight on the path between vertices x and y.
// Each edge weight is stored at its deeper endpoint. While x and y are on
// different heavy chains, the endpoint whose chain head is deeper has its
// chain segment [head, endpoint] queried and then jumps to the head's parent.
// Once both are on the same chain, the segment strictly below the shallower
// vertex is queried. The shallower vertex's own slot holds an edge above the
// path, so it is excluded. Returns 0 when x == y (no edges on the path).
ll MX(ll x, ll y)
{
    ll best = 0;
    while (top[x] != top[y])
    {
        if (dpt[top[y]] > dpt[top[x]]) swap(x, y);
        const ll head = top[x];
        best = max(best, query(1, N, 1, in[head], in[x]));
        x = par[head];
    }

    if (dpt[y] < dpt[x]) swap(x, y); // make x the shallower vertex
    if (x != y) best = max(best, query(1, N, 1, in[x] + 1, in[y]));
    return best;
}

int main(void)
{
    ios_base::sync_with_stdio(false); cin.tie(NULL);
    vector <pair <pair <ll,ll>, ll> > edge; edge.push_back({ { -1,-1 } ,-1}); //인덱스는 1부터 시작하게
    vector <ll> EdgeToVertex(101010);
    ll M, i, a, b, c;

    cin >> N;
    for (i = 1; i <= N - 1; i++)
    {
        cin >> a >> b >> c;
        edge.push_back({ { a,b } ,c });
        graph[a].push_back(b);
        graph[b].push_back(a);
    }
    dfs(1); dfs1(1); dfs2(1);
    for (i = 1; i <= N - 1; i++)
    {
        ll a = edge[i].first.first, b = edge[i].first.second;
        if (dpt[a] < dpt[b]) EdgeToVertex[i] = b;
        else EdgeToVertex[i] = a;
        num[in[EdgeToVertex[i]]] = edge[i].second;
    }
    init(1, N, 1);

    cin >> M;
    for (i = 0; i < M; i++)
    {
        cin >> a >> b >> c;
        if (a == 1)
            update(EdgeToVertex[b], c);
        else
            cout << MX(b, c) << "\n";
    }
}


감사합니다.

0개의 댓글