백준 알고리즘 13246번 : 행렬 제곱의 합(제출 예정)

Zoo Da·2021년 12월 11일
0

백준 알고리즘

목록 보기
288/337
post-thumbnail

링크

https://www.acmicpc.net/problem/13246

sol1) 행렬 연산 구현

#pragma GCC target("avx,avx2,fma")
#pragma GCC optimize("Ofast")
#pragma GCC optimize("unroll-loops")
#include <bits/stdc++.h>
#define fastio ios::sync_with_stdio(0), cin.tie(0), cout.tie(0)
#define int int64_t
using namespace std;

using ll = long long;
using matrix = vector<vector<ll>>;
const int MOD = 1000;

struct Matrix
{
    int sz;
    Matrix(int n)
    {
        sz = n;
    }
    matrix unitMatrix()
    {
        matrix I(sz, vector<ll>(sz));
        for (int i = 0; i < sz; i++)
            I[i][i] = 1;
        return I;
    }

    matrix matrix_mul(matrix a, matrix b)
    {
        matrix ret(sz, vector<ll>(sz, 0));
        for (int i = 0; i < sz; i++)
        {
            for (int j = 0; j < sz; j++)
            {
                for (int k = 0; k < sz; k++)
                {
                    ret[i][j] += a[i][k] * b[k][j];
                }
                ret[i][j] %= MOD;
            }
        }
        return ret;
    }

    matrix matrix_plus(matrix a, matrix b)
    {
        matrix ret(sz, vector<ll>(sz, 0));
        for (int i = 0; i < sz; i++)
        {
            for (int j = 0; j < sz; j++)
            {
                ret[i][j] = a[i][j] + b[i][j];
                ret[i][j] %= MOD;
            }
        }
        return ret;
    }

    matrix matrix_pow(matrix x, int n)
    {
        matrix ret = unitMatrix();
        for (; n; n >>= 1)
        {
            if (n & 1)
                ret = matrix_mul(ret, x);
            x = matrix_mul(x, x);
        }
        return ret;
    }

    void printMatrix(const matrix x)
    {
        const int n = x.size();
        for (int i = 0; i < n; i++)
        {
            for (int j = 0; j < n; j++)
            {
                cout << x[i][j] << ' ';
            }
            cout << "\n";
        }
    }
};

int32_t main()
{
    fastio;
    int n, b;
    cin >> n >> b;
    Matrix M(n);
    matrix base(n, vector<ll>(n));
    matrix ans(n, vector<ll>(n, 0));
    for (int i = 0; i < n; i++)
        for (int j = 0; j < n; j++)
            cin >> base[i][j];
    for (int i = 1; i <= b; i++)
    {
        auto ret = M.matrix_pow(base, i);
        ans = M.matrix_plus(ret, ans);
    }
    M.printMatrix(ans);
}

기초적인 행렬의 연산을 구현해주면 되는 문제였습니다.

profile
메모장 겸 블로그

0개의 댓글