개선된 LCA 정복하기

·2024년 7월 25일
0
post-thumbnail

개요

근 한달간 코딩테스트를 계속 봤다. 나름 꾸준히 알고리즘 문제를 풀고 있긴 하나, 잘 늘지 않는 게 알고리즘 능력인듯 하다.

여러 코테 시험을 보며 유난히 후회가 컸던 문제가, 바로 LCA 관련 문제였다.

O(logN)O(logN) 기반 LCA 알고리즘을 적용하는 문제였는데, 기본 LCA 는 풀어본 경험이 있어서, 뭔가 풀 수 있을 거 같으면서도, 결국 구현을 하지 못해 아쉬움이 많이 남았다.

이 시험을 계기로, 코테에서 제대로 아이디어가 떠오르지 않으면 그 문제는 빠르게 포기하고 다른 문제에 집중해야한다는 것을 배웠다.

불행중 다행이라면, 코테를 통과?하여 면접을 보게 되었다.

또다시 아쉬움이 남지 않도록 LCA 알고리즘을 정복해보자.

LCA

Lowest Common Ancestor 의 약자로, 임의의 두 정점 A,B 에 대해, 가장 가까운 공통 조상을 의미한다.

예를들어 다음과 같은 트리 노드가 존재한다고 했을 때, 1015LCA5 가 된다 :

         1
       / | \
      2  3  4
     /|\    |
    5 6 7   8
   / \ \  \
  9 `10 11 12
 / \
13 14
   |
  `15

기본 LCA

LCA 의 풀이과정은 다음과 같다.

  1. 완전 탐색을 통해 노드들의 깊이와 직전의 부모 노드를 구한다. depth[N+1] / parent[N+1]

  2. 임의의 두 노드가 깊이가 서로 다르다면, 깊이를 맞춰주는 평탄화 작업을 진행한다. (이 때 두 노드가 포함관계라면 상위의 노드가 LCA 가 된다.)

  3. 깊이가 동일함에도, 서로 가리키는 부모 노드가 다른 경우, 동일한 부모 노드가 나올 때 까지 상위의 노드를 탐색한다. (최종에는 루트 노드가 있으므로 트리의 자료구조라면 마지막엔 결국 동일한 부모 노드가 나올 수 밖에 없다.)

풀어보기

https://www.acmicpc.net/problem/11437

풀이

import java.io.BufferedReader;
import java.io.InputStreamReader;
import java.util.ArrayList;
import java.util.List;
import java.util.StringTokenizer;

public class Main {
    private static List<Integer>[] adj;
    private static int depth[], parent[];

    public static void main(String[] args) throws Exception {
        BufferedReader br = new BufferedReader(new InputStreamReader(System.in));
        int N = Integer.parseInt(br.readLine());

        adj = new ArrayList[N + 1];

        depth = new int[N + 1];
        parent = new int[N + 1];
        for (int i = 1; i <= N; i++) {
            adj[i] = new ArrayList<>();
        }

        for (int i = 0; i < N - 1; i++) {
            StringTokenizer stk = new StringTokenizer(br.readLine());
            int n1 = Integer.parseInt(stk.nextToken());
            int n2 = Integer.parseInt(stk.nextToken());
            adj[n1].add(n2);
            adj[n2].add(n1);
        }

        setTree(1, 0, 1);

        int M = Integer.parseInt(br.readLine());
        StringBuilder sb = new StringBuilder();
        for (int i = 0; i < M; i++) {
            StringTokenizer stk = new StringTokenizer(br.readLine());
            int n1 = Integer.parseInt(stk.nextToken());
            int n2 = Integer.parseInt(stk.nextToken());
            sb.append(LCA(n1,n2)).append("\n");
        }
        System.out.print(sb.toString());
    }

    private static int LCA(int n1, int n2) {
        if (n1 == 1 || n2 == 1) return 1;

        int target = n1; // 평탄화가 반영될 노드 (더 깊은 노드)
        int compare = n2; // 비교 노드

        if (depth[n1] < depth[n2]) {
            target = n2;
            compare = n1;
        }

        // 평탄화
        while(depth[target] != depth[compare]) {
	         target = p[target];
        }

        // 공통조상 찾기
        while(target != compare) {
	        target = p[target];
	        compare = p[compare];
        }
        return target;
    }

    private static void setTree(int node, int p) {
        depth[node] = depth[p] + 1;
        parent[node] = p;

        for (int child : adj[node]) {
            if (child == p) continue;
            setTree(child, node, d + 1);
        }
    }
}

개선된 LCA

기본 LCA 알고리즘의 경우, 깊이의 크기가 시간복잡도에 많은 영향을 끼칠 수 있다는 단점이 있다.

예를 들어 다음과 같은 트리가 있다고 해보자 :

1
|\
2 1001
|
3
|
.
.
.
|
1000

위의 경우, 1000 , 1001 에 대한 LCA 를 구하라고 한다면, 1000 노드는 평탄화 작업을 진행하는데에만 거의 O(1000) 의 시간복잡도가 든다.

이러한 단점을 해결한 것이 개선된 LCA 알고리즘이다.

parent[N+1][k] 배열을 생성하여, 깊이 차이만큼 부모 노드를 저장한다. 단, 그냥 몽땅 저장하는 것이 아닌 2k2^k 씩 저장한다. 이 경우, 2,4,8,16 ... 칸 씩 한번에 점핑이 가능하여 평탄화와 공통 부모 찾기시 더 빠른 탐색이 가능하다.

2k2^k 씩 저장하는 가?

부모 노드를 그룹화하여 저장하는 거라면 3k3^k 혹은 4k4^k 도 가능하지 않을까? 심지어 훨씬 적은 저장공간만 활용하고, 동일한 kk 에 대해 점핑 크기도 더 크다. 3k3^k4k4^k 로는 풀이가 불가능한가?

고민한 내용을 바탕으로 내 의견을 이야기하자면, 3k3^k, 4k4^k 역시 풀이가 가능하며, 2k2^k 대비 메모리도 더 조금 사용할 수 있으나, 시간 초과가 일어날 가능성이 커진다.

이는 이진법의 특징에 기인한다. 2진법은 0과 1만을 사용하여 임의의 자릿수 k 에 대해 값이 존재한다, 존재하지 않는다의 차이만으로 모든 정수를 만들어낼 수 있다.

그러나, 3k3^k 로 부모 노드를 그룹화한다면, 각 자릿수에 대해, 해당 자릿수가 0인지 1 인지, 2 인지를 모두 파악해야봐야 한다. 숫자가 4k4^k, 5k5^k 가 될수록, 확인해야할 수 수들은 더 많아지며, 반복하는 자릿수 만큼 탐색 시간이 점점 늘어나는 것이다.

예컨데, 19 만큼 깊이 차이가 있는 두 노드에 대해 평탄화를 진행한다고 하자.

이를 2k2^k 인 경우와 5k5^k 인 경우를 비교해서 적용한다면, 2k2^k24+21+202^4 + 2^1 + 2^0 으로 3번만에 평탄화가 이루어지지만, 5k5^k(513)+(504)(5^1 *3) + (5^0 * 4) 으로 총 7번의 작업이 진행된다. 이는 5k5^k 로 그룹화를 진행하는 경우, 동일한 k 를 사용한 작업을 반복해야한다는 문제가 발생한다.

재밌는 건, 2k2^k 는 어떤 정수를 가져오더라도 동일한 k 를 반복하여 사용하지 않는다는 특징이 있다. 임의의 깊이 차이에 대해 최대 2k2^k 만큼 점핑을 했다면, 다음에 사용되는 점핑 값은 2k2^k 보다 작은 값이 올 수 밖에 없다.

만약 2k2^k 만큼 반영했는데, 2k2^k 만큼 남는 경우가 발생한다면, 이는 2k+12^{k+1} 로 대체가 가능하다. ex) 23+23=242^3+2^3=2^4

요약하자면 NkN^k 의 그룹화하여 생성할 수 있는 트리의 최대 깊이를 MaxD 라고 했을 때 2k2^k 는 평탄화는 최대 O(MaxD) 만에, 부모 노드는 O(MaxD) 미만의 시간 복잡도 (평탄화에서 일부 점핑을 했으므로) 로 문제를 해결할 수 있다는 특징을 가진다.

그러나 2k2^k 가 아닌 다른 NkN^k 는 평탄화 및 부모 노드 찾기 시, 2k2^k 대비 저장공간은 줄일 수 있으나, kk 에 대한 반복 작업이 추가 반영되어, 시간 초과를 발생시킬 수 있는 가능성이 커진다.

2k2^k 부모 배열에 대한 점화식 구하기

parent[a][k]parent[parent[a][k-1]][k-1] 라는 식이 성립한다. 다음의 노드를 통해 예시를 들어보자.

         1
       / | \
      2  3  4
     /|\    |
    5 6 7   8
   / \ \  \
  9  10 11 12
 / \
13 14
   |
  15
  • 13 노드에 대해, 202^0 부모 노드는 9 , 212^1 부모 노드는 5 , 222^2 부모 노드는 1 이 된다.
  • 이 때 5 노드에 대해 212^1 부모 노드는 1 이다.
  • 5 노드는 13 노드에게 212^1 의 부모 노드에 해당하므로, 13 의 입장에서 parent[13][2]= parent[5][1] 에 해당하며 이는 parent[13][2] = parent[parent[13][1]][1] 에 해당한다고 볼 수 있다.

평탄화

개선된 LCA 의 경우, 2k2^k 의 부모를 저장하여 사용한다는 점 외에는 기본 LCA 와 매우 유사하다.

두 노드의 깊이가 다른 경우, maxDepth 를 기점으로 두 노드의 깊이가 동일해지는 경우까지, target 의 노드를 부모 노드로 대체한다.

공통 부모 찾기

공통 부모 역시 maxDepth 를 기점으로 두 부모 노드가 동일하지 않은 경우, 현재 노드를 상위의 부모 노드로 변경하여 동일할 때 까지 비교한다.

풀어보기

https://www.acmicpc.net/problem/11438

풀이

import java.io.BufferedReader;
import java.io.InputStreamReader;
import java.util.ArrayList;
import java.util.List;
import java.util.StringTokenizer;

public class Main {
    private static List<Integer>[] adj;
    private static int maxDepth, parent[][], depth[];

    public static void main(String[] args) throws Exception {
        BufferedReader br = new BufferedReader(new InputStreamReader(System.in));
        int N = Integer.parseInt(br.readLine());

        adj = new ArrayList[N + 1];

        depth = new int[N + 1];
        for (int i = 1; i <= N; i++) {
            adj[i] = new ArrayList<>();
        }

        for (int i = 0; i < N - 1; i++) {
            StringTokenizer stk = new StringTokenizer(br.readLine());
            int n1 = Integer.parseInt(stk.nextToken());
            int n2 = Integer.parseInt(stk.nextToken());
            adj[n1].add(n2);
            adj[n2].add(n1);
        }

        maxDepth = (int) Math.floor(Math.log(N) / Math.log(2));
        parent = new int[N + 1][maxDepth + 1];

        setTree(1, 0, 1);

        for (int k = 1; k <= maxDepth; k++) {
            for (int i = 1; i <= N; i++) {
                parent[i][k] = parent[parent[i][k - 1]][k - 1];
            }
        }

        int M = Integer.parseInt(br.readLine());
        StringBuilder sb = new StringBuilder();
        for (int i = 0; i < M; i++) {
            StringTokenizer stk = new StringTokenizer(br.readLine());
            int n1 = Integer.parseInt(stk.nextToken());
            int n2 = Integer.parseInt(stk.nextToken());
            sb.append(LCA(n1, n2)).append("\n");
        }
        System.out.print(sb.toString());
    }

    private static int LCA(int n1, int n2) {
        if (n1 == 1 || n2 == 1) return 1;

        int target = n1; // 평탄화가 반영될 노드 (더 깊은 노드)
        int compare = n2; // 비교 노드

        if (depth[n1] < depth[n2]) {
            target = n2;
            compare = n1;
        }

        // 평탄화
        if (depth[target] != depth[compare]) {
            for (int i = maxDepth; i >= 0; i--) {
                if (depth[parent[target][i]] >= depth[compare]) {
                    target = parent[target][i];
                }
            }
        }

        // 공통조상 찾기
        int ret = target;

        if (target != compare) {
            for (int i = maxDepth; i >= 0; i--) {
                if (parent[target][i] != parent[compare][i]) {
                    target = parent[target][i];
                    compare = parent[compare][i];
                }
                ret = parent[target][i];
            }
        }
        return ret;
    }

    private static void setTree(int node, int p, int d) {
        depth[node] = d;
        parent[node][0] = p;

        for (int child : adj[node]) {
            if (child == p) continue;
            setTree(child, node, d + 1);
        }
    }
}

profile
새로운 것에 관심이 많고, 프로젝트 설계 및 최적화를 좋아합니다.

0개의 댓글