[PyTorch] Torch.matmul : 텐서간의 곱셈 방식

seonjin2·2023년 8월 11일
0
post-custom-banner
  1. 벡터(1D) - 벡터(1D) 간의 곱셈

    • 벡터간의 내적값, 스칼라(scala) 값을 리턴
    • (예시) [2,2,2][333]=2×3+2×3+2×3=18[2, 2, 2] \cdot \begin{bmatrix} 3 \\ 3 \\ 3 \end{bmatrix} = 2 \times 3 + 2 \times 3 + 2 \times 3 = 18
    
    import torch
    A = torch.full((3, ), 2) # 벡터 생성
    B = torch.full((3, ), 3) # 벡터 생성
    
    result = torch.matmul(A, B)
    print (result, result.shape) # 스칼라값 리턴
    
    >>> tensor(18) torch.Size([])
  2. 행렬(2D) - 행렬(2D) 간의 곱셈

    • Matrix Multiplication, Inner Product, Dot Product
    • (조건) 앞 행렬의 열 차원과 뒷 행렬의 행 차원이 같아야 곱셈 가능
    • (예시) [222222][333333]=[6+66+66+66+66+66+66+66+66+6]\begin{bmatrix} 2 & 2 \\ 2 & 2 \\ 2 & 2 \end{bmatrix} * \begin{bmatrix} 3 & 3 & 3 \\ 3 & 3 & 3 \end{bmatrix} = \begin{bmatrix} 6 + 6 & 6 + 6 & 6 + 6 \\ 6 + 6 & 6 + 6 & 6 + 6 \\ 6 + 6 & 6 + 6 & 6 + 6 \end{bmatrix}
    
    A = torch.full((3, 2), 2) # 행렬 생성
    B = torch.full((2, 3), 3) # 행렬 생성
    
    result = torch.matmul(A, B)
    print (result ,result.size())
    
    >>> tensor([[12, 12, 12],
                  [12, 12, 12],
                  [12, 12, 12]]) torch.Size([3, 3])
  3. 벡터(1D) - 행렬(2D) 간의 곱셈

    • 벡터가 행 벡터로 치환되며, (행렬이 아닌) 벡터로 반환

    • (조건) 벡터의 차원과 행렬의 행 차원이 같아야 곱셈 가능

    • (예시) [222]×[333333][2 2 2]×[333333]=[2×3+2×3+2×32×3+2×3+2×3]\begin{bmatrix} 2 \\ 2 \\2 \end{bmatrix} \times \begin{bmatrix} 3 & 3 \\ 3 & 3 \\ 3 & 3 \end{bmatrix} \rightarrow [ 2 \ 2 \ 2 ] \times \begin{bmatrix} 3 & 3 \\ 3 & 3 \\ 3 & 3 \end{bmatrix} = \begin{bmatrix} 2 \times 3 + 2 \times 3 + 2 \times 3 & 2 \times 3 + 2 \times 3 + 2 \times 3 \end{bmatrix}

      
      x = torch.full((3, ), 2)  # 벡터 생성
      A = torch.full((3, 2), 3) # 행렬 생성
      result = torch.matmul(x, A)
      print (result, result.size())
      
      >>> tensor([18, 18]) torch.Size([2])
  4. 벡터(1D) - 텐서(3D 이상) 간의 곱셈

    • 텐서(3D 이상)은 batched matrix로 간주 : (batch, 행, 열)

      • 1차원 감소한 행렬 반환
    • (조건) 벡터의 차원과 텐서의 뒤에서 두번째 차원이 동일해야 곱셈 가능

    • (예시) batch1:([2 2 2]×[333333]=[2×3+2×3+2×32×3+2×3+2×3],  batch2:([2 2 2]×[333333]=[2×3+2×3+2×32×3+2×3+2×3]batch1 \, : ([ 2 \ 2 \ 2 ] \times \begin{bmatrix} 3 & 3 \\ 3 & 3 \\ 3 & 3 \end{bmatrix} = \begin{bmatrix} 2 \times 3 + 2 \times 3 + 2 \times 3 & 2 \times 3 + 2 \times 3 + 2 \times 3 \end{bmatrix},\\ \qquad \; batch2 \, : ([ 2 \ 2 \ 2 ] \times \begin{bmatrix} 3 & 3 \\ 3 & 3 \\ 3 & 3 \end{bmatrix} = \begin{bmatrix} 2 \times 3 + 2 \times 3 + 2 \times 3 & 2 \times 3 + 2 \times 3 + 2 \times 3 \end{bmatrix}

      
      x = torch.full((3, ), 2)     # 벡터 생성
      A = torch.full((2, 3, 2), 3) # 텐서 생성
      result = torch.matmul(x, A)
      print (result, result.size())
      
      >>> tensor([[18, 18],
                  [18, 18]]) torch.Size([2, 2])
  5. 행렬(2D) - 벡터(1D) 간의 곱셈

    • 벡터는 열 벡터로 활용, (행렬이 아닌) 벡터로 반환

    • (조건) 행렬의 열 차원과 벡터의 차원이 동일

    • (예시) [333333]×[22]=[3×2+3×23×2+3×23×2+3×2]\begin{bmatrix} 3 & 3 \\ 3 & 3 \\ 3 & 3 \end{bmatrix} \times \begin{bmatrix} 2 \\ 2 \end{bmatrix} = \begin{bmatrix} 3 \times 2 + 3 \times 2 & 3 \times 2 + 3 \times 2 & 3 \times 2 + 3 \times 2 \end{bmatrix}

      
      x = torch.full((2, ), 2)      # 벡터 생성
      A = torch.full((3, 2), 3)     # 행렬 생성
      result = torch.matmul(A, x)
      print (result, result.size())
      
       >>> tensor([12, 12, 12]) torch.Size([3])
  6. 텐서(3D 이상) - 벡터(1D) 간의 곱셈

    • 텐서(3D 이상)은 batched matrix로 간주 : (batch, 행, 열)

      • 1차원 감소한 행렬 반환
    • (조건) 텐서의 가장 뒷 차원과 벡터의 차원이 동일해야 곱셈 가능

    • (예시) batch1:[333333]×[22]=[3×2+3×23×2+3×23×2+3×2],  batch2:[333333]×[22]=[3×2+3×23×2+3×23×2+3×2]batch1 \, : \begin{bmatrix} 3 & 3 \\ 3 & 3 \\ 3 & 3 \end{bmatrix} \times \begin{bmatrix} 2 \\ 2 \end{bmatrix} = \begin{bmatrix} 3 \times 2 + 3 \times 2 & 3 \times 2 + 3 \times 2 & 3 \times 2 + 3 \times 2 \end{bmatrix}, \\ \qquad \; batch2 \,: \begin{bmatrix} 3 & 3 \\ 3 & 3 \\ 3 & 3 \end{bmatrix} \times \begin{bmatrix} 2 \\ 2 \end{bmatrix} = \begin{bmatrix} 3 \times 2 + 3 \times 2 & 3 \times 2 + 3 \times 2 & 3 \times 2 + 3 \times 2 \end{bmatrix}

      
      x = torch.full((2, ), 2)      # 벡터 생성
      A = torch.full((2, 3, 2), 3)  # 텐서 생성
      result = torch.matmul(A, x)
      print (result, result.size())
      
      >>> tensor([[12, 12, 12],
                  [12, 12, 12]]) torch.Size([2, 3])
      
  7. 텐서(3D 이상) - 텐서(3D 이상) 간의 곱셈

    • 텐서(3D 이상)은 batched matrix로 간주 : (batch, 행, 열)

    • Batched Matrix Multiplication : batch 만큼의 각 행렬이 곱해지는 방식

    • (조건)
      a. batch 차원이 동일
      b. 행렬(2D)간의 곱셈처럼, 앞 텐서의 가장 뒷 차원과 뒷 텐서의 뒤에서 두번째 차원이 동일

      • 단, batch 차원은 브로드캐스팅 적용 : (차원이 동일하지 않더라도) 한 텐서의 차원이 없거나 1일 경우 해당 차원이 확장
    • (예시) batch1:[222222][333333]=[6+66+66+66+66+66+66+66+66+6],  batch2:[222222][333333]=[6+66+66+66+66+66+66+66+66+6]batch1 \, :\begin{bmatrix} 2 & 2 \\ 2 & 2 \\ 2 & 2 \end{bmatrix} * \begin{bmatrix} 3 & 3 & 3 \\ 3 & 3 & 3 \end{bmatrix} = \begin{bmatrix} 6 + 6 & 6 + 6 & 6 + 6 \\ 6 + 6 & 6 + 6 & 6 + 6 \\ 6 + 6 & 6 + 6 & 6 + 6 \end{bmatrix}, \\ \qquad \; batch2 \,: \begin{bmatrix} 2 & 2 \\ 2 & 2 \\ 2 & 2 \end{bmatrix} * \begin{bmatrix} 3 & 3 & 3 \\ 3 & 3 & 3 \end{bmatrix} = \begin{bmatrix} 6 + 6 & 6 + 6 & 6 + 6 \\ 6 + 6 & 6 + 6 & 6 + 6 \\ 6 + 6 & 6 + 6 & 6 + 6 \end{bmatrix}

      
      data1 = torch.full((2, 3, 2), 2) # 텐서 생성
      data2 = torch.full((2, 2, 3), 3) # 텐서 생성
      
      result = torch.matmul(data1, data2) 
      print(result, data3.shape) 
      
      >>> tensor([[[12, 12, 12],
                   [12, 12, 12],
                   [12, 12, 12]],
      
                  [[12, 12, 12],
                   [12, 12, 12],
                   [12, 12, 12]]]) torch.Size([2, 3, 3])
                   
      # 브로드캐스팅 적용 : (차원이 동일하지 않더라도) 한 차원이 없거나 1일 경우 해당 차원이 확장 
      data1 = torch.full((1, 3, 2), 2) # 텐서 생성
      # data1 = torch.full((3, 2), 2) 
      data2 = torch.full((2, 2, 3), 3) # 텐서 생성
      
      result = torch.matmul(data1, data2) 
      print(result, data3.shape) 
      
      >>> tensor([[[12, 12, 12],
                   [12, 12, 12],
                   [12, 12, 12]],
      
                  [[12, 12, 12],
                   [12, 12, 12],
                   [12, 12, 12]]]) torch.Size([2, 3, 3])
      
profile
정리 노트
post-custom-banner

0개의 댓글