import torch

# Twenty evenly spaced samples on the closed interval [-1, 1] (a 1-D tensor).
x = torch.linspace(start=-1, end=1, steps=20)
# Bare expression so a notebook cell displays the shape.
x.shape
import torch

# Same twenty samples on [-1, 1], then a trailing axis is inserted at dim 1,
# turning the length-20 vector into a (20, 1) column.
x = torch.linspace(start=-1, end=1, steps=20)
x = x.unsqueeze(dim=1)
# Bare expression so a notebook cell displays the shape.
x.shape
# Collapse the (20, 1) column from the previous cell back to one dimension.
# Tensor.flatten() returns a new tensor rather than modifying x in place, so
# the result must be rebound to x; otherwise the next line would still report
# the original 2-D shape.
x = x.flatten()
# Bare expression so a notebook cell displays the (now 1-D) shape.
x.shape