Strassen

spring·2022년 10월 21일
0
def strassen(A, B, size):
    # 1. 슈트라센은 2^n @ 2^n 만 지원한다.
    if size == 4:
        return A @ B
    d = size // 2
    A11, A12, A21, A22 = A[:d, :d], A[:d, d:], A[d:, :d], A[d:, d:]
    B11, B12, B21, B22 = B[:d, :d], B[:d, d:], B[d:, :d], B[d:, d:]
    M1 = strassen(A11 + A22, B11 + B22,d)
    M2 = strassen(A21 + A22, B11, d)
    M3 = strassen(A11, B12 - B22, d)
    M4 = strassen(A22, B21 - B11, d)
    M5 = strassen(A11 + A12, B22, d)
    M6 = strassen(A21 - A11, B11 + B12, d)
    M7 = strassen(A12 - A22, B21 + B22, d)

    C11 = M1 + M4 - M5 + M7
    C12 = M3 + M5
    C21 = M2 + M4
    C22 = M1 - M2 + M3 + M6

    return np.vstack((np.hstack((C11, C12)), np.hstack((C21, C22))))
profile
Researcher & Developer @ NAVER Corp | Designer @ HONGIK Univ.

0개의 댓글