[PyTorch] nn.ModuleList()

영이·2024년 5월 30일
0

nn.ModuleList()란

파이토치에서 사용되는 모듈들을 리스트 형태로 관리하는 클래스다.
nn.Sequential()과 비교하는 경우가 많은데, nn.Sequential()에는 자동으로 forward를 호출하는 기능이 있지만, nn.ModuleList()에는 그 기능이 없다. 또한 모듈끼리의 연결관계도 없다.

왜 사용하는 걸까?

모듈을 파이썬 리스트에 넣어두면 PyTorch가 인식하지 못한다. 따라서 우리는 모듈들의 리스트의 존재를 PyTorch에 알리고자 nn.ModuleList()로 wrapping을 한다.

언제 사용하면 좋을까?

모듈끼리 받는 input이 서로 다르고, 여러 개를 정의해야 할 때 사용하면 편리하게 모듈을 관리할 수 있다.

예시

import torch
import torch.nn as nn

class MyModel(nn.Module):
    def __init__(self):
        super(MyModel, self).__init__()
        self.linears = nn.ModuleList()
        for i in range(5):
            self.linears.append(nn.Linear(10, 20))

    def forward(self, x):
        for layer in self.linears:
            x = layer(x)
        return x

참고문헌

https://wikidocs.net/194942
https://bo-10000.tistory.com/62

profile
연구가 싫었는데 어쩌다보니 대학원생이 되어버린 몸

0개의 댓글