BOJ 23032) 서프라이즈

Wonjun Lee·2024년 6월 10일

문제 링크)

https://www.acmicpc.net/problem/23032

입력)

첫째 줄에 정수 N(2 ≤ N ≤ 2,000)이 주어진다.

다음 줄에 1번 학생부터 N번 학생까지 차례대로 설문 조사에 적은 스테이크의 무게를 나타내는 정수 W(1 ≤ W ≤ 10,000)가 주어진다.

출력)

이벤트에 당첨된 학생들의 스테이크 무게 합을 출력한다.


풀이과정

문제 이해

이 문제는 선형적으로 주어지는 정수 리스트에서 생성 가능한 부분집합 중 오직 연속적인 위치로 구성되며 2개 이상의 원소를 갖는 집합들을 검사하라는 조건을 가지고 있다.

연속적인 위치라는 것은 주어진 인덱스를 변화시키지 않고 즉시 문제에 적용하라는 의미이며, 정렬 알고리즘을 적용해선 안됨을 알 수 있다.

부분 집합에 대해 두 개의 그룹으로 나눴을 때, 두 그룹의 정수 차가 최소인 그룹 중에서 가장 정수 합이 큰 값을 출력한다.

입력은 최대 2000개의 정수이며, O(N^2)의 시간복잡도까지는 적용 가능할 수 있다. (비록 시간 제한이 1초이긴 하지만)

부분집합이 선택되면, 이 부분집합에 있는 모든 요소들은 그룹에 소속된다. 이 그룹도 연속적인 위치의 요소들로만 구성된다.

접근 방법

일단 무식한 방법으로 접근해본다.

가능한 모든 부분집합들을 전부 반복문을 통해 지정하고, 각 반복에서 최소차를 갖는 두 그룹을 계산한다. 그 결과를 최종 결과용 변수와 비교한다.

이 알고리즘에서 각 그룹별로 가장 차이가 적은 두 그룹을 고를 때, 그 부분 집합 원소들의 개수만큼 반복이 수행되므로, 각 반복마다 점근적으로 O(N)의 시간이 요구된다.

부분 집합을 구하는 반복문이 N(N-1)/2번 반복되므로, 최종적인 성능은 O(N^3)에 가깝다. 최대 길이를 대입해보면, 8E9 이므로 시간초과가 발생할 것이다.

시간 초과 발생 후 가장 처음 고민한 것은 해싱을 이용한 방법이었다.

해싱을 토대로 이전 반복에서 계산한 결과를 다음반복에서 사용하는 DP를 구현해보고자 했으나 진 부분집합들에 대해서 그룹 경계가 어디서 나눠질지 예측하는 것이 어렵고, 계산하는데 O(N)의 시간이 필요할 것이라 판단하였다. 그리고 구현의 복잡도도 증가하며, 디버깅이 어려워질 것 같아서 다른 방법을 고민했다.

다른 투 포인터 문제들 처럼 양쪽 끝에서 포인터들을 각 각 이동시키는 알고리즘을 이용해 탐색 범위를 좁혀볼까 했다.

두 그룹 정수 차를 기반으로 포인터를 이동시키는 방식으로는 아래 테스트 케이스에 대해서 오류를 만든다.

100 200 1300 1000 1000 -> 답은 2000이지만, 100 200 1300을 1개 그룹으로, 1000 1000을 1개 그룹으로 하고, 탐색 범위를 가운데로 좁히는 과정에서 2000이 더 크기 때문에 오른쪽 1000을 범위에서 제외하게 된다.

이로인해 정답을 계산할 수 없다.

<참고용 : 시간 초과가 난 반복 구문과 O(N)이 소요되는 함수 호출 구조>

for i in range(n-1) :
    for j in range(i+1, n) :
        Calculate(from i, to j)


이전 알고리즘을 기반으로하여 시간복잡도를 줄이기 위해선, 연속한다는 특징을 이용하여야 한다.

모든 부분집합은 연속하며 어쨌든 반복문을 이용해 모든 경우를 탐색할 수 있다. 그렇다면, 위치적으로 유사성이 있는 부분집합들은 굳이 개별적으로 그룹을 나눠보는 연산을 적용할 것이 아니라, 단 한 번의 순차탐색 만으로 모두 검사해볼 수 있지 않을까?

예를 들자면 다음과 같다.

2 1 5 2 3

우선 처음엔 2와 1만으로 계산해본다.

[2 | 1] 5 2 3
left : 2g
right : 1g
diff = 1g, sum = 3g

다음으로 5를 포함하여 2 1 5라는 부분집합에 대해 검사한다고 해보자.

그룹은 연속되는 요소들로 이뤄지므로, 5가 추가되면 이 5는 무조건 right group에 속한다. 그러므로 이전에 나눈 그룹에서 그대로 right에 5를 더해보자.

[2 | 1 5] 2 3
left : 2g
right : 6g
diff = 4g, sum = 8g

이제 이 그룹에서 차이를 줄일 방법을 생각해본다. right에 수가 추가되었다. 그럼, 추가되기 전의 right가 left보다 합이 적었던 컸던, 다시 left와 비교해봐야 한다. 그리고 left에 추가할 수 있는 right 그룹의 가장 좌측 요소를 left 그룹에 더하고, right 그룹에선 뺐을 때 차이와 비교해본다.

[2 | 1 5] 일 때, 차이 = 4g
[2 1 | 5] 일 때, 차이 = 2g
-> 더 작아졌고, 더이상 right 그룹에서 요소를 뺄 수 없다.
따라서 부분집합 [2 1 5]에선 [2 1 | 5]로 나누는게 최선이다.

이런 과정을 반복한다. 새로운 요소를 부분집합에 추가했을 때 처음부터 다시 비교할 필요가 없는 이유를 생각해보자.

1) left < right 였을 때,
right에 새로운 요소를 추가하면 right가 더 커진다. 차이를 줄이기 위해선 left를 증가시키고 right를 감소시켜야한다.

2) left > right 였을 때,
새로운 요소의 추가 이후에 right가 더 커진다. 이로인해 차이가 작아질 수 있고, 더 커졌을 수도 있다. 어느 쪽이든 left와 right의 크기 차이를 줄일 수 있는지 확인해야 하므로 검사해야한다.

3) left = right 였을 때,
이미 최소 차이인 0이었고, right가 커짐에 따라 차이가 증가한다. 따라서 left를 증가시키고 right를 감소시키며 해당 부분집합에서 최소 차이를 찾는다.

설명이 너무 길었는데 한 문장으로 정리하자면,
결국 바로 옆 요소를 추가한 부분집합은 이전 부분집합과 크게 닮아 있으므로 이를 이용해 각 그룹의 크기를 조절하면 된다는 것이다.

난잡하며 논리정연하지 못한 막글이므로, 코드를 보는 것이 차라리 덜 헷갈릴 것 같다!

코드 구현

아래는 위에서 설명한 알고리즘을 구현한 코드이다.
시간 복잡도는 O(N^2)으로 함수 호출용 반복문이 하나만 사용된다.

import sys

def solve() :
    N = int(sys.stdin.readline())
    grams = list(map(int, sys.stdin.readline().split()))
    
    def findClosestGroups(left) :
        bound = left+1
        left_group, right_group = grams[left], 0
        min_diff, max_grams = sys.maxsize, 0
        for i in range(left + 1, N) :
            right_group += grams[i]
            while(bound < i and abs(left_group - right_group) > abs(left_group + grams[bound] - (right_group - grams[bound]))) :
                left_group += grams[bound]
                right_group -= grams[bound]
                bound += 1
            
            if min_diff > abs(left_group - right_group) :
                min_diff = abs(left_group - right_group)
                max_grams = left_group + right_group
            elif min_diff == abs(left_group - right_group) :
                max_grams = max(max_grams, left_group + right_group)
        
        return (min_diff, max_grams)
    
    min_diff_so_far = sys.maxsize
    max_gram_so_far = 0
    
    # 좌 우에서 Sandwitch -> 100 200 1300 1000 1000 -> error
    # 모든 경우를 다 뒤져야 함 -> 랜덤하므로.
    # 중복되는 연산을 제거한다.
    for i in range(N - 1) :
        min_diff, max_gram = findClosestGroups(i)
        if min_diff < min_diff_so_far :
            min_diff_so_far = min_diff
            max_gram_so_far = max_gram
        elif min_diff == min_diff_so_far :
            max_gram_so_far = max(max_gram_so_far, max_gram)
    print(max_gram_so_far)
    
solve()

0개의 댓글