10월 24일 - 교수님은 기다리지 않는다.

Yullgiii·2024년 10월 25일
0


또 잔뜩 화가나게 만드네.....

시간초과가 난 코드

import java.util.Scanner;

public class Main {
    static final int MAX = 100010;
    static long[] parent = new long[MAX];
    static long[] diff = new long[MAX];

    public static void main(String[] args) {
        Scanner scanner = new Scanner(System.in);

        while (true) {
            int n = scanner.nextInt();
            int m = scanner.nextInt();
            if (n == 0 && m == 0) {
                break;
            }
            
            init(n);
            
            for (int i = 0; i < m; i++) {
                char cmd = scanner.next().charAt(0);
                if (cmd == '!') {
                    int a = scanner.nextInt();
                    int b = scanner.nextInt();
                    int w = scanner.nextInt();
                    union(a, b, w);
                } else if (cmd == '?') {
                    int a = scanner.nextInt();
                    int b = scanner.nextInt();
                    if (find(a) == find(b)) {
                        System.out.println(diff[b] - diff[a]);
                    } else {
                        System.out.println("UNKNOWN");
                    }
                }
            }
        }
        scanner.close();
    }

    // 초기화 메서드
    static void init(int n) {
        for (int i = 1; i <= n; i++) {
            parent[i] = i;
            diff[i] = 0;
        }
    }

    // find 연산 메서드 (경로 압축 최적화 적용)
    static long find(long x) {
        if (parent[(int) x] != x) {
            long t = parent[(int) x];
            parent[(int) x] = find(parent[(int) x]);
            diff[(int) x] += diff[(int) t]; // 부모 노드로부터 루트까지의 거리 추가
        }
        return parent[(int) x];
    }

    // union 연산 메서드
    static void union(long a, long b, long w) {
        if (a > b) {
            long temp = a;
            a = b;
            b = temp;
            w = -w;
        }
        find(a);
        find(b);
        long x = diff[(int) b], y = diff[(int) a];
        a = find(a);
        b = find(b);
        if (a != b) {
            parent[(int) b] = a;
            diff[(int) b] = w + y - x;
        }
    }
}

정답 코드

import java.util.*;

public class Main {
    static final int MAX = 100010;
    static int[] parent = new int[MAX];
    static long[] diff = new long[MAX];

    public static void main(String[] args) {
        Scanner scanner = new Scanner(System.in);

        while (true) {
            int n = scanner.nextInt();
            int m = scanner.nextInt();
            if (n == 0 && m == 0) {
                break;
            }
            
            init(n);
            
            StringBuilder result = new StringBuilder();
            for (int i = 0; i < m; i++) {
                char cmd = scanner.next().charAt(0);
                if (cmd == '!') {
                    int a = scanner.nextInt();
                    int b = scanner.nextInt();
                    long w = scanner.nextLong();
                    union(a, b, w);
                } else if (cmd == '?') {
                    int a = scanner.nextInt();
                    int b = scanner.nextInt();
                    if (find(a) == find(b)) {
                        result.append(diff[b] - diff[a]).append("\n");
                    } else {
                        result.append("UNKNOWN\n");
                    }
                }
            }
            System.out.print(result);
        }
        scanner.close();
    }

    // 초기화 메서드
    static void init(int n) {
        for (int i = 1; i <= n; i++) {
            parent[i] = i;
            diff[i] = 0;
        }
    }

    // find 연산 메서드 (경로 압축 최적화 적용)
    static int find(int x) {
        if (parent[x] != x) {
            int originalParent = parent[x];
            parent[x] = find(parent[x]);
            diff[x] += diff[originalParent]; // 부모 노드로부터 루트까지의 거리 추가
        }
        return parent[x];
    }

    // union 연산 메서드
    static void union(int a, int b, long w) {
        int rootA = find(a);
        int rootB = find(b);
        if (rootA != rootB) {
            parent[rootB] = rootA;
            diff[rootB] = diff[a] - diff[b] + w; // 두 노드의 관계 업데이트
        }
    }
}

Union-Find 자료구조를 활용한 문제 해결 방식을 두 가지 코드를 통해 비교하며 이해했다. 첫 번째 코드는 시간 초과가 발생했지만, 두 번째 코드는 최적화된 방식으로 문제를 해결할 수 있었다. 아래는 두 코드의 차이점과 최적화 기법에 대해 정리한 내용이다.

코드 개요

  • 두 코드는 유니온-파인드(Union-Find) 알고리즘을 사용하여 주어진 노드 간의 연결 관계와 가중치 차이를 관리하고, 이를 통해 특정 질의에 빠르게 응답하는 방식이다.
  • 첫 번째 코드에서는 시간 초과가 발생했으나, 두 번째 코드에서는 최적화를 통해 시간 내에 문제를 해결할 수 있었다.

첫 번째 코드가 시간 초과가 난 이유

  • 입출력 비효율성: 첫 번째 코드에서는 System.out.println()을 반복적으로 사용하여 결과를 출력했는데, 이 방식은 입출력 연산이 잦아지면 시간이 크게 소모될 수 있다.
    System.out.println(diff[b] - diff[a]);
  • 타입 일관성 문제: 첫 번째 코드에서 parentfind 함수의 변수 타입이 long으로 되어 있었고, 이는 불필요하게 큰 타입을 사용해 성능에 영향을 줄 수 있었다. 실제로 문제의 범위에 맞게 타입을 int로 일관되게 사용하면 더 빠른 연산이 가능하다.

최적화된 두 번째 코드 설명

주요 변경 사항

  1. StringBuilder 사용

    • 많은 결과 출력을 한 번에 처리하기 위해 StringBuilder를 사용했다. 이를 통해 출력의 효율성을 크게 개선했고, 잦은 입출력 연산으로 인한 시간 초과를 방지했다.
    StringBuilder result = new StringBuilder();
    result.append(diff[b] - diff[a]).append("\n");
    System.out.print(result);
  2. 타입 일관성 유지

    • 첫 번째 코드에서는 parentdiff 배열의 타입으로 long을 사용했으나, 실제로는 int로 충분했다. 불필요한 타입 사용을 피함으로써 연산 속도를 향상시켰다.
    static int[] parent = new int[MAX];
    static long[] diff = new long[MAX];
  3. 경로 압축과 거리 누적

    • 경로 압축(Path Compression)을 통해 각 노드가 집합의 루트 노드를 직접 가리키도록 하여 find 연산의 시간 복잡도를 줄였다.
    static int find(int x) {
        if (parent[x] != x) {
            int originalParent = parent[x];
            parent[x] = find(parent[x]);
            diff[x] += diff[originalParent]; // 부모 노드로부터 루트까지의 거리 추가
        }
        return parent[x];
    }
    • 거리 누적(diff 배열)을 통해 두 노드 간의 가중치 차이를 빠르게 계산할 수 있도록 했다. 이는 find 연산 중 부모 노드의 정보를 갱신하며 가중치 누적을 반영하는 방식이다.

주요 함수 설명

  1. init(int n)

    • 각 노드의 부모를 자기 자신으로 초기화하고(parent[i] = i), 초기 무게 차이(diff[i])를 0으로 설정한다. 이는 모든 노드가 독립된 집합으로 시작하는 유니온-파인드의 기본 설정이다.
    static void init(int n) {
        for (int i = 1; i <= n; i++) {
            parent[i] = i;
            diff[i] = 0;
        }
    }
  2. find(int x)

    • 루트 노드를 찾는 과정에서 경로 압축을 통해 각 노드가 직접 루트를 가리키도록 한다. 이를 통해 이후 find 작업이 더 빠르게 수행된다.
    • diff[x]는 부모 노드에서 루트까지 경로 압축을 수행하면서 누적된 가중치 차이를 기록하여, 이후 질의(?)에서 빠르게 결과를 계산할 수 있게 한다.
    static int find(int x) {
        if (parent[x] != x) {
            int originalParent = parent[x];
            parent[x] = find(parent[x]);
            diff[x] += diff[originalParent];
        }
        return parent[x];
    }
  3. union(int a, int b, long w)

    • 두 노드를 같은 집합에 포함시키며, 두 노드 간의 가중치 차이를 diff 배열에 반영한다. ab가 연결될 때 ba보다 얼마나 무거운지(혹은 가벼운지) 계산하여 부모-자식 관계를 설정한다.
    static void union(int a, int b, long w) {
        int rootA = find(a);
        int rootB = find(b);
        if (rootA != rootB) {
            parent[rootB] = rootA;
            diff[rootB] = diff[a] - diff[b] + w;
        }
    }

주요 배운 점

  • 경로 압축(Path Compression): find 연산을 최적화하여 각 노드가 직접 루트 노드를 가리키도록 함으로써 이후의 탐색 시간이 줄어드는 효과를 볼 수 있다.
  • 거리 누적(diff 배열): 루트까지의 경로를 탐색하면서 누적된 가중치를 기록하는 방식으로, 특정 노드 간의 거리 계산을 빠르게 할 수 있었다.
  • 입출력 최적화: StringBuilder를 사용하여 많은 출력이 필요한 경우 효율성을 높일 수 있었다는 점을 깨달았다.
  • 타입 일관성 유지: 문제의 범위에 맞는 자료형을 선택하는 것이 성능에 큰 영향을 미칠 수 있음을 다시 한번 확인했다.

So...

  • 유니온-파인드 자료구조는 경로 압축과 거리 누적 기법을 통해 시간 복잡도를 크게 줄일 수 있다.
  • 입출력 연산은 프로그램 성능에 중요한 영향을 미치며, 효율적인 방법(StringBuilder)을 사용하는 것이 필요하다.
  • 두 코드를 비교하며, 불필요한 연산을 줄이고 최적화를 적용하는 것이 얼마나 중요한지 알 수 있었다. 이를 통해 더 나은 성능을 갖춘 프로그램을 작성할 수 있도록 노력해야겠다.
profile
개발이란 무엇인가..를 공부하는 거북이의 성장일기 🐢

0개의 댓글