[백준] 13511번 트리와 쿼리 2 / Java, Python

Jini·2021년 9월 3일
0

백준

목록 보기
218/226

Baekjoon Online Judge

algorithm practice


- 단계별 문제풀기


36. 최소 공통 조상

트리에서 두 정점의 최소 공통 조상을 구하는 자료구조를 배워 봅시다.




5. 트리와 쿼리 2

13511번

트리 상의 경로에서 k번째 정점을 구하는 문제



이번 문제는 N개의 정점으로 이루어진 트리(무방향 사이클이 없는 연결 그래프)가 있다. 정점은 1번부터 N번까지 번호가 매겨져 있고, 간선은 1번부터 N-1번까지 번호가 매겨져 있다. 두 쿼리를 수행하는 프로그램을 작성하는 문제이다.

1 u v: u에서 v로 가는 경로의 비용을 출력한다.
2 u v k: u에서 v로 가는 경로에 존재하는 정점 중에서 k번째 정점을 출력한다. k는 u에서 v로 가는 경로에 포함된 정점의 수보다 작거나 같다.

dfs와 LCA를 이용해서, 이전 LCA 문제와 비슷한 방식으로 구할 수 있다. 조건을 나누어 거리(간선 비용)를 구한다.

Java : LCA함수와 parent를 이용해 조상을 구하고, dfs를 통해 depth를 확인하고 dist를 구한다.

Python : parent를 이용해 각 노드의 부모 노드 및 depth 계산하고, dp를 이용해 희소 테이블을 초기화 및 희소 테이블을 계산해주고, 조건을 나누어 거리를 구할 수 있다.



Java / Python


  • Java



import java.io.*;
import java.util.*;

public class Main {

	static int N, M; // N : 정점수, M : 쿼리 수

	static int[] depth;
	static long[] dist;
	static int[][] parent; // parent[j][i] = parent[parent[j][i - 1]][i - 1];
	static ArrayList<Node>[] tree;

	static class Node {
		int target, cost;

		public Node(int target, int cost) {
			this.target = target;
			this.cost = cost;
		}
	}

	public static void main(String[] args) throws Exception {
		BufferedReader br = new BufferedReader(new InputStreamReader(System.in));
		BufferedWriter bw = new BufferedWriter(new OutputStreamWriter(System.out));
		StringTokenizer st;
		StringBuilder sb = new StringBuilder();

		N = Integer.parseInt(br.readLine());

		depth = new int[N + 1];
		dist = new long[N + 1];
		parent = new int[N + 1][18];
		tree = new ArrayList[N + 1];

		for (int i = 1; i < N + 1; i++) {
			tree[i] = new ArrayList<Node>();
		}

		for (int i = 0; i < N - 1; i++) {
			st = new StringTokenizer(br.readLine());

			int a = Integer.parseInt(st.nextToken());
			int b = Integer.parseInt(st.nextToken());
			int c = Integer.parseInt(st.nextToken());

			tree[a].add(new Node(b, c));
			tree[b].add(new Node(a, c));
		}

		DFS(1, 1);

		// parent 채우기
		for (int i = 1; i < 18; i++) {
			for (int j = 2; j <= N; j++) {
				parent[j][i] = parent[parent[j][i - 1]][i - 1];
			}
		}

		// LCA
		M = Integer.parseInt(br.readLine());
		for (int i = 0; i < M; i++) {

			st = new StringTokenizer(br.readLine());
			int a = Integer.parseInt(st.nextToken());
			int u = Integer.parseInt(st.nextToken());
			int v = Integer.parseInt(st.nextToken());
			int root = LCA(u, v);
			if (a == 1) {
				sb.append(dist[u] + dist[v] - 2 * dist[root] + "\n");
			} else {
				int k = Integer.parseInt(st.nextToken());
				int cnt = depth[u] - depth[root] + 1;
				if (cnt == k)
					sb.append(root + "\n");
				else if (cnt > k) {
					k--;
					int tmp = u;
					for (int j = 0; j < 18; j++) {
						if ((k & 1 << j) != 0) {
							k -= 1 << j;
							tmp = parent[tmp][j];
						}
					}
					sb.append(tmp + "\n");
				} else {
					k = cnt + depth[v] - depth[root] - k + 1;
					k--;
					int tmp = v;
					for (int j = 0; j < 18; j++) {
						if ((k & 1 << j) != 0) {
							k -= 1 << j;
							tmp = parent[tmp][j];
						}
					}
					sb.append(tmp + "\n");
				}
			}
		}
		bw.write(sb.toString());

		bw.flush();
		bw.close();
		br.close();
	}

	// depth 확인
	// dist 추가
	static void DFS(int node, int cur) {
		depth[node] = cur;

		for (Node next : tree[node]) {
			if (depth[next.target] == 0) {
				parent[next.target][0] = node;
				dist[next.target] = dist[node] + next.cost;
				DFS(next.target, cur + 1);
			}
		}
		return;
	}

	static int LCA(int a, int b) {
		if (depth[a] < depth[b]) {
			// a가 더 얕으면 swap
			int temp = a;
			a = b;
			b = temp;
		}
		for (int i = 18; i >= 0; i--) {
			if (Math.pow(2, i) <= depth[a] - depth[b]) {
				a = parent[a][i]; // 높이 차이 만큼 a 높이 올리기
			}
		}
		if (a == b) return a;

		for (int i = 17; i >= 0; i--) {
			if (parent[a][i] != parent[b][i]) { 
				a = parent[a][i];
				b = parent[b][i];
			}
		}
		return parent[a][0];
	}
}




  • Python



import sys
from collections import deque
from math import log2
input = sys.stdin.readline

# tree 입력, 정렬, 부모노드, depth 계산 부분 
N = int(input())
tree = [[] for _ in range(N + 1)]

for _ in range(N - 1):
    a, b, c = map(int, input().split())
    tree[a].append([b, c])
    tree[b].append([a, c])
    
parent = [[0, 0] for _ in range(N + 1)]
depth = [0] * (N + 1)
visit = [False] * (N + 1)
que = deque([1])
visit[1] = True

while que:
    now = que.popleft()
    for b, c in tree[now]:
        if not visit[b]:
            que.append(b)
            parent[b][0] = now
            parent[b][1] = c
            depth[b] = depth[now] + 1
            visit[b] = True

# 희소 테이블
logN = int(log2(N) + 1)
dp = [[[0, 0] for _ in range(logN)] for __ in range(N + 1)]

for i in range(N + 1):
    dp[i][0][0] = parent[i][0]
    dp[i][0][1] = parent[i][1]

for j in range(1, logN):
    for i in range(1, N + 1):
        dp[i][j][0] = dp[dp[i][j - 1][0]][j - 1][0]
        dp[i][j][1] = dp[i][j - 1][1] + dp[dp[i][j - 1][0]][j - 1][1] 

# 쿼리문 입력, 처리 
M = int(sys.stdin.readline())
for _ in range(M):
    Query = list(map(int, input().split()))
    u, v = Query[1], Query[2]
    u2, v2 = u, v
    # 공통 조상 탐색 
    if depth[u2] < depth[v2]:
        u2, v2 = v2, u2
    diff = depth[u2] - depth[v2]
    for i in range(logN):
        if diff & 1 << i:
            u2 = dp[u2][i][0]
    if u2 == v2:
        lca = u2
    else:
        for i in range(logN - 1, -1, -1):
            if dp[u2][i][0] != dp[v2][i][0]:
                u2 = dp[u2][i][0]
                v2 = dp[v2][i][0]
        lca = dp[u2][0][0]
    if Query[0] == 1:
        cost = 0
        diff_u = depth[u] - depth[lca]
        diff_v = depth[v] - depth[lca]
        for i in range(logN):
            if diff_u & 1 << i:
                cost += dp[u][i][1]
                u = dp[u][i][0]
            if diff_v & 1 << i:
                cost += dp[v][i][1]
                v = dp[v][i][0]
        print(cost)
    else: 
        k = Query[3]
        # u 의 k - 1 조상을 계산
        if k <= depth[u] - depth[lca]: 
            diff = k - 1
            for i in range(logN):
                if diff & 1 << i:
                    u = dp[u][i][0]
            print(u)
        else: # 남은 거리를 v부터 계산
            diff = depth[v] + depth[u] - 2 * depth[lca] - k + 1
            for i in range(logN - 1, -1, -1):
                if diff & 1 << i:
                    v = dp[v][i][0]
            print(v)





profile
병아리 개발자 https://jules-jc.tistory.com/

0개의 댓글