백준 C++ 10830 행렬 제곱

DoooongDong·2023년 3월 27일
0
post-thumbnail

문제 설명

문제: 백준 10830 행렬 제곱
난이도: 골드 4

문제 요약

  • 크기가 N*N인 행렬 A가 주어집니다.
  • 이때, A의 B제곱을 구하는 프로그램을 작성합니다.
  • 숫자가 너무 커질 수 있으므로, 각 원소를 1000으로 나눈 나머지를 출력합니다.
  • N은 2이상 5이하의 숫자이고, B는 최대 1000억입니다.
  • 첫 째 줄부터 N개의 줄에 걸쳐 행렬 A를 B제곱한 결과를 출력합니다.

문제 해결 방법

먼저 A라는 행렬이 있고 B가 5로 주어졌다고 하겠습니다.

그럼 구하고자 하는 결과는 A를 5번 곱한 A^5를 구해야합니다.

A^5를 구하는 가장 쉬운 방법은 진짜 A를 5번 곱해주는 것 입니다.

그럼 만약, B가 1000억이라면?
A^1000억을 구하기 위해서 A를 1000억번 곱해주어야하는데
이런 방법으로 구하게 된다면 시간복잡도가 O(N^3*B)가 되어 1초안에 풀리지 않습니다.

그래서 저희는 빠른 행렬 제곱 알고리즘을 사용해야합니다.

방법은 이렇습니다.

위와 똑같이 A라는 행렬이 있고 B가 5로 주어졌다고하겠습니다.

A^5를 다르게 표현하면 (A^2)^2 * A 입니다.

다른 경우도 보겠습니다. A^10을 다르게 표현하면 (A^5)^2 입니다.

B가 5로 홀수인 경우에는 5번 곱해주던 것을 B를 반으로 나눈 A^2를 구해서 A만 곱해주면 되는 것이죠.

짝수인 10일 경우에는 A^10을 구하기 위해서 10번 곱해주던 것을 A^5를 구하고 A^5를 한번더 곱해주기만 하면 됩니다.

절반만 계산하기 때문에 이때, 시간복잡도는 O(N^3*logB)가 됩니다.

코드를 보면서 더 이해해 봅시다.

전체 코드

#include<iostream>
#include<algorithm>
#include<vector>
using namespace std;
typedef long long ll;
typedef vector<vector<ll>> vll;
typedef vector<ll> vl;
ll n,b;
ll mod = 1000;

vll multi(vll& a, vll& b){ // 두 행렬을 곱해주는 함수입니다.
    vll ret(n, vl(n));
    for(int i=0; i<n; i++){
        for(int j=0; j<n; j++){
            for(int k=0; k<n; k++){
                ret[i][j] += a[i][k] * b[k][j];
                ret[i][j] %= mod;
            }
        }
    }
    return ret;
}

vll pow(vll& a, ll exp) { // A^exp 를 구하는 함수입니다.
    if(exp == 1) { // 지수가 1일 경우 mod 연산을 수행하고 A 행렬을 반환합니다.
        for(int i=0; i<n; i++) {
            for(int j=0; j<n; j++) {
                a[i][j] %= mod;
            }
        }
        return a;
    }
    vll half = pow(a, exp / 2); // 재귀적으로 지수를 절반씩 줄여가면서 절반만 계산해나갑니다.
    vll ret = multi(half, half); // A^10이라면 half는 A^5이고 multi 함수를 통해 A^5 * A^5를 수행합니다.
    if (exp % 2 == 1) // 지수가 홀수일 경우에는 half * half 한 뒤에 a 행렬을 한번 더 곱해줍니다.
    {
        ret = multi(ret, a);
    }
    return ret;
}

void printInfo(vll& a) {
    for(int i=0; i<n; i++){
        for(int j=0; j<n; j++){
            cout << a[i][j] << ' ';
        }
        cout << "\n";
    }
}

int main(void) {
    ios_base::sync_with_stdio(0);
    cin.tie(0);
    cout.tie(0);
    cin >> n >> b;
    vll m(n, vl(n)); // n X n
    for(int i=0; i<n; i++) {
        for(int j=0; j<n; j++) {
            cin >> m[i][j];
        }
    }
    vll ret = pow(m, b);
    printInfo(ret);
    return 0;
}
profile
꺾이지 말자 :)

0개의 댓글