BOJ 13896 - Sky Tax 링크
(2023.04.06 기준 P3)
(No Cheating)
트리 형태의 N개의 도시가 있고 수도는 R번 도시이다.
쿼리가 "S, U" 형태로 Q개 주어진다.
- S = 0 : 수도를 U로 바꾸기.
- S = 1 : 모든 도시에서 수도로 향할 때, U번 도시를 거치는 도시 개수 출력.
쿼리를 알맞게 처리하기.
1 <= N <= 100000, 1 <= Q <= 50000 이다. 그러므로 수도가 바뀔 때마다 다시 DFS로 서브 트리의 크기를 구하는 방법으로 구하면 TLE다.
최소 공통 조상을 이용해 한번 잘 풀어보자.
이런 형태의 트리가 있다고 생각을 해보자. 1을 루트로 한 트리이며 현재 수도는 5번이다.
만약 수도와 U가 같다면? 모든 도시가 수도로 가야 하며 또 수도는 U이니깐 이 쿼리의 답은 전체 도시 수인 N이 된다.
만약 수도와 U의 최소 공통 조상이 U와 다르다면? 이 쿼리의 답은 U를 루트로 한 서브 트리의 크기가 된다.
만약 수도와 U의 최소 공통 조상이 U가 된다면? 그림을 보면 수도는 5번, 수도는 1번, 최소 공통 조상은 1번이 된다. 수도로 가기 위해 1번을 거치는 도시들은 직관적으로 봤을 때 1번을 포함한 왼쪽 자식들이다. 그렇다면 포함하지 않는 자식들은? 1번의 오른쪽 자식. 즉, 수도에서 U로 거슬러 올라갈 때 가장 마지막 도시를 루트로 한 서브트리의 크기가 포함하지 않는 자식인 것이다.
다른 트리를 살펴보자.
결국 수도와 U의 최소 공통 조상이 U가 된다면, (N - U로 거슬러 올라갈 때 가장 마지막 도시를 루트로 한 서브트리의 크기)가 답이 된다.
#include <bits/stdc++.h>
using namespace std;
const int MAXN = 100000, MAXH = (int)ceil(log2(MAXN));
int N, Q, R, H, lv[MAXN], sz[MAXN];
int pa[MAXN][MAXH]; // 희소 배열
vector<int> graph[MAXN];
int dfs(int i, int p){
sz[i] = 1;
for (auto j: graph[i]){
if (j == p) continue;
pa[j][0] = i;
lv[j] = lv[i] + 1;
sz[i] += dfs(j, i);
}
return sz[i];
}
int lca(int i, int j){
if (lv[i] < lv[j]) swap(i, j);
int dif = lv[i] - lv[j];
int k = 0;
while (dif){
if (dif & 1) i = pa[i][k];
dif >>= 1; k++;
}
if (i != j){
for (k = H - 1; k >= 0; k--)
if (pa[i][k] != pa[j][k]) i = pa[i][k], j = pa[j][k];
i = pa[i][0];
}
return i;
}
void solve(){
cin >> N >> Q >> R;
R--; // 0-based index
for (int i = 0; i < N; i++) graph[i].clear(); // 그래프 초기화
for (int i = 1, A, B; i < N; i++){
cin >> A >> B;
graph[--A].push_back(--B);
graph[B].push_back(A);
}
// 희소 배열, 깊이, 서브트리의 크기 초기화
H = (int)ceil(log2(N));
fill(&pa[0][0], &pa[N - 1][H], -1);
fill(lv, lv + N, 0);
fill(sz, sz + N, 0);
// 희소 배열 채우기
dfs(0, -1);
for (int j = 1; j < H; j++) for (int i = 0; i < N; i++)
pa[i][j] = pa[pa[i][j - 1]][j - 1];
for (int i = 0, S, U; i < Q; i++){
cin >> S >> U;
U--; // 0-based index
if (S){
if (R == U) cout << N << '\n'; // U가 수도와 같다면 모든 도시가 U. 즉, 수도를 거친다.
else{
int l = lca(R, U);
if (l == U){
int dif = lv[R] - lv[U] - 1; // 수도에서 U로 가는 경로 중 가장 마지막 노드로 가서
int r = R, k = 0;
while (dif){
if (dif & 1) r = pa[r][k];
dif >>= 1; k++;
}
cout << N - sz[r] << '\n'; // '가장 마지막 노드를 루트로 한 서브트리의 모든 노드'를 제외한 모든 노드가 U를 거쳐 수도로 간다.
}
else cout << sz[U] << '\n'; // U를 루트로 한 서브트리의 모든 노드가 U를 거쳐서 수도로 간다.
}
}
else R = U; // 수도 바꾸기
}
}
int main(){
ios_base::sync_with_stdio(0);
cin.tie(0);
int T;
cin >> T;
for (int i = 1; i <= T; i++){
cout << "Case #" << i << ":" << '\n';
solve();
}
}
import sys; input = sys.stdin.readline
sys.setrecursionlimit(100000)
from math import ceil, log2
MAXN = 100000; MAXH = ceil(log2(MAXN))
lv = [0] * MAXN; sz = [0] * MAXN
pa = [[0] * MAXH for _ in range(MAXN)]
graph = [[] for _ in range(MAXN)]
def dfs(i, p):
sz[i] = 1
for j in graph[i]:
if j == p:
continue
pa[j][0] = i
lv[j] = lv[i] + 1
sz[i] += dfs(j, i)
return sz[i]
def lca(i, j):
if lv[i] < lv[j]:
i, j = j, i
dif = lv[i] - lv[j]
k = 0
while dif:
if dif & 1:
i = pa[i][k]
dif >>= 1
k += 1
if i != j:
for k in range(H - 1, -1, -1):
if pa[i][k] != pa[j][k] != -1:
i = pa[i][k]
j = pa[j][k]
i = pa[i][0]
return i
for T in range(1, int(input()) + 1):
print('Case #%d:' % T)
N, Q, R = map(int, input().split())
R -= 1 # 0-based index
for i in range(N): # 그래프 초기화
graph[i].clear()
for _ in range(N - 1):
A, B = map(int, input().split())
A -= 1; B -= 1
graph[A].append(B)
graph[B].append(A)
# 희소 배열, 깊이, 서브트리의 크기 초기화
H = ceil(log2(N))
for i in range(N):
lv[i] = sz[i] = 0
for j in range(H):
pa[i][j] = 0
# 희소 배열 채우기
dfs(0, -1)
for j in range(1, H):
for i in range(N):
pa[i][j] = pa[pa[i][j - 1]][j - 1]
for _ in range(Q):
S, U = map(int, input().split())
U -= 1 # 0-based index
if S:
if R == U: # U가 수도와 같다면 모든 도시가 U. 즉, 수도를 거친다.
print(N)
else:
l = lca(R, U)
if l == U: # U를 루트로 한 서브트리에 수도가 포함된다.
dif = lv[R] - lv[U] - 1 # 수도에서 U로 가는 경로 중 가장 마지막 노드로 가서
r = R; k = 0
while dif:
if dif & 1:
r = pa[r][k]
dif >>= 1
k += 1
print(N - sz[r]) # '가장 마지막 노드를 루트로 한 서브트리의 모든 노드'를 제외한 모든 노드가 U를 거쳐 수도로 간다.
else: # U를 루트로 한 서브트리의 모든 노드가 U를 거쳐서 수도로 간다.
print(sz[U])
else:
R = U # 수도 바꾸기
#include <bits/stdc++.h>
using namespace std;
int N, Q, R, A, B, M, S, U, l, dif, w, r, level[100000], sz[100000];
int parent[100000][(int)ceil(log2(100000))]; // 희소 배열
vector<int> graph[100000];
int dfs(int here, int prev){
sz[here] = 1;
for (auto there: graph[here]){
if (there == prev) continue;
parent[there][0] = here;
level[there] = level[here] + 1;
sz[here] += dfs(there, here);
}
return sz[here];
}
int lca(int u, int v){
if (level[u] < level[v]) swap(u, v);
dif = level[u] - level[v];
w = 0;
while (dif){
if (dif & 1) u = parent[u][w];
dif >>= 1;
w += 1;
}
if (u != v){
for (int w = M - 1; w; w--) if (parent[u][w] != parent[v][w]) u = parent[u][w], v = parent[v][w];
u = parent[u][0];
}
return u;
}
void solve(){
cin >> N >> Q >> R;
R--; // 0-based index
for (int i = 0; i < N; i++) graph[i].clear(); // 그래프 초기화
for (int i = 0; i < N - 1; i++){
cin >> A >> B;
graph[--A].push_back(--B);
graph[B].push_back(A);
}
M = ceil(log2(N));
for (int i = 0; i < N; i++){ // 희소 배열, 깊이, '자기를 루트로 한 서브트리의 크기' 초기화
for (int j = 0; j < M; j++) parent[i][j] = -1;
level[i] = 0, sz[i] = 0;
}
dfs(0, -1);
for (int j = 1; j < M; j++) for (int i = 0; i < N; i++){ // 희소 배열 완성
if (parent[i][j - 1] != -1) parent[i][j] = parent[parent[i][j - 1]][j - 1];
}
for (int i = 0; i < Q; i++){
cin >> S >> U;
U--;
if (S){
if (R == U) cout << N << '\n'; // U가 수도와 같다면 모든 도시가 U. 즉, 수도를 거친다.
else{
l = lca(R, U);
if (l == U){
r = R;
dif = level[r] - level[U] - 1; // 수도에서 U로 가는 경로 중 가장 마지막 노드로 가서
w = 0;
while (dif){
if (dif & 1) r = parent[r][w];
dif >>= 1;
w += 1;
}
cout << N - sz[r] << '\n'; // '가장 마지막 노드를 루트로 한 서브트리의 모든 노드'를 제외한 모든 노드가 U를 거쳐 수도로 간다.
}
else cout << sz[U] << '\n'; // U를 루트로 한 서브트리의 모든 노드가 U를 거쳐서 수도로 간다.
}
}
else R = U; // 수도 바꾸기
}
}
int main(){
ios_base::sync_with_stdio(0);
cin.tie(0);
int T;
cin >> T;
for (int i = 1; i <= T; i++){
cout << "Case #" << i << ":" << '\n';
solve();
}
}
import sys; input = sys.stdin.readline
sys.setrecursionlimit(100000)
from math import ceil, log2
def dfs(here, prev):
size[here] = 1
for there in graph[here]:
if there == prev:
continue
parent[there][0] = here
level[there] = level[here] + 1
size[here] += dfs(there, here)
return size[here]
def lca(u, v):
if level[u] < level[v]:
u, v = v, u
dif = level[u] - level[v]
w = 0
while dif:
if dif & 1:
u = parent[u][w]
dif >>= 1
w += 1
if u != v:
for w in range(M - 1, -1, -1):
if parent[u][w] != parent[v][w]:
u = parent[u][w]
v = parent[v][w]
u = parent[u][0]
return u
for T in range(1, int(input()) + 1):
print('Case #%d:' % T)
N, Q, R = map(int, input().split())
R -= 1 # 0-based index
graph = [[] for _ in range(N)]
for _ in range(N - 1):
A, B = map(int, input().split())
A -= 1; B -= 1
graph[A].append(B)
graph[B].append(A)
M = ceil(log2(N))
parent = [[-1] * M for _ in range(N)] # 희소 배열
level = [0] * N # 깊이
size = [0] * N # 자기를 루트로 한 서브트리의 크기
dfs(0, -1)
for j in range(1, M): # 희소 배열 완성
for i in range(N):
if parent[i][j - 1] != -1:
parent[i][j] = parent[parent[i][j - 1]][j - 1]
for _ in range(Q):
S, U = map(int, input().split())
U -= 1
if S:
if R == U: # U가 수도와 같다면 모든 도시가 U. 즉, 수도를 거친다.
print(N)
else:
l = lca(R, U)
if l == U: # U를 루트로 한 서브트리에 수도가 포함된다.
r = R
dif = level[r] - level[U] - 1 # 수도에서 U로 가는 경로 중 가장 마지막 노드로 가서
k = 0
while dif:
if dif & 1:
r = parent[r][k]
dif >>= 1
k += 1
print(N - size[r]) # '가장 마지막 노드를 루트로 한 서브트리의 모든 노드'를 제외한 모든 노드가 U를 거쳐 수도로 간다.
else: # U를 루트로 한 서브트리의 모든 노드가 U를 거쳐서 수도로 간다.
print(size[U])
else:
R = U # 수도 바꾸기