torch.squeeze()
torch.topk()
torch.mul()
torch.matmul() / @
[torch] mul vs matmul
torch.cat()
torch.stack()
[torch] cat vs stack
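The entries above can be illustrated with one short sketch. It assumes a recent PyTorch install, and the tensors are made-up examples rather than data from the original notes. It shows what each listed call returns, with the two "vs" comparisons placed side by side: element-wise `mul` next to matrix-product `matmul`/`@`, and `cat` (joins along an existing dimension) next to `stack` (adds a new dimension).

```python
import torch

# torch.squeeze(): remove dimensions of size 1 (all of them, or only `dim`)
x = torch.zeros(2, 1, 3, 1)
squeezed_all = torch.squeeze(x)          # shape (2, 3)
squeezed_dim1 = torch.squeeze(x, dim=1)  # shape (2, 3, 1): only dim 1 removed

# torch.topk(): the k largest values (sorted descending) and their indices
t = torch.tensor([1.0, 5.0, 3.0, 4.0])
values, indices = torch.topk(t, k=2)     # values [5., 4.], indices [1, 3]

# mul vs matmul
a = torch.tensor([[1, 2], [3, 4]])
b = torch.tensor([[5, 6], [7, 8]])
elementwise = torch.mul(a, b)            # a * b entry by entry: [[5, 12], [21, 32]]
matrix_prod = torch.matmul(a, b)         # row-by-column products: [[19, 22], [43, 50]]
matrix_prod_op = a @ b                   # the @ operator gives the same result as matmul

# cat vs stack
p = torch.zeros(2, 3)
q = torch.ones(2, 3)
concatenated = torch.cat([p, q], dim=0)  # joins along existing dim 0: shape (4, 3)
stacked = torch.stack([p, q], dim=0)     # inserts a new dim 0: shape (2, 2, 3)
```

The rule of thumb the comparison entries point at: `cat` needs the tensors to match in every dimension except the one being joined, while `stack` needs identical shapes and increases the rank by one.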