텐서와 미분자동화(Tensor and Autograd)

hottogi·2022년 11월 11일
0

Tensor

import torch

pytorch를 불러옵니다.

x = torch.tensor([[1, 2, 3],[4, 5, 6],[7, 8, 9]])
print(x)
print(x.size())
print(x.shape)
print(x.ndimension())

tensor([[1, 2, 3],
[4, 5, 6],
[7, 8, 9]])
torch.Size([3, 3])
torch.Size([3, 3])
2

차원(랭크)이 0인 텐서는 숫자 하나인 스칼라
차원(랭크)이 1인 텐서는 일렬로 숫자를 나열한 벡터
차원(랭크)이 2인 텐서는 2차원 행렬
차원(랭크)이 3인 텐서는 정육면체 같은 3차원 행렬

pytorch tensor = numpy array = pandas dataframe
랭크 = 차원
완전히 같은것은 아니자만 편의상 같다고 정의.

x = torch.unsqueeze(x, 0)
print(x)
print(x.size())
print(x.shape)
print(x.ndimension())

tensor([[[1, 2, 3],
[4, 5, 6],
[7, 8, 9]]])
torch.Size([1, 3, 3])
torch.Size([1, 3, 3])
3

x = torch.squeeze(x)
print(x)
print(x.size())
print(x.shape)
print(x.ndimension())

tensor([[1, 2, 3],
[4, 5, 6],
[7, 8, 9]])
torch.Size([3, 3])
torch.Size([3, 3])
2

try:
  x = x.view(3, 3)
  print(x)
except Exception as e:
  print(e)

tensor([[1, 2, 3],
[4, 5, 6],
[7, 8, 9]])

w = torch.randn(5,3, dtype=torch.float)
x = torch.tensor([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]])
print(w.size())
print(x.size())
print(w)
print(x)

torch.Size([5, 3])
torch.Size([3, 2])
tensor([[ 1.6905, 0.3728, 0.1947],
[-0.0254, 0.9622, 0.3893],
[ 0.1812, -0.6835, -0.5771],
[-0.0055, 0.0462, -0.9516],
[-0.3549, -1.1712, -0.9546]])
tensor([[1., 2.],
[3., 4.],
[5., 6.]])

b = torch.randn(5, 2, dtype = torch.float)
print(b.size())
print(b)

torch.Size([5, 2])
tensor([[ 1.2620, -1.4799],
[ 0.8538, -1.3717],
[-0.4126, 1.1260],
[ 1.8840, -0.9163],
[ 0.6686, -0.1546]])

wx = torch.mm(w, x)
print(wx.size())
print(wx)

torch.Size([5, 2])
tensor([[ 3.7826, 6.0406],
[ 4.8079, 6.1341],
[ -4.7546, -5.8339],
[ -4.6250, -5.5360],
[ -8.6414, -11.1221]])

result = wx + b
print(result.size())
print(result)

torch.Size([5, 2])
tensor([[ 5.0446, 4.5608],
[ 5.6618, 4.7625],
[ -5.1672, -4.7079],
[ -2.7410, -6.4523],
[ -7.9728, -11.2767]])

Autograd

Autograd는 수식의 기울기를 자동으로 계산한다는 뜻입니다
파이토치의 Autograd는 미분 계산을 자동화하여 경사하강법을 구현하는 수고를 덜어줍니다.

w = torch.tensor(1.0, requires_grad=True)
a = w*3
l = a**2
l.backward()
print(w.grad)
print(format(w.grad))

tensor(18.)
18.0

import torch
import pickle
import matplotlib.pyplot as plt
from google.colab import files
uploaded = files.upload()
shp_original_img = (100, 100)
broken_image = torch.FloatTensor(pickle.load(open("broken_image_t.p", 'rb'), encoding='latin1'))

오염된 사진 불러오기.

plt.imshow(broken_image.view(100,100)) 

<matplotlib.image.AxesImage at 0x7f2a5c022c50>

경사하강법을 사용하여 사진을 복원합니다.

def weird_function(x, n_iter = 5):
  h = x
  filt = torch.tensor([-1./3, 1./3, -1./3])
  for i in range(n_iter):
    zero_tensor = torch.tensor([1.0*0])
    h_l = torch.cat((zero_tensor, h[:-1]), 0)
    h_r = torch.cat((h[1:], zero_tensor), 0)
    h = filt[0]*h + filt[2]*h_l +filt[1]*h_r
    if i % 2 == 0:
      h = torch.cat((h[h.shape[0]//2:], h[:h.shape[0]//2]), 0)
  return h

def distance_loss(hypothesis, broken_image):
  return torch.dist(hypothesis, broken_image)

random_tensor = torch.randn(10000, dtype = torch.float)
lr = 0.8
for i in range(0, 20000):
  random_tensor.requires_grad_(True)
  hypothesis = weird_function(random_tensor)
  loss = distance_loss(hypothesis, broken_image)
  loss.backward()

  with torch.no_grad():
    random_tensor = random_tensor - lr*random_tensor.grad
  if i % 1000 == 0:
    print('Loss at {} = {}'. format(i, loss.item()))

Loss at 0 = 12.02292251586914
Loss at 1000 = 1.0950708389282227
Loss at 2000 = 0.5325770974159241
Loss at 3000 = 0.38084402680397034
Loss at 4000 = 0.30580782890319824
Loss at 5000 = 0.257468044757843
Loss at 6000 = 0.22201725840568542
Loss at 7000 = 0.19379359483718872
Loss at 8000 = 0.16995538771152496
Loss at 9000 = 0.1489325910806656
Loss at 10000 = 0.12980693578720093
Loss at 11000 = 0.11201657354831696
Loss at 12000 = 0.09520231187343597
Loss at 13000 = 0.07912485301494598
Loss at 14000 = 0.0636184811592102
Loss at 15000 = 0.04856501892209053
Loss at 16000 = 0.03387928381562233
Loss at 17000 = 0.020067676901817322
Loss at 18000 = 0.02116371877491474
Loss at 19000 = 0.021166490390896797

사진이 잘 복원됐는지 확인합니다.

plt.imshow(random_tensor.view(100, 100).data)

<matplotlib.image.AxesImage at 0x7f2a59f8c350>

profile

0개의 댓글