쉬트라쎈? 행렬곱쎈?

서정윤·2022년 4월 9일
0

알고리고리

목록 보기
3/3

쉬트라쎈 행렬 곱셈

행렬 곱셈 문제

  • 문제:두 n * n 행렬의 곱을 구하시오
  • 일반적인 행렬 곱셈의 시간복잡도는 ∈ Θ(n^3)
  • 쉬트라센의 방법을 사용해서 행렬 곱셈의 시간 복잡도(∈ Θ(n^
    2.81))를 더 줄여보자~

쉬트라쎈의 방법

일반적인 행렬 곱셈은 8번의 곱셈과 4번의 덧셈을 해야됨 하지만 쉬트라쎈은 7번의 곱셈과 18번의 덧셈/뺄셈!! 곱셈은 덧셈보다 더 연산 부담이 많다

쉬트라쎈의 방법:분할정복

  • 큰 행렬을 네 개의 부분 행렬로 나누어서 정복하자

Strassen's Matrix Multiplication

def strassen (A, B):
    n = len(A)
    if (n <= threshold):
        return matrixmult(A, B)
    A11, A12, A21, A22 = divide(A)
    B11, B12, B21, B22 = divide(B)
    M1 = strassen(madd(A11, A22), madd(B11, B22))
    M2 = strassen(madd(A21, A22), B11)
    M3 = strassen(A11, msub(B12, B22))
    M4 = strassen(A22, msub(B21, B11))
    M5 = strassen(madd(A11, A12), B22)
    M6 = strassen(msub(A21, A11), madd(B11, B12))
    M7 = strassen(msub(A12, A22), madd(B21, B22))
    return conquer(M1, M2, M3, M4, M5, M6, M7)

n이 threshold보다 크면 행렬을 divide하자

def divide(A):
    n = len(A)
    m = n // 2
    A11 = [[0] * m for _ in range(m)]
    A12 = [[0] * m for _ in range(m)]
    A21 = [[0] * m for _ in range(m)]
    A22 = [[0] * m for _ in range(m)]
    for i in range(m):
        for j in range(m):
            A11[i][j] = A[i][j]
            A12[i][j] = A[i][j + m]
            A21[i][j] = A[i + m][j]
            A22[i][j] = A[i + m][j + m]

    return A11, A12, A21, A22
def conquer(M1, M2, M3, M4, M5, M6, M7):
    C11 = madd(msub(madd(M1, M4), M5), M7)
    C12 = madd(M3, M5)
    C21 = madd(M2, M4)
    C22 = madd(msub(madd(M1, M3), M2), M6)
    m = len(C11)
    n = 2 * m
    C = [[0] * n for _ in range(n)]
    for i in range(m):
        for j in range(m):
            C[i][j] = C11[i][j]
            C[i][j + m] = C12[i][j]
            C[i + m][j] = C21[i][j]
            C[i + m][j + m] = C22[i][j]

    return C
def madd (A, B):
    n = len(A)
    C = [[0] * n for _ in range(n)]
    for i in range(n):
        for j in range(n):
        C[i][j] = A[i][j] + B[i][j]
    return C
def msub (A, B):
    n = len(A)
    C = [[0] * n for _ in range(n)]
        for i in range(n):
            for j in range(n):
                C[i][j] = A[i][j] - B[i][j]
    return C
def matrixmult (A, B):
    n = len(A)
    C = [[0] * n for _ in range(n)]
    for i in range(n):
        for j in range(n):
            for k in range(n):
                C[i][j] += A[i][k] * B[k][j]
    return C

3개의 댓글

comment-user-thumbnail
2022년 4월 9일

제 이번 중간고사 범위에 있는데.. 고마워요..

2개의 답글

관련 채용 정보