행렬 제곱 - 분할정복 응용

조해빈·2023년 3월 16일
0

백준

목록 보기
25/53

행렬 제곱 - 10830번

문제

크기가 N*N인 행렬 A가 주어진다. 이때, A의 B제곱을 구하는 프로그램을 작성하시오. 수가 매우 커질 수 있으니, A^B의 각 원소를 1,000으로 나눈 나머지를 출력한다.

입력
첫째 줄에 행렬의 크기 N과 B가 주어진다. (2 ≤ N ≤ 5, 1 ≤ B ≤ 100,000,000,000)

둘째 줄부터 N개의 줄에 행렬의 각 원소가 주어진다. 행렬의 각 원소는 1,000보다 작거나 같은 자연수 또는 0이다.

출력
첫째 줄부터 N개의 줄에 걸쳐 행렬 A를 B제곱한 결과를 출력한다.

풀이

우선, 행렬을 입력받아 그 둘을 제곱하는 코드는 다음과 같을 것이다. 참고로, 거듭제곱이 아니라 한 번 제곱하는 수식이다.

import sys
input = sys.stdin.readline
N, B = map(int, input().split())
A = [ list(map(int, input().split())) for _ in range(N) ]

arr = [ [0 for _ in range(N)] for _ in range(N) ]

for r in range(N):
    for c in range(N):
        sum = 0
        for i in range(N):
            sum += A[r][i] * A[i][c]
            arr[r][c] = sum%1000

print(arr)

A[r][i] * A[i][c] 부분은 즉 matrix1의 한 행과 matrix2의 한 열을 곱하는 것이다.

위를 def화 시키면 다음과 같은 것이다. (길이가 같은 다른 두 행렬을 곱한다는 전제다.)

def mul(matrix1, matrix2):
    n = len(matrix1)
    arr = [ [0 for _ in range(N)] for _ in range(N) ]
    for r in range(n):
        for c in range(n):
            sum = 0
            for i in range(n):
                sum += matrix1[r][i] * matrix2[i][c]
                arr[r][c] = sum%1000
    return arr

이제 우리가 만든 이 함수 mul()로 거듭제곱을 할 것이다. 거듭제곱은 분할정복으로 푸는 아이디어를 그대로 사용한다. 다음은 우리가 만들 행렬 거듭제곱 함수 square()의 초기 세팅이다.

def square(a, b):
    if b==1:
        for x in range(len(a)):
            for y in range(len(a)):
                a[x][y] %= 1000
        return a

    tmp = square(a, b//2)
    tmp = mul(tmp, tmp)
    .
    .
    .
    .

함수는 행렬 a와 정수의 숫자 b를 인자로 받고 있다. if b==1: 의 코드를 보면, 주어진 제곱의 횟수 B가 1일 때는 원소마다 직접 1000으로 나누는 연산을 해준 뒤 바로 연산처리된 행렬 a를 프린트한다.

그 외의 경우들은 모두 변수 tmp를 선언하는데 이는 인자 b를 2분할한 채 호출된 함수 square이다.

이어 더 적어 보자.

	.
    .
    tmp = square(a, b//2)
    tmp = mul(tmp, tmp)
    if b%2==0:
        return tmp
    else:
        return mul(tmp, a)

result = square(A, B)
for r in result:
    print(*r)

if b%2: 즉 B가 홀수인 경우, 인자 b를 2분할 것에 나머지 1회를 더 제곱해줘야 하므로 mul(tmp)에 a를 한 번 더 mul 해준다.

다음의 코드는 최종적으로 정답이다.

import sys
input = sys.stdin.readline
N, B = map(int, input().split())
A = [ list(map(int, input().split())) for _ in range(N) ]


def mul(matrix1, matrix2):
    n = len(matrix1)
    arr = [ [0 for _ in range(N)] for _ in range(N) ]
    for r in range(n):
        for c in range(n):
            sum = 0
            for i in range(n):
                sum += matrix1[r][i] * matrix2[i][c]
                arr[r][c] = sum%1000
    return arr

def square(a, b):
    if b==1:
        for x in range(len(a)):
            for y in range(len(a)):
                a[x][y] %= 1000
        return a
    tmp = square(a, b//2)
    tmp = mul(tmp, tmp)
    if b%2==0:
        return tmp
    else:
        return mul(tmp, a)

result = square(A, B)
for r in result:
    print(*r)
profile
JS, CSS, HTML, React etc

0개의 댓글