텐서 조작 - 3

Sukhun-Net·2024년 6월 20일

1) 뷰(View) - 원소의 수를 유지하면서 텐서의 크기 변경. 매우 중요!

텐서의 뷰(View)는 넘파이에서의 리쉐이프(Reshape)와 같은 역할


t = np.array([[[0, 1, 2],
               [3, 4, 5]],
              [[6, 7, 8],
               [9, 10, 11]]])
ft = torch.FloatTensor(t)

행렬 ft는 torch.Size([2, 2, 3]) (depth, row, column)

3차원 텐서에서 2차원 텐서로 변경

ft.view([-1, 3])

결과 
tensor([[ 0.,  1.,  2.],
        [ 3.,  4.,  5.],
        [ 6.,  7.,  8.],
        [ 9., 10., 11.]]

행렬 ft는 torch.Size([4, 3])로 변환

  • -1은 첫번째 차원은 파이토치에 맡기겠다는 의미
  • 내부적으로 크기 변환은 다음과 같이 이루어졌습니다. (2, 2, 3) -> (2 × 2, 3) -> (4, 3)

규칙 정리

  • view는 기본적으로 변경 전과 변경 후의 텐서 안의 원소의 개수가 유지

  • 파이토치의 view는 사이즈가 -1로 설정되면 다른 차원으로부터 해당 값을 유추

차원 텐서에서 3차원 텐서로 차원은 유지, Shape 변경

(2 × 2 × 3) = (? × 1 × 3) = 12를 만족해야 하므로 ?는 4

ft.view([-1, 1, 3])

tensor([[[ 0.,  1.,  2.]],

        [[ 3.,  4.,  5.]],

        [[ 6.,  7.,  8.]],

        [[ 9., 10., 11.]]])
torch.Size([4, 1, 3])

2) 스퀴즈(Squeeze) - 1인 차원을 제거

ft = torch.FloatTensor([[0], [1], [2]])


결과
tensor([[0.],
        [1.],
        [2.]])
        
torch.Size([3, 1])

두번째 차원이 1이므로 squeeze를 사용하면 (3,)의 크기를 가지는 텐서로 변경

ft.squeeze()
tensor([0., 1., 2.])

torch.Size([3])

=> 1이었던 두번째 차원이 제거되면서 (3,)의 크기를 가지는 텐서로 변경되어 1차원 벡터가 됨

3) 언스퀴즈(Unsqueeze) - 특정 위치에 1인 차원을 추가

# (3,)의 크기를 가지는 1인 차원 텐서 생성 
ft = torch.Tensor([0, 1, 2])

결과 
torch.Size([3])

# 첫번째 차원에 1인 차원을 추가해보겠습니다. 첫번째 차원의 인덱스를 의미하는 숫자 0을 인자로 넣으면 첫번째 차원에 1인 차원이 추가

ft = torch.Tensor([0, 1, 2])
ft.unsqueeze(0)

결과 
tensor([[0., 1., 2.]])
torch.Size([1, 3])

*View 로 구현할 수도 있다. ft.view(1, -1)

4) 타입 캐스팅(Type Casting)

자료형을 변환하는 것

lt = torch.LongTensor([1, 2, 3, 4])

lt.float()

5) 연결하기(concatenate)


x = torch.FloatTensor([[1, 2], [3, 4]])
y = torch.FloatTensor([[5, 6], [7, 8]])


torch.cat([x, y], dim=0)
dim=0은 첫번째 차원을 늘리라는 의미

결과 

tensor([[1., 2.],
        [3., 4.],
        [5., 6.],
        [7., 8.]])

6) 스택킹(Stacking)

x = torch.FloatTensor([1, 4])
y = torch.FloatTensor([2, 5])
z = torch.FloatTensor([3, 6])
torch.stack([x, y, z])


결과 
tensor([[1., 4.],
        [2., 5.],
        [3., 6.]])
        

torch.cat([x.unsqueeze(0), y.unsqueeze(0), z.unsqueeze(0)], dim=0) 와 동일

profile
Data Scientist (Computer Vision, Multimodal)

0개의 댓글