옛날 어느 시골 마을에 서로 우애가 좋던 형과 아우가 있었다. 형제는 돌아가신 부모님으로부터 물려받은 각자의 논에 벼를 심은 뒤 열심히 가꾸어 추수했고, 남은 볏단을 논에 쌓아 두었다. 형제는 서로의 형편을 걱정해 자신의 논에 쌓아 놓은 볏단을 매일 밤 몰래 서로의 논에 옮겨 놓았다. 형제는 마음이 통했는지 서로의 논에 모인 볏단의 양은 변하지 않았고, 어느 날 진실을 알게 된 형제는 부둥켜안고 눈물을 흘렸다.
현욱이는 이 의좋은 형제 이야기를 다음과 같이 시뮬레이션하려고 한다.
형제는 부모님으로부터 논을 물려받아 형과 아우가 똑같이
개씩 논을 나눠 가졌다. 추수를 마치고 보니 형의 논에는 각각
개의 볏단이, 아우의 논에는 각각
개의 볏단이 쌓여 있었다.
형제는 마음이 통했기 때문에 매일 밤 서로 같은 정수
를 정해서 자신의
번째 논의 볏단을 상대의
번째 논에 옮겨 놓았다. 볏단을 옮기기 전에 형제의
번째 논에는 아직 볏단이 쌓여 있어야 한다.
형제는 모든 볏단이
번째 논으로 모일 때까지 매일 밤 볏단을 옮기는 것을 반복했다.
현욱이는 모든 볏단을
번째 논으로 모으는 방법들을 시뮬레이션하고 있다. 더 이상 볏단을 옮길 수 없을 때,
번째 논에 모인 두 볏단의 양은 최대 얼마만큼 차이날 수 있는지 구해보자.
...
DP 문제라고 생각해 이런 저런 공식을 생각했지만, 아무리 생각해도 답이 나오지 않았다. 문제 난이도를 보고 시작했기에 설마... 하고 알고리즘이 뭔지 확인했는데 "그리디"라고 적혀있어 처음부터 다시생각했다.
그리디 알고리즘이면 직관적인 알고리즘일 것이라 생각해 그림을 다시 살펴보았는데, 서로의 최대값에만 더하는 것으로 보였다.
그래서 i부터 순회를 돌아 arr[i+1:] 가 max인 논에 볏단을 옮기면 어떨까? 하고 코드를 짜고 이리저리 테스트해봤는데 예외가 없었다.
import sys
input = sys.stdin.readline
def solve():
N = int(input())
A = list(map(int, input().split()))
B = list(map(int, input().split()))
for i in range(N-1):
A_tmp = A[i]
B_tmp = B[i]
A[i] = -1
B[i] = -1
B[B.index(max(B[i+1:]))] += A_tmp
A[A.index(max(A[i+1:]))] += B_tmp
print(abs(A[N-1] - B[N-1]))
solve()
해당코드는 max에서 O(n), index에서 O(n)이 돌기 때문에 O(2 n^2) = O(n^2)이다. 따라서 시간초과로 통과하지 못할 것 같았지만 일단 제출해보았다.
역시나 시간초과가 뜬 모습
하지만 시간복잡도를 개선하니 틀렸습니다! 가 나왔다. 아무래도 내 풀이는 그리디하지 못했거나 문제의 조건을 충족하지 못한것이다.
실제로 이방식은 A[-2]가 A[-1]로 갈수있다는 치명적인 오류가 있었다. 그래서 다시 처음부터 생각하기로 했다
0 | 1 | 2 | 3 | |
---|---|---|---|---|
A | 4 | 3 | 2 | 1 |
B | 2 | 3 | 3 | 2 |
문제를 다시보자. i -> j로 한번 옮길때, 교차(A원소가 B원소로 더해지는 일)가 일어나게 된다.
따라서경우에서 2->3으로 옮길때는 교차가 일어나게 되는데, 2번 논의 경우는 3번으로만 옮길 수 있으므로 반드시 교차가 일어난다고 볼 수 있다.
(B[3] += A[2]; A[3] += B[2])
이를 이용해 우리는 최종적으로 원하는 사람의 논에 볏단을 몰아줄 수 있다.
편의를 위해 다음과 같은 경우를 생각해보자.(문제 조건에서는 0인 경우는 없지만, 0이 있어도 상관은 없다)
초기 | 0 | 1 | 2 | 3 |
---|---|---|---|---|
A | 4 | 3 | 1 | 1 |
B | 0 | 0 | 0 | 0 |
우리는 2번(N-2번) 논의 볏단은 반드시 3번논에 교차해서 들어간다는 사실을 알 수 있다. 따라서 A[3]에 A[0]과 A[1]의 볏단을 넣고싶으면 B[2]에 넣으면 된다.
i=0 | 0 | 1 | 2 | 3 |
---|---|---|---|---|
A | 0 | 3 | 1 | 1 |
B | 0 | 0 | 4 | 0 |
i=1 | 0 | 1 | 2 | 3 |
---|---|---|---|---|
A | 0 | 0 | 1 | 1 |
B | 0 | 0 | 7 | 0 |
따라서 최종적으로는 다음의 결과가 나오며, 차는 7이 된다.
최종 | 0 | 1 | 2 | 3 |
---|---|---|---|---|
A | 0 | 0 | 0 | 8 |
B | 0 | 0 | 0 | 1 |
이번엔 다음과 같은 경우를 생각해보자
초기 | 0 | 1 | 2 | 3 |
---|---|---|---|---|
A | 0 | 0 | 0 | 0 |
B | 4 | 3 | 1 | 1 |
마찬가지로 우리는 A[3]이 최대가 되도록 할것이다. B[0]과 B[1]은 바로 A[3]으로 넣을수 있기 때문에 이 경우는 더 간단하다.
i=0 | 0 | 1 | 2 | 3 |
---|---|---|---|---|
A | 0 | 0 | 0 | 4 |
B | 0 | 3 | 1 | 1 |
i=1 | 0 | 1 | 2 | 3 |
---|---|---|---|---|
A | 0 | 0 | 0 | 7 |
B | 0 | 0 | 1 | 1 |
최종적으로는 다음이 나온다.
최종 | 0 | 1 | 2 | 3 |
---|---|---|---|---|
A | 0 | 0 | 0 | 8 |
B | 0 | 0 | 0 | 1 |
따라서 다음과 같은 사실을 알 수 있다.
B[-2] += A[i]
A[-1] += B[i]
이 사실을 바탕으로 코드를 짜보자.
import sys
input = sys.stdin.readline
def solve():
N = int(input())
A = list(map(int, input().split()))
B = list(map(int, input().split()))
def get_diff(A, B):
"""
A[-1]을 최대로 만든 후, A[-1]과 B[-1]의 차이를 구한다.
"""
# 리스트는 mutable하므로 미리 복사한다.
a_end = A[-2:]
b_end = B[-2:]
for i in range(N-2):
if A[i] > B[i]:
a_end[-2] += B[i]
b_end[-2] += A[i]
else:
a_end[-1] += B[i]
b_end[-1] += A[i]
a_end[-1] += b_end[-2]
b_end[-1] += a_end[-2]
return abs(a_end[-1] - b_end[-1])
print(max(get_diff(A,B), get_diff(B,A)))
solve()
위 코드를 아래와 같이도 바꿀 수 있다.
import sys
input = sys.stdin.readline
def solve():
N = int(input())
A = list(map(int, input().split()))
B = list(map(int, input().split()))
def get_diff2(A, B):
"""
A[-1]을 최대로 만든 후, A[-1]과 B[-1]의 차이를 구한다.
"""
a_end = A[-1]
b_end = B[-1]
for i in range(N-2):
if A[i] > B[i]:
a_end += A[i]
b_end += B[i]
else:
a_end += B[i]
b_end += A[i]
a_end += B[-2]
b_end += A[-2]
return abs(a_end - b_end)
print(max(get_diff2(A,B), get_diff2(B,A)))
solve()
1번코드:
2번코드
그리디는 방법을 떠올리는 것이 너무 어려운 것 같다.
"N-2번은 반드시 교차하여 N-1번 논에 들어간다" 라는 사실만 깨달으면 쉬운 문제인 것 같다.