Heavy Light Decomposition (HLD)

코딩생활·2024년 1월 1일
0

알고리즘

목록 보기
4/4

안녕하세요. 오늘은 HLD를 배울거예요.

HLD란?

HLD는 이름 그대로 무거운거랑 가벼운거를 나누는 것입니다.
이런 쿼리가 있을 때 많이 쓰죠.

트리에서
특정 간선(혹은 정점)의 가중치 바꾸기
특정 경로에서 가중치를 더하거나 곱하거나 xor하거나 최대최소 찾기

이게 사실 선형이면 세그트리로 너무쉽게 해결되는 문제입니다.
하지만 여기는 트리 속이므로 그게 어렵습니다. 그래서 간선들을 길게 잘라서 만드는것이 HLD입니다.

HLD 알고리즘

무겁다/가볍다는 무게를 비교하는 표현입니다.
그래서 트리에서도 무게를 표현할겁니다.

특정 정점을 루트로 하는 서브트리의 크기를 sz[node]라고 합시다.
만약 next가 node의 자식이고 sz[node]<=sz[next]x2 라면, 즉 어떤 자식이 자신의 반 이상을 차지하고 있다면 자신과 그 노드를 잇는 간선이 heavy하게 됩니다. 그 자식이 무겁다는 뜻이죠. 그래서 한 node에서 자식으로 내려가는 heavy간선은 0개 혹은 1개가 됩니다.

여기서 핵심 관찰이 필요합니다.
바로 heavy간선을 타고 올라가면 sz가 두배를 넘기지 못하지만, 반대로 heavy간선이 아닌 노드를 타고 올라가면 sz가 두배를 넘기게 됩니다. 이런 간선을 light간선이라고 합니다. light 간선을 타고 올라가면 sz값이 2배이상씩 커지므로 어떤 정점을 잡든지 light간선을 log N개 이하로 타고가면 루트에 도달할 수 있게 됩니다.

여기서 또 관찰이 필요합니다.
바로 heavy 간선들끼리는 같은 체인으로 생각하는 것입니다. 그러면heavy간선들을 타고 최대한 많이 올라가면 light간선처럼 sz가 두배 이상 되는 노드가 나타나고 관리할 체인이 최대 log N개가 됩니다. 그런데 특정 노드의 자식들중 heavy간선으로 연결된 자식을 가장 첫번째로 배치할 겁니다. 그러므로 어떤 정점의 in값이 val이라면 그 정점에서 heavy값으로 내려온 자식은 in값이 val+1이 됩니다. 이는 세그먼트 트리를 사용하기 위해 연속성을 부여한 것으로, 한 체인에서는 연속성이 있어야하기 때문입니다.

참고로 HLD는 정점에 가중치가 있는것이 기본적인 세팅이므로 간선에 가중치를 넣으려면 그 간선에 있는 두 노드중 자식노드를 그 간선과 동일하게 생각해주면 됩니다.

구현

방법은 정말 많지만 맞는 구현을 한가지 라는 말이 있을 정도로 대중적이고 쉬운 구현 방법이 존재합니다.
배열과 함수부터 정의합시다.

tree: 세그트리에 쓰이는 tree
num: 세그트리의 값을 채우는 배열
in,out: dfs ordering (dfs 순서)
dpt: 깊이
par: 부모노드
top: 그 체인을 타고 올라가면 나오는 노드
sz: 그 노드를 루트로 하는 서브트리의 크기
ck: dfs 체크
down: 자식노드들 (나중에는 down[node][0]이 heavy간선으로 잇는 노드가 됨)
graph: 입력으로 받는 그래프
edge: 입력으로 받는 간선들만 따로 모으기
EdgeToVertex: i번째 간선이 잇는 두 정점중에서 자식노드

init,query,change: 세그트리와 관련된 함수들 (최댓값 함수)
dfs: down벡터 채우기
dfs1: sz,dpt,par채우기, down[node][0]값에 가장 무거운 자식 넣기
dfs2: in,out,top 채우기
update: change의 작은 버전
MX: 두 노드 사이의 간선의 최댓값; 자세한 내용은 밑에서

MX에서는 이런일을 합니다.
두 노드 x,y의 top값이 다르면 계속 하기 (다른 체인에 있으면)
각 체인의 top값 까지 보면서 세그트리로 최댓값 계속 갱신하면서 올라가기
top값이 같아지면 같은 체인에 있다는 뜻이므로 세그트리 한번 하고 끝내기

소스코드

HLD 기본 문제인 트리와 쿼리 1 문제의 코드를 올리겠습니다.

#include <iostream>
#include <vector>
#include <algorithm>
#define ll long long
using namespace std;

ll tree[404040] = { 0 }, num[404040] = { 0 };
ll init(ll s, ll e, ll node)
{
    if (s == e) return tree[node] = num[s];
    ll mid = (s + e) / 2;
    return tree[node] = max(init(s, mid, node * 2), init(mid + 1, e, node * 2 + 1));
}
ll query(ll s, ll e, ll node, ll l, ll r)
{
    if (e < l || r < s) return 0;
    if (l <= s && e <= r) return tree[node];
    ll mid = (s + e) / 2;
    return max(query(s, mid, node * 2, l, r), query(mid + 1, e, node * 2 + 1, l, r));
}
void change(ll s, ll e, ll node, ll idx, ll value)
{
    if (e < idx || idx < s) return;
    if (s == e)
    {
        tree[node] = value;
        return;
    }
    ll mid = (s + e) / 2;
    change(s, mid, node * 2, idx, value);
    change(mid + 1, e, node * 2 + 1, idx, value);
    tree[node] = max(tree[node * 2], tree[node * 2 + 1]);
}


ll in[101010] = { 0 }, out[101010] = { 0 }, dpt[101010] = { 0 }, par[101010] = { 0 }, top[101010] = { 0 }, sz[101010] = { 0 };
bool ck[101010] = { 0 };
vector <ll> down[101010];
vector <ll> graph[101010];

void dfs(ll node)
{
    ck[node] = true;
    for (ll next : graph[node])
    {
        if (ck[next]) continue;
        down[node].push_back(next);
        dfs(next);
    }
}
void dfs1(ll node)
{
    sz[node] = 1; //처음에는 자기자신뿐

    ll mx = 0, p = 0, idx = 0;
    for (ll next : down[node])
    {
        dpt[next] = dpt[node] + 1;
        par[next] = node;
        dfs1(next);
        sz[node] += sz[next];
        if (sz[next] > mx) //최댓값을 발견하면
        {
            mx = sz[next]; //저장
            p = idx;
        }
        idx++;
    }
    if (down[node].size()) swap(down[node][0], down[node][p]); //heavy간선이 젤 앞에 오게
}
ll pv = 0;
void dfs2(ll node)
{
    in[node] = ++pv;
    for (ll next : down[node])
    {
        if (next == down[node][0]) //heavy간선이면
            top[next] = top[node]; //같은 묶음에 있으므로 top값 물려받기
        else //아니면
            top[next] = next; //자기자신
        dfs2(next);
    }
    out[node] = pv;
}

ll N;
void update(ll idx, ll value)
{
    change(1, N, 1, in[idx], value);
}
ll MX(ll x, ll y)
{
    ll ans = 0;
    while (top[x] != top[y])
    {
        if (dpt[top[x]] < dpt[top[y]]) swap(x, y);
        ll st = top[x];
        ans = max(ans, query(1, N, 1, in[st], in[x]));
        x = par[st];
    }

    if (dpt[x] > dpt[y]) swap(x, y);
    if (x == y) //완전히 똑같다면
        return ans; //그대로 반환
    else //아니면
        return max(ans, query(1, N, 1, in[x] + 1, in[y]));
}

int main(void)
{
    ios_base::sync_with_stdio(false); cin.tie(NULL);
    vector <pair <pair <ll,ll>, ll> > edge; edge.push_back({ { -1,-1 } ,-1}); //인덱스는 1부터 시작하게
    vector <ll> EdgeToVertex(101010);
    ll M, i, a, b, c;

    cin >> N;
    for (i = 1; i <= N - 1; i++)
    {
        cin >> a >> b >> c;
        edge.push_back({ { a,b } ,c });
        graph[a].push_back(b);
        graph[b].push_back(a);
    }
    dfs(1); dfs1(1); dfs2(1);
    for (i = 1; i <= N - 1; i++)
    {
        ll a = edge[i].first.first, b = edge[i].first.second;
        if (dpt[a] < dpt[b]) EdgeToVertex[i] = b;
        else EdgeToVertex[i] = a;
        num[in[EdgeToVertex[i]]] = edge[i].second;
    }
    init(1, N, 1);

    cin >> M;
    for (i = 0; i < M; i++)
    {
        cin >> a >> b >> c;
        if (a == 1)
            update(EdgeToVertex[b], c);
        else
            cout << MX(b, c) << "\n";
    }
}

안녕히계세요~

0개의 댓글