[BOJ 13896] - Sky Tax (최소 공통 조상, 트리, 희소 배열, C++, Python)

보양쿠·2023년 4월 6일
0

BOJ

목록 보기
96/260
post-custom-banner

BOJ 13896 - Sky Tax 링크
(2023.04.06 기준 P3)
(No Cheating)

문제

트리 형태의 N개의 도시가 있고 수도는 R번 도시이다.
쿼리가 "S, U" 형태로 Q개 주어진다.

  • S = 0 : 수도를 U로 바꾸기.
  • S = 1 : 모든 도시에서 수도로 향할 때, U번 도시를 거치는 도시 개수 출력.

쿼리를 알맞게 처리하기.

알고리즘

1 <= N <= 100000, 1 <= Q <= 50000 이다. 그러므로 수도가 바뀔 때마다 다시 DFS로 서브 트리의 크기를 구하는 방법으로 구하면 TLE다.
최소 공통 조상을 이용해 한번 잘 풀어보자.

풀이

이런 형태의 트리가 있다고 생각을 해보자. 1을 루트로 한 트리이며 현재 수도는 5번이다.

만약 수도와 U가 같다면? 모든 도시가 수도로 가야 하며 또 수도는 U이니깐 이 쿼리의 답은 전체 도시 수인 N이 된다.

만약 수도와 U의 최소 공통 조상이 U와 다르다면? 이 쿼리의 답은 U를 루트로 한 서브 트리의 크기가 된다.

만약 수도와 U의 최소 공통 조상이 U가 된다면? 그림을 보면 수도는 5번, 수도는 1번, 최소 공통 조상은 1번이 된다. 수도로 가기 위해 1번을 거치는 도시들은 직관적으로 봤을 때 1번을 포함한 왼쪽 자식들이다. 그렇다면 포함하지 않는 자식들은? 1번의 오른쪽 자식. 즉, 수도에서 U로 거슬러 올라갈 때 가장 마지막 도시를 루트로 한 서브트리의 크기가 포함하지 않는 자식인 것이다.

다른 트리를 살펴보자.
결국 수도와 U의 최소 공통 조상이 U가 된다면, (N - U로 거슬러 올라갈 때 가장 마지막 도시를 루트로 한 서브트리의 크기)가 답이 된다.

코드 (2023.07.04 수정)

  • C++
#include <bits/stdc++.h>
using namespace std;

const int MAXN = 100000, MAXH = (int)ceil(log2(MAXN));

int N, Q, R, H, lv[MAXN], sz[MAXN];
int pa[MAXN][MAXH]; // 희소 배열
vector<int> graph[MAXN];

int dfs(int i, int p){
    sz[i] = 1;
    for (auto j: graph[i]){
        if (j == p) continue;
        pa[j][0] = i;
        lv[j] = lv[i] + 1;
        sz[i] += dfs(j, i);
    }
    return sz[i];
}

int lca(int i, int j){
    if (lv[i] < lv[j]) swap(i, j);
    int dif = lv[i] - lv[j];

    int k = 0;
    while (dif){
        if (dif & 1) i = pa[i][k];
        dif >>= 1; k++;
    }

    if (i != j){
        for (k = H - 1; k >= 0; k--)
            if (pa[i][k] != pa[j][k]) i = pa[i][k], j = pa[j][k];
        i = pa[i][0];
    }

    return i;
}

void solve(){
    cin >> N >> Q >> R;
    R--; // 0-based index
    for (int i = 0; i < N; i++) graph[i].clear(); // 그래프 초기화

    for (int i = 1, A, B; i < N; i++){
        cin >> A >> B;
        graph[--A].push_back(--B);
        graph[B].push_back(A);
    }

    // 희소 배열, 깊이, 서브트리의 크기 초기화
    H = (int)ceil(log2(N));
    fill(&pa[0][0], &pa[N - 1][H], -1);
    fill(lv, lv + N, 0);
    fill(sz, sz + N, 0);

    // 희소 배열 채우기
    dfs(0, -1);
    for (int j = 1; j < H; j++) for (int i = 0; i < N; i++)
        pa[i][j] = pa[pa[i][j - 1]][j - 1];

    for (int i = 0, S, U; i < Q; i++){
        cin >> S >> U;
        U--; // 0-based index
        if (S){
            if (R == U) cout << N << '\n'; // U가 수도와 같다면 모든 도시가 U. 즉, 수도를 거친다.
            else{
                int l = lca(R, U);
                if (l == U){
                    int dif = lv[R] - lv[U] - 1; // 수도에서 U로 가는 경로 중 가장 마지막 노드로 가서
                    int r = R, k = 0;
                    while (dif){
                        if (dif & 1) r = pa[r][k];
                        dif >>= 1; k++;
                    }
                    cout << N - sz[r] << '\n'; // '가장 마지막 노드를 루트로 한 서브트리의 모든 노드'를 제외한 모든 노드가 U를 거쳐 수도로 간다.
                }
                else cout << sz[U] << '\n'; // U를 루트로 한 서브트리의 모든 노드가 U를 거쳐서 수도로 간다.
            }
        }
        else R = U; // 수도 바꾸기
    }
}

int main(){
    ios_base::sync_with_stdio(0);
    cin.tie(0);

    int T;
    cin >> T;
    for (int i = 1; i <= T; i++){
        cout << "Case #" << i << ":" << '\n';
        solve();
    }
}
  • Python (PyPy3)
import sys; input = sys.stdin.readline
sys.setrecursionlimit(100000)
from math import ceil, log2

MAXN = 100000; MAXH = ceil(log2(MAXN))
lv = [0] * MAXN; sz = [0] * MAXN
pa = [[0] * MAXH for _ in range(MAXN)]
graph = [[] for _ in range(MAXN)]

def dfs(i, p):
    sz[i] = 1
    for j in graph[i]:
        if j == p:
            continue
        pa[j][0] = i
        lv[j] = lv[i] + 1
        sz[i] += dfs(j, i)
    return sz[i]

def lca(i, j):
    if lv[i] < lv[j]:
        i, j = j, i
    dif = lv[i] - lv[j]

    k = 0
    while dif:
        if dif & 1:
            i = pa[i][k]
        dif >>= 1
        k += 1

    if i != j:
        for k in range(H - 1, -1, -1):
            if pa[i][k] != pa[j][k] != -1:
                i = pa[i][k]
                j = pa[j][k]
        i = pa[i][0]

    return i

for T in range(1, int(input()) + 1):
    print('Case #%d:' % T)

    N, Q, R = map(int, input().split())
    R -= 1 # 0-based index
    for i in range(N): # 그래프 초기화
        graph[i].clear()

    for _ in range(N - 1):
        A, B = map(int, input().split())
        A -= 1; B -= 1
        graph[A].append(B)
        graph[B].append(A)

    # 희소 배열, 깊이, 서브트리의 크기 초기화
    H = ceil(log2(N))
    for i in range(N):
        lv[i] = sz[i] = 0
        for j in range(H):
            pa[i][j] = 0

    # 희소 배열 채우기
    dfs(0, -1)
    for j in range(1, H):
        for i in range(N):
            pa[i][j] = pa[pa[i][j - 1]][j - 1]

    for _ in range(Q):
        S, U = map(int, input().split())
        U -= 1 # 0-based index
        if S:
            if R == U: # U가 수도와 같다면 모든 도시가 U. 즉, 수도를 거친다.
                print(N)
            else:
                l = lca(R, U)
                if l == U: #  U를 루트로 한 서브트리에 수도가 포함된다.
                    dif = lv[R] - lv[U] - 1 # 수도에서 U로 가는 경로 중 가장 마지막 노드로 가서
                    r = R; k = 0
                    while dif:
                        if dif & 1:
                            r = pa[r][k]
                        dif >>= 1
                        k += 1
                    print(N - sz[r]) # '가장 마지막 노드를 루트로 한 서브트리의 모든 노드'를 제외한 모든 노드가 U를 거쳐 수도로 간다.
                else: # U를 루트로 한 서브트리의 모든 노드가 U를 거쳐서 수도로 간다.
                    print(sz[U])
        else:
            R = U # 수도 바꾸기

코드 (수정 전)

  • C++
#include <bits/stdc++.h>
using namespace std;

int N, Q, R, A, B, M, S, U, l, dif, w, r, level[100000], sz[100000];
int parent[100000][(int)ceil(log2(100000))]; // 희소 배열
vector<int> graph[100000];

int dfs(int here, int prev){
    sz[here] = 1;
    for (auto there: graph[here]){
        if (there == prev) continue;
        parent[there][0] = here;
        level[there] = level[here] + 1;
        sz[here] += dfs(there, here);
    }
    return sz[here];
}

int lca(int u, int v){
    if (level[u] < level[v]) swap(u, v);
    dif = level[u] - level[v];

    w = 0;
    while (dif){
        if (dif & 1) u = parent[u][w];
        dif >>= 1;
        w += 1;
    }

    if (u != v){
        for (int w = M - 1; w; w--) if (parent[u][w] != parent[v][w]) u = parent[u][w], v = parent[v][w];
        u = parent[u][0];
    }

    return u;
}

void solve(){
    cin >> N >> Q >> R;
    R--; // 0-based index
    for (int i = 0; i < N; i++) graph[i].clear(); // 그래프 초기화

    for (int i = 0; i < N - 1; i++){
        cin >> A >> B;
        graph[--A].push_back(--B);
        graph[B].push_back(A);
    }

    M = ceil(log2(N));
    for (int i = 0; i < N; i++){ // 희소 배열, 깊이, '자기를 루트로 한 서브트리의 크기' 초기화
        for (int j = 0; j < M; j++) parent[i][j] = -1;
        level[i] = 0, sz[i] = 0;
    }
    dfs(0, -1);

    for (int j = 1; j < M; j++) for (int i = 0; i < N; i++){ // 희소 배열 완성
        if (parent[i][j - 1] != -1) parent[i][j] = parent[parent[i][j - 1]][j - 1];
    }

    for (int i = 0; i < Q; i++){
        cin >> S >> U;
        U--;
        if (S){
            if (R == U) cout << N << '\n'; // U가 수도와 같다면 모든 도시가 U. 즉, 수도를 거친다.
            else{
                l = lca(R, U);
                if (l == U){
                    r = R;
                    dif = level[r] - level[U] - 1; // 수도에서 U로 가는 경로 중 가장 마지막 노드로 가서
                    w = 0;
                    while (dif){
                        if (dif & 1) r = parent[r][w];
                        dif >>= 1;
                        w += 1;
                    }
                    cout << N - sz[r] << '\n'; // '가장 마지막 노드를 루트로 한 서브트리의 모든 노드'를 제외한 모든 노드가 U를 거쳐 수도로 간다.
                }
                else cout << sz[U] << '\n'; // U를 루트로 한 서브트리의 모든 노드가 U를 거쳐서 수도로 간다.
            }
        }
        else R = U; // 수도 바꾸기
    }
}

int main(){
    ios_base::sync_with_stdio(0);
    cin.tie(0);

    int T;
    cin >> T;
    for (int i = 1; i <= T; i++){
        cout << "Case #" << i << ":" << '\n';
        solve();
    }
}
  • Python (PyPy3)
import sys; input = sys.stdin.readline
sys.setrecursionlimit(100000)
from math import ceil, log2

def dfs(here, prev):
    size[here] = 1
    for there in graph[here]:
        if there == prev:
            continue
        parent[there][0] = here
        level[there] = level[here] + 1
        size[here] += dfs(there, here)
    return size[here]

def lca(u, v):
    if level[u] < level[v]:
        u, v = v, u
    dif = level[u] - level[v]

    w = 0
    while dif:
        if dif & 1:
            u = parent[u][w]
        dif >>= 1
        w += 1

    if u != v:
        for w in range(M - 1, -1, -1):
            if parent[u][w] != parent[v][w]:
                u = parent[u][w]
                v = parent[v][w]
        u = parent[u][0]

    return u

for T in range(1, int(input()) + 1):
    print('Case #%d:' % T)

    N, Q, R = map(int, input().split())
    R -= 1 # 0-based index

    graph = [[] for _ in range(N)]
    for _ in range(N - 1):
        A, B = map(int, input().split())
        A -= 1; B -= 1
        graph[A].append(B)
        graph[B].append(A)

    M = ceil(log2(N))
    parent = [[-1] * M for _ in range(N)] # 희소 배열
    level = [0] * N # 깊이
    size = [0] * N # 자기를 루트로 한 서브트리의 크기
    dfs(0, -1)

    for j in range(1, M): # 희소 배열 완성
        for i in range(N):
            if parent[i][j - 1] != -1:
                parent[i][j] = parent[parent[i][j - 1]][j - 1]

    for _ in range(Q):
        S, U = map(int, input().split())
        U -= 1
        if S:
            if R == U: # U가 수도와 같다면 모든 도시가 U. 즉, 수도를 거친다.
                print(N)
            else:
                l = lca(R, U)
                if l == U: #  U를 루트로 한 서브트리에 수도가 포함된다.
                    r = R
                    dif = level[r] - level[U] - 1 # 수도에서 U로 가는 경로 중 가장 마지막 노드로 가서
                    k = 0
                    while dif:
                        if dif & 1:
                            r = parent[r][k]
                        dif >>= 1
                        k += 1
                    print(N - size[r]) # '가장 마지막 노드를 루트로 한 서브트리의 모든 노드'를 제외한 모든 노드가 U를 거쳐 수도로 간다.
                else: # U를 루트로 한 서브트리의 모든 노드가 U를 거쳐서 수도로 간다.
                    print(size[U])
        else:
            R = U # 수도 바꾸기
profile
GNU 16 statistics & computer science
post-custom-banner

0개의 댓글