[11049/DP] 행렬 곱셈 순서 (JAVA)

Jiwoo Kim·2020년 11월 13일
0

알고리즘 정복하기

목록 보기
2/85
post-thumbnail
post-custom-banner

문제

문제


풀이

DP 알고리즘 문제들 중 유명한 문제라고 한다. 찾아보기 전엔 몰랐음 ㅎ;

반복문으로 DP를 풀 수도 있지만 반복문은 딱 봤을 때 이해하기 힘들어서 재귀로 풀었다.
재귀가 실행속도는 조금 더 느리긴 한데, 이름도 더 만족스럽게 지을 수 있고 이해도 더 잘됐다.

1번부터 N번까지 행렬을 모두 곱할 때의 최소 카운트를 구하는 문제다.

DP는 큰 문제를 작은 문제로 쪼개는 것이니까, 작은 범위에서부터 최솟값을 구하고 그 최솟값을 바탕으로 다음 범위의 최솟값을 구하면 된다.

각 변수들은 다음을 의미한다.

rows[i]: i번째 행렬의 열 갯수
cols[i]: i번째 행렬의 행 갯수
dp[i][j]: i번부터 j번까지의 행렬을 곱할 때의 최소 카운트

getMinCount() 내부의 반복문을 돌 때 모든 경우의 수 중 최솟값을 찾아낼 수 있도록, i번째 행렬 기준으로 좌우의 행렬 2개를 곱할 때의 카운트를 구한다. 왼쪽 행렬 만드는 카운트 + 오른쪽 행렬 만드는 카운트 + 둘을 곱하는 카운트를 구하면 전체 카운트를 구할 수 있다.

종료 조건으로는 start==end 일 때 0 리턴만 넣었다가 시간초과를 받았다.
그래서 중복 계산을 줄이기 위해 이미 계산된 값이 있으면 바로 리턴하는 조건을 추가했다.


코드

// 백준 11049 '행렬 곱셈 순서'
// DP
// 2020.08.17

import java.io.*;
import java.util.Arrays;
import java.util.StringTokenizer;

public class Main {

    static int N;

    static int[] rows = new int[502];
    static int[] cols = new int[502];
    static int[][] dp = new int[502][502];

    public static void main(String[] args) throws IOException {
        BufferedReader br = new BufferedReader(new InputStreamReader(System.in));
        BufferedWriter bw = new BufferedWriter(new OutputStreamWriter(System.out));

        // Get inputs
        StringTokenizer tk = new StringTokenizer(br.readLine());
        N = Integer.parseInt(tk.nextToken());

        for (int i = 1; i <= N; i++) {
            tk = new StringTokenizer(br.readLine());
            rows[i] = Integer.parseInt(tk.nextToken());
            cols[i] = Integer.parseInt(tk.nextToken());
            Arrays.fill(dp[i], Integer.MAX_VALUE);
        }

        // Print result
        bw.write(Integer.toString(getMinCount(1, N)));

        br.close();
        bw.close();
    }

    private static int getMinCount(int start, int end) {

        if (start == end)
            return 0;

        if (dp[start][end] != Integer.MAX_VALUE) {
            return dp[start][end];
        }

        for (int i = start; i < end; i++) {
            int cost = getMinCount(start, i) + getMinCount(i + 1, end) + rows[start] * cols[i] * cols[end];
            dp[start][end] = Math.min(dp[start][end], cost);
        }
        return dp[start][end];
    }
}
post-custom-banner

0개의 댓글