P10830. 행렬 제곱

wnajsldkf·2023년 4월 19일
0

Algorithm

목록 보기
50/58
post-thumbnail

설명

P1629 곱셈과 재귀적으로 작은 단위로 나누어 계산한다는 점에서 비슷한 느낌을 받았다. 이번 글에서는 어떻게 나누는지와 행렬의 곱을 계산하는 것을 함께 정리해보겠다.

문제를 보면 N*N 행렬 A가 주어지고 이것을 B제곱하는 프로그램이다.

B라는 값을 2로 나누면서 가장 작은 단위로 쪼개고, 그 단위에서 두 행렬끼리 곱셈 연산을 한다. 가장 작은 단위인 b == 1 까지 도달하면 숫자가 커지는 것을 방지하기 위해 1000으로 나눈 값을 가지고 계산한다.

행렬의 제곱 횟수는 B가 짝수인지, 홀수인지에 따라 구분된다.

def square(matrix, b):
	if b == 1:
    	for y in range(N):
			for x in range(N):
            	matrix[y][x] %= 1000
		return matrix               
    
    # 짝수인 경우
    if b % 2 == 0:
    	return cal(square(matrix, b // 2), square(matrix, b // 2))
	# 홀수인 경우
    else:
    	return cal(cal(square(matrix, b // 2), square(matri, b // 2)), matrix)

행렬 곱셈

가장 작은 단위로 행렬을 구했다면, 이제 가장 작은 단위의 두 행렬을 곱한다. (관련문제: P2740. 행렬 곱셈)
두 행렬의 곱을 생각해보면 다음과 같다.

두 행렬을 계산하는데 3중 반복문이 사용되었는데, 행렬 A는 ROW 방향으로 행렬B는 COL 방향으로 이동하면서 계산하여 합산하는 방식이다.

matrixA = [[1,2], [3,4], [5,6]]
matrixB = [[-1,-2,0], [0,0,3]]

result = [[0]*3]*2

for row in range(3):
    for col in range(3):
        total = 0
        for i in range(2):
            total += matrixA[row][i] * matrixB[i][col]
        result[row][col] = total

for r in result:       
    print(*r)

# 실행 결과
# -1 -2 6
# -3 -6 12
# -5 -10 18

코드

from sys import stdin as s

s = open("input.txt", "rt")

N, B = map(int, s.readline().split())  # N: 행렬크기, B:제곱횟수
matrix = [0] * N

for i in range(N):
    matrix[i] = list(map(int, s.readline().split()))


def cal(matrixA, matrixB):
    n = len(matrixA)
    result = [[0] * n for _ in range(n)]

    for i in range(n):
        for j in range(n):
            total = 0
            for k in range(n):
                total += matrixA[i][k] * matrixB[k][j]
            result[i][j] = total % 1000

    return result


def square(matrix, b):
    # 행렬의 모든 원소를 1000으로 나눠준다.
    if b == 1:
        for y in range(N):
            for x in range(N):
                matrix[y][x] %= 1000
        return matrix

    # b가 짝수인 경우
    if b % 2 == 0:
        return cal(square(matrix, b // 2), square(matrix, b // 2))
    # b가 홀수인 경우
    else:
        return cal(cal(square(matrix, b // 2), square(matrix, b // 2)), matrix)


result = square(matrix, B)

for r in result:
    print(*r)
profile
https://mywnajsldkf.tistory.com -> 이사 중

0개의 댓글