[파이토치] 1. 텐서 개요

2star_·2024년 10월 26일
0

ML/DL

목록 보기
12/18

파이토치 텐서 개념정리

1. 텐서 개요

  • 일반적으로 1차원 데이터를 벡터, 2차원 데이터를 행렬(matrix), 3차원 데이터를 텐서라고 말한다. 하지만 파이토치에서는 이런 차원에 관계없이 입출력 그리고 모든 데이터를 텐서 데이터 타입으로 정의하여 처리한다. 딥러닝에서 학습 파라미터인 W와 b에 대한 미분을 파이토치에서 자동으로 해 주는 기능이 최적화되어 있음.

1.1 리스트 형식의 데이터를 텐서로 직접 만들 수 있다.

list_data = [[1,1],[2,2]]

tensor_1 = torch.Tensor(list_data)

print(tensor_1)
print(f"tensor type : {type(tensor_1)}, tensor shape: {tensor_1.shape}")
print(f"tensor dtype : {tensor_1.dtype}, tensor device : {tensor_1.device}")

결과값

tensor([[1., 1.],
        [2., 2.]])
tensor type : <class 'torch.Tensor'>, tensor shape: torch.Size([2, 2])
tensor dtype : torch.float32, tensor device : cpu
if torch.cuda.is_available():
    tensor_1 = tensor_1.to("cuda")
# .to("cuda") 메소드를 이용해서 텐서를 GPU로 이동시키기

print(f"tensor type : {type(tensor_1)}, tensor shape: {tensor_1.shape}")
print(f"tensor dtype : {tensor_1.dtype}, tensor device : {tensor_1.device}")

결과값

tensor type : <class 'torch.Tensor'>, tensor shape: torch.Size([2, 2])
tensor dtype : torch.float32, tensor device : cuda:0

1.2 numpy 데이터로부터 텐서 생성

리스트와 달리 정수형(int)로 변환된 것을 알 수 있다.

import numpy as np

numpy_data = np.array(list_data)

tensor2_1 = torch.from_numpy(numpy_data) # 넘파이 데이터로부터 텐서 만들기

print(tensor2_1)
print(f"tensor type : {type(tensor2_1)}, tensor shape : {tensor2_1.shape}")
print(f"tensor dtype : {tensor2_1.dtype}, tensor device : {tensor2_1.device}")

결과값

tensor([[1, 1],
        [2, 2]], dtype=torch.int32)
tensor type : <class 'torch.Tensor'>, tensor shape : torch.Size([2, 2])
tensor dtype : torch.int32, tensor device : cpu

딥러닝에서는 기본 데이터타입이 실수(float)이므로 type casting을 해서 float() 해주는 것이 필요하다.

tensor2_2 = torch.from_numpy(numpy_data).float() # 정수형 -> 실수형으로 변환

print(tensor2_2)
print(f"tensor type : {type(tensor2_2)}, tensor shape : {tensor2_2.shape}")
print(f"tensor dtype : {tensor2_2.dtype}, tensor device : {tensor2_2.device}")

결과값

tensor([[1., 1.],
        [2., 2.]])
tensor type : <class 'torch.Tensor'>, tensor shape: torch.Size([2, 2])
tensor dtype : torch.float32, tensor device : cpu

1.3 random 데이터로부터 텐서 생성 -> numpy로 바꾸기

tensor3 = torch.rand(2,3)  # rand() 메소드는 0~1 사이의 균일한 분포의 random 값을 생성한다.
print(tensor3)

tensor4 = torch.randn(2,2) # randn() 메소드는 평균 0, 분산 1 인 정규분포 random 값을 생성한다.
# 딥러닝에서 가중치와 바이어스를 초기화 할 때 많이 사용된다.
print(tensor4)

결과값

tensor([[0.4952, 0.5712, 0.6845],
        [0.3435, 0.5821, 0.2674]])
tensor([[-1.4153, -1.0185],
        [-0.4283,  0.9513]])
tensor5 = torch.randn(2,2)
print(tensor5)
# 텐서를 numpy로 바꾸기
numpy_from_tensor = tensor5.numpy()
print(numpy_from_tensor)

결과값

tensor([[ 2.1022,  0.6033],
        [ 1.3188, -0.9308]])
[[ 2.1022055   0.60325843]
 [ 1.3187758  -0.9308399 ]]

2. 파이토치 텐서 연산

tensor6tensor7
123789
456101112

2.1 텐서 슬라이싱, 인덱싱

tensor6 = torch.Tensor([[1,2,3], [4,5,6]])
tensor7 = torch.Tensor([[7,8,9], [10,11,12]])

print(tensor6[0]) # 텐서6의 첫번째 행의 모든 데이터
print(tensor6[:,1:]) # 텐서6의 모든 행의 데이터와, 두번째 열 이후의 데이터와의 교집합
print(tensor7[0:2, 0:-1]) # 텐서7의 첫째 행부터 두번째 행까지의 데이터와, 첫번째 열 부터 두번째 열 까지의 데이터와의 교집합
print(tensor7[-1,-1]) # 텐서7의 두번째 행과 마지막 열의 데이터
print(tensor7[..., -2]) # 텐서7의 모든 행의 두번째 마지막 열 데이터

결과값

tensor([1., 2., 3.])
tensor([[2., 3.],
        [5., 6.]])
tensor([[ 7.,  8.],
        [10., 11.]])
tensor(12.)
tensor([ 8., 11.])

2.2 텐서 연산

tensor8 = tensor6.mul(tensor7) # tensor8 = tensor6 * tensor7
print(tensor8)

결과값

tensor([[ 7., 16., 27.],
        [40., 55., 72.]])

matrix multiplication (matmul) 계산은 앞 텐서의 열과 뒤 텐서의 행을 맞춰줘야 한다.

tensor9 = tensor6.matmul(tensor7) # tensor6 @ tensor7
# 결과는 오류가 뜬다. 열과 행이 맞지 않기 때문.

# tensor7을 (3,2) 형태로 맞춘 후 계산해야 한다.
tensor9 = tensor6.matmul(tensor7.view(3,2)) # tensor6 @ tensor7.view(3,2)
print(tensor9)

결과값

tensor([[ 58.,  64.],
        [139., 154.]])

2.3 텐서 합치기

tensor_cat = torch.cat([tensor6, tensor7]) # 열을 기준으로(세로로) 합침 dim = 0
print(tensor_cat)

결과값

tensor([[ 1.,  2.,  3.],
        [ 4.,  5.,  6.],
        [ 7.,  8.,  9.],
        [10., 11., 12.]])
tensor_cat_dim1 = torch.cat([tensor6, tensor7], dim=1) # dim=1 이면 행을 기준으로(가로로) 합침
print(tensor_cat_dim1)

결과값

tensor([[ 1.,  2.,  3.,  7.,  8.,  9.],
        [ 4.,  5.,  6., 10., 11., 12.]])
profile
안녕하세요.

0개의 댓글