텐서 조작하기

최지안·2024년 9월 23일

벡터, 행렬, 텐서

  • 벡터 : 1차원으로 구성된 값
  • 행렬 : 2차원으로 구성된 값
  • 텐서 : 3차원 이상

PyTorch로 텐서 선언하기

import torch

1D dimension

t = np.array([0., 1., 2., 3., 4., 5., 6.])
# 파이썬으로 설명하면 List를 생성해서 np.array로 1차원 array로 변환함.
print(t)
[0. 1. 2. 3. 4. 5. 6.]

2D dimension

t = np.array([[1., 2., 3.], [4., 5., 6.], [7., 8., 9.], [10., 11., 12.]])
print(t)
[[ 1.  2.  3.]
 [ 4.  5.  6.]
 [ 7.  8.  9.]
 [10. 11. 12.]]

브로드캐스팅

m1 = torch.FloatTensor([[3, 3]])
m2 = torch.FloatTensor([[2, 2]])
print(m1 + m2)
tensor([[5., 5.]])

행렬 연산

행렬 곱셈과 곱셈의 차이

m1 = torch.FloatTensor([[1, 2], [3, 4]])
m2 = torch.FloatTensor([[1], [2]])
print('Shape of Matrix 1: ', m1.shape) # 2 x 2
print('Shape of Matrix 2: ', m2.shape) # 2 x 1
print(m1.matmul(m2)) # 2 x 1
Shape of Matrix 1:  torch.Size([2, 2])
Shape of Matrix 2:  torch.Size([2, 1])
tensor([[ 5.],
        [11.]])

평균

t = torch.FloatTensor([1, 2])
print(t.mean())
tensor(1.5000)

덧셈

t = torch.FloatTensor([[1, 2], [3, 4]])
print(t.sum()) # 단순히 원소 전체의 덧셈을 수행
print(t.sum(dim=0)) # 행을 제거
print(t.sum(dim=1)) # 열을 제거
print(t.sum(dim=-1)) # 열을 제거
tensor(10.)
tensor([4., 6.])
tensor([3., 7.])
tensor([3., 7.])

최대와 아그맥스

t = torch.FloatTensor([[1, 2], [3, 4]])
print(t)
tensor([[1., 2.],
        [3., 4.]])

원소의 수를 유지하면서 텐서의 크기 변경

t = np.array([[[0, 1, 2],
               [3, 4, 5]],
              [[6, 7, 8],
               [9, 10, 11]]])
ft = torch.FloatTensor(t)

언스퀴즈

특정 위치에 1인 차원을 추가

ft = torch.Tensor([0, 1, 2])
print(ft.shape)
print(ft.unsqueeze(0)) # 인덱스가 0부터 시작하므로 0은 첫번째 차원을 의미한다.
print(ft.unsqueeze(0).shape)
tensor([[0., 1., 2.]])
torch.Size([1, 3])

연결하기

x = torch.FloatTensor([[1, 2], [3, 4]])
y = torch.FloatTensor([[5, 6], [7, 8]])
print(torch.cat([x, y], dim=0))
tensor([[1., 2.],
        [3., 4.],
        [5., 6.],
        [7., 8.]])
print(torch.cat([x, y], dim=1))
tensor([[1., 2., 5., 6.],
        [3., 4., 7., 8.]])

스택킹

x = torch.FloatTensor([1, 4])
y = torch.FloatTensor([2, 5])
z = torch.FloatTensor([3, 6])
print(torch.stack([x, y, z]))
tensor([[1., 4.],
        [2., 5.],
        [3., 6.]])

0개의 댓글