Pytorch로 모델 만들기(작성중)

가람·2021년 7월 27일
0

Pytorch로 모델을 만들 때, 주로 사용하는 방법으로
torch.nn.Module의 subclass를 정의하는 방법이 있다.

모델을 subclass로 정의할 때 기본구조는 아래와 같다.

import torch.nn as nn

class Model(nn.Module):
    def __init__(self):
	super(Model, self).__init__()
        ...
    
    def forward(self, x):
    	...
        return output

Pytorch 모델로 사용하기위해 아래의 세가지 조건을 만족해야한다.
1. torch.nn.Module 을 상속해야한다.
torch.nn.Module을 상속하면 다음과 같은 이점이 있다.
(작성예정)
2. __init__()을 override 해야한다.
모델을 어떻게 구성할지 정의 및 초기화 하는 메소드이다.
모델구조와 파라미터 초기화등을 정의해준다.
3. forward()을 override 해야한다.
backward()와 대비되는 개념이다.
모델에 input을 넣으면 어떤 과정을 통해 output을 구할지 정의해 준다.

간단한 MLP 예시는 아래와 같다.

import torch.nn as nn

class Model(nn.Module):
    def __init__(self):
        super(Model, self).__init__()
        self.layer1 = nn.Linear(2, 10)
        self.layer2 = nn.Linear(10, 10)
        self.layer3 = nn.Linear(10, 1)
        self.relu = nn.ReLU()

    def forward(self, x):
        output = self.layer1(x)
        output = self.relu(output)
        output = self.layer2(output)
        output = self.relu(output)
        output = self.layer3(output)
        return output

Pytorch 코드를 보면 super()관련해서 아래와 같이 parameter가 없는 방식과 있는 방식이 있다.

super(Model, self).__init__
super().__init__

두 코드의 차이는 python version에 따른 문법 차이이다.
super(Model, self).__init__ 은 python 2.x 문법이고,
super().__init__은 python 3.x 문법이다.
파이썬 3.x 버전은 2.x버전 문법도 사용 가능하므로
super(Model, self).__init__와 같이 사용하면 좀 더 범용성이 높다고 할 수 있다.

profile
hello world :)

0개의 댓글