def strassen(A, B, size):
if size == 4:
return A @ B
d = size // 2
A11, A12, A21, A22 = A[:d, :d], A[:d, d:], A[d:, :d], A[d:, d:]
B11, B12, B21, B22 = B[:d, :d], B[:d, d:], B[d:, :d], B[d:, d:]
M1 = strassen(A11 + A22, B11 + B22,d)
M2 = strassen(A21 + A22, B11, d)
M3 = strassen(A11, B12 - B22, d)
M4 = strassen(A22, B21 - B11, d)
M5 = strassen(A11 + A12, B22, d)
M6 = strassen(A21 - A11, B11 + B12, d)
M7 = strassen(A12 - A22, B21 + B22, d)
C11 = M1 + M4 - M5 + M7
C12 = M3 + M5
C21 = M2 + M4
C22 = M1 - M2 + M3 + M6
return np.vstack((np.hstack((C11, C12)), np.hstack((C21, C22))))