[pytorch] torch.nn.Module

J·2021년 2월 18일
0

pytorch

목록 보기
3/23

torch.nn.Module은 모든 뉴럴 네트워크 모듈의 기본 클래스이다. 일반적인 모델들은 이 클래스를 상속받아야한다. 모듈들은 다른 모듈을 또 포함할 수 있다.

__init__() 메소드에는 신경망 레이어의 구성요소들을 정의하고, __forward__에서는 호출 될 때 수행되는 연산을 정의한다. torch.nn.Module을 상속받는 모든 클래스에서 override되어야 한다. 일반적으로 이 두가지 메소드는 반드시 정의한다.

import torch.nn as nn
import torch.nn.functional as F

class Model(nn.Module):
    def __init__(self):
        super(Model, self).__init__()
        self.conv1 = nn.Conv2d(1, 20, 5)
        self.conv2 = nn.Conv2d(20, 20, 5)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        return F.relu(self.conv2(x))

torch.nn.Module에는 많은 메소드들이 있지만 모두 소개할 수는 없어 가장 중요한 내용만 소개하였다.

출처

  1. https://pytorch.org/docs/stable/generated/torch.nn.Module.html#torch.nn.Module
profile
I'm interested in processing video&images with deeplearning and solving problem in our lives.

0개의 댓글