백준 4803번 트리 Java

: ) YOUNG·2024년 4월 11일
1

알고리즘

목록 보기
362/411
post-thumbnail

백준 4803번
https://www.acmicpc.net/problem/4803

문제



생각하기


  • 유니온 파인드 문제이다.


동작

문제는 맞게 풀었던 것 같은데, 은근히 반례를 생각해내지 못해서 고생했다.

일단 사이클을 찾는 방법은 union() 전에 합치는 a 노드와 b노드의 루트가 같은지 파악한다.

만약 둘의 루트가 이미 같은데도 union()이라면 사이클이 발생하기 때문에 트리가 될 수 없다.

그래서 cycleList에 넣어서 해당 노드가 포함된 집합은 트리가 아니라고 저장해둔다.

        for (int i = 1; i <= N; i++) {
            // 모두 입력 후 부모 갱신해줘야 됨
            if (cycleList.contains(i)) {
                parents[i] = find(i);
                cycleList.add(parents[i]);
            }
        }

그리고 이 문제에서 여기가 핵심이지 않을까 생각하는데,
union()에서 cycleList를 만들었다고 끝나는게 아니라 parents배열에서 자신의 루트노드 값이 제대로 저장되어 있지 않는 경우가 꼭 있기 때문에 find(i) 연산을 통해서 다시 갱신해줘야 한다.




이미 find연산을 했는데 왜 제대로 저장이 안되나?

예를 들어 1 2노드가 합쳐지면 2번의 부모는 1로 루트노드가 잘 저장이 된다. 근데 2 3번 노드를 합친다고 하면 1 2 3이 하나의 집합으로 트리가 되지만, 3번의 루트노드는 2로 저장되어 있기 때문에 union 연산만으로는 루트노드가 제대로 저장되지 않는다.

그렇기 때문에 혹여나 각 노드별로 자신의 부모를 정확히 알아야 하는 경우에는 꼭 find연산을 실행해야 한다.


아무튼 이렇게 제대로 만들어진 cycleList 를 가지고


        int ans = 0;
        for (int i = 1; i <= N; i++) {
            if (!cycleList.contains(parents[i]) && parents[i] == i) ans++;
        }

포함되지 않으면서 자신이 루트노드인 경우 트리 조건에 부합하는 루트 노드를 찾을 수 있다.




좋은 예제


9 9
1 2
2 3
3 4
4 5
3 5
6 7
7 8
6 8
8 9
0 0

답 : No tree
// 두개의 집합이 존재하지만 둘 다 내부에 사이클이 존재하므로 트리는 없다.

7 7
1 2
2 3
3 1
4 5
5 6
6 4
1 6
0 0

답 : One Tree
// 7은 자기 자신만 남기 때문에 정답은 1이 된다.


결과


코드



import java.io.*;
import java.util.Arrays;
import java.util.HashSet;
import java.util.StringTokenizer;

public class Main {

    // input
    private static BufferedReader br;

    // variables
    private static int N, M, T;
    private static int[] parents, ranks;

    public static void main(String[] args) throws IOException {
        br = new BufferedReader(new InputStreamReader(System.in));
        BufferedWriter bw = new BufferedWriter(new OutputStreamWriter(System.out));

        T = 1;
        for (; ; ) {
            StringTokenizer st = new StringTokenizer(br.readLine());
            N = Integer.parseInt(st.nextToken());
            M = Integer.parseInt(st.nextToken());
            if (N == 0) break;

            input();

            bw.write(solve());
        }

        bw.close();
    } // End of main()

    private static String solve() throws IOException {
        StringBuilder sb = new StringBuilder();
        HashSet<Integer> cycleList = new HashSet<>();

        for (int i = 0; i < M; i++) {
            StringTokenizer st = new StringTokenizer(br.readLine());
            int a = Integer.parseInt(st.nextToken());
            int b = Integer.parseInt(st.nextToken());

            boolean ret = union(a, b);

            if (ret) {
                cycleList.add(a);
            }
        }

        for (int i = 1; i <= N; i++) {
            // 모두 입력 후 부모 갱신해줘야 됨
            if (cycleList.contains(i)) {
                parents[i] = find(i);
                cycleList.add(parents[i]);
            }
        }

        int ans = 0;
        for (int i = 1; i <= N; i++) {
            if (!cycleList.contains(parents[i]) && parents[i] == i) ans++;
        }

        sb.append("Case ").append(T++).append(": ");
        if (ans == 0) {
            sb.append("No trees.");
        } else if (ans == 1) {
            sb.append("There is one tree.");
        } else {
            sb.append("A forest of ").append(ans).append(" trees.");
        }
        sb.append('\n');

        return sb.toString();
    } // End of solve()

    private static boolean union(int a, int b) {
        int rootA = find(a);
        int rootB = find(b);

        if (rootA == rootB) return true;

        if (ranks[rootA] < ranks[rootB]) {
            parents[rootB] = rootA;
        } else {
            parents[rootA] = rootB;

            if (ranks[rootA] == ranks[rootB]) {
                ranks[rootA]++;
            }
        }

        return false;
    } // End of union()

    private static int find(int node) {
        if (parents[node] != node) {
            parents[node] = find(parents[node]);
        }
        return parents[node];
    } // End of find()

    private static void input() {
        parents = new int[N + 1];
        for (int i = 0; i <= N; i++) {
            parents[i] = i;
        }

        ranks = new int[N + 1];
    } // End of input()
} // End of Main class

0개의 댓글