torch.squeeze(input: Tensor, dim: Optional[Union[int, Tuple[int, ...]]] = None) -> Tensor
torch.unsqueeze(input: Tensor, dim: int) -> Tensor
import torch
import torch.nn as nn
# A batch of 16 samples, each with 4 features.
x = torch.rand(16, 4)

# Reducing over dim=1 removes that dimension, so the result is 1-D.
torch.sum(x, dim=1).size()  # torch.Size([16])
# unsqueeze(-1) inserts a new size-1 dimension at the end, restoring a 2-D shape.
torch.sum(x, dim=1).unsqueeze(-1).size()  # torch.Size([16, 1])

# nn.Linear(4, 1) maps each 4-feature row to one output value, so the result
# keeps a trailing size-1 dimension.
linear = nn.Linear(4, 1)
linear(x).size()  # torch.Size([16, 1])
# squeeze(-1) removes only the trailing size-1 dimension. Passing an explicit
# dim matters: a bare squeeze() removes *every* size-1 dimension, so with a
# batch size of 1 it would also collapse the batch dimension.
# Note: .size must be *called*; `.size` alone is a bound method, not a shape.
linear(x).squeeze(-1).size()  # torch.Size([16])
[1] PyTorch 2.0 documentation (https://pytorch.org/docs/2.0/)