오늘은 이전에 풀었던 백준 11812번 k진 트리 문제에 대한 저의 풀이법을 적어보고자 합니다,
우선 k진 트리가 무엇인가?
k진 트리란 각 노드가 가질 수 있는 자식 노드가 최대 k개인 트리를 의미하며 예를 들어 2진 트리는 두 개의 자식 노드(왼쪽과 오른쪽)를 가질 수 있는 경우를 의미합니다,
그럼 이 문제를 풀기 위해서는 어떤점들을 유의해야할까요?
- 찾고자 하는 노드의 깊이를 알아낸다
- 두 노드의 깊이를 동일하게 맞춰준다
- 이후부턴 최소 공통 조상 알고리즘과 동일하게 최소 공통 조상이 될때까지 탐색한다
그러면 여기서 찾고자 하는 노드의 깊이는 어떻게 알 수 있을까요 ?
그림을 그려보며 한번 찾아봅시다.
위 그림을 보면 k가 3인 3진 트리의 형태를 보이고 있습니다,
그리고 우리가 깊이를 알기 위해서는 각 깊이마다 가장 왼쪽에 있는 값과 가장 오른쪽에 있는 값을 알아내야하며
depth가 1인 1, 2, 3의 경우 left는 1이고 right는 3입니다
그리고 1번 노드의 가장 왼쪽은 4이며 이렇게 쭉 그려다나가다 보면 3번 노드의 가장 오른쪽은 12가 될 것입니다, 그리고 4번 노드의 가장 왼쪽은 13이 될 것이며 이런 패턴을 이해하게 되면 깊이를 알기 위해서 가장 왼쪽과 가장 오른쪽에 있는 노드를 찾을 수 있는 코드를 알 수 있게 됩니다,
left = left * k + 1; // 1 = 1 * 3 + 1 = 4
right = right * k + k; // 3 = 3 * 3 + 3 = 12
이렇게 깊이를 알아냈고 그러면 두 노드의 깊이를 맞추기 위해서는 어떻게 해야할까요?
우선 깊이를 맞추기 위해서는 노드의 부모를 알아낼 수 있는 코드가 필요합니다,
다시 위 그림을 살펴보면 1번 노드의 자식 노드는 4, 5, 6 입니다 그리고 여기서
위 left와 right 코드를 찾는것과 동일하게 k 만큼 나누게 되면 4 / 3 = 1, 5 / 3 = 1, 6 / 3 = 2가 됩니다
하지만 (node - 1) / 2 와 같이 노드에 -1를 뺀 후 k 만큼 나눠주면 바로 위에 부모 노드 까지는 알아낼 수 있습니다, 이를 통해 최소 공통 조상과 동일한 방법으로
1. 깊이를 비교한다 (u < v) 일경우 u와 v를 바꿔준다
2. 깊이를 맞춰준다
-> u와 v의 깊이만큼 부모를 가져오며 깊이를 맞춰준다 -> 이때도 count 해주어야함
3. 깊이가 똑같다면 최소 공통 조상까지 탐색하며 count를 한다
아래는 정답 코드입니다.
#include <bits/stdc++.h>
using namespace std;
long long n, k, t;
long long getDepth(long long node) {
long long depth = 0;
long long left = 1;
long long right = k;
if(k == 1) {
return node;
} else {
if(node != 0) {
depth = 1;
while (!(left <= node && node <= right)) {
++depth;
left = left * k + 1;
right = right * k + k;
}
}
}
return depth;
}
long long getParent(long long node) {
long long parent = (node - 1) / k;
return parent;
}
long long solve(long long u, long long v) {
long long ret = 0;
if(getDepth(u) < getDepth(v)) {
swap(u, v);
}
if(k == 1) {
ret = getDepth(u) - getDepth(v);
} else {
long long diff = getDepth(u) - getDepth(v);
ret += diff;
while(diff--) {
u = getParent(u);
}
while(u != v) {
u = getParent(u);
v = getParent(v);
ret += 2;
}
}
return ret;
}
int main() {
ios_base::sync_with_stdio(0);
cin.tie(0);
cin >> n >> k >> t;
while(t--) {
long long u, v;
cin >> u >> v;
cout << solve(u - 1, v - 1) << '\n';
}
return 0;
}