아래 내용은 pytorch 2.1.0 버전으로 작성됨.
0. 바로 사용하기
result = torch.matmul(a, b)
a
: input data1
b
: input data2
1. 기본형
- torch 메서드로 사용
torch.matmul(input, other, *, out=None) → Tensor
- torch.Tensor 메서드로 사용
torch_Tensor_1.matmul(torch_Tensor_2) → Tensor
- 연산자 오버로딩
result = torch_Tensor_1 @ torch_Tensor_2
- in-place 연산
torch.matmul_(input_1, input_2)
2. 기능
- 행렬 곱(dot-product)를 연산하여 torch.Tensor로 반환해줌
- 특징
- Broadcasting (batch matrix에 대해서만)
- In-place 연산 (
torch.matmul_()
)
3. 파라미터
- 행렬 곱(dot-product)를 진행할 input data
other
- 행렬 곱을 진행할 other input data.
ref