[백준/Java] 11438 LCA 2

박찬병·2024년 11월 2일

Problem Solving

목록 보기
23/48

https://www.acmicpc.net/problem/11438

문제 요약

N개의 정점으로 이루어진 트리가 주어진다. 루트는 1번이며, N번까지 있다.
M개의 두 노드 쌍이 주어질 때, 두 노드의 가장 가까운 공통 조상을 출력한다.
N은 최대 10만, M은 최대 10만이다.


문제 접근

두 값이 최대 10만이기 때문에 O(NlogN)이하의 알고리즘을 사용해야 한다.
쿼리가 M개 주어지고, 각 쿼리에 대한 연산을 로그로 한다면 O(MlogN)이 된다.

O(logN)으로 최소 공통 조상을 얻기 위해서는 2의 제곱수를 이용한 방법을 사용해야 한다.
가장 먼저 입력을 받아 그래프를 생성한다.
루트 노드부터 그래프를 훑으며 자신의 깊이와 부모를 저장한다.
각 노드의 2의 0제곱, 1제곱, ...의 부모를 저장하는 배열을 사용할 것이다.
이 배열을 채운다.
이후 쿼리를 받으며 탐색을 시도한다.
탐색의 기본은 두 노드의 높이를 동일하게 맞춘 뒤, 같이 올라가면서 최소 공통 조상을 얻는다는 점이다.
이때 앞선 배열을 이용해 2^N으로 올라가면서 높이를 맞추고, 공통 조상을 찾아야 O(logN)이 된다.

  1. 입력을 받아 그래프(트리) 생성 - O(N)
  2. 트리를 보고 자신의 깊이와 부모 저장 - O(N): BFS나 DFS 아무거나 사용 (배열 첫번째 줄을 채우는 것)
  3. 2의 제곱수 부모 배열 채우기 - O(NlogN): 그냥 배열 한 줄씩 채워도 NlogN이라 괜찮음
  4. 쿼리 받음(반복) - O(M)
  5. 두 노드 높이 맞추기 - O(logN)
  6. 최소 공통 조상 찾기 - O(logN)

결론: 1 + 2 + 3 + 4 * (5+6) 이므로 O((N+M)logN)이 됨

정점의 번호는 1부터 N까지이다. -> 배열의 크기를 N+1로 사용
다 int 범위 내이므로 int 사용해도 괜찮음


풀이

기본적인 아이디어는 다음과 같다.

  1. (추가 예정)

이를 구현한 코드는 다음과 같다.

import java.util.*;
import java.io.*;

public class Main {
	
	static final int ROOT = 1;
	
	static int N, M;
	static ArrayList<Integer>[] tree;
	
	static int[][] parents;
	static int[] depth;
	
	public static void findDepth(int nowNode, int nowDepth) {
		// 일단 depth를 기록함
		depth[nowNode] = nowDepth;
		
		// 내가 가리키는 노드들의 2^0번째 부모를 나로 기록하고, 재귀로 들어감
		// 다만 부모가 자식을 가리키고, 자식도 부모를 가리키고 있으므로 이를 처리해야 한다.
		for (int next: tree[nowNode]) {
			if (depth[next] > 0) continue; // 이미 처리되었다면 부모이므로 넘어감
			
			parents[0][next] = nowNode;
			findDepth(next, nowDepth+1);
		}
	}
	
	public static void main(String[] args) throws IOException {
		BufferedReader br = new BufferedReader(new InputStreamReader(System.in));
		StringBuilder sb = new StringBuilder();
		
		// N과 정점을 입력받아 트리를 생성함
		N = Integer.parseInt(br.readLine());
		
		tree = new ArrayList[N+1];
		for (int i = 0; i < N+1; i++) {
			tree[i] = new ArrayList<>();
		}
		
		// 트리의 엣지 개수는 (노드 개수-1)이다.
		// 둘 중 무엇이 부모인지 모르기 때문에 일단 양방향으로 모두 저장
		for (int i = 0; i < N-1; i++) {
			StringTokenizer st = new StringTokenizer(br.readLine());
			
			int nodeA = Integer.parseInt(st.nextToken());
			int nodeB = Integer.parseInt(st.nextToken());
			
			tree[nodeA].add(nodeB);
			tree[nodeB].add(nodeA);
		}
		
		// N보다 같거나 큰, 가장 작은 2의 제곱수를 찾음(부모 배열의 크기를 설정하기 위함)
		int multiMax = 0;
		int squared = 1;
		while (squared < N) {
			squared *= 2;
			multiMax++;
		}
		
		// 2^multiMax까지 필요한데, 인덱스가 0부터 시작하므로 1을 더해줌
		parents = new int[multiMax+1][N+1]; // parents[i][j]는 j번 노드의 2^i번째 부모를 의미함
		depth = new int[N+1];
		
		// 루트부터 트리를 돌아 배열의 첫번째 줄(2^0번째 부모)와 본인의 깊이를 기록함
		findDepth(ROOT, 1); // 여기서는 루트의 깊이를 1로 시작한다. 어차피 차이가 중요하기 때문에 이렇게 해도 괜찮음
		
		// 제곱수 배열을 완성함
		for (int i = 1; i < multiMax+1; i++) {
			for (int j = 1; j < N+1; j++) {
				parents[i][j] = parents[i-1][parents[i-1][j]];
			}
		}
		
		// 테스트 코드
//		for (int i = 0; i < multiMax+1; i++) {
//			System.out.println(Arrays.toString(parents[i]));
//		}
		
		// 쿼리를 받음
		M = Integer.parseInt(br.readLine());
		
		for(int i = 0; i < M; i++) {
			StringTokenizer st = new StringTokenizer(br.readLine());
			
			int nodeA = Integer.parseInt(st.nextToken());
			int nodeB = Integer.parseInt(st.nextToken());
			
			// 두 노드의 깊이를 비교해서 다르다면 깊이를 맞춰 줌
			int depthA = depth[nodeA];
			int depthB = depth[nodeB];
			
			if (depthA != depthB) {
				// 깊이가 다르다면 항상 A가 B보다 깊도록 설정해줌(A가 올라간다)
				if (depthB > depthA) {
					int temp = depthA;
					depthA = depthB;
					depthB = temp;
					
					temp = nodeA;
					nodeA = nodeB;
					nodeB = temp;
				}
				
				// A가 올라가면서 깊이를 동일하게 맞춤
				// 가장 큰 2의 제곱수부터 빼가면서 찾아감
				for (int j = multiMax+1; j >= 0; j--) {
					int shift = 1 << j;
					if (depthA - shift >= depthB) {
						depthA -= shift;
						nodeA = parents[j][nodeA];
					}
				}
				
				//sb.append(nodeA+" "+nodeB+"\n");
			}
			
			// 깊이가 같을 때, 최소 공통 조상을 찾음
			// 2의 제곱수 부모를 비교하면서, 언제 처음 같아지는지 확인
			// 같아지기 바로 직전으로 이동해서 다시 판정
			int lca = 0;
			
			while (true) {
				// 같아지면 최소 공통 조상이다.
				if (nodeA == nodeB) {
					lca = nodeA;
					break;
				}
				
				// row가 0일 때 같다면 그냥 그 쪽으로 이동
				int row = 0;
				if (parents[row][nodeA] == parents[row][nodeB]) {
					nodeA = parents[row][nodeA];
					nodeB = parents[row][nodeB];
					continue;
				}
				
				// 언제 처음 같아지는지 확인해서, 같아지기 직전으로 이동
				while (parents[row][nodeA] != parents[row][nodeB]) {
					row++;
				}
				nodeA = parents[row-1][nodeA];
				nodeB = parents[row-1][nodeB];
			}
			
			
			// 얻은 값을 결과에 추가해줌
			sb.append(lca+"\n");
		}

		// 결과 출력
		System.out.println(sb);
	}
}

회고

최소 공통 조상(LCA)를 구하는 보통의 방법은 굉장히 직관적이지만, 이 문제는 시간복잡도를 더 낮춘 효율적인 방법을 사용한다.
이 방법은 별로 직관적이지는 않고, 사실 이 알고리즘을 배운 적이 없다면 풀 수가 없는 문제인 것 같다.

0개의 댓글