PyTorch: torch.nn.Module

danbibibi·2022년 1월 20일
0

PyTorch 🔥

목록 보기
2/20

torch.nn.Module

모든 neural network modules의 기본 class

  • 모듈은 다른 모듈을 포함할 수 있고, 트리 구조로 형성할 수 있다.
  • 새로운 모델을 만들 때 상속 받아서 사용한다.
  • nn.Module을 상속받아서 nn.Module의 기본적인 기능들을 사용할 수 있다.
  • super().__init__(): nn.Module의 생성자 호출
  • nn.Module 상속 후 __init__()forward()를 override해서 자신만의 모델을 만들 수 있다!
  • __init__(): 모델에서 사용될 module( ex) nn.Linear, nn.Conv2d), activation function 등을 정의
  • forward() : 모델에서 실행되어야하는 계산을 정의
    ( nn.Module을 상속받은 클래스의 객체는 forward() 함수를 호출하지 않아도 model 객체에 파라미터로 전달하면 바로 forward()가 시작된다. )
import torch
import torch.nn as nn

class Model(nn.Module): # 상속  => __init__, forward를 override
    def __init__(self):
        super().__init__() # nn.Module(부모 class)의 생성자 호출 ( -> 자식 class에서 사용 가능)

    def forward(self, x1, x2):
        output = torch.add(x1, x2)
        return output

x1 = torch.tensor([1])
x2 = torch.tensor([2])

add = Model()
output = add(x1, x2)

print(output)
profile
블로그 이전) https://danbibibi.tistory.com

0개의 댓글