[백준 11049] 행렬 곱셈 순서(Java, Python)

KDG: First things first!·2025년 2월 26일

백준

목록 보기
6/8





문제 해설

이 문제는 동적 프로그래밍(DP)을 이용하여 행렬 곱셈 연산의 최소 비용을 구하는 문제이다.

N개의 행렬이 주어졌을 때, 곱셈의 연산 순서를 최적화하여 최소 연산 횟수를 구하는 문제로 행렬 곱셈은 결합법칙(Associativity) 을 따르므로 괄호를 어디에 치느냐에 따라 연산량이 달라진다.

예를 들어, 3개의 행렬이 있을 때:
(A × B) × C 와 A × (B × C) 의 연산량이 다를 수 있기 때문에 가장 적은 연산량을 가지는 최적의 행렬 곱셈 순서를 찾는 것이 목표다.



정답 코드(Java) 및 설명


import java.io.BufferedReader;
import java.io.IOException;
import java.io.InputStreamReader;
import java.util.StringTokenizer;

public class Main {

    public static void main(String[] args) throws IOException {
        // 입력을 빠르게 받기 위해 BufferedReader 사용
        BufferedReader br = new BufferedReader(new InputStreamReader(System.in));
        // 결과를 한 번에 출력하기 위해 StringBuilder 사용
        StringBuilder sb = new StringBuilder();

        // 행렬의 개수 입력 받기
        int N = Integer.parseInt(br.readLine());

        // 행렬의 크기를 저장할 배열 (N개의 행렬이므로, 크기를 N+1로 설정)
        int[][] matrix = new int[N + 1][2];

        // 최소 곱셈 연산 횟수를 저장할 DP 테이블 (N+1 x N+1 크기)
        int[][] dp = new int[N + 1][N + 1];

        // 행렬의 크기 입력 받기
        for (int i = 1; i <= N; i++) {
            StringTokenizer st = new StringTokenizer(br.readLine());
            matrix[i][0] = Integer.parseInt(st.nextToken()); // 행렬의 행(row) 크기
            matrix[i][1] = Integer.parseInt(st.nextToken()); // 행렬의 열(column) 크기
        }

        // 행렬 체인 곱셈의 최소 연산 횟수를 구하는 DP 수행
        // len: 현재 고려하는 행렬 곱셈의 길이(몇 개의 행렬을 곱할 것인지) (2개부터 N개까지)
        for (int len = 2; len <= N; len++) {
            // i: 시작 행렬의 인덱스
            for (int i = 1; i <= N - len + 1; i++) {
                int j = i + len - 1; // 끝 행렬의 인덱스
                dp[i][j] = Integer.MAX_VALUE; // 최소값을 찾기 위해 초기값을 최댓값으로 설정

                // k: 행렬을 나누는 위치 (i <= k < j)
                for (int k = i; k < j; k++) {
                    // 현재 분할 위치(k)에서 곱셈 연산 비용 계산
                    int cost = dp[i][k] + dp[k + 1][j] + (matrix[i][0] * matrix[k][1] * matrix[j][1]);
                    // 최소 연산 횟수 업데이트
                    dp[i][j] = Math.min(dp[i][j], cost);
                }
            }
        }

        // 최소 연산 횟수를 StringBuilder에 저장 후 출력
        sb.append(dp[1][N]);
        System.out.println(sb);
    }
}

int cost = dp[i][k] + dp[k + 1][j] + (matrix[i][0] * matrix[k][1] * matrix[j][1]);
  • dp[i][k]: 행렬 A_i ~ 행렬 A_k 곱하는 최소 연산 횟수

  • dp[k+1][j]: 행렬 A_(k+1) ~ 행렬 A_j 곱하는 최소 연산 횟수

  • (matrix[i][0] X matrix[k][1] X matrix[j][1]):

    • matrix[i][0]: 시작 행렬의 행 크기
    • matrix[k][1]: k번째 행렬의 열 크기 → 곱셈의 중간 연결점
    • matrix[j][1]: j번째 행렬의 열 크기
    • 이 값이 실제 곱셈 연산 횟수를 결정


Python 정답 코드

(※ 백준에서 채점할 때 Python3로는 시간 초과가 나기 때문에 PyPy3를 사용하여 채점하여야 한다.)


import sys
input = sys.stdin.readline  # 입력 속도를 빠르게 하기 위해 sys.stdin.readline 사용
inf = float('inf')  # 무한대 값 설정 (최소값을 갱신하기 위해 사용)

N = int(input())  # 행렬의 개수 입력 받기

# 행렬의 크기를 저장할 배열 (N개의 행렬이므로 크기를 N+1로 설정)
matrix = [[0, 0] for _ in range(N + 1)]

# 최소 곱셈 연산 횟수를 저장할 DP 테이블 (N+1 x N+1 크기)
dp = [[0] * (N + 1) for _ in range(N + 1)]

# 행렬의 크기 입력 받기
for i in range(1, N + 1):
    a, b = map(int, input().split())  # 행렬의 행(a)과 열(b) 크기 입력 받기
    matrix[i][0] = a  # i번째 행렬의 행 크기 저장
    matrix[i][1] = b  # i번째 행렬의 열 크기 저장

# 행렬 체인 곱셈의 최소 연산 횟수를 구하는 DP 수행
# len: 현재 고려하는 행렬 곱셈의 길이 (2개부터 N개까지)
for len in range(2, N + 1):
    # i: 시작 행렬의 인덱스
    for i in range(1, N - len + 2):
        j = i + len - 1  # 끝 행렬의 인덱스 설정
        dp[i][j] = inf  # 최소값을 찾기 위해 초기값을 무한대로 설정

        # k: 행렬을 나누는 위치 (i <= k < j)
        for k in range(i, j):
            # 현재 분할 위치(k)에서 곱셈 연산 비용 계산
            cost = dp[i][k] + dp[k + 1][j] + (matrix[i][0] * matrix[k][1] * matrix[j][1])

            # 최소 연산 횟수 업데이트
            dp[i][j] = min(dp[i][j], cost)

# 최소 연산 횟수를 출력 (1번 행렬부터 N번 행렬까지 곱할 때의 최소 곱셈 연산 횟수)
print(dp[1][N])
profile
알고리즘, 자료구조 블로그: https://gyun97.github.io/

0개의 댓글