torch.nn.Module 이란?

Jadon·2021년 12월 27일
0

[CLASS] torch.nn.Module

공식문서에 따르면

torch.nn.Module 은 PyTorch의 모든 Neural Network의 Base Class이다. 모듈은 다른 모듈을 포함할 수 있고, 트리 구조로 형성할 수 있다.

공식문서에 예제를 코딩해보면서 감을 잡아보자.

import torch.nn as nn
import torch.nn.functional as F

class MyModel(nn.Module):
    def __init__(self):
        super(Model, self).__init__()
        self.conv1 = nn.Conv2d(1, 20, 5)
        self.conv2 = nn.Conv2d(20, 20, 5)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        return F.relu(self.conv2(x))

MyModel을 만들 때, nn.Module을 상속받아서 기본적인 기능들을 사용할 수 있게 만들어준다.

이것을 응용해서 더하기 모델을 만들어 보자.

import torch
from torch import nn


class Add(nn.Module):
    def __init__(self):
        super().__init__() # 반드시 Add class의 부모 클래스인 nn.Module을 super()을 사용해서 초기화 시켜줘야 한다.

    def forward(self, x1, x2):
        output = torch.add(x1, x2)
        return output


x1 = torch.tensor([1])
x2 = torch.tensor([2])

add = Add()
output = add(x1, x2)

print(output)

여기서 의문점이 들었던 것이 forward() 함수를 호출하지 않았는데 add객체에 파라미터로 전달하면 바로 forward function이 시작된다. 이것은 nn.Module을 상속받은 클래스의 특징이라고 볼 수 있다.

0개의 댓글