파이토치에서 사용되는 모듈들을 리스트 형태로 관리하는 클래스다.
nn.Sequential()과 비교하는 경우가 많은데, nn.Sequential()에는 자동으로 forward를 호출하는 기능이 있지만, nn.ModuleList()에는 그 기능이 없다. 또한 모듈끼리의 연결관계도 없다.
모듈을 파이썬 리스트에 넣어두면 PyTorch가 인식하지 못한다. 따라서 우리는 모듈들의 리스트의 존재를 PyTorch에 알리고자 nn.ModuleList()로 wrapping을 한다.
모듈끼리 받는 input이 서로 다르고, 여러 개를 정의해야 할 때 사용하면 편리하게 모듈을 관리할 수 있다.
import torch
import torch.nn as nn
class MyModel(nn.Module):
def __init__(self):
super(MyModel, self).__init__()
self.linears = nn.ModuleList()
for i in range(5):
self.linears.append(nn.Linear(10, 20))
def forward(self, x):
for layer in self.linears:
x = layer(x)
return x