🎲 [AI] Loss Function

mandu · April 27, 2025


This post is based on the FastCampus course '[skill-up] 처음부터 시작하는 딥러닝 유치원' (roughly, "Deep Learning Kindergarten, Starting from Scratch"),
with notes from my own additional study added.


1. Again, what is the goal of deep learning?

  • Our goal is to imitate a hypothetical function that returns the desired output for a given input
  • We approximate this function with a Linear Layer
  • We need a criterion for measuring how well that approximation works
  • That criterion is the Loss Function

2. What is Loss?

  • Loss: the difference between the model's actual output (ŷ) and the target output (y)
  • The smaller the Loss, the closer the model is to the target function
  • So the key is to choose a Linear Layer whose Loss is small

3. The concept of a Loss Function

  • Every time the parameters of the Linear Layer change, the Loss value is computed again
  • The Loss Function can be understood as the following function
    Input: the parameters of the Linear Layer
    Output: the Loss value at those parameters
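
To make the "parameters in, Loss value out" view concrete, here is a minimal sketch. The toy data (generated by y = 3x), the one-dimensional layer, and the helper name `loss_at` are all assumptions made for illustration; they do not come from the lecture.

```python
import torch

# Toy data generated by an assumed target function y = 3x.
x = torch.tensor([[1.0], [2.0], [3.0]])
y = 3.0 * x

def loss_at(w, b):
    """Loss Function: maps Linear Layer parameters (w, b) to a Loss value."""
    y_hat = x * w + b                        # output of a 1-D linear layer
    return ((y - y_hat) ** 2).mean().item()  # mean squared difference

# Each parameter setting yields its own Loss value.
for w in [0.0, 1.0, 2.0, 3.0]:
    print(f"w={w:.1f}, b=0.0 -> loss={loss_at(w, 0.0):.4f}")
```

In this toy setup the Loss shrinks as w approaches 3 and reaches 0 at w = 3, b = 0, which is exactly the parameter setting that reproduces the assumed target function.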

4. Common ways to compute Loss

4.1 Euclidean Distance

  • The L2 distance (L1 uses absolute values instead of squares)
  • The straight-line distance between two points
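
The difference between the L2 and L1 distances can be checked numerically. The sketch below uses two made-up vectors and compares a hand-written computation with `torch.linalg.norm`; the values are illustrative only.

```python
import torch

# Two illustrative points (values chosen for this sketch).
y = torch.tensor([1.0, 2.0, 3.0])
y_hat = torch.tensor([1.0, 0.0, 0.0])

diff = y - y_hat                        # [0, 2, 3]

l2 = torch.sqrt((diff ** 2).sum())      # Euclidean (L2) distance
l1 = diff.abs().sum()                   # L1 distance: sum of absolute values

# torch.linalg.norm computes the same quantities for a 1-D tensor.
assert torch.isclose(l2, torch.linalg.norm(diff, ord=2))
assert torch.isclose(l1, torch.linalg.norm(diff, ord=1))

print(l2.item(), l1.item())
```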

4.2 RMSE (Root Mean Square Error)

  • The square root of the mean squared error
  • The Euclidean distance normalized by the dimension: it is divided by √n, where n is the number of dimensions
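
Written out as a formula, assuming y and ŷ are n-dimensional vectors (the symbol n is introduced here for illustration), the relation between RMSE and the Euclidean distance can be sketched as:

```latex
\mathrm{RMSE}(y, \hat{y})
  = \sqrt{\frac{1}{n}\sum_{i=1}^{n}\left(y_i - \hat{y}_i\right)^2}
  = \frac{1}{\sqrt{n}}\,\lVert y - \hat{y} \rVert_2
```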

4.3 MSE (Mean Square Error)

  • One of the most frequently used Loss Functions
  • RMSE without the square root (the square root is monotonic, so the minimizing parameters are the same)
  • The 1/n factor is sometimes dropped as well (a positive constant does not change where the minimum is)

  • For a batch of multiple samples, the average is taken over every element of every sample
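
As a reconstruction rather than the lecture's exact notation, the batched MSE can be written as follows, assuming a batch of N samples with n dimensions each (N and n are symbols introduced here). This matches what `.mean()` computes on a (batch_size, dim) tensor.

```latex
\mathrm{MSE}(\hat{y}_{1:N}, y_{1:N})
  = \frac{1}{N \cdot n}\sum_{i=1}^{N}\sum_{j=1}^{n}\left(y_{i,j} - \hat{y}_{i,j}\right)^2
  = \frac{1}{N \cdot n}\sum_{i=1}^{N}\lVert y_i - \hat{y}_i \rVert_2^2
```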

  • PyTorch code
# Manual implementation
import torch

def mse(x_hat, x):
    # |x_hat| = (batch_size, dim)
    # |x| = (batch_size, dim)
    y = ((x - x_hat)**2).mean()

    return y

# Example inputs (illustrative values, not from the original post)
x = torch.tensor([[1., 1.], [2., 2.]])
x_hat = torch.tensor([[0., 0.], [0., 0.]])
print(mse(x_hat, x))

# Using the mse_loss function from torch.nn.functional
import torch.nn.functional as F

print(F.mse_loss(x_hat, x)) # default: reduction='mean'
print(F.mse_loss(x_hat, x, reduction='sum'))
print(F.mse_loss(x_hat, x, reduction='none'))

# Using the torch.nn.MSELoss class
import torch.nn as nn

mse_loss = nn.MSELoss()
print(mse_loss(x_hat, x))
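
The three `reduction` modes are related in a simple way, which the following sketch checks on small illustrative tensors (the values are assumptions chosen so the arithmetic is easy to follow). It also shows how RMSE could be obtained from the mean-reduced result.

```python
import torch
import torch.nn.functional as F

# Illustrative tensors of shape (batch_size, dim) = (2, 2).
x = torch.tensor([[1.0, 1.0], [2.0, 2.0]])
x_hat = torch.tensor([[0.0, 0.0], [0.0, 0.0]])

per_element = F.mse_loss(x_hat, x, reduction='none')  # shape (2, 2): squared errors
total = F.mse_loss(x_hat, x, reduction='sum')         # sum of squared errors
mean = F.mse_loss(x_hat, x)                           # default reduction='mean'

# 'sum' and 'mean' are reductions of the 'none' result.
assert torch.equal(per_element, (x - x_hat) ** 2)
assert torch.isclose(total, per_element.sum())
assert torch.isclose(mean, total / x.numel())

# RMSE is the square root of the mean-reduced MSE.
rmse = torch.sqrt(mean)
print(per_element.shape, total.item(), mean.item(), rmse.item())
```

Note that `'mean'` divides by the total number of elements (batch_size × dim), not by the batch size alone.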

5. Summary

  • To imitate the function we want,

    • feed the training inputs into the Linear Layer,
    • compute the difference between the resulting output (ŷ) and the target (y),
    • and find the parameters (θ) that minimize that difference
  • The core of training is therefore adjusting the model so that the Loss is minimized
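
The steps in this summary can be sketched end to end. The post does not say how the best parameters are searched for, so the example below uses a naive random search over freshly initialized `nn.Linear` layers purely as an illustration. The toy target function (y = 0.5x - 0.3), the number of candidates, and the variable names are all assumptions.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)

# Toy training data from an assumed target function y = 0.5x - 0.3.
x = torch.linspace(-1, 1, 20).unsqueeze(-1)   # shape (20, 1)
y = 0.5 * x - 0.3

best_loss, best_layer = float('inf'), None
for _ in range(100):
    candidate = nn.Linear(1, 1)               # randomly initialized parameters
    with torch.no_grad():
        y_hat = candidate(x)                  # feed the inputs into the layer
        loss = F.mse_loss(y_hat, y).item()    # difference between y_hat and y
    if loss < best_loss:                      # keep the lowest-Loss parameters
        best_loss, best_layer = loss, candidate

print(f"best loss: {best_loss:.4f}")
print("weight:", best_layer.weight.item(), "bias:", best_layer.bias.item())
```

Among the sampled candidates, the layer kept at the end is the one whose parameters give the smallest Loss on the training data, which is the selection criterion described in section 2.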

