[백준/Java] 11049 행렬 곱셈 순서

박찬병·2024년 10월 26일

Problem Solving

목록 보기
17/48

https://www.acmicpc.net/problem/11049

문제 요약

크기가 a x b인 행렬과 b x c인 행렬을 곱할 때 필요한 곱셈의 수는 a x b x c번이다.
이를 고려하면 행렬 N개를 곱하는 데 필요한 곱셈의 수는 행렬을 곱하는 순서에 따라 달라짐을 알 수 있다.
행렬 N개의 크기가 각각 주어졌을 때, 모든 행렬을 곱하는 데 필요한 곱셈의 최솟값을 구하여라.
이때 입력으로 주어지는 행렬의 순서를 바꾸면 안 된다.

행렬의 개수 N은 최대 500이다.
각 행렬의 크기를 나타내는 r, c는 최대 500이다.
행렬의 곱셈을 최악의 순서로 수행해도 전체 곱셈 횟수는 2^31-1 이하이다.


문제 접근

DP를 이용해서 문제를 해결할 수 있다.
행렬 곱은 인접한 행렬끼리만 가능하다는 점을 이용해야 한다.
가로 세로가 각각 행렬의 순서를 인덱스로 사용하는 2차원 배열을 만든다.
세로가 행렬곱 시작 위치, 가로가 행렬곱 끝 위치를 나타낸다고 생각하여 각 경우의 최소 곱셈 횟수를 구하며 배열을 채워나간다.
최종적으로 해당 배열의 [0][N-1] 인덱스 값이 전체 행렬곱의 최소 곱셈 수를 나타내게 된다.

최소 곱셈 횟수를 기록하는 배열을 어떤 순서로 채워나갈 지 고민해야 한다.
기본적으로 행렬곱을 수행하는 행렬의 개수가 하나씩 증가하는 경우를 고려해야 할 것이므로 곱하는 행렬의 수를 하나씩 늘리면서 값을 찾아야 할 것 같다.
이러한 작업을 첫 인덱스부터 끝 인덱스까지 수행하면 된다.

이때 유의할 점은, 곱하는 행렬의 수를 하나씩 늘릴 때 단순히 이전 연산에 새로운 행렬을 곱하는 것이 아니라 해당 범위의 최소 곱셈 횟수를 새로 구해야 한다는 점이다.
예를 들어 인덱스 1~4까지 4개의 행렬을 곱하는 최소 곱셈은 1~3에 4를 곱하는 것이 끝이 아니라, 1과 2~4를 곱하기, 1~2와 3~4를 곱하기를 포함한 3가지 경우를 모두 고려해야 한다는 점이다.

위의 이야기를 고려하면 반복문을 3중으로 사용하여 문제를 해결할 수 있다.
이 문제에서 행렬의 개수 N은 최대 500이므로 시간복잡도 O(N3)O(N^3)도 시간 내에 충분히 수행할 수 있다.


풀이

기본적인 아이디어는 다음과 같다.

  1. 특정 위치부터 특정 위치까지의 행렬곱의 최소 곱셈 횟수를 기록하는 배열을 선언한다.
  2. 이 배열은 다음과 같은 방식으로 채운다.
    2.1. 행렬곱을 수행하는 행렬의 개수를 1부터 N까지 증가한다.
    2.2. 이때, 행렬곱을 수행하는 시작 인덱스를 0부터 N-1까지 순회한다.
    2.3. 이때, 새로운 범위의 최소 곱셈 횟수를 구하기 위해 행렬곱하는 두 행렬 간의 길이를 변화시키며 값을 계산한다.
  3. 배열을 모두 채운 뒤, 시작점이 0이고 끝점이 N-1인 해당 배열의 값이 정답이 된다.

이를 구현한 코드는 다음과 같다.

import java.util.*;
import java.io.*;

public class Main {
    public static void main(String[] args) throws Exception {
        BufferedReader br = new BufferedReader(new InputStreamReader(System.in));

        // 입력 받기
        int N = Integer.parseInt(br.readLine());

        int[][] matrix = new int[N][2]; // 입력 매트릭스를 기록하는 배열
        for (int i = 0; i < N; i++) {
            StringTokenizer stN = new StringTokenizer(br.readLine());

            int r = Integer.parseInt(stN.nextToken());
            int c = Integer.parseInt(stN.nextToken());

            matrix[i][0] = r;
            matrix[i][1] = c;
        }

        // DP에서 사용할 연산의 최솟값을 기록하는 배열
        // 세로는 시작점, 가로는 끝점을 나타낸다.
        int[][] minOps = new int[N][N];

        // 길이가 1인 경우는 연산이 0이다.
        for (int i = 0; i < N; i++) {
            minOps[i][i] = 0;
        }

        // 길이가 2인 경우부터 N인 경우까지 진행한다.
        for (int i = 0; i < N-1; i++) {
            // 시작 인덱스는 0부터, 인덱스+길이-1이 N이 되기 전까지 순회한다.
            for (int startIdx = 0; i+startIdx+1 < N; startIdx++) {
                int endIdx = startIdx+i+1;
                // 길이가 2인 경우에는 계산한 값이 없으므로 그냥 계산해서 기록한다.
                if (i == 0) {
                    int A = matrix[startIdx][0];
                    int B = matrix[startIdx][1];
                    int C = matrix[endIdx][1];
                    int ops = A*B*C;

                    minOps[startIdx][endIdx] = ops;
                }
                // 시작점이 포함된 부분과 끝점이 포함된 부분을 더하며 비교한다.
                // 이는 기존 연산값 + 새로 합치는 연산값을 기반으로 비교해야 한다.
                else {
                    // 시작점 부분의 길이가 1인 것부터, i+1까지 가능하다.
                    // startIdx~startIdx+k 와 startIdx+k+1~endIdx를 더한다.

                    int minOpsNow = Integer.MAX_VALUE;

                    for (int k = 0; k < i+1; k++) {
                        int A = matrix[startIdx][0];
                        int B = matrix[startIdx+k][1];
                        int C = matrix[endIdx][1];
                        int ops = A*B*C;

                        ops += minOps[startIdx][startIdx+k];
                        ops += minOps[startIdx+k+1][endIdx];

                        minOpsNow = Math.min(minOpsNow, ops);
                    }

                    minOps[startIdx][endIdx] = minOpsNow;
                }
            }
        }

        // 0부터 N-1까지를 모두 고려한 최소 연산 값이 정답이다.
        int answer = minOps[0][N-1];

        System.out.println(answer);
    }
}

회고

틀렸던 이유
위의 문제 접근에서 이야기했듯이 행렬곱의 길이를 증가시키면서 최소 곱셈 횟수를 구할 때는 다양한 길이의 행렬을 곱할 수 있는데, 단순히 행렬을 뒤에 하나 덧붙인다고 생각하고 계산하여 틀렸다.

0개의 댓글