[백준] 10830번 - 행렬 제곱

chanyeong kim·2022년 5월 31일
0

백준

목록 보기
108/200
post-thumbnail

📩 출처

문제

크기가 N*N인 행렬 A가 주어진다. 이때, A의 B제곱을 구하는 프로그램을 작성하시오. 수가 매우 커질 수 있으니, A^B의 각 원소를 1,000으로 나눈 나머지를 출력한다.

입력

첫째 줄에 행렬의 크기 N과 B가 주어진다. (2 ≤ N ≤ 5, 1 ≤ B ≤ 100,000,000,000)

둘째 줄부터 N개의 줄에 행렬의 각 원소가 주어진다. 행렬의 각 원소는 1,000보다 작거나 같은 자연수 또는 0이다.

출력

첫째 줄부터 N개의 줄에 걸쳐 행렬 A를 B제곱한 결과를 출력한다.

👉 생각

  • 거듭 제곱을 분할정복으로 나타내면 다음과 같다.
# 거듭제곱 분할정복
def power(base, exponent):
    if exponent == 0 or base == 0:
        return 1
    
    if exponent % 2 == 0:
        newbase = power(base, exponent/2)
        return newbase * newbase
    else:
        newbase = power(base, (exponent-1)/2)
        return (newbase * newbase) * base
  • 행렬의 곱을 구하는 함수 matrix를 만들어서 위의 거듭제곱 분할정복에 곱 대신에 넣어주면 된다.
  • 그렇지만 b의 최대값이 매우 크기 때문에 시간초과가 발생한다. 따라서 재귀함수에서 중복 호출을 막기 위해 메모이제이션이 필요하다.
  • 아래와 같이 한번 호출 된 적이 있다면 바로 그값을 리턴해 버리는 것이다. check 안에 들어가는 값은 새롭게 재귀를 보내기 전에 값을 할당하도록 했다.
    if int(exponent) in check:
        return check[int(exponent)]
  • 정답
import sys
n, b = map(int, input().split())
arr = [list(map(int, sys.stdin.readline().split())) for _ in range(n)]

def matrix(matrix1, matrix2):
    answer = [[0]*n for _ in range(n)]
    matrix2 = list(zip(*matrix2))
    for i in range(n):
        for j in range(n):
            tmp = 0
            for k in range(n):
                tmp += matrix1[i][k] * matrix2[j][k]
            answer[i][j] = tmp % 1000
    return answer

check = {}
def power(base, exponent):
    if int(exponent) in check:
        return check[int(exponent)]

    if exponent == 1:
        for i in range(n):
            for j in range(n):
                base[i][j] %= 1000
        check[int(exponent)] = base
        return base

    if exponent % 2 == 0:
        newbase = power(base, exponent / 2)
        check[int(exponent)] = matrix(newbase, newbase)
        return matrix(newbase, newbase)
    else:
        newbase = power(base, (exponent - 1) / 2)
        check[int(exponent)] = matrix(matrix(newbase, newbase), base)
        return matrix(matrix(newbase, newbase), base)

arr = power(arr, b)
for num in check[b]:
    print(*num)

0개의 댓글