[알고리즘] 가장 긴 증가하는 부분 수열 (LIS) 구하기 - 최적화와 역추적까지

이상현·2025년 3월 26일

알고리즘

목록 보기
12/15
post-thumbnail

가장 긴 증가하는 부분 수열

3102120

위 배열에서 최장 증가 부분 수열은 3, 10, 20 으로 길이가 3이다.
최장 증가 부분 수열의 길이를 구하는 알고리즘을 알아보자.

다이나믹 프로그래밍 O(N2)O(N^2)

우선, 다이나믹 프로그래밍을 사용한 방법이다.
dp[i] = ary[i]를 포함하는 최장 증가 부분 수열의 길이 이다.

이중 반복문으로, 앞의 요소들의 값을 보면서 dp값을 갱신한다.

i보다 작은 j에 대해, ary[j] < ary[i] 인 경우 dp[i] = max(dp[i], dp[j] + 1) 로 갱신하면 된다.

int calcLISLen(int[] ary, int n) {
    int dp[] = new int[n];
    Arrays.fill(dp, 1);

    for (int i = 1; i < n; i++) {
        for (int j = 0; j < i; j++) {
            if (ary[j] < ary[i]) {
                dp[i] = Math.max(dp[i], dp[j] + 1);
            }
        }
    }

    return Arrays.stream(dp).max().orElse(1);
}

정답은 3으로 잘 나온다.

백준 11052 - 가장 긴 증가하는 부분 수열
이 문제는 DP로 통과 가능하다.

최적화 - 이분 탐색 O(NlogN)O(NlogN)

이분 탐색을 활용하면 이 작업을 O(NlogN)O(NlogN) 으로 할 수 있다.
이 로직은 탐욕적 선택최적 후보 유지 전략에 기반한다.

우선, 길이가 ary 와 같은 lis 배열이 추가된다.
lis 배열의 각 원소는 해당 길이의 증가 부분 수열에서 가능한 최소의 마지막 값을 의미한다. 이를 유지하면 앞으로 등장하는 더 큰 수를 쉽게 추가할 수 있다.

즉, 현재 값이 lis 배열의 마지막 요소보다 작거나 같을 경우 해당 위치의 값을 대체하여, 나중에 더 긴 부분 수열을 만들 가능성을 높이는 것이다.

단, 계산이 끝난 후 lis 배열이 실제 lis 배열을 의미하지는 않는다. lis 배열의 길이만 유효한 결과이다.

알고리즘

ary 의 요소를 처음부터 탐색하면서, lis 배열의 어디에 집어넣을지 본다.

  • ary[i] 보다 lis 배열의 마지막 요소가 크거나 같으면, lis 배열에서 ary[i] 보다 크거나 같은 수 중에 가장 작은 수가 있는 위치에 ary[i]덮어쓴다. (lower bound 이진 탐색 활용)
  • ary[i] 보다 lis 배열의 마지막 요소가 작으면, lis 배열 뒤에 ary[i] 값을 추가한다.

ary

i01234
ary[i]3102120

ary가 이렇게 있다면, 요소를 돌 때 마다 lis 배열이 어떻게 되는지 하나하나 살펴보겠다.


  1. 초기화: ary의 첫번째 요소를 넣어둔다.
i01234
lis[i]3

  1. i = 1 부터 시작
    ary[1] = 10lis의 마지막 요소 3보다 크므로 뒤에 추가한다.
i01234
lis[i]310

  1. i = 2
    ary[2] = 2, 2 <= 10 이므로, lis 배열중에 2보다 크거나 같고, 가장 작은 수는 3이다. 그 위치에 덮어쓴다.
i01234
lis[i]210

  1. i = 3
    ary[3] = 1, 1 <= 10 이므로, 1이 있는 0번 위치에 덮어쓴다.
i01234
lis[i]110

  1. i = 4
    ary[4] = 20 , 20 <= 10 이 아니므로 뒤에 추가한다.
i01234
lis[i]21020

모두 순회했다. LIS 의 길이는 3인것이 정답으로 나왔다.

보시다시피 위 계산을 통해서는 실제 LIS 배열을 얻을 수는 없다. 원본 배열 ary를 보면, 2, 10, 20 은 연속하지 않아서 조건을 만족하지 않는다.

코드

백준 12738 - 가장 긴 증가하는 부분 수열 3
이 문제는 O(N2)O(N^2) 으로 풀리지 않아 위 방법으로 O(NlogN)O(NlogN) 에 풀어야 한다.

입출력은 제외하고 중요 코드만 작성하겠다.

// n, ary, lis 초기화..

lis[0] = ary[0]; // 첫번쨰 요소는 미리 넣어둔다.
int len = 1; // lis 배열의 이론적 길이를 의미

for (int i = 1; i < n; i++) {
	// lis 배열의 마지막 요소가 ary[i] 보다 크거나 같으면
    if (lis[len - 1] >= ary[i]) {
        int index = binarySearch(lis, len, ary[i]);
        lis[index] = ary[i]; // 덮어쓰기
    } else {
        lis[len++] = ary[i]; // 뒤에 추가하고 길이 증가
    }
}

// 정렬된 배열에서 target 의 lower bound 의 인덱스 찾기
private static int binarySearch(int[] lis, int len, int target) {
    int start = 0;
    int end = len - 1;
    int ans = len - 1;
    while (start <= end) {
        int mid = (start + end) / 2;
        if (lis[mid] >= target) {
            end = mid - 1;
            ans = mid;
        } else {
            start = mid + 1;
        }
    }
    return ans;
}

LIS 의 길이가 아니라, 실제 LIS 를 구하려면?

위 방법으로는 실제 LIS 를 구하는것이 불가능했다.
그럼 실제 LIS인 3, 10, 20 를 구하려면 어떻게 해야 할까?

위에서 lis 의 길이를 구하는 과정을 보면, 나중에 진짜 lis 가 되는 요소의 정보가 다른 값으로 덮어써지면서 없어진다. 덮어써지기 전에 정보를 어딘가 저장해야 한다.

계산 과정에서, ary의 요소가 lis에서 몇번째 인덱스에 들어갔었는지 한번 index 배열에 기록해보자.

ary

i01234
ary[i]3102120

  1. 초기화
i01234
lis[i]3

ary0번째 요소가 lis0번에 들어갔으니 index[0]0 을 넣는다.

i01234
index[i]0

  1. i = 1
i01234
lis[i]310

ary[1]lis[1] 에 들어갔으므로 index[1] = 1

i01234
index[i]01

  1. i = 2
i01234
lis[i]210

ary[2]lis[0] 에 들어갔으므로 index[2] = 0

i01234
index[i]010

  1. i = 3
i01234
lis[i]110

ary[3]lis[0] 에 들어갔으므로 index[3] = 0

i01234
index[i]0100

  1. i = 4
i01234
lis[i]21020

ary[4]lis[2] 에 들어갔으므로 index[4] = 2

i01234
index[i]01002

순회가 끝났다. 계산이 끝난 후 index 배열을 유심히 보면, 실제 lis 배열인 3, 10, 20 (인덱스는 0, 1, 4) 를 찾아낼 수 있다.

반대 방향으로 순회하면서, 가장 먼저 나오는 2, 1, 0을 순서대로 찾으면 된다.

  1. 역순으로 순회한다.
  2. 2 (lis길이 - 1) 가 먼저 나오는곳: 4
  3. 1이 먼저 나오는 곳: 1
  4. 0이 먼저 나오는 곳: 0

다시 순서를 뒤집으면 0, 1, 4 즉, 실제 lis 배열의 인덱스이다.

코드를 이를 활용해야 하는 문제와 함께 보자.

코드

백준 14003 - 가장 긴 증가하는 부분 수열 5

이 문제는, lis의 길이를 이진 탐색으로 O(NlogN)O(NlogN) 으로 구하고, 또한 실제 LIS 배열을 출력해야 한다.

기존 코드에서 추가된 로직에는 주석을 작성했다.

int[] indexAry = new int[n]; // 새로 추가된 indexAry
Arrays.fill(indexAry, -1);
lis[0] = ary[0];
indexAry[0] = 0;
int len = 1;
for (int i = 1; i < n; i++) {
	if (lis[len - 1] >= ary[i]) {
		int index = binarySearch(lis, len, ary[i]);
		indexAry[i] = index; // 현재 요소가 lis배열에 어디에 추가됐는지 기록
		lis[index] = ary[i];
	} else {
		indexAry[i] = len; // 현재 요소가 lis배열에 어디에 추가됐는지 기록
		lis[len++] = ary[i];
	}
}
System.out.println(len);

int tmp = len - 1; // lis의 길이 - 1 부터 하나씩 찾기 위한 변수
Stack<Integer> ansAry = new Stack<>(); // 먼저 넣은걸 나중에 꺼내기 위해 스택 사용.

for (int i = n - 1; i >= 0; i--) { // 역방향 탐색
	if (indexAry[i] == tmp) { // 처음 만나는 tmp 값이면
		ansAry.push(ary[i]); // 해당 위치에 저장했던 요소 값 저장
		tmp--; // 다음 값 탐색
	}
}

StringBuilder sb = new StringBuilder();
while (!ansAry.isEmpty()) {
	sb.append(ansAry.pop()).append(" "); // 스택에 저장한 순서의 반대로 출력하기 위함
}
System.out.println(sb);

앞서 설명했듯이 기존에 계산을 하는 도중에, 요소가 lis의 어디에 들어갔었는지만 기록하면 결과를 알 수 있다.

참고 - 14003 전체 코드

import java.io.BufferedReader;
import java.io.IOException;
import java.io.InputStreamReader;
import java.util.*;

public class Main {
    static int n;
    static int[] ary, lis;

    public static void main(String[] args) throws IOException {
        BufferedReader bf = new BufferedReader(new InputStreamReader(System.in));
        n = Integer.parseInt(bf.readLine());
        StringTokenizer st = new StringTokenizer(bf.readLine());
        ary = new int[n];
        for (int i = 0; i < n; i++) {
            ary[i] = Integer.parseInt(st.nextToken());
        }
        lis = new int[n];
        int[] indexAry = new int[n];
        Arrays.fill(indexAry, -1);
        lis[0] = ary[0];
        indexAry[0] = 0;
        int len = 1;
        for (int i = 1; i < n; i++) {
            if (lis[len - 1] >= ary[i]) {
                int index = binarySearch(lis, len, ary[i]);
                indexAry[i] = index;
                lis[index] = ary[i];
            } else {
                indexAry[i] = len;
                lis[len++] = ary[i];
            }
        }
        System.out.println(len);

        int tmp = len - 1;
        Stack<Integer> ansAry = new Stack<>();
        for (int i = n - 1; i >= 0; i--) {
            if (indexAry[i] == tmp) {
                ansAry.push(ary[i]);
                tmp--;
            }
        }
        StringBuilder sb = new StringBuilder();

        while (!ansAry.isEmpty()) {
            sb.append(ansAry.pop()).append(" ");
        }

        System.out.println(sb);
    }

    private static int binarySearch(int[] lis, int len, int target) {
        int start = 0;
        int end = len - 1;
        int ans = len - 1;
        while (start <= end) {
            int mid = (start + end) / 2;
            if (lis[mid] >= target) {
                end = mid - 1;
                ans = mid;
            } else {
                start = mid + 1;
            }
        }
        return ans;
    }
}

0개의 댓글