BOJ 15681 트리와 쿼리 (Java)

사람·2025년 1월 30일
0

BOJ

목록 보기
21/74

문제

https://www.acmicpc.net/problem/15681
간선에 가중치와 방향성이 없는 임의의 루트 있는 트리가 주어졌을 때, 아래의 쿼리에 답해보도록 하자.

정점 U를 루트로 하는 서브트리에 속한 정점의 수를 출력한다.
만약 이 문제를 해결하는 데에 어려움이 있다면, 하단의 힌트에 첨부한 문서를 참고하자.

입력
트리의 정점의 수 N과 루트의 번호 R, 쿼리의 수 Q가 주어진다. (2 ≤ N ≤ 105, 1 ≤ R ≤ N, 1 ≤ Q ≤ 105)

이어 N-1줄에 걸쳐, U V의 형태로 트리에 속한 간선의 정보가 주어진다. (1 ≤ U, V ≤ N, U ≠ V)

이는 U와 V를 양 끝점으로 하는 간선이 트리에 속함을 의미한다.

이어 Q줄에 걸쳐, 문제에 설명한 U가 하나씩 주어진다. (1 ≤ U ≤ N)

입력으로 주어지는 트리는 항상 올바른 트리임이 보장된다.

출력
Q줄에 걸쳐 각 쿼리의 답을 정수 하나로 출력한다.

예제 입력 1
9 5 3
1 3
4 3
5 4
5 6
6 7
2 3
9 6
6 8
5
4
8
예제 출력 1
9
4
1

접근

굳이 복붙을 하지는 않았지만 문제 뒤에 긴 힌트가 있었다.
읽기 귀찮기도 하고 스포당하기 싫어서 처음에는 안 보고 풀었는데 메모리 초과가 났다.
왜냐면 Q개의 쿼리 입력에 대해 매번 트리를 탐색하면서 해당 정점 U를 찾은 후 서브트리의 정점의 수를 세었기 때문이었다. 모든 입력에 대해 트리 전체를 순회하여야 했던 것.

그러고 나서 힌트를 뒷부분만 살짝 보니...

메인 함수 내에서 makeTree(5, -1)과 countSubtreeNodes(5) 를 차례대로 한 번씩 호출할 경우, 5번을 루트로 하는 트리에서 모든 정점에 대해 각 정점을 루트로 하는 서브트리에 속한 정점의 수를 계산해둘 수가 있다. 이를 이용하면, 모든 질의 U에 대해 size[U] 를 출력하기만 하면 되므로, 정점이 10만 개, 질의가 10만 개인 데이터에서도 충분히 빠른 시간 내에 모든 질의를 처리할 수가 있게 될 것이다.

dp와 같은 방식으로 풀 수 있겠구나 하는 걸 깨닫게 되었다.
트리는 계층 구조로 이루어져 있으므로 부모 노드를 루트로 하는 서브 트리의 모든 노드의 개수는 자식을 루트로 하는 모든 서브 트리의 노드 개수의 합으로 구할 수 있기 때문이다.

그러니까 한 번 루트에서 리프까지의 모든 노드의 개수를 세면, 그 과정에서 동시에 각 서브 트리의 노드 개수까지 세어질 수 있는 것이다. 이 값들을 메모이제이션 해두고 그냥 쿼리 입력 값에 따라 출력만 하면 되는 거였다. i번 노드를 루트로 하는 서브 트리의 모든 노드 개수를 dp[i]라는 배열에 저장해 두기 때문에, 인덱스로 바로 접근할 수 있어 정점 U를 매번 찾을 필요도 없었다.

구현

메모리 초과가 났던 풀이

import java.io.*;
import java.util.*;

class Main {
    static List<Integer>[] graph;
    public static void main(String[] args) throws IOException {
        BufferedReader br = new BufferedReader(new InputStreamReader(System.in));
        StringTokenizer st = new StringTokenizer(br.readLine());
        int N = Integer.parseInt(st.nextToken());
        int R = Integer.parseInt(st.nextToken());
        int Q = Integer.parseInt(st.nextToken());
        graph = new List[N + 1];
        for (int i = 1; i <= N; i++) {
            graph[i] = new ArrayList<>();
        }
        for (int i = 0; i < N - 1; i++) {
            st = new StringTokenizer(br.readLine());
            int U = Integer.parseInt(st.nextToken());
            int V = Integer.parseInt(st.nextToken());
            graph[U].add(V);
            graph[V].add(U);
        }
        StringBuilder sb = new StringBuilder();
        Node rootNode = new Node(R);
        makeTree(rootNode);
        for (int i = 0; i < Q; i++) {
            sb.append(countNodes(findNode(Integer.parseInt(br.readLine()), rootNode))).append("\n");
        }
        System.out.print(sb.toString());
    }

    private static void makeTree(Node subRoot) {
        for (int childNum : graph[subRoot.num]) {
            Node childNode = new Node(childNum);
            subRoot.children.add(childNode);
            graph[childNum].remove(graph[childNum].indexOf(subRoot.num));
            makeTree(childNode);
        }
    }

    private static Node findNode(int target, Node subRoot) {
        if (subRoot.num == target) {
            return subRoot;
        }
        if (subRoot.children.isEmpty()) {
            return new Node(0);
        }
        for (Node child : subRoot.children) {
            Node found = findNode(target, child);
            if (found.num > 0) {
                return found;
            }
        }
        return new Node(0);
    }

    private static int countNodes(Node subRoot) {
        if (subRoot.children.isEmpty()) {
            return 1;
        }

        int count = 1;
        for (Node child : subRoot.children) {
            count += countNodes(child);
        }
        return count;
    }

    static class Node {
        int num;
        List<Node> children;

        Node (int num) {
            this.num = num;
            this.children = new ArrayList<>();
        }
    }
}

위에 서술했듯, dp를 이용하면 findNode()라는 메소드는 필요가 없었다.

정답 풀이

import java.io.*;
import java.util.*;

class Main {
    static int[] numOfNodes;
    static List<Integer>[] graph;
    public static void main(String[] args) throws IOException {
        BufferedReader br = new BufferedReader(new InputStreamReader(System.in));
        StringTokenizer st = new StringTokenizer(br.readLine());
        StringBuilder sb = new StringBuilder();

        int N = Integer.parseInt(st.nextToken());
        int R = Integer.parseInt(st.nextToken());
        int Q = Integer.parseInt(st.nextToken());
        graph = new List[N + 1];
        numOfNodes = new int[N + 1];
        
        for (int i = 1; i <= N; i++) {
            graph[i] = new ArrayList<>();
        }

        for (int i = 0; i < N - 1; i++) {
            st = new StringTokenizer(br.readLine());
            int U = Integer.parseInt(st.nextToken());
            int V = Integer.parseInt(st.nextToken());
            graph[U].add(V);
            graph[V].add(U);
        }
        
        Node rootNode = new Node(R);
        makeTree(rootNode);
        // 각 노드를 루트로 하는 서브 트리의 노드 개수 구하기
        numOfNodes[R] = countNodes(rootNode);
        for (int i = 0; i < Q; i++) {
            sb.append(numOfNodes[Integer.parseInt(br.readLine())]).append("\n");
        }
        System.out.print(sb.toString());
    }

    private static void makeTree(Node subRoot) {
        for (int childNum : graph[subRoot.num]) {
            Node childNode = new Node(childNum);
            subRoot.children.add(childNode);
            graph[childNum].remove(graph[childNum].indexOf(subRoot.num));
            makeTree(childNode);
        }
    }

    private static int countNodes(Node subRoot) {
    	// 이미 노드의 개수를 구한 적이 있는 서브 트리인 경우
        if (numOfNodes[subRoot.num] > 0) {
            return numOfNodes[subRoot.num];
        }
        // 리프 노드인 경우 더 이상 서브 트리가 만들어질 수 없으므로 자기 자신의 개수인 1을 리턴하고 종료
        if (subRoot.children.isEmpty()) {
            return 1;
        }

        int count = 1;
        // subRoot를 루트로 하는 서브 트리의 모든 노드의 개수
        // == subRoot의 자식 노드들을 루트로 하는 모든 서브 트리의 노드 개수의 합
        for (Node child : subRoot.children) {
            int num = countNodes(child);
            // child를 루트로 하는 서브 트리의 노드 개수를 메모이제이션
            numOfNodes[child.num] = num;
            count += num;
        }
        return count;
    }

    static class Node {
        int num;
        List<Node> children;

        Node (int num) {
            this.num = num;
            this.children = new ArrayList<>();
        }
    }
}

profile
알고리즘 블로그 아닙니다.

0개의 댓글