백준 1761번 정점들의 거리

chisae·2023년 12월 9일

백준

목록 보기
5/10

오늘 백준 1761번 정점들의 거리를 풀어보았습니다.
일단 이 문제의 핵심 키워드들입니다

1. 최소 공통 조상(LCA) 알고리즘 사용
2. DFS 탐색시 루트노드로부터 거리 계산
3. 위에 두개 이용해서 두 정점의 거리 계산


우선 이 문제를 풀기 위해선 최소 공통 조상 알고리즘을 알고 있어야 합니다, 여기서 최소 공통 조상 알고리즘이란 "트리 구조에서 두 노드를 선택했을 때 이들의 경로가 처음 만나는 지점"이 최소 공통 조상입니다.


그리고 이런 최소 공통 조상의 구현방법은 여러가지가 있으며 대표적으론 단순하게 구현 하는 방법과 이진 탐색을 이용한 방법이 있습니다,

첫 번째로 단순하게 구현하는 방법은 한 노드에서 루트까지 거슬러 올라가면서 모든 조상을 찾고 다른 노드에 대해서도 같은 작업을 수행한 후 공통된 조상 중 가장 깊은 것을 찾으면 됩니다.
(브루트포스와 유사하며 이때 시간복잡도는 O(N)이 나옵니다)

두 번째로 이진 탐색 기법이 있습니다 미리 각 노드의 2의 거듭제곱 조상을 계산해둔 후 그 후 두 노드의 깊이를 맞춘 다음에 이진 탐색 방법으로 LCA를 찾습니다
(이때 시간 복잡도는 O(log N)이 나옵니다)

그리고 오늘은 이 중 두 번째 방법인 이진 탐색 기법을 활용해 1761번 정점들의 거리를 풀어보았읍니다.



이진 탐색을 통하여 LCA을 구현할 때 중요한 점은 아래와 같습니다

1. 각 정점의 바로 위에 부모 노드들 알기
2. 각 정점의 모든 부모 노드를 알기
3. 두 정점의 깊이를 맞추기 

이진 탐색으로 통하여 LCA를 구현하면 시간복잡도가 log N이 나온다고 했을 때
대부분의 분들은 이미 짐작하셨을 거라고 생각합니다 정점으로부터 위에 있는 부모를 탐색하는데
만약 2의 거듭제곱근을 기준으로 조상을 탐색하게 되면 굳이 탐색하지 않아도 되는 노드들은 탐색하지 않을 수 있습니다

위 그림과 같이 9번과 6번 노드의 최소 공통 조상은 1번입니다,
그럼 여기서 9번 노드로부터 2의 거듭제곱으로 부모 노드를 탐색하면 어떻게 되는지 보겠습니다

2의 0승 = 바로 위에 부모
2의 1승 = 2 번째 위치에 있는 부모
2의 2승 = 4 번째 위치에 있는 부모

이처럼 단순하게 탐색하는 방식과 다르게 더 효율적으로 LCA를 탐색할 수 있습니다 하지만.. 이것만으로는 LCA를 구현할 수 없습니다.. 그 이유로는

  1. 잘못된 LCA 계산, 두 정점의 깊이가 다를 경우에는 더 깊은 정점이 더 상위 레벨의 조상과 비교될 수 있습니다
  2. 비효율적인 계산, 깊이가 맞지 않는 상태일 경우에는 더 깊은 정점이 루트 노드까지 갈정도로 상위 노드로 계속해서 이동해야할 수도 있습니다
  3. 무한 루프 발생 가능성, 잘못된 구현때문에 두 노드가 절대 같은 레벨에 도달하지 못하여 무한 루프가 빠질 위험이있다...

위에 그림도 보면 벌써부터 탐색해야 하는 수가 다른데 만약 노드가 10만개가 넘어갈 경우에는 깊이가 다르면 정말 골치아플 수 있다.. 그러므로 깊이를 맞춰줘야한다

이런식으로 9번 노드에서 4번 노드(2의 1승, 두 번째 부모)로 이동 후 깊이가 2로 맞추어졌기 때문에
이 때부터 LCA를 구하면 된다.

그러면 각 정점의 바로 위에 부모는 어떻게 찾을 수 있을까?
그거는 트리를 배웠다면 아주 간단하다.

for(int next : adj[cur]) {
	parent[next][0] = cur;
}

부모 배열을 만든 후 인접리스트를 통해 계속해서 바로 위에 부모들은 갱신해주면 된다
(바로 위에 부모이기에 parent[next][0]인 것이다 2의 0승은 = 1)

그러면 바로 위에 부모 말고 그 위에 부모는 어떻게 해야 구현할 수 있을까?
이건 생각보다 복잡하다.

for (int k = 1; k < MAX_DEPTH; k++) { // 바로 위에 부모는 이미 갱신한 상태이므로 1부터 시작
    for (int cur = 1; cur <= n; cur++) { // 노드는 1 ~ n까지
        parent[cur][k] = parent[parent[cur][k - 1]][k - 1];
    }
}

parent[cur][k] = parent[parent[cur][k - 1]][k - 1];

이걸 예시를 들어서 위 그림에 9번 노드라고 생각한다면
parent[9][1] = parent[8][0]
parent[9][2] = parent[4][1] 이렇게 되는데

parent[9][0] = 8이며
parnet[9][1] = 4이다

이렇게까지 노드의 부모까지 모두 구하게 되면 이후 위에서 설명한 LCA를 통해
깊이를 맞춘 후 두 정점의 맨 위에 있는 부모가 같아질 때까지 가장 먼거리부터 줄여가면서
최소 공통 조상을 찾을 수 있다





아래는 이를 통해 구현한 코드이다.



#include <bits/stdc++.h>

using namespace std;

const int MAX_DEPTH = 21;
int parent[40001][21]; 
int depth[40001];
int dist[40001];
vector<vector<pair<int, int>>> adj;  //adjacency
int n, m, k;

void dfs(int cur) {
	
	for(auto& p : adj[cur]) { // 주소값을 가져오는게 더 효율 좋음(속도, 수정) 
		int next = p.first;
		int cost = p.second;
	    
		if(depth[next] == -1) { // 갱신 안 했을 경우 
			depth[next] = depth[cur] + 1; // 깊이 갱신 
			parent[next][0] = cur; // 바로 위에 부모 노드 갱신 
			dist[next] = dist[cur] + cost;
			dfs(next);
		}
	}
	
	return;
}

void connection() {
	
    for (int k = 1; k < MAX_DEPTH; k++) { // 바로 위에 부모는 이미 갱신한 상태이므로 1부터 시작
        for (int cur = 1; cur <= n; cur++) { // 노드는 1 ~ n까지
            parent[cur][k] = parent[parent[cur][k - 1]][k - 1];
        }
    }
}


int LCA(int u, int v) {
	
	if(depth[u] < depth[v]) {
		swap(u, v);
	}
	
	int diff = depth[u] - depth[v];
	
	for(int i = 0; diff != 0; i++) { // 깊이 조절 
		if(diff % 2 == 1) {
			u = parent[u][i];
		}
		
		diff /= 2; 
	}
	
	if (u != v) {
		for(int i = MAX_DEPTH - 1; i >= 0; i--) {
			if(parent[u][i] != -1 && parent[u][i] != parent[v][i]) {
				u = parent[u][i];
				v = parent[v][i];
			}
		}
		
		u = parent[u][0];
	}
	
	return u;
	
}

int main() {
	
	ios_base::sync_with_stdio(0);
    cin.tie(0);

		cin >> n;
		adj.resize(n + 1);
		fill(depth, depth + n + 1, -1); // 그냥 배열 초기화 
		fill(dist, dist + n + 1, 0); // 거리 배열 초기화 
		memset(parent, -1, sizeof(parent)); // 2차원 배열 초기화 
		
		depth[1] = 0;
		
		for(int i = 0; i < n - 1; i++) {
			int from, to, cost; 
			cin >> from >> to >> cost;
			adj[from].push_back({to, cost});
			adj[to].push_back({from, cost});
		}
		
		dfs(1);
		connection();
		
		cin >> m; 
		
		for(int i = 0; i < m; i++) {
			int u, v;
			cin >> u >> v;
			cout << dist[u] + dist[v] - (dist[LCA(u, v)] * 2) << '\n';
		} 
	
		
	return 0;
} 
profile
초보 개발자

0개의 댓글