[PyTorch] 신경망 모델 정의하기

김희원·2025년 1월 5일
post-thumbnail

PyTorch 모델의 기본 구조

PyTorch로 설계하는 신경망은 기본적으로 다음과 같은 구조를 갖는다.

import torch.nn as nn
import torch.nn.functional as F

class Model_Name(nn.Module):
    def __init__(self):
    
        super(Model_Name, self).__init__()
        self.module1 = ...
        self.module2 = ...
        
        """
        ex)
        self.conv1 = nn.Conv2d(1, 20, 5)
        self.conv2 = nn.Conv2d(20, 20, 5)
        """

    def forward(self, x):
    
        x = some_function1(x)
        x = some_function2(x)
        
        """
        ex)
        x = F.relu(self.conv1(x))
        x = F.relu(self.conv2(x))
        """
        return x
        
model = Model_Name()

PyTorch 모델로 쓰기 위해선 다음 두 가지 조건을 따라야한다. 내장된 모델(nn.Linear등)도 이를 만족한다.

  1. torch.nn.Module을 상속해야한다.
  • interitance: 상속; 어떤 클래스를 만들 때 다른 클래스의 기능을 그대로 가지고오는 것.
  1. __init()__forward()를 override 해야한다.
  • override: 재정의; torch.nn.Module(부모클래스)에서 정의한 메소드를 자식클래스에서 변경하는 것.
  • __init()__에서는 모델에서 사용될 module(nn.Linear, nn.Conv2d), activation function(nn.functional.relu, nn.functional.sigmoid)등을 정의한다.
  • forward()에서는 모델에서 실행되어야하는 계산을 정의한다. backward 계산은 backward()를 이용하면 PyTorch가 알아서 해주니까 forward()만 정의해주면 된다. input을 넣어서 어떤 계산을 진행하하여 output이 나올지를 정의해준다고 이해하면 됨.

클래스 상속 (Class inheritance)

PyTorch 모델 구조를 보다가 클래스 상속에 대한 궁금증이 생겼다. [Python] 클래스 상속(Class inheritance) 그리고 Pytorch 모델에서의 해석을 참고하여 작성하였다.

class Person:
  def __init__(self, fname, lname): # init으로 fname, lanme을 받아 firstname, lastname에 각각 저장함.
    self.firstname = fname
    self.lastname = lname

  def printname(self): # 저장한 firstname, lastname을 출력함.
    print(self.firstname, self.lastname)


x = Person("John", "Doe")
x.printname()

Parent 클래스 상속 받기

class Student(Person):
  pass

Parent class인 Person에서 모든 것을 상속받았고 그 외의 기능은 없으므로(pass) Person과 동일하게 사용할 수 있다.

y = Student("HH", "jj")
y.printname()
# output:
HH jj

init 함수 대체하기 (override)

Parent class의 __init__ 함수가 마음에 안든다면 대체할 수 있다. 그냥 원래 __init__ 함수를 선언하듯 하면 된다.

class Student(Person):
  def __init__(self, fname, lname):
    print(fname, lname)
    
y = Student("hello", 100)
output:
hello 100

__init__ 함수만 바뀐 것이기 때문에 Parent class의 다른 기능들은 사용할 수 있다.

y.printname()은 원래는 사용가능하지만, 이 경우 parent class에서 self.firstname, self.lastname으로 저장을 했었는데 그러지 않았으므로 에러를 출력할 것이다.

super() 함수로 상속하기

개인적으로 이 부분이 가장 궁금했던 부분이다.

Parent class의 __init__을 유지하고 싶다면 아래처럼 작성하면 된다.

class Student(Person):
  def __init__(self, fname, lname):
    Person.__init__(self, fname, lname)

하지만 굳이 이름을 넣지 않고도 super()를 통해서도 가능하다.

class Student(Person):
  def __init__(self, fname, lname):
    super().__init__(fname, lname) # __init__에 'self' 인자가 들어가지 않는다.

다만 super()를 사용 시 __init__self 인자를 넣으면 안된다.

처음에는 "어 근데 유지를 하고싶은데 굳이 선언을 해줘야 하나? 위에 pass처럼 그냥 두면 되는거 아닌가?" 라는 생각이 들었는데 아래를 보고 의문이 풀렸다.

class Student(Person):
  def __init__(self, fname, lname, year):
    super().__init__(fname, lname)
    self.graduationyear = year

x = Student("Mike", "Olsen", 2019)

이렇게 코드를 짜면 Person의 __init__을 받으면서 self.graduationyear까지 init에 추가시킬 수 있다.
말하자면 Person.__init__ + 내가 추가하고 싶은 __init__

Person의 __init__에서 self.firstname, self.lastname이 저장되고 나의 __init__에서 self.graduationyear가 저장되었다.

0개의 댓글