[백준]행렬 곱셈 순서 with Java

hyeok ryu·2024년 3월 10일
1

문제풀이

목록 보기
94/154

문제

https://www.acmicpc.net/problem/11049


입력

첫째 줄에 행렬의 개수 N이 주어진다.
둘째 줄부터 N개 줄에는 행렬의 크기 r과 c가 주어진다.
항상 순서대로 곱셈을 할 수 있는 크기만 입력으로 주어진다.


출력

첫째 줄에 입력으로 주어진 행렬을 곱하는데 필요한 곱셈 연산의 최솟값을 출력한다.
정답은 2^31-1 보다 작거나 같은 자연수이다.
또한, 최악의 순서로 연산해도 연산 횟수가 2^31-1보다 작거나 같다.


풀이

제한조건

  • N(1 ≤ N ≤ 500).
  • (1 ≤ r, c ≤ 500)

접근방법

DP, 분할정복

메모이제이션 없이 모든 경우(곱셈 순서)를 고려하여 계산할 경우 시간초과가 발생할 것이다.
연산 횟수를 줄일 방법을 생각해보자.

2개의 행렬이 있다고 생각해보자.
AB 하나의 순서가 존재한다.

3개의 행렬이 있다고 가정하면,
(AB)C, A(BC) 2가지의 순서가 발생한다.

이때 AB 과정에서 몇 번의 연산이 발생했는지를 기록해둔다면, 행렬이 많아 졌을 때, 중복되는 계산을 줄일 수 있을 것이다.

즉 a번째 행렬부터 b번째 행렬까지의 곱셈 순서를 기록해 둔다면, 중복되는 계산을 줄일 수 있다.

어떤걸 메모이제이션 해야할 지 생각을 했으니, 이걸 문제에 어떻게 적용시킬 지 생각해보자.
마치 분할정복처럼 생각하면 쉽지 않을까.

전체를 구하기 위해서 중간지점을 설정하고, 구간을 나누어 계산하는 방식으로 진행해보자.

dp[left][right] = dp[left][mid] + dp[mid+1][right] + 두 행렬의 연산 비용
위와 같이 식을 둔다면, 모든 경우에 대해서 탐색이 가능하며, 메모이제이션을 적용할 수 있을 것다.

DP문제를 생각할 때,
피보나치에서 반복문 또는 재귀를 통한 메모이제이션 뿐만아니라
분할 정복+메모이제이션으로 하면 생각하기 쉬운 문제들이 많다!


코드

import java.io.BufferedReader;
import java.io.InputStreamReader;
import java.util.Arrays;

public class Main {
	static class Matrix {
		int r;
		int c;

		Matrix(int a, int b) {
			r = a;
			c = b;
		}
	}

	static int N;
	static Matrix[] arr;
	static int[][] dp;

	public static void main(String[] args) throws Exception {
		BufferedReader in = new BufferedReader(new InputStreamReader(System.in));

		N = stoi(in.readLine());
		arr = new Matrix[N + 1];
		dp = new int[N + 1][N + 1];
		for (int i = 0; i < N; ++i) {
			String[] inputs = in.readLine().split(" ");
			arr[i + 1] = new Matrix(stoi(inputs[0]), stoi(inputs[1]));
		}

		int result = search(1, N);
		System.out.println(result);
	}

	private static int search(int left, int right) {
		if (left == right)
			return 0;

		if (dp[left][right] != 0)
			return dp[left][right];

		dp[left][right] = Integer.MAX_VALUE;
		for (int mid = left; mid < right; ++mid) {
			int leftCount = search(left, mid);
			int rightCount = search(mid + 1, right);
			int sum = leftCount + rightCount + arr[left].r * arr[mid].c * arr[right].c;
			dp[left][right] = Math.min(dp[left][right], sum);
		}
		return dp[left][right];
	}

	public static int stoi(String s) {
		return Integer.parseInt(s);
	}
}

0개의 댓글