🧠 Implementing a 2-layer MLP (Solving the XOR Problem)

soyeon · March 27, 2025

Hanghae Plus AI Course


A single perceptron (= a linear layer) cannot solve the XOR problem

  • Single perceptron: a linear classifier (a linear, i.e. straight-line, function such as y = ax + b; with two inputs the decision boundary is the line w₁x₁ + w₂x₂ + b = 0)
    - ⇒ cannot solve nonlinear (curved-boundary) problems such as XOR (exclusive OR)

    Linear function (linear = a straight line)

    (i.e. no single straight line can separate the points by color)
  • XOR outputs True (1) only when the two inputs differ, so (0, 1) and (1, 0) lie on one diagonal and (0, 0) and (1, 1) on the other. No single line separates the two groups, so a single perceptron cannot solve it
  • BUT a multilayer perceptron can learn nonlinearity through its hidden layer, so it can solve XOR
  • Without an activation function → the network can only compute a linear transformation (stacked linear layers collapse into a single one) ⇒ it cannot learn complex patterns
  • Activation functions introduce nonlinearity, which lets the network learn complex patterns
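To make these bullets concrete, here is a small pure-Python sketch that is not part of the original post. A brute-force search over a coarse grid of weights finds a single line for AND but none for XOR. The grid search is only an illustration, since XOR is not linearly separable for any weights. A 2-layer ReLU network whose weights were picked by hand (for illustration only, not learned) then reproduces XOR exactly.

```python
from itertools import product

# Truth tables: inputs (x1, x2) and their labels
INPUTS = [(0, 0), (0, 1), (1, 0), (1, 1)]
XOR = [0, 1, 1, 0]
AND = [0, 0, 0, 1]

def step(z):
    """Perceptron output: 1 if the weighted sum is positive, else 0."""
    return 1 if z > 0 else 0

def find_linear_separator(labels, grid):
    """Search a grid of (w1, w2, b) for one line w1*x1 + w2*x2 + b = 0
    that classifies every input correctly. Returns the first hit or None."""
    for w1, w2, b in product(grid, repeat=3):
        if all(step(w1 * x1 + w2 * x2 + b) == t
               for (x1, x2), t in zip(INPUTS, labels)):
            return (w1, w2, b)
    return None

GRID = [i / 2 for i in range(-6, 7)]  # -3.0, -2.5, ..., 3.0

def relu(z):
    return max(0.0, float(z))

def xor_mlp(x1, x2):
    """Hand-picked (hypothetical) weights: two hidden ReLU units, then a linear output."""
    h1 = relu(x1 + x2)        # counts how many inputs are on
    h2 = relu(x1 + x2 - 1)    # fires only when both inputs are on
    return h1 - 2 * h2        # 0, 1, 1, 0 on the four inputs

print(find_linear_separator(AND, GRID))            # some (w1, w2, b) that works
print(find_linear_separator(XOR, GRID))            # None
print([xor_mlp(x1, x2) for x1, x2 in INPUTS])      # [0.0, 1.0, 1.0, 0.0]
```

The hidden unit `h2` is what a single line cannot express: it bends the output back down once both inputs are on.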

Implementing the MLP model

https://tutorials.pytorch.kr/beginner/basics/optimization_tutorial.html

  1. Fix the randomness seed

import random

import numpy as np
import torch

# Fixing the seed makes randomly generated values (e.g. initial weights) reproducible
seed = 7777

random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
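What seeding buys you can be checked quickly. This sketch uses only `random` and NumPy; `torch.manual_seed` plays the same role for torch's own generator. Re-seeding with the same value replays the same sequence of "random" numbers. The helper name `sample` is hypothetical.

```python
import random

import numpy as np

def sample(seed):
    """Re-seed both generators, then draw a few numbers from each."""
    random.seed(seed)
    np.random.seed(seed)
    return [random.random() for _ in range(3)], np.random.rand(3).tolist()

first = sample(7777)
second = sample(7777)
print(first == second)  # True: same seed -> same sequence
```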
  2. Create the data
x = torch.tensor([
    [0., 0.],
    [0., 1.],
    [1., 0.],
    [1., 1.]
])
y = torch.tensor([0, 1, 1, 0])

print(x.shape, y.shape)
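  • x is the input; y holds the ground-truth labels that the predictions are compared against
  • y has shape torch.Size([4]): a 1-D vector of 4 labels
  • x has shape torch.Size([4, 2]): a 2-D tensor (a 4×2 matrix, 4 samples × 2 features)
  • Number of features in x (= number of columns of the matrix) → 2
  • Number of target values per sample in y → 1 (one label per row of x)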
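As a cross-check of these shapes, the XOR labels can be derived directly from the two columns of x. NumPy is used here as a stand-in for torch; torch tensors report the same shapes via `.shape`.

```python
import numpy as np

# Same inputs as above, as a NumPy array
x = np.array([[0., 0.], [0., 1.], [1., 0.], [1., 1.]])
# XOR label: 1 when the two inputs differ, else 0
y = (x[:, 0] != x[:, 1]).astype(int)

print(x.shape)           # (4, 2): 4 samples, 2 features each
print(y.shape)           # (4,): one label per sample, 1-D
print(y)                 # [0 1 1 0]
# A trailing axis gives the (4, 1) column that matches a 1-output layer
print(y[:, None].shape)  # (4, 1)
```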
  3. Define the model
from torch import nn

class Model(nn.Module):
  def __init__(self, d, d_prime):
    super().__init__()

    self.layer1 = nn.Linear(d, d_prime) # input -> hidden layer
    self.layer2 = nn.Linear(d_prime, 1) # hidden layer -> output (1 = number of target values per sample)
    self.act = nn.ReLU()

  def forward(self, x):
    # x: (n, d)
    x = self.layer1(x)  # (n, d_prime) ์€๋‹‰์ธต
    x = self.act(x)     # (n, d_prime) ํ™œ์„ฑํ™” ํ•จ์ˆ˜ ์ ์šฉ
    x = self.layer2(x)  # (n, 1)

    return x

model = Model(2, 10)  # Model(number of features in x, number of hidden-layer nodes)
  • In PyTorch, a model is implemented by subclassing the nn.Module class
  • nn.Module: the base class for all network modules. Subclasses are expected to override the forward-pass method (def forward); calling a module that does not define one raises NotImplementedError
  • nn.Linear(number of input features, number of output features): a linear (affine) transformation, y = xWᵀ + b
  • nn.ReLU: the ReLU activation function, also provided as a module in torch.nn
  • Model(number of features in x, number of hidden-layer nodes): a larger hidden layer tends to make training easier (size it according to your available compute)
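The math behind `nn.Linear` and the role of ReLU can be reproduced in a few lines of NumPy. The weights below are random placeholders, not the trained model's. `nn.Linear` stores its weight as `(out_features, in_features)` and computes `x @ W.T + b`. Without an activation, `layer2(layer1(x))` collapses into one linear layer, which is exactly the single-perceptron limitation described earlier.

```python
import numpy as np

rng = np.random.default_rng(0)

def linear(x, W, b):
    """Same math as nn.Linear: W is (out_features, in_features), y = x @ W.T + b."""
    return x @ W.T + b

def relu(z):
    return np.maximum(z, 0.0)

x = np.array([[0., 0.], [0., 1.], [1., 0.], [1., 1.]])
W1, b1 = rng.normal(size=(10, 2)), rng.normal(size=10)  # like nn.Linear(2, 10)
W2, b2 = rng.normal(size=(1, 10)), rng.normal(size=1)   # like nn.Linear(10, 1)

# Without an activation: W2 (W1 x + b1) + b2 = (W2 W1) x + (W2 b1 + b2)
two_layers = linear(linear(x, W1, b1), W2, b2)
W_merged, b_merged = W2 @ W1, W2 @ b1 + b2              # shapes (1, 2) and (1,)
one_layer = linear(x, W_merged, b_merged)
print(np.allclose(two_layers, one_layer))  # True

# The ReLU in between is what prevents this merge
with_relu = linear(relu(linear(x, W1, b1)), W2, b2)
print(with_relu.shape)  # (4, 1)
```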

Forward pass explained

  1. self.layer1(x) → projects the input to the size of the hidden layer so the network has learnable features to work with

    The (4, 2) input is multiplied by the transposed weight matrix (nn.Linear stores it as (10, 2)) and a bias is added → (4, 10)

  2. self.act(x) → ReLU: positive values pass through unchanged, negative values become 0. The shape does not change (4, 10)

  3. self.layer2(x) → maps the 10 hidden values to a single output per sample (the prediction) → (4, 1)
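The three steps can be traced shape by shape. This NumPy sketch uses random placeholder weights and zero biases, purely for illustration.

```python
import numpy as np

rng = np.random.default_rng(7777)

x = np.array([[0., 0.], [0., 1.], [1., 0.], [1., 1.]])  # (4, 2)
W1, b1 = rng.normal(size=(10, 2)), np.zeros(10)         # layer1: 2 -> 10
W2, b2 = rng.normal(size=(1, 10)), np.zeros(1)          # layer2: 10 -> 1

h = x @ W1.T + b1          # step 1: (4, 2) -> (4, 10)
a = np.maximum(h, 0.0)     # step 2: ReLU, shape stays (4, 10)
out = a @ W2.T + b2        # step 3: (4, 10) -> (4, 1)

for name, arr in [("layer1", h), ("relu", a), ("layer2", out)]:
    print(name, arr.shape)
# ReLU only clips negatives to 0; positive entries pass through unchanged
print(np.array_equal(a[h > 0], h[h > 0]), (a[h <= 0] == 0).all())  # True True
```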

  4. Set up the optimizer (gradient descent)

from torch.optim import SGD

optimizer = SGD(model.parameters(), lr=0.1)
  • Validate/test at every epoch
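With its default settings (no momentum, no weight decay), `SGD(model.parameters(), lr=0.1)` updates every parameter as p ← p − lr · grad. A pure-Python sketch of that rule on a toy one-parameter problem follows. The function `sgd_step` and the objective (w − 3)² are hypothetical, chosen only to make the update visible.

```python
def sgd_step(params, grads, lr):
    """One plain SGD update: p <- p - lr * grad for each parameter."""
    return [p - lr * g for p, g in zip(params, grads)]

# Toy problem: minimize f(w) = (w - 3)^2, whose gradient is 2 * (w - 3)
w = 0.0
for _ in range(50):
    grad = 2 * (w - 3)
    (w,) = sgd_step([w], [grad], lr=0.1)

print(round(w, 4))  # 3.0
```

With lr=0.1, each step shrinks the distance to the minimum by a factor of 0.8. A much larger learning rate would overshoot instead.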
  5. Training
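def train(n_epochs, model, optimizer, x, y):
  for e in range(n_epochs):
    model.zero_grad()  # clear gradients left over from the previous step

    y_pred = model(x)  # (4, 1)
    loss = (y_pred[:, 0] - y).pow(2).sum()  # sum of squared errors

    loss.backward()    # compute d(loss)/d(parameter) for every parameter
    optimizer.step()   # update the parameters with SGD

    print(f"Epoch {e:3d} | Loss: {loss.item():.4f}")
  return model
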
n_epochs = 100
model = train(n_epochs, model, optimizer, x, y)
  • zero_grad(): resets the gradients (otherwise PyTorch accumulates them across backward() calls)
  • loss.backward(): computes the gradient of the loss with respect to each parameter
  • optimizer.step(): updates the parameters using the computed gradients
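To see what `loss.backward()` computes for this particular model and loss, the gradients can be derived by hand and checked against finite differences. The NumPy sketch below does that with random placeholder parameters shaped like `Model(2, 10)`. PyTorch does this automatically through autograd, and because it adds into each parameter's `.grad`, the `zero_grad()` call is needed every epoch. The helpers `loss_fn`, `backward` and `numeric_grad` are hypothetical names.

```python
import numpy as np

rng = np.random.default_rng(0)
x = np.array([[0., 0.], [0., 1.], [1., 0.], [1., 1.]])
y = np.array([0., 1., 1., 0.])

# Placeholder parameters with the same shapes as Model(2, 10)
params = {
    "W1": rng.normal(size=(10, 2)), "b1": rng.normal(size=10),
    "W2": rng.normal(size=(1, 10)), "b2": rng.normal(size=1),
}

def loss_fn(p):
    h = x @ p["W1"].T + p["b1"]
    a = np.maximum(h, 0.0)
    y_pred = a @ p["W2"].T + p["b2"]
    return ((y_pred[:, 0] - y) ** 2).sum()  # same loss as train()

def backward(p):
    """Hand-derived gradients of loss_fn (what autograd computes)."""
    h = x @ p["W1"].T + p["b1"]                # (4, 10)
    a = np.maximum(h, 0.0)                     # (4, 10)
    y_pred = a @ p["W2"].T + p["b2"]           # (4, 1)
    d_pred = 2 * (y_pred[:, 0] - y)[:, None]   # dL/dy_pred, (4, 1)
    d_a = d_pred @ p["W2"]                     # (4, 10)
    d_h = d_a * (h > 0)                        # ReLU passes gradient only where h > 0
    return {
        "W2": d_pred.T @ a,                    # (1, 10)
        "b2": d_pred.sum(axis=0),              # (1,)
        "W1": d_h.T @ x,                       # (10, 2)
        "b1": d_h.sum(axis=0),                 # (10,)
    }

def numeric_grad(p, key, eps=1e-6):
    """Central finite differences, used only to check backward()."""
    g = np.zeros_like(p[key])
    for idx in np.ndindex(p[key].shape):
        old = p[key][idx]
        p[key][idx] = old + eps
        up = loss_fn(p)
        p[key][idx] = old - eps
        down = loss_fn(p)
        p[key][idx] = old
        g[idx] = (up - down) / (2 * eps)
    return g

grads = backward(params)
for key in params:
    # expect True for every parameter
    print(key, np.allclose(grads[key], numeric_grad(params, key), atol=1e-5))
```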
  6. Test
print(model(x))
print(y)

tensor([[0.0208], 
        [1.0484],
        [1.0156],
        [0.0496]], grad_fn=<AddmmBackward0>)
tensor([0, 1, 1, 0])
  • The predictions are close to the targets: about 0 for [0, 0] and [1, 1], about 1 for [0, 1] and [1, 0]
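Because the model outputs continuous values (it is trained with a squared-error loss), one common way to read them as XOR classes is to apply a 0.5 threshold. The sketch below uses the exact outputs printed above. The threshold and the accuracy/MSE summary are assumptions for illustration, not part of the original code.

```python
# Raw outputs printed above (copied from the test run)
preds = [0.0208, 1.0484, 1.0156, 0.0496]
targets = [0, 1, 1, 0]

# Turn regression outputs into class labels with a 0.5 threshold
labels = [1 if p >= 0.5 else 0 for p in preds]
accuracy = sum(l == t for l, t in zip(labels, targets)) / len(targets)
mse = sum((p - t) ** 2 for p, t in zip(preds, targets)) / len(targets)

print(labels)         # [0, 1, 1, 0]
print(accuracy)       # 1.0
print(round(mse, 4))  # 0.0014
```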
📚 A frontend developer who writes up what they learn
