최종 수정일: 2024년 11월 18일
주어지는 수열의 가장 긴 증가하는 부분 수열의 길이를 구하는 문제입니다.
이때 가장 긴 증가하는 부분 수열을 영어로는 Longest Increasing Subsequence라고 합니다. 따라서 가장 긴 증가하는 부분 수열을 앞으로는 약자인 LIS라고 부르도록 하겠습니다.
1 2 3 1 2
일 경우 인덱스 0부터 시작하는 1 2 3이 LIS입니다. 앞에서부터 읽으면 직관적으로 알 수 있습니다.
하지만 4 5 1 2 3
일 경우 인덱스 2부터 시작하는 1 2 3이 LIS입니다. 앞의 2개만 읽고 4 5를 LIS라고 생각하더라도, 후에 더 긴 LIS가 있을 수 있습니다.
수열의 어느 값이 정답 LIS일지는 바로 알기 어렵습니다.
따라서 수열의 각 값을 LIS의 끝으로 하여 길이를 계산했을 때 가장 긴 값이 답이 되겠습니다. 배열의 최댓값을 구하기 위해 O(n)으로 배열을 순회하여 값을 찾아내는 것과 유사합니다.
예를 들어 수열이 4 5 1 2 3
인 경우 4, 5, 1, 2, 3을 각각의 값을 LIS의 끝으로 고정했을 때의 길이를 구하고 나온 5개의 값들 중 가장 큰 값
이 문제에서 요구하는 정답이 됩니다.
순차적으로 보면 아래 과정을 거쳐 정답이 계산됩니다. 배열의 0번 인덱스부터 마지막 값까지를 순회하며 그 값을 끝으로 하는 LIS 길이를 구합니다.
4를 끝으로 하는 LIS의 길이는 4 하나뿐이니 1입니다.
5를 끝으로 하는 LIS의 길이는 4 5이니 2입니다.
1을 끝으로 하는 LIS의 길이는 4 5 1로 1 하나 뿐이니 1입니다.
2를 끝으로 하는 LIS의 길이는 4 5 1 2에서의 1 2(또는 4 5) 이므로 2입니다.
3을 끝으로 하는 LIS의 길이는 4 5 1 2 3에서의 1 2 3이므로 3입니다.
수열의 모든 부분에서의 LIS를 계산했습니다. 정답은 최대 길이를 가져야 하므로 4 5 1 2 3
수열에서의 max([1, 2, 1, 2, 3]) = 3이 정답이 됩니다.
그러면 임의의 값을 끝으로 하는 LIS의 길이는 어떻게 구할 수 있을까요? 구했던 과정과 똑같이 부분 수열의 각 값을 끝으로 했을 때의 LIS를 계산(또는 저장했던 값을 가져옴)합니다. 우리가 주어진 수열의 모든 값에서의 LIS를 구하려고 했던 것처럼 부분 수열에서도 마찬가지로 반복합니다. 더 구체적으로는, 임의의 값 n을 LIS의 끝 값으로 정했을 때, 이 n의 이전에 올 수 있는 값은 n보다 반드시 작은 값이어야 합니다. n보다 같거나 큰 값m이 n의 이전에 올 경우
m -> n은 증가하는 부분 수열이 아니게 되기 때문입니다. 따라서 이를 고려하며 예시를 머리로 계산해보겠습니다.
4 5 1 2 3
에서 3을 끝으로 하는 LIS의 길이를 예로 들어보겠습니다. 3을 끝으로 했을 때 3의 이전에 올 수 있는 값은 3보다 작은 값만 가능합니다. 4->3, 5->3 같은 경우는 증가하지 않기 때문에 불가능합니다. 따라서 3의 이전에 올 수 있는 값은 1, 2만 가능합니다.
그리고 이때 고려해야할 점은 각각의 값을 LIS의 끝으로 고정했을 때의 길이를 구하고 나온 값들 중 가장 큰 값
이 정답을 찾을 수 있는 방법입니다. 따라서 지금 3보다 작다고 구해진 1, 2에서도 같은 방법을 적용해야합니다.
즉, 1을 끝으로 했을 때의 LIS 길이, 2를 끝으로 했을 때의 LIS 길이를 비교해 더 긴 값을 선택한다는 것입니다.
그리고 이 중 선택된 길이 + 1한 값이 3을 끝으로 하는 LIS의 길이입니다. 2가 선택되었다고 하면 ...->2
에서 ...->2->3
으로 2를 끝으로 하는 LIS에서 3이 한 개 더 추가되었기 때문에 1을 더해줍니다.
코드로 보면 다음과 같습니다.
for i in range(n): # 부분 수열로 나눈다
lastValue = arr[i] # i번째의 값을 LIS의 끝으로 한다. 에서 i번째의 값
for j in range(i): # 나눠진 부분 수열을 순회한다
curValue = arr[j]
if curValue < lastValue: # LIS의 오른쪽 값이 반드시 현재 비교하기로 선택한 값보다 커야한다.('증가하는'이 조건이므로)
LIS[i] = max(LIS[i], LIS[j] + 1) # 현재 비교하기로 선택한 값 + 1의 이유는 lastValue까지 포함한 LIS 길이
앞에서 든 예시의 총 과정은 아래 이미지와 같습니다.
n = int(input())
ns = list(map(int, input().split()))
dp = [1] * (n)
for i in range(n):
for j in range(i):
if ns[j] < ns[i]:
dp[i] = max(dp[i], dp[j] + 1)
print(max(dp))
읽어주셔서 감사합니다.
글을 읽으며 이해가 어려웠던 부분이나 질문하고 싶은 내용이 있으시다면 이메일또는 댓글 남겨주시길 바랍니다.
당연히 아실것 같지만 이건 O(N^2) 시간복잡도라 10^4~5 개가 넘어가는 녀석들한텐 쓰기 힘들죠
관심 있으신분들은 이진탐색을 사용 하는 방식을 한번 찾아보시길