๐Ÿ“pytorch ํด๋ž˜์Šค๋กœ ๋ชจ๋ธ ๊ตฌํ˜„ํ•˜๊ธฐ

Yoonsneeยท2022๋…„ 12์›” 16์ผ
0

๋‹จ์ˆœํ•œ ์„ ํ˜•ํšŒ๊ท€ ๋ชจ๋ธ

ํ•„์š”ํ•œ ๋ชจ๋“ˆ ๋ถˆ๋Ÿฌ์˜ค๊ธฐ

import torch
import torch.nn as nn
import torch.nn.functional as F # ๋ชจ๋“  

๋ชจ๋ธ ์„ ์–ธ

model = torch.nn.Linear(1,1)

ํŒŒ์ดํ† ์น˜๋Š” Define by run๋ฐฉ์‹์œผ๋กœ ์‹คํ–‰๋˜๋ฏ€๋กœ ์œ ์—ฐํ•˜๋‹ค๋Š” ์žฅ์ ์ด ์žˆ์Šต๋‹ˆ๋‹ค. Define by run์ด๋ž€ ์—ฐ์‚ฐ ์ •์˜์™€ ๋™์‹œ์— ๊ฐ’์ด ์ดˆ๊ธฐํ™” ๋˜๋Š” ๋ฐฉ์‹์ด๋ฉฐ, ๋ฐ˜๋ฉด์— Define and run ๋ฐฉ์‹์€ ๋จผ์ € ๋ชจ๋ธ์„ ๋งŒ๋“ค์–ด์ฃผ๊ณ  ๊ฐ’์„ ๋ชจ๋‘ ๋”ฐ๋กœ ๋„ฃ์–ด์ฃผ๊ฒŒ ๋ฉ๋‹ˆ๋‹ค.

class LinearRegressionModel(torch.nn.Module): # torch.nn.Module์„ ์ƒ์†๋ฐ›์Œ
    def __init__(self): 
        super().__init__() # ๋‹ค๋ฅธ ํด๋ž˜์Šค์˜ ์†์„ฑ ๋ฐ ๋ฉ”์†Œ๋“œ๋ฅผ ์ž๋™์œผ๋กœ ๋ถˆ๋Ÿฌ ์˜ด
        self.linear = torch.nn.Linear(1,1) # input_dim 1, output_dim 1
    def forward(self, x):
        return self.linear(x)

์ผ๋‹จ ๋‹จ์ˆœํ•˜๊ฒŒ ํ•˜๊ธฐ ์œ„ํ•ด์„œ input_dim๋„ 1๋กœ ์ฃผ๊ณ , ์ถœ๋ ฅ์ธต๋„ 1๋กœ ์ฃผ์—ˆ์–ด์š”.

model = LinearRegressionModel()

๋‹ค์ค‘์„ ํ˜•ํšŒ๊ท€ ๋ชจ๋ธ

๋ชจ๋ธ ์„ ์–ธ

model = torch.nn.Linear(3,1)
class MultivariateLinearRegressionModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(3,1)
    def forward(self, x):
        return self.linear(x)

๋ชจ๋ธ์„ ๋ถˆ๋Ÿฌ์™€ ์˜ตํ‹ฐ๋งˆ์ด์ € ์„ค์ •ํ•ด์ฃผ๊ธฐ

model= MultivariateLinearRegressionModel()

์˜ตํ‹ฐ๋งˆ์ด์ €๋Š” SGD๋ฅผ ์‚ฌ์šฉํ–ˆ๊ณ , lr์€ 0.00001๋กœ ์ฃผ์—ˆ์Šต๋‹ˆ๋‹ค.

optimizer = torch.optim.SGD(model.parameters(), lr = 1e-5)

ํ™œ์„ฑํ™” ํ•จ์ˆ˜์™€, loss๋“ฑ์„ ๋„ฃ์–ด์ฃผ๊ฒŒ ๋ฉ๋‹ˆ๋‹ค.

nb_epochs = 2000
for epoch in range(1,nb_epochs+1):
    # H(x) ๊ณ„์‚ฐ
    prediction = model(x_train)
    # cost ๊ณ„์‚ฐ
    cost = F.mse_loss(prediction,y_train)
    # gradient๋ฅผ 0์œผ๋กœ ์ดˆ๊ธฐํ™”
    optimizer.zero_grad()
    # ๋น„์šฉํ•จ์ˆ˜๋ฅผ ๋ฏธ๋ถ„ํ•˜์—ฌ gradient ๊ณ„์‚ฐ
    cost.backward()
    # W, b๋ฅผ ์—…๋ฐ์ดํŠธ
    optimizer.step()
    
    if epoch % 100 == 0:
         print('Epoch {:4d} / {} Cost: {:.6f}'.format(epoch, nb_epochs, cost.item()
                                                     ))

ํŒŒ์ดํ† ์น˜์—์„œ๋Š” ๊ฐ€์ค‘์น˜๋ฅผ ์ดˆ๊ธฐํ™”ํ•ด์ฃผ๋Š” ๊ณผ์ •์ด ๋ฐ˜๋“œ์‹œ ํ•„์š”ํ•ฉ๋‹ˆ๋‹ค!
gradient๋ฅผ ์ดˆ๊ธฐํ™”ํ•ด์ฃผ๋Š”๋ฐ์—๋Š” ์ถ”ํ›„์— ์„ค๋ช…๋“œ๋ฆฌ๋„๋ก ํ• ๊ฒŒ์š”!

์ฐธ๊ณ ๋ฌธํ—Œ: PyTorch๋กœ ์‹œ์ž‘ํ•˜๋Š” ๋”ฅ ๋Ÿฌ๋‹ ์ž…๋ฌธ

profile
์œค์“ฐ๋„ค๋ฝ€๋ผ

0๊ฐœ์˜ ๋Œ“๊ธ€