Pytorch ResNet에 Softmax함수가 없는 이유

chanykim·2022년 6월 30일
0

정리

ResNet 구현 당시 Pytorch(https://github.com/pytorch/vision/blob/main/torchvision/models/resnet.py)를 참고하여 구현하였는데 아래와 같이 논문에서는,

마지막에 출력함수로서 활성화 함수softmax를 쓴다고 했지만 pytorch에서 제공하는 resnet과 다른 블로그의 글을 보니 신경망 마지막에 Linear함수로 끝났습니다.

그래서 '논문에선 softmax를 쓴다고 명시되어 있는데 왜 나온대로 하지 않은 걸까?' 라고 생각이 들었고, 구글에 해당 내용에 대해 검색해보니 https://stats.stackexchange.com/questions/542054/why-does-torchvision-models-resnet18-not-use-softmax 이런 글을 발견하였습니다.
요약하자면 loss함수에서 torch.nn.CrossEntropyLoss를 사용하면 신경망에 softmax를 넣을 필요는 없다는 글이었습니다.

그래서 왜 그럴까 확인해보니, softmax의 아래 수식과
다음 CrossEntropyLoss의 수식을 비교했을 때

비슷한 수식었고, https://wikidocs.net/60572 에서 잘 설명해주었습니다.

#F.softmax() + torch.log() = F.log_softmax()
z = torch.rand(3, 5, requires_grad=True)

print(torch.log(F.softmax(z, dim=1)))
print("\n")
print(F.log_softmax(z, dim=1))
#tensor([[-1.5267, -1.9271, -1.5901, -1.3657, -1.7256],
#        [-1.3906, -1.2852, -2.1131, -1.6060, -1.8778],
#        [-1.4346, -2.0793, -1.6576, -1.4251, -1.5814]], grad_fn=<LogBackward0>)


#tensor([[-1.5267, -1.9271, -1.5901, -1.3657, -1.7256],
#        [-1.3906, -1.2852, -2.1131, -1.6060, -1.8778],
#        [-1.4346, -2.0793, -1.6576, -1.4251, -1.5814]], grad_fn=<LogSoftmaxBackward0>)

F.cross_entropy는 비용 함수에 소프트맥스 함수까지 포함하고 있음을 기억하고 있어야 구현 시 혼동하지 않습니다. 라는 글을 보는 순간 그래서 안쓰는구나를 확인할 수 있었습니다.
어짜피 loss함수에서 똑같은 작동을 하는데 굳이 두 번 쓸필요 없다는 뜻 같습니다.

결과값을 보면 똑같다는 것을 알 수 있습니다.

#F.log_softmax() + F.nll_loss() = F.cross_entropy()
y = torch.randint(5, (3,)).long()
z = torch.rand(3, 5, requires_grad=True)

print(F.nll_loss(F.log_softmax(z, dim=1), y))	#tensor(1.7025, grad_fn=<NllLossBackward0>)
print(F.cross_entropy(z, y))					#tensor(1.7025, grad_fn=<NllLossBackward0>)

torch.nn.functional.cross_entropy vs torch.nn.CrossEntropyLoss

함수냐, 클래스냐 차이만 있고 하는 일은 똑같습니다. torch.nn으로 구현한 클래스의 경우에는 attribute를 활용해 state를 저장하고 활용할 수 있고 torch.nn.functional로 구현한 함수의 경우에는 인스턴스화 시킬 필요 없이 사용이 가능합니다.

# nn.CrossEntropyLoss()
loss = nn.CrossEntropyLoss()
input = torch.randn(3, 5, requires_grad=True)
target = torch.empty(3, dtype=torch.long).random_(5)
output = loss(input, target)
output.backward()

# F.cross_entropy()
input = torch.randn(3, 5, requires_grad=True)
target = torch.randint(5, (3,), dtype=torch.int64)
loss = F.cross_entropy(input, target)
loss.backward()

참고

profile
오늘보다 더 나은 내일

0개의 댓글