딥러닝 - PyTorch: 텐서 다루기

dumbbelldore·2025년 1월 13일
0

zero-base 33기

목록 보기
72/97

1. 텐서 (Tensor)

1-1. 주요 개념

  • PyTorch의 텐서는 Tensorflow와 달리 Constant, Variable을 별도의 클래스로 구분하지 않음
  • 기본적으로 모든 텐서에 대한 자동 미분 기능을 지원하므로, 불필요한 경우 requires_grad=False로 지정할 수 있음
import torch

# 기본 텐서 (Tensorflow의 Constant 역할)
x = torch.tensor([1.0, 2.0, 3.0], requires_grad=False)

# 미분 가능 텐서 (Tensorflow의 Variable 역할)
y = torch.tensor([1.0, 2.0, 3.0], requires_grad=True)

print(x) # tensor([1., 2., 3.])
print(y) # tensor([1., 2., 3.], requires_grad=True)

# rank, shape, dtype 확인 
print(x.ndim) # 1
print(x.shape) # torch.Size([3])
print(x.dtype) # torch.float32

1-2. 텐서 다루기

- Numpy 배열 변환

print(x.numpy()) # array([1., 2., 3.], dtype=float32)
  • 선형 증가 1차원 텐서 생성
print(torch.arange(0, 10, 2)) # tensor([0, 2, 4, 6, 8])
  • 모든 요소가 1인 텐서 생성
print(torch.ones(2, 3))
# tensor([[1., 1., 1.],
#         [1., 1., 1.]])
  • 선형 균등간격 텐서 생성
# 0 부터 10 까지 균등 간격의 수 3개 생성
print(torch.linspace(0, 10, 3))
# tensor([ 0.,  5., 10.])
  • 로그스케일 균등간격 텐서 생성
# 10^0 부터 10^2 까지 균등 간격의 수 3개 생성
print(torch.logspace(0, 2, 3))
# tensor([1., 10., 100.])

2. 난수 생성

  • Tensorflow와 마찬가지로 torch의 하위 클래스로 다양한 난수 생성 기능을 제공함
# 0 이상 1 미만 균등분포 난수
rand_tensor = torch.rand(2, 3)
print(rand_tensor)
# tensor([[0.9912, 0.0085, 0.3654],
#        [0.3622, 0.1672, 0.8118]])

# 평균이 0이고 표준편차가 1인 정규분포 난수
randn_tensor = torch.randn(3, 3)
print(randn_tensor)
# tensor([[-0.3203,  0.7413, -0.9294],
#        [-0.8491, -0.5490, -0.4873],
#        [ 1.0603, -0.9379, -0.3244]])
      
# 특정 범위의 정수 값 난수
randint_tensor = torch.randint(0, 10, (2, 3))
print(randint_tensor)
# tensor([[0, 2, 8],
#         [2, 5, 9]])
  • 결과의 재현가능성을 위해서는, 시드(seed)를 설정하고 작업을 진행하여야 함
# 동일한 seed 하 랜덤 값 두개 생성 후 비교
torch.manual_seed(42)
rand1 = torch.rand(3)

torch.manual_seed(42)
rand2 = torch.rand(3)

print(rand1 == rand2) # tensor([True, True, True])

3. 데이터 타입 변경

  • tf.cast()를 사용하던 Tensorflow와 달리, PyTorch에서는 텐서의 type() 또는 to() 함수를 이용해 데이터 간 타입 변경을 수행함
a = torch.tensor([1.0, 2.0, 3.0], dtype=torch.float32)
print(a.dtype) # torch.float32

a = a.type(torch.int8) # a.to(torch.int8) 동일 기능
print(a.dtype) # torch.int8

4. GPU 사용 설정

  • PyTorch에서는 GPU를 사용하기 위해 텐서에 적절한 device를 명시해주어야 함
  • NVIDIA GPU의 경우 device="cuda"를, AMD GPU(Mac 포함)은 device="mps"를 포함하여 텐서를 정의하여야 함
# Mac 환경 예제
# GPU 사용 가능하면 "mps"를, 불가능하면 "cpu"를 사용하도록 지정
dev = torch.device("mps" if torch.backends.mps.is_available() else "cpu")
b = torch.tensor([1, 2, 3], device=dev)
print(b) # tensor([1, 2, 3], device='mps:0')

5. view() vs. reshape()

  • PyTorch에는 텐서의 형태를 변형하기 위해 view() 함수와 reshape() 함수 두 가지를 활용할 수 있음
  • 두 함수 모두 텐서의 데이터는 변경하지 않으며, 크기만 변형함
  • view()메모리에 연속적으로 저장된 텐서만 사용할 수 있지만 연산 측면에서 효율적임
  • reshape()는 메모리 연속성이 없어도 되지만, 필요 시 새로운 메모리 블록을 할당한다는 점에 차이가 있음
x = torch.tensor([[1, 2, 3], [4, 5, 6]])
print(x)
# tensor([[1, 2, 3],
#         [4, 5, 6]])

# view()와 reshape() 모두 사용 가능
x_view = x.view(3, 2)
x_reshape = x.reshape(3, 2)

print(x_view)
# tensor([[1, 2],
#         [3, 4],
#         [5, 6]])
print(x_reshape)
# tensor([[1, 2],
#         [3, 4],
#         [5, 6]])

# Transpose 하여 메모리를 강제로 비연속적으로 변환
y = x.t()
print(y)
# tensor([[1, 4],
#         [2, 5],
#         [3, 6]])

# view()는 에러 발생
try:
    y_view = y.view(-1)
except Exception as e:
    print(f"view() 에러: {e}")
# view() 에러: view size is not compatible with input tensor's size 
# and stride (at least one dimension spans across two contiguous 
# subspaces). Use .reshape(...) instead.

# reshape()는 에러없이 동작
y_reshape = y.reshape(-1)
print(y_reshape)
# tensor([1, 4, 2, 5, 3, 6])

6. 주요 연산

6-1. 사칙 연산

  • 덧셈: torch.add(x, y) 또는 x + y
  • 뺄셈: torch.sub(x, y) 또는 x - y
  • 곱셈: torch.mul(x, y) 또는 x * y
  • 나눗셈: torch.div(x, y) 또는 x / y

6-2. 행렬 연산

  • 행렬 곱: torch.matmul(x, y) 또는 x @ y
  • 전치: x.T 또는 x.transpose(dim0, dim1)
  • 행렬식: torch.det(x)
  • 역행렬: torch.linalg.inv(x)
  • 고유값 분해: torch.linalg.eig(x)

6-3. 통계 연산

  • 합계: torch.sum(x)
  • 평균: torch.mean(x)
  • 표준편차: torch.std(x)
  • 최대값: torch.max(x)
  • 최소값: torch.min(x)
  • 인덱스 포함 최대값: torch.max(x, dim)
  • 누적 합계: torch.cumsum(x, dim)

6-4. 기타 연산

  • 지수: torch.exp(x)
  • 로그: torch.log(x)
  • 절댓값: torch.abs(x)
  • 제곱근: torch.sqrt(x)
  • 삼각 함수: torch.sin(x), torch.cos(x), torch.tan(x)

*이 글은 제로베이스 데이터 취업 스쿨의 강의 자료 일부를 발췌하여 작성되었습니다.

profile
데이터 분석, 데이터 사이언스 학습 저장소

0개의 댓글