LCA(Lowest Common Ancestor)

Bty·2022년 2월 24일
post-thumbnail

LCA(Lowest Common Ancestor)

최소 공통 조상이라고도 한다.
두 정점이 주어졌을 때,
두 정점에서 자신의 정점을 포함한 가장 가까운 공통의 부모 정점을 말한다.

그림으로 살펴보자.

7번 정점과 8번 정점의 LCA는 무엇일까?
4번 정점이다.

3번 정점과 4번 정점의 LCA는 1번 정점,
2번 정점과 4번 정점의 LCA는 2번이다.

이제 이를 코드로 어떻게 구현할까?
LCA를 이용한 문제를 통해 알아보자.

LCA 구현 문제

LCA
LCA의 개념문제이다.
트리 그래프가 주어지며, 루트 노드는 1번 노드라는 조건이 주어졌다.

크게 다음과 같은 단계로 알고리즘을 구현한다.

  1. 인접 리스트에 그래프의 정보를 저장한다.
  2. DFS나 BFS를 이용해 각 정점의 깊이 및 해당 정점의 부모 정점을 저장한다.
  3. 정점 v1, v2의 LCA를 찾는다.

3번 과정을 예를 들어 자세히 살펴보자.

LCA를 찾기 위해서는,
두 정점이 같은 깊이에 있도록 정점을 이동한 뒤,
두 정점의 부모 노드를 비교해가며 루트 노드 방향으로 한 단계씩 이동하면 된다.

이제 초록색으로 채워진 3번, 7번 정점의 LCA를 찾아보자.

두 정점의 깊이를 같게 하려면 어떻게 이동해야 할까?
7번 정점과 3번 정점이 같은 깊이에 위치하도록
7번 정점을 2번 정점으로 이동한다.

여기서부터 단계적으로 부모 노드를 차례차례 비교해가면 LCA를 찾을 수 있다.
해당 예시에서 7번 정점과 3번 정점의 LCA는 1번 정점이다.

코드

import java.io.BufferedReader;
import java.io.IOException;
import java.io.InputStreamReader;
import java.util.*;

class Main {

    static ArrayList<Integer>[] graph;
    static int[] parent;
    static int[] depth;

    public static void main(String[] args) throws IOException {
        BufferedReader br = new BufferedReader(new InputStreamReader(System.in));
        StringTokenizer st;

        int n = Integer.parseInt(br.readLine());

        graph = new ArrayList[n + 1];
        parent = new int[n + 1];
        depth = new int[n + 1];

        for (int i = 0; i < n + 1; i++)
            graph[i] = new ArrayList<>();

        for (int i = 0; i < n - 1; i++) {
            st = new StringTokenizer(br.readLine());
            int start = Integer.parseInt(st.nextToken());
            int end = Integer.parseInt(st.nextToken());

            graph[start].add(end);
            graph[end].add(start);
        }

        // dfs(1, 1);
         bfs();
        int m = Integer.parseInt(br.readLine());

        for (int i = 0; i < m; i++) {
            st = new StringTokenizer(br.readLine());
            int start = Integer.parseInt(st.nextToken());
            int end = Integer.parseInt(st.nextToken());

            lca(start, end);
        }
    }

    static void dfs(int curVertex, int curDepth) {

        depth[curVertex] = curDepth;
        for(int next : graph[curVertex]) {
            if(depth[next] == 0) {
                parent[next] = curVertex;
                dfs(next, curDepth + 1);
            }
        }
    }

    static void bfs() {
        Queue<Integer> q = new LinkedList<>();

        q.add(1);
        depth[1] = 1;

        while(!q.isEmpty()) {
            int curVertex = q.poll();

            for (int next : graph[curVertex]) {
                if(depth[next] == 0) {
                    parent[next] = curVertex;
                    depth[next] = depth[curVertex] + 1;
                    q.add(next);
                }
            }


        }
    }

    static void lca(int v1, int v2) {

        if(depth[v1] > depth[v2]) {
            int tmp = v1;
            v1 = v2;
            v2 = tmp;
        }

        while(depth[v1] != depth[v2]) {
            v2 = parent[v2];
        }

        while(v1 != v2) {
            v1 = parent[v1];
            v2 = parent[v2];
        }

        System.out.println(v1);
    }
}

parent[v1] : v1의 부모 노드를 저장하고 있다.
depth[v1] : v1의 깊이를 저장하고 있다.

그런데 실수할 수 있는 부분이 있다.

위에서 구현한 LCA 함수의 일부분이다.

while(v1 != v2) { // while(parent[v1] != parent[v2])로 작성해도 될까? 
            
            v1 = parent[v1];
            v2 = parent[v2];
        }

정점이 같아질 경우를 고려하면 되지만, 그렇지 않으면
while(parent[v1] != parent[v2])처럼 작성할 경우
다음과 같은 예시에서 오답이 나타난다.

2번 정점과 5번 정점의 LCA는 2번 정점이지만,
위처럼 구현하면 이들의 LCA는 1번 정점이 되어버린다.

LCA는 정점 본인을 포함할 수 있음을 명심하자.

시간 복잡도

DFS 또는 BFS를 한번 실행해서 각 정점의 부모 정점과 깊이를 계산할 수 있고,
두 정점의 깊이만큼의 연산이 필요하다.

편향 트리일 가능성이 있으므로 O(logN)O(logN)이 아닌 O(N)O(N)이다.

정점의 개수가 더 많아지면?

LCA 문제의 경우
정점의 개수가 최대 50000개, 쿼리의 개수가 10000개이므로
O(5000010000)O(50000 * 10000)으로 풀어낼 수 있다.

하지만 LCA 2 문제의 경우,
정점, 쿼리 모두 100000개를 가지고 있으므로
위에서 접근하던 방식으로는 시간 초과가 발생할 것이다.

쿼리의 시간복잡도를 바꿀 수는 없으므로,
LCA를 찾아내는 시간복잡도를 최적화할 방법이 필요하다.

아래의 예시를 보자.

위처럼 정점 100개가 편향되어 있는 트리가 있다고 하자.
이 때 정점 4번과 정점 100번의 LCA를 구해보면,

앞서 사용한 방법을 이용할 경우
두 정점의 깊이를 맞춰주기 위해
정점 100에서 한칸씩 위로 이동해 정점 3에 도달한다.

이런 부분에서 연산을 줄일 수 있다.

정점 100에서 한칸씩 97번 이동하는 게 아니라,
한꺼번에 많이 이동해 연산을 줄여보자는 것이다.

2의 제곱 형태로 이동하는 것이 수월하다.

97=26+25+2097 = 2^6 + 2^5 + 2^0이므로,
97번 이동해야 하던 기존 방법과는 다르게 3번만에 정점 3으로 이동할 수 있다.

그럼 이렇게 이동 하려면,
정점 100에서 26,25,202^6, 2^5, 2^0번째의 상위 부모를 알아야 한다.

parent[i][j]를 i번 정점에서 2j2^j번째 상위 부모 노드의 정점이라고 하자.
그럼 parent[3][2]는 3번 정점의 222^2번째 상위 부모 노드의 정점인 것이다.

그럼 다음과 같이 DP를 적용해 부모 노드들을 기록할 수 있다.

for (int i = 1; i <= 20; i++) {
    for (int j = 1; j <= n; j++) {
        parent[j][i] = parent[parent[j][i-1]][i-1];

코드의 의미를 한번에 파악하기 쉽지 않다.
게다가 parent[j][i]처럼 j와 i의 위치가 바뀌어 있다. 더 헷갈린다.

천천히 살펴보자.

앞서 사용했던 용어를 다시 짚고 넘어가보자.
A 노드의 1번째 상위 부모 노드는 A 노드의 바로 위 노드이다.
상위 부모라는 말이 다소 부자연스럽지만, 확실한 의미를 담기 위해 사용하겠다.

각 노드의 첫번째(=202^0) 상위 부모 노드는 DFS나 BFS를 통해 쉽게 기록할 수 있다.

그렇다면 각 노드의 두 번째(=212^1) 상위 부모 노드는 어떻게 구할까?

모든 노드는 각 노드의 첫번째 상위 부모 노드를 알고 있다.

그렇다면,

각 노드의 두 번째(=212^1) 상위 부모 노드는
각 노드의 첫 번째(=202^0) 상위 부모 노드의 첫 번째(=202^0) 상위 부모 노드이다.

같은 의미로,

각 노드의 네 번째(=222^2) 상위 부모 노드는
각 노드의 두 번째(=212^1) 상위 부모 노드의 두 번째(=212^1) 상위 부모 노드이다.

할아버지(=212^1)는
나를 기준으로 두 번째 상위 부모이기도 하지만,
부모(=202^0)의 부모(=202^0)이기도 하다.

고조 할아버지(=222^2)는
나를 기준으로 네 번째 상위 부모이기도 하지만,
할아버지(=212^1)의 할아버지(=212^1)이기도 하다.

모든 노드들은 각자의 첫 번째 상위 부모 노드를 알고 있기 때문에,
이를 바탕으로 두 번째 상위 부모 노드를 알 수 있고,
또 이를 바탕으로 네 번째 상위 부모 노드를 알 수 있고, ....

굳이 비유가 필요할까 싶지만, 생각나서 적어보았다.


추가 설명을 하자면,

노드 A의 첫 번째 상위 부모 노드를 안다고 해서

이를 바탕으로 어떻게
A의 두 번째 상위 부모 노드를 알수 있을지 의문이 들 수 있다.

A의 두 번째 상위 부모 노드는
A의 첫 번째 상위 부모 노드의 첫 번째 상위 부모 노드이다.

우리는 모든 노드의 첫 번째 상위 부모 노드를 알고 있으므로
A의 첫 번째 상위 부모 노드의 첫 번째 상위 부모 노드도 알 수 있게 된다.


두 정점의 깊이가 동일해졌을 때에도 마찬가지로
한 칸씩 이동하며 비교하는 것이 아닌,

2의 제곱 형태로 이동하며 비교한다.

아래와 같이 구현할 수 있다.

static void lca(int v1, int v2) {
        if(depth[v1] < depth[v2]) {
            int tmp = v1;
            v1 = v2;
            v2 = tmp;

        }

        for (int i = 20; i >= 0; i--) {
            if(depth[v1] - depth[v2] >= (1 << i)) {
                v1 = parent[v1][i];
            }
        }

        if(v1 == v2) {
            System.out.println(v1);
            return;
        }

        for (int i = 20; i >= 0; i--) {
            if(parent[v1][i] != parent[v2][i]) {
                v1 = parent[v1][i];
                v2 = parent[v2][i];

            }

            }

        System.out.println(parent[v1][0]);
        }

값 20은 2의 20제곱이 100만을 넘어가기 시작하는 수이고,
문제의 입력값 범위와 비슷하기 때문에 선택한 값이다.

LCA가 7번째 상위 노드에 있다고 하면,
7=22+21+207 = 2^2 + 2^1 + 2^0이므로
4번째 상위 노드로 이동 후, 2번째 상위 노드로 이동한 뒤

그 정점에서 바로 한단계 위 노드가 LCA가 된다.

이번에는 LCA가 16번째 상위 노드에 있다고 가정하면
16=2416 = 2^4이므로 바로 16번째 노드로 이동하도록 구현할 수도 있겠지만,

23+22+21+202^3 + 2^2 + 2^1 + 2^0에 1을 더한 값이기도 하다.
LCA의 바로 아래 노드까지 이동한 후 반복문을 종료하도록 하는 것이
가장 깔끔한 구현 방법이다.

코드

import java.io.BufferedReader;
import java.io.IOException;
import java.io.InputStreamReader;
import java.util.*;

class Main {

   static ArrayList<Integer>[] graph;
   static int[][] parent;
   static int[] depth;

   public static void main(String[] args) throws IOException {
       BufferedReader br = new BufferedReader(new InputStreamReader(System.in));
       StringTokenizer st;

       int n = Integer.parseInt(br.readLine());

       graph = new ArrayList[n + 1];
       parent = new int[n + 1][21];
       depth = new int[n + 1];

       for (int i = 0; i < n + 1; i++)
           graph[i] = new ArrayList<>();

       for (int i = 0; i < n - 1; i++) {
           st = new StringTokenizer(br.readLine());
           int start = Integer.parseInt(st.nextToken());
           int end = Integer.parseInt(st.nextToken());

           graph[start].add(end);
           graph[end].add(start);
       }

       dfs(1, 1);

       for (int i = 1; i <= 20; i++) {
           for (int j = 1; j <= n; j++) {
               parent[j][i] = parent[parent[j][i-1]][i-1];
           }
       }

       int m = Integer.parseInt(br.readLine());

       for (int i = 0; i < m; i++) {
           st = new StringTokenizer(br.readLine());
           int start = Integer.parseInt(st.nextToken());
           int end = Integer.parseInt(st.nextToken());

           lca(start, end);
       }
   }



   static void lca(int v1, int v2) {
       if(depth[v1] < depth[v2]) {
           int tmp = v1;
           v1 = v2;
           v2 = tmp;

       }

       for (int i = 20; i >= 0; i--) {
           if(depth[v1] - depth[v2] >= (1 << i)) {
               v1 = parent[v1][i];
           }
       }

       if(v1 == v2) {
           System.out.println(v1);
           return;
       }

       for (int i = 20; i >= 0; i--) {
           if(parent[v1][i] != parent[v2][i]) {
               v1 = parent[v1][i];
               v2 = parent[v2][i];

           }

           }

       System.out.println(parent[v1][0]);
       }

   static void dfs(int curVertex, int curDepth) {

       depth[curVertex] = curDepth;
       for (int next : graph[curVertex]) {
           if(depth[next] == 0) {
               parent[next][0] = curVertex;
               dfs(next, curDepth + 1);
           }
       }

   }
}

이로써 O(logN)O(logN)의 시간복잡도로 LCA를 구할 수 있게 되었다.

참고로 depth[next] == 0이면
해당 정점을 탐색하지 않았다는 뜻과 동일하므로
굳이 visited[] 배열을 만들어서 방문 여부를 기록할 필요는 없다.

응용 문제

LCA를 어느 상황에 응용할 수 있을까?
도로 네트워크와 같은 문제에 응용할 수 있다.
LCA는 정점 사이의 거리, 두 정점 간의 가중치의 최솟값, 최댓값 등을 효과적으로 저장할 수 있다.

이 문제는 DFS로 각 노드의 첫 번째 상위 부모 노드의 최솟값, 최댓값 가중치를 저장한 뒤
DP를 이용해 2j2^j번째 상위 노드로 가는 길에서의 최솟값, 최댓값 가중치를 저장하면
된다.

코드

import java.io.BufferedReader;
import java.io.IOException;
import java.io.InputStreamReader;
import java.util.*;

class Main {

    static class Edge {
        int end;
        int weight;

        public Edge(int end, int weight) {
            this.end = end;
            this.weight = weight;
        }
    }

    static ArrayList<Edge>[] graph;
    static int[][] parent;
    static int[] depth;
    static int[][] maxWeight;
    static int[][] minWeight;

    public static void main(String[] args) throws IOException {
        BufferedReader br = new BufferedReader(new InputStreamReader(System.in));
        StringTokenizer st;

        int n = Integer.parseInt(br.readLine());

        graph = new ArrayList[n + 1];
        parent = new int[n + 1][21];
        maxWeight = new int[n + 1][21];
        minWeight = new int[n + 1][21];
        depth = new int[n + 1];

        for (int i = 0; i < n + 1; i++)
            graph[i] = new ArrayList<>();

        for (int i = 0; i < n - 1; i++) {
            st = new StringTokenizer(br.readLine());
            int start = Integer.parseInt(st.nextToken());
            int end = Integer.parseInt(st.nextToken());
            int weight = Integer.parseInt(st.nextToken());

            graph[start].add(new Edge(end, weight));
            graph[end].add(new Edge(start, weight));
        }

        dfs(1, 1);

        for (int i = 1; i <= 20; i++) {
            for (int j = 1; j <= n; j++) {
                parent[j][i] = parent[parent[j][i-1]][i-1];
                maxWeight[j][i] = Math.max(maxWeight[j][i-1], maxWeight[parent[j][i-1]][i-1]);
                minWeight[j][i] = Math.min(minWeight[j][i-1], minWeight[parent[j][i-1]][i-1]);


            }
        }

        int m = Integer.parseInt(br.readLine());

        for (int i = 0; i < m; i++) {
            st = new StringTokenizer(br.readLine());
            int start = Integer.parseInt(st.nextToken());
            int end = Integer.parseInt(st.nextToken());

            lca(start, end);
        }
    }



    static void lca(int v1, int v2) {
        int minLength = Integer.MAX_VALUE, maxLength = Integer.MIN_VALUE;

        if(depth[v1] < depth[v2]) {
            int tmp = v1;
            v1 = v2;
            v2 = tmp;

        }

        for (int i = 20; i >= 0; i--) {
            if(depth[v1] - depth[v2] >= (1 << i)) {
                minLength = Math.min(minLength, minWeight[v1][i]);
                maxLength = Math.max(maxLength, maxWeight[v1][i]);
                v1 = parent[v1][i];
            }
        }

        if(v1 == v2) {
            System.out.println(minLength + " " + maxLength);
            return;
        }

        for (int i = 20; i >= 0; i--) {
            if(parent[v1][i] != parent[v2][i]) {
                minLength = Math.min(Math.min(minWeight[v1][i], minWeight[v2][i]), minLength);
                maxLength = Math.max(Math.max(maxWeight[v1][i], maxWeight[v2][i]), maxLength);
                v1 = parent[v1][i];
                v2 = parent[v2][i];

            }

        }

        System.out.println(Math.min(minLength, Math.min(minWeight[v1][0], minWeight[v2][0])) +
                " " + Math.max(maxLength, Math.max(maxWeight[v1][0], maxWeight[v2][0])));
    }

    static void dfs(int curVertex, int curDepth) {

        depth[curVertex] = curDepth;
        for (Edge nextEdge : graph[curVertex]) {
            if(depth[nextEdge.end] == 0) {
                parent[nextEdge.end][0] = curVertex;
                maxWeight[nextEdge.end][0] = nextEdge.weight;
                minWeight[nextEdge.end][0] = nextEdge.weight;
                dfs(nextEdge.end, curDepth + 1);
            }
        }

    }
}
profile
Better than yesterday

0개의 댓글