squeeze, unsqueeze

서민석·2023년 4월 3일

torch.squeeze

torch.squeeze(input: Tensor, dim: Optional[int, Tuple] = None) -> Tensor

  • size가 1인 dim이 제거된 input이 반환됨
  • dim 인자를 통해 제거할 dim을 지정할 수 있음

torch.unsqueeze

torch.unsqueeze(input: Tensor, dim: int) -> Tensor

  • 지정된 dim에 size가 1인 dim이 추가된 input이 반환됨

예시: tensor의 size를 맞춰줄 때 사용할 수 있음

import torch
import torch.nn as nn


x = torch.rand(16, 4)
torch.sum(x, dim=1).size() # torch.Size([16])
torch.sum(x, dim=1).unsqueeze(-1).size() # torch.Size([16, 1])

linear = nn.Linear(4, 1)
linear(x).size() # torch.Size([16, 1])
linear(x).squeeze(-1).size # torch.Size([16])

Reference

[1] PyTorch 2.0 documentation (링크)

0개의 댓글