PyTorch Workflow: (2) Building a PyTorch Linear model

Yul Kang · December 5, 2022

This content is based on Daniel Bourke's YouTube channel (https://www.youtube.com/@mrdbourke), specifically this video: https://www.youtube.com/watch?v=Z_ikDlimN6A&ab_channel=DanielBourke

Create a linear model by subclassing nn.Module

import torch
from torch import nn

class LinearRegressionModelV2(nn.Module):
  def __init__(self):
    super().__init__()
    # nn.Linear() creates the model parameters (weight and bias) for us
    # Also called: linear transform, fully connected layer, dense layer
    self.linear_layer = nn.Linear(in_features=1, out_features=1)

  # forward() defines the computation in the model
  def forward(self, x: torch.Tensor) -> torch.Tensor:
    return self.linear_layer(x)

# Set the manual seed so the random initial parameters are reproducible
torch.manual_seed(42)
model_1 = LinearRegressionModelV2()
# In a notebook the last expression is displayed; use print() in a script
model_1, model_1.state_dict()

Result

(LinearRegressionModelV2(
   (linear_layer): Linear(in_features=1, out_features=1, bias=True)
 ),
 OrderedDict([('linear_layer.weight', tensor([[0.7645]])),
              ('linear_layer.bias', tensor([0.8300]))]))
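Under the hood, nn.Linear(in_features=1, out_features=1) holds a weight of shape (1, 1) and a bias of shape (1,), and its forward pass computes y = x @ W.T + b, which for a single feature reduces to y = weight * x + bias. The sketch below is not from the video: it rebuilds the model with the same seed and checks that the model's output matches that formula applied by hand. The inputs in X are made-up values; note that the input must have shape (batch_size, in_features), so a 1-D tensor is reshaped with unsqueeze(dim=1).

```python
import torch
from torch import nn

class LinearRegressionModelV2(nn.Module):
  def __init__(self):
    super().__init__()
    self.linear_layer = nn.Linear(in_features=1, out_features=1)

  def forward(self, x: torch.Tensor) -> torch.Tensor:
    return self.linear_layer(x)

torch.manual_seed(42)
model_1 = LinearRegressionModelV2()

# Made-up inputs: 5 samples with 1 feature each -> shape (5, 1)
X = torch.linspace(0, 1, 5).unsqueeze(dim=1)

with torch.inference_mode():
  y_model = model_1(X)
  # Same computation written out by hand: y = x @ W^T + b
  W = model_1.linear_layer.weight  # shape (1, 1)
  b = model_1.linear_layer.bias    # shape (1,)
  y_manual = X @ W.T + b

print(X.shape, y_model.shape)  # torch.Size([5, 1]) torch.Size([5, 1])
print(torch.allclose(y_model, y_manual))  # True
```

torch.inference_mode() only turns off gradient tracking for this check; the comparison would give the same result without it.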
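nn.Linear initializes its weight and bias with random values, which is why the code calls torch.manual_seed(42) before creating the model: with the same seed (on the same PyTorch setup), a freshly built model starts from the same parameters, such as the 0.7645 and 0.8300 shown above. Below is a minimal check using a hypothetical helper, make_model(), that is not part of the original post. It also shows that the state_dict() keys are prefixed with the attribute name used in __init__ (linear_layer).

```python
import torch
from torch import nn

class LinearRegressionModelV2(nn.Module):
  def __init__(self):
    super().__init__()
    self.linear_layer = nn.Linear(in_features=1, out_features=1)

  def forward(self, x: torch.Tensor) -> torch.Tensor:
    return self.linear_layer(x)

def make_model(seed: int) -> LinearRegressionModelV2:
  # Hypothetical helper: reset the global RNG, then build a fresh model
  torch.manual_seed(seed)
  return LinearRegressionModelV2()

model_a = make_model(42)
model_b = make_model(42)

sd_a, sd_b = model_a.state_dict(), model_b.state_dict()
# Keys are prefixed with the attribute name "linear_layer"
print(list(sd_a.keys()))  # ['linear_layer.weight', 'linear_layer.bias']
# Same seed -> identical initial parameters
print(all(torch.equal(sd_a[k], sd_b[k]) for k in sd_a))  # True
```

Without the manual_seed() call, each new model would start from different random parameters, so the printed state_dict would change from run to run.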