[PyTorch] torch.tensor.scatter_

qw4735·2023년 3월 8일
0

PyTorch

목록 보기
2/8

torch.tensor.scatter_

  • scatter_를 사용하는 방식은 원핫 인코딩을 할 때도 편리함.
    label = torch.tensor([3,4,5,6,7])
    one_hot = torch.zeros(5, 10)
    print(one_hot)
    > tensor([[0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
              [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
              [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
              [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
              [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]])
    label = label.view(-1,1)
    print(label)
    > tensor([[3],
              [4],
              [5],
              [6],
              [7]])
    print(one_hot.scatter_(1, label, 1))     # 열별로(dim=1)(가로방향), label 위치에 1을 넣기
    > tensor([[0., 0., 0., 1., 0., 0., 0., 0., 0., 0.],
              [0., 0., 0., 0., 1., 0., 0., 0., 0., 0.],
              [0., 0., 0., 0., 0., 1., 0., 0., 0., 0.],
              [0., 0., 0., 0., 0., 0., 1., 0., 0., 0.],
              [0., 0., 0., 0., 0., 0., 0., 1., 0., 0.]])
          
  • example
#열방향으로(가로방향으로, dim=1) (0,2) 자리에 1.23, (0,3) 자리에 1.23 넣기
z = torch.zeros(2, 4).scatter_(1, torch.tensor([[2],[3]]), 1.23)  
torch.tensor([[2],[3]])
> tensor([[2],
          [3]])
> tensor([[0.0000, 0.0000, 1.2300, 0.0000],
          [0.0000, 0.0000, 0.0000, 1.2300]])      
x = torch.rand(2,5)
torch.zeros(3,5).scatter_(0, torch.tensor([[0,1,2,0,0],[2,0,0,1,2]]), x)
torch.tensor([[0,1,2,0,0],[2,0,0,1,2]])

reference : https://aigong.tistory.com/35

0개의 댓글