


이 문제는 동적 프로그래밍(DP)을 이용하여 행렬 곱셈 연산의 최소 비용을 구하는 문제이다.
N개의 행렬이 주어졌을 때, 곱셈의 연산 순서를 최적화하여 최소 연산 횟수를 구하는 문제로 행렬 곱셈은 결합법칙(Associativity) 을 따르므로 괄호를 어디에 치느냐에 따라 연산량이 달라진다.
예를 들어, 3개의 행렬이 있을 때:
(A × B) × C 와 A × (B × C) 의 연산량이 다를 수 있기 때문에 가장 적은 연산량을 가지는 최적의 행렬 곱셈 순서를 찾는 것이 목표다.
import java.io.BufferedReader;
import java.io.IOException;
import java.io.InputStreamReader;
import java.util.StringTokenizer;
public class Main {
public static void main(String[] args) throws IOException {
// 입력을 빠르게 받기 위해 BufferedReader 사용
BufferedReader br = new BufferedReader(new InputStreamReader(System.in));
// 결과를 한 번에 출력하기 위해 StringBuilder 사용
StringBuilder sb = new StringBuilder();
// 행렬의 개수 입력 받기
int N = Integer.parseInt(br.readLine());
// 행렬의 크기를 저장할 배열 (N개의 행렬이므로, 크기를 N+1로 설정)
int[][] matrix = new int[N + 1][2];
// 최소 곱셈 연산 횟수를 저장할 DP 테이블 (N+1 x N+1 크기)
int[][] dp = new int[N + 1][N + 1];
// 행렬의 크기 입력 받기
for (int i = 1; i <= N; i++) {
StringTokenizer st = new StringTokenizer(br.readLine());
matrix[i][0] = Integer.parseInt(st.nextToken()); // 행렬의 행(row) 크기
matrix[i][1] = Integer.parseInt(st.nextToken()); // 행렬의 열(column) 크기
}
// 행렬 체인 곱셈의 최소 연산 횟수를 구하는 DP 수행
// len: 현재 고려하는 행렬 곱셈의 길이(몇 개의 행렬을 곱할 것인지) (2개부터 N개까지)
for (int len = 2; len <= N; len++) {
// i: 시작 행렬의 인덱스
for (int i = 1; i <= N - len + 1; i++) {
int j = i + len - 1; // 끝 행렬의 인덱스
dp[i][j] = Integer.MAX_VALUE; // 최소값을 찾기 위해 초기값을 최댓값으로 설정
// k: 행렬을 나누는 위치 (i <= k < j)
for (int k = i; k < j; k++) {
// 현재 분할 위치(k)에서 곱셈 연산 비용 계산
int cost = dp[i][k] + dp[k + 1][j] + (matrix[i][0] * matrix[k][1] * matrix[j][1]);
// 최소 연산 횟수 업데이트
dp[i][j] = Math.min(dp[i][j], cost);
}
}
}
// 최소 연산 횟수를 StringBuilder에 저장 후 출력
sb.append(dp[1][N]);
System.out.println(sb);
}
}
int cost = dp[i][k] + dp[k + 1][j] + (matrix[i][0] * matrix[k][1] * matrix[j][1]);
dp[i][k]: 행렬 A_i ~ 행렬 A_k 곱하는 최소 연산 횟수
dp[k+1][j]: 행렬 A_(k+1) ~ 행렬 A_j 곱하는 최소 연산 횟수
(matrix[i][0] X matrix[k][1] X matrix[j][1]):
(※ 백준에서 채점할 때 Python3로는 시간 초과가 나기 때문에 PyPy3를 사용하여 채점하여야 한다.)
import sys
input = sys.stdin.readline # 입력 속도를 빠르게 하기 위해 sys.stdin.readline 사용
inf = float('inf') # 무한대 값 설정 (최소값을 갱신하기 위해 사용)
N = int(input()) # 행렬의 개수 입력 받기
# 행렬의 크기를 저장할 배열 (N개의 행렬이므로 크기를 N+1로 설정)
matrix = [[0, 0] for _ in range(N + 1)]
# 최소 곱셈 연산 횟수를 저장할 DP 테이블 (N+1 x N+1 크기)
dp = [[0] * (N + 1) for _ in range(N + 1)]
# 행렬의 크기 입력 받기
for i in range(1, N + 1):
a, b = map(int, input().split()) # 행렬의 행(a)과 열(b) 크기 입력 받기
matrix[i][0] = a # i번째 행렬의 행 크기 저장
matrix[i][1] = b # i번째 행렬의 열 크기 저장
# 행렬 체인 곱셈의 최소 연산 횟수를 구하는 DP 수행
# len: 현재 고려하는 행렬 곱셈의 길이 (2개부터 N개까지)
for len in range(2, N + 1):
# i: 시작 행렬의 인덱스
for i in range(1, N - len + 2):
j = i + len - 1 # 끝 행렬의 인덱스 설정
dp[i][j] = inf # 최소값을 찾기 위해 초기값을 무한대로 설정
# k: 행렬을 나누는 위치 (i <= k < j)
for k in range(i, j):
# 현재 분할 위치(k)에서 곱셈 연산 비용 계산
cost = dp[i][k] + dp[k + 1][j] + (matrix[i][0] * matrix[k][1] * matrix[j][1])
# 최소 연산 횟수 업데이트
dp[i][j] = min(dp[i][j], cost)
# 최소 연산 횟수를 출력 (1번 행렬부터 N번 행렬까지 곱할 때의 최소 곱셈 연산 횟수)
print(dp[1][N])