MSE Loss

그는사악해·2023년 2월 7일
0

Back_to_Basic

목록 보기
2/2
post-thumbnail

INTRO

: Regression에서 일반적으로 쓰이는 Loss Function 중 하나.

Reference

Import

import numpy as np

import torch 
import torch.nn as nn

Code

def mse_loss(preds, trues):
    # preds, trues: torch tensor
    # reduction == 'mean'
    assert preds.shape == trues.shape
    
    return torch.sum((preds - trues)**2) / preds.view(-1).shape[0]

Test - 1

# 임의의 2차원 행렬
preds1 = torch.randn(3, 4)
trues1 = torch.randn(3, 4)

# nn.MSELoss()와 mse_loss로 구현한 함수로 구한 loss를 각각 비교
loss1_1 = mse_loss(preds1, trues1)
loss1_2 = nn.MSELoss()(preds1, trues1)

loss1_1, loss1_2

Test - 2

# 3차원 torch tensor에 대해서도 해보자.
preds2 = torch.randn(3, 5, 4)
trues2 = torch.randn(3, 5, 4)

# nn.MSELoss()와 mse_loss로 구현한 함수로 구한 loss를 각각 비교
loss2_1 = mse_loss(preds2, trues2)
loss2_2 = nn.MSELoss()(preds2, trues2)
loss2_1, loss2_2

End

: 거의 개인 공부 공간이 되어가는 구만

profile
데이터를 베어라

0개의 댓글