import torch


def mse(x_hat, x):
    """Return the mean squared error between ``x_hat`` and ``x``.

    The squared element-wise difference is averaged over *all* elements,
    so the result is a 0-dim tensor. This matches the default
    ``reduction="mean"`` of ``F.mse_loss`` / ``nn.MSELoss``.

    Args:
        x_hat: estimate tensor; expected shape (batch_size, dim).
        x: reference tensor; must be broadcast-compatible with ``x_hat``
            (expected shape (batch_size, dim)).

    Returns:
        A 0-dim tensor holding the mean of ``(x - x_hat) ** 2``.
    """
    # Use the x_hat argument, not a global. The parameter was previously
    # misspelled, which silently bound the module-level x_hat instead.
    y = ((x - x_hat) ** 2).mean()
    return y


x = torch.tensor([[1, 1], [2, 2]], dtype=torch.float32)
x_hat = torch.tensor([[0, 0], [0, 0]], dtype=torch.float32)
mse(x_hat, x)
import torch.nn as nn
import torch.nn.functional as F

# Built-in functional form. Its default reduction="mean" averages over all
# elements, so it should equal mse(x_hat, x). Argument order is
# (input, target).
F.mse_loss(x_hat, x)

# Equivalent module form; the loss object is reusable, e.g. inside a
# training loop.
mse_loss = nn.MSELoss()
mse_loss(x_hat, x)