softmax regression의 비용 함수 구현하기

개발하는 G0·2023년 9월 6일
0

pytorch

목록 보기
2/3

https://wikidocs.net/60572

import torch
import torch.nn.functional as F

torch.manual_seed(1)

z = torch.FloatTensor([1, 2, 3])

hypothesis = F.softmax(z, dim=0)
print(hypothesis)
tensor([0.0900, 0.2447, 0.6652])
hypothesis.sum()
tensor(1.)
z = torch.rand(3, 5, requires_grad=True)

hypothesis = F.softmax(z, dim=1)
print(hypothesis)
tensor([[0.2645, 0.1639, 0.1855, 0.2585, 0.1277],
        [0.2430, 0.1624, 0.2322, 0.1930, 0.1694],
        [0.2226, 0.1986, 0.2326, 0.1594, 0.1868]], grad_fn=<SoftmaxBackward>)
y = torch.randint(5, (3,)).long()
print(y)
tensor([0, 2, 1])
# 모든 원소가 0의 값을 가진 3 × 5 텐서 생성
y_one_hot = torch.zeros_like(hypothesis) 
y_one_hot.scatter_(1, y.unsqueeze(1), 1)
tensor([[1., 0., 0., 0., 0.],
        [0., 0., 1., 0., 0.],
        [0., 1., 0., 0., 0.]])
print(y.unsqueeze(1))
tensor([[0],
        [2],
        [1]])
print(y_one_hot)
tensor([[1., 0., 0., 0., 0.],
        [0., 0., 1., 0., 0.],
        [0., 1., 0., 0., 0.]])
cost = (y_one_hot * -torch.log(hypothesis)).sum(dim=1).mean()
print(cost)
tensor(1.4689, grad_fn=<MeanBackward1>)
# Low level
torch.log(F.softmax(z, dim=1))
tensor([[-1.3301, -1.8084, -1.6846, -1.3530, -2.0584],
        [-1.4147, -1.8174, -1.4602, -1.6450, -1.7758],
        [-1.5025, -1.6165, -1.4586, -1.8360, -1.6776]], grad_fn=<LogBackward>)
# High level
F.log_softmax(z, dim=1)
tensor([[-1.3301, -1.8084, -1.6846, -1.3530, -2.0584],
        [-1.4147, -1.8174, -1.4602, -1.6450, -1.7758],
        [-1.5025, -1.6165, -1.4586, -1.8360, -1.6776]], grad_fn=<LogSoftmaxBackward>)
# Low level
# 첫번째 수식
(y_one_hot * -torch.log(F.softmax(z, dim=1))).sum(dim=1).mean()
tensor(1.4689, grad_fn=<MeanBackward1>)
# 두번째 수식
(y_one_hot * - F.log_softmax(z, dim=1)).sum(dim=1).mean()
tensor(1.4689, grad_fn=<MeanBackward0>)
# High level
# 세번째 수식
F.nll_loss(F.log_softmax(z, dim=1), y)
tensor(1.4689, grad_fn=<NllLossBackward>)
# 네번째 수식
F.cross_entropy(z, y)
tensor(1.4689, grad_fn=<NllLossBackward>)
  • F.cross_entropy는 비용 함수에 소프트맥스 함수까지 포함하고 있음을 기억하고 있어야 구현 시 혼동하지 않습니다.
profile
초보 개발자

0개의 댓글