class Decoder(nn.Module):
def __init__(self, seq_len, input_dim=64, n_features=1):
super(Decoder, self).__init__()
self.seq_len, self.input_dim = seq_len, input_dim
self.hidden_dim, self.n_features = 2 * input_dim, n_features
self.rnn1 = nn.LSTM(
input_size=input_dim,
hidden_size=input_dim,
num_layers=1,
batch_first=True
)
self.rnn2 = nn.LSTM(
input_size=input_dim,
hidden_size=self.hidden_dim,
num_layers=1,
batch_first=True
)
self.output_layer = nn.Linear(self.hidden_dim, n_features)
def forward(self, x):
x = x.repeat(self.seq_len, self.n_features)
x = x.reshape((self.n_features, self.seq_len, self.input_dim))
x, (hidden_n, cell_n) = self.rnn1(x)
x, (hidden_n, cell_n) = self.rnn2(x)
x = x.reshape((self.seq_len, self.hidden_dim))
return self.output_layer(x)
decoder의 forward에서 x.repeat라는걸 쓴다
뭘까?
https://seducinghyeok.tistory.com/9
이 블로그에서는 torch.repeat()와 torch.expand()를 비교하여 보여줬다
차원을 늘리는거란다
Decoder에 딱 맞는 코드구만
덤으로
expand 또한 텐서를 반복시키는건데 1차원에서만 쓰는거라고 한다
끝.