[백준] #11049 행렬 곱셈 순서 (Java)

주재완·2023년 12월 29일
0

백준

목록 보기
3/8
post-thumbnail

중요하거나 어려웠던 문제에 대해 작성합니다.

2차원 DP


문제 📝

실제로 필자가 구현이 꽤 까다롭기도 하였고, 2차원 배열을 이용한 동적 계획법에서 꼭 알아두면 좋은 테크닉이 있어서 선정하게 되었다. 문제는 여기 있는 문제 링크를 클릭하면 된다.


해결 💡

Java ☕️

import java.util.*;
import java.io.*;

class Matrix { // 단순히 배열로 처리 가능 하나 편의를 위해 별도 클래스 생성
    int row = 0;	// 해당 행렬의 행
    int col = 0;	// 해당 행렬의 열
    int value = 0;	// 해당 행렬이 되기까지 연산 횟수
	
    // 생성자 설정
    public Matrix(int row, int col) {
        this.row = row;
        this.col = col;
    }
}

public class Main {
    public static void main(String[] args) throws IOException {
        BufferedReader br = new BufferedReader(new InputStreamReader(System.in));
        int N = Integer.parseInt(br.readLine());
        Matrix[] matrixs = new Matrix[N + 1];
        Matrix[][] dp = new Matrix[N + 1][N + 1];
		// dp의 행: 연산의 시작점, dp의 열: 연산의 끝점
        // 헷갈리기 쉬운데, 시작과 끝 index만 사용한다고 생각하면 헷갈리지 않는다.
        
        // 입력 받는 for문
        for(int i = 1; i <= N; i++) {
            StringTokenizer st = new StringTokenizer(br.readLine());
            int row = Integer.parseInt(st.nextToken());
            int col = Integer.parseInt(st.nextToken());
            matrixs[i] = new Matrix(row, col);
        }
		
        // 각 dp항에 행렬의 행, 열을 넣는 작업
        for(int first = 1; first <= N; first++) {
            for(int last = 1; last <= N; last++) {
                dp[first][last] = new Matrix(matrixs[first].row, matrixs[last].col);
                // dp의 행: 연산의 시작점, dp의 열: 연산의 끝점
                // 행렬의 곱셈 결과는 행렬의 크기가 (첫 행렬의 행)X(마지막 행렬의 열)이 됨을 잊지 말자.
            }
        }
		
        // 각 dp항의 value를 구하는 작업
        // 예제를 계산할 때 우선 진행 과정 아래와 같다.
        // 첫번째 to 두번째, 두번째 to 세번째, 첫번째 to 세번째
        // dp index로는 (1,2),(2,3),(1,3) 이다.
        // 값을 보면 가장 안쪽 루프(가장 자주 변하는 값)는 시작점이고, 그다음으로 시작점과 끝점의 간격을 두면 된다.
        // 시작점: first, 간격: length, 끝점은 당연히 last = first + length
        for(int length = 1; length < N; length++) {
            for(int first = 1; first + length <= N; first++) {
                int last = first + length;
                dp[first][last].value = Integer.MAX_VALUE;
                // min을 for문 돌려 구하므로 주어진 범위의 최댓값 보다 크게 초기화한다.
                // first ~ i 연산 결과까지 행렬(이하 앞 행렬), i+1 ~ last 까지 행렬(이하 뒤 행렬) 두 행렬을 서로 곱하는 연산을 시행한다.
                // 따라서 i는 first 이상 last 미만이다.
                for(int i = first; i < last; i++) {
                	// n, m, k는 문제에도 나와있는 기본적인 연산 횟수를 구할 때 사용한다.
                    // 다만 앞뒤 행렬을 연산할 때, nmk 연산에 추가적으로 앞뒤 행렬을 얻기까지 필요한 연산도 필요하다.
                    // 앞뒤 행렬을 구하는 데까지 사용한 value 들 역시 더하자.
                    // 앞 행렬 얻기까지 연산 = dp[first][i].value
                    // 뒤 행렬 얻기까지 연산 = dp[i+1][last].value
                    // 기본적으로 행렬을 곱할 때 필요한 연산 = n * m * k
                    // 윗 값들을 다 더하자.
                    int n = dp[first][i].row;
                    int m = dp[first][i].col;
                    int k = dp[i + 1][last].col;
                    int tmp = dp[first][i].value + dp[i + 1][last].value + n * m * k;
                    
                    dp[first][last].value = Math.min(dp[first][last].value, tmp);
                }
            }
        }

        System.out.println(dp[1][N].value);
        br.close();
    }
}
profile
언제나 탐구하고 공부하는 개발자, 주재완입니다.

0개의 댓글