[백준][10830] 행렬 제곱

suhan0304·2023년 11월 7일
0

백준

목록 보기
26/53
post-thumbnail

문제

  • 크기가 N*N인 행렬 A가 주어질 때 행렬 A의 B 제곱을 구하여라

입력

  • 첫째 줄, 행렬의 크기 N과 B가 주어진다.
  • 둘째 줄, N개의 줄에 행렬의 각 원소가 주어진다.

출력

  • 첫째 줄에부터 N개의 줄에 걸쳐 행렬 A를 B 제곱한 결과를 출력한다.

풀이

문제의 풀이에 앞서 행렬 곱을 구하는 방법은 다음과 같다.

i,j 원소 값 += i번째 행의 k 원소 * j번째 열의 k 원소

def matrix_square(a, b):
    ans = [[0 for _ in range(N)] for _ in range(N)]
    for i in range(N):
        for j in range(N):
            for k in range(N):
                ans[i][j] += a[i][k] * b[k][j]
                ans[i][j] %= 1000
    return ans

B가 100,000,000,000까지 입력 가능하다는 것을 보고 가장 먼저 떠오른 것은 단순 제곱으로는 실행 시간 초과를 해결 할 수 없다는 것이었다. 따라서 시간 복잡도를 줄이기 위해 제곱을 거듭제곱을 구하는 건 어떨까 하고 생각했다. 따라서 다음과 같이 리스트를 만들고 필요한 값만 꺼내고 행렬 곱으로 계산하면 B번 제곱하는 것보다 log B번 계산하는 것이 시간 복잡도가 매우 작기 때문에 실행 시간 초과가 발생하지 않을 것이라고 생각했다.

12345---
A1A^1A2A^2A4A^4A8A^8A16A^{16}···

이 표를 보고나서 바로 2진법이 떠올랐다. 만약 B가 11이라고 하면 A8A^8 * A2A^2 * A1A^1 로 표현할 수 있고 이는 11을 2진법으로 표현한 1011의 자리에 맞는 표의 값들을 곱한 것이다. 따라서 위의 방식처럼 코드를 구현하면 다음과 같다.

idx = bin(B)[:1:-1]
ans = []
for i in range(len(idx)):
    if idx[i] == "1":
        if len(ans) == 0:
            ans = arr[i]
        else:
            ans = matrix_square(ans, arr[i])

위의 bin으로 idx를 2진수로 바꾼후 거꾸로 배치한다. (위 표를 보면 작은 자리수가 앞쪽에 있기 때문에) 그 이후 2진수의 값 중 1이 들어있는 값의 행렬 결과를 가지고 와서 ans와 곱해서 제곱해준다. 이러한 방식으로 2진수를 모두 돌고 나면 우리가 원하는 A의 값을 구할 수 있다.

실제로 시간 복잡도가 많이 줄었을까? 만약 B가 최대 100,000,000,000 들어온다면?

  • O(B) = 100,000,000,000
  • O(log2(B)log_2(B)) = 36.5$

시간 복잡도 O(logN)이 얼마나 좋은 시간 복잡도인지 새삼 느껴졌다.


코드

import sys

input = sys.stdin.readline


def matrix_square(a, b):
    ans = [[0 for _ in range(N)] for _ in range(N)]
    for i in range(N):
        for j in range(N):
            for k in range(N):
                ans[i][j] += a[i][k] * b[k][j]
                ans[i][j] %= 1000
    return ans


N, B = map(int, input().split())

A = []
for _ in range(N):
    A.append(list(map(int, input().split())))
arr = [A]
for i in range(len(bin(B)[2:]) - 1):
    arr.append(matrix_square(arr[i], arr[i]))
idx = bin(B)[:1:-1]
ans = []
for i in range(len(idx)):
    if idx[i] == "1":
        if len(ans) == 0:
            ans = arr[i]
        else:
            ans = matrix_square(ans, arr[i])
for line in ans:
    for x in line:
        print(x % 1000, end=" ")
    print()
profile
Be Honest, Be Harder, Be Stronger

0개의 댓글