Python List와 마찬가지로 nn.Module을 저장하는 역할을 하며, index 접근도 가능하다.
Python List를 nn.ModuleList()
로 감싸 주면 된다!
class MyModule(nn.Module):
def __init__(self):
super(MyModule, self).__init__()
self.linears = nn.ModuleList([nn.Linear(10, 10) for i in range(10)])
def forward(self, x):
# ModuleList can act as an iterable, or be indexed using ints
for i, l in enumerate(self.linears):
x = self.linears[i // 2](x) + l(x)
return x
nn.ModuleList안에 Module들을 넣어 줌으로써 Module의 존재를 PyTorch에게 알려 주어야 한다. 만약 nn.ModuleList에 넣어 주지 않고, Python List에만 Module들을 넣어 준다면, PyTorch는 모델 파라미터의 존재를 알 수 없다! 때문에 optimizer 선언 시 model.parameter()
를 사용하여 파라미터를 넘겨주려 할 때 에러가 발생한다. 따라서 Module들을 Python List에 넣어 보관하는 경우에는 마지막에 nn.ModuleList로 wrapping을 해줘야 한다!