문제 링크
https://www.acmicpc.net/problem/11066
우선 N이 500정도로 크지 않으므로 대충 O(N^3) 정도까지는 괜찮겠다 싶었다.
dp로 풀어야겠다 생각했고, 다음과 같이 정의했다.
dp[i][j]
= i번부터 j번 파일까지 합칠 때 드는 최소 비용
그렇다면 dp[i][j]
가 최소가 되게 하려면 어떻게 해야할까?
구간 i~j 사이의 임의의 번호 m에서, 저 구간을 반으로 쪼개보면서 그 중 최소를 찾아내면 된다.
(i ~ j) -> (i ~ m) , (m+1 ~ j) 이런식으로 말이다.
결국 i~j 구간의 길이를 조금씩 늘려가면서 최종적인 dp[0][k-1]
을 구하면 되는 것이다.
점화식은 다음과 같다.
for m in range(i, j):
dp[i][j] = min(dp[i][j], dp[i][m] + dp[m + 1][j] + total)
m
을 기준으로 왼쪽에 있는 놈(i~m)의 최소비용과 오른쪽에 있는 놈(m+1~j)의 최소비용을 더하고,
이번에 드는 비용(이번에 합칠 파일의 크기) total
을 더한 값이 최소가 되게끔 매 m
마다 갱신해준다.
for i in range(k): # 길이 0
dp[i][i] = 0
for l in range(2, k): # 길이 2 ~ k-1
for i in range(k - l):
j = i + l
total = s[j] - s[i - 1] # 구간합
for m in range(i, j):
dp[i][j] = min(dp[i][j], dp[i][m] + dp[m + 1][j] + total)
주의해야 할 점은, dp[i][i]
는 비용이 0이라는 것이다.
이 경우는 그 파일 그 자체인 경우이다.
단 한번도 합치지 않은 것이므로 i~i 까지 든 비용은 0인 것이다.
import sys
# sys.setrecursionlimit(10 ** 8) <-- pypy 제출시 주석처리 안 하면 메모리초과...
input = lambda: sys.stdin.readline().rstrip()
MAX = 1000000000
tc = int(input())
for _ in range(tc):
k = int(input())
c = list(map(int, input().split()))
dp = [[MAX for _ in range(k)] for _ in range(k)]
s = [0 for _ in range(k + 1)] # 누적합
for i in range(k):
s[i] = s[i - 1] + c[i]
for i in range(k): # 길이 0
dp[i][i] = 0
for l in range(1, k): # 길이 1 ~ k-1
for i in range(k - l):
j = i + l # 끝점 j
total = s[j] - s[i - 1] # 구간합
for m in range(i, j): # i~j 사이 임의의 m에 대해
dp[i][j] = min(dp[i][j], dp[i][m] + dp[m + 1][j] + total) # 최소를 찾는 것
print(dp[0][k - 1])