처음에 ELECTRA를 이용한 MultiClassification을 구현하려고 하였지만 학습이 안 되는 문제가 생겨 사람들이 예시로 많이 올려놓은 BERT기반으로 먼저 구현해보기로 하였다.
사람들이 예시로 올려놓은 코드들을 보면 forward 함수에서
def forward(input_ids, attention_mask, token_type_ids) :
_, pooler_output = self.bert(input_ids, attention_mask, token_type_ids)
output = dropout(pooler_output)
output = linear(output)
return output
위와 같은 형태로 layer를 쌓는 것을 볼 수 있어서 나 또한 그렇게 쌓았었다.
하지만 내 코드는 돌아가기만 했지, 전혀 학습이 되지 않았다.
(loss가 1에서 진동하는 것을 볼 수 있었다.)
그래서 다른 코드에서 문제가 있었나 싶어서 봤는데
def configure_optimizers(self):
optimizer = torch.optim.AdamW(self.bert.parameters(), lr=self.learning_rate, weight_decay=1e-5)
lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer, T_0=10)
return {
'optimizer': optimizer,
'lr_scheduler': lr_scheduler
}
configure_optimizer
의 optimizer
선언 부분에 parameters를 bert한정으로만 해서 학습하게끔 되어 있는 것을 발견하고 이 부분만 고치면 학습이 가능할 것이라고 예상했다.
self.bert.parameters()
를 self.parameters()
로 바꿔 학습했지만 여전히 loss가 1에서 진동하는 모습을 보였다.(저게 가장 큰 문제라고 생각해서 고치자마자 드라마틱하게 학습될 거라고 기대했다.)
multiclass 뿐만이 아니라 이진분류인 nsmc 데이터셋으로도 학습이 잘 안 되는 것을 보고 데이터 문제는 아니라고 판단했다.
이 문제에 대해 고민하던 도중 깃헙에 들어온 pull request를 봤더니 본래의 forward
함수에서 fc layer을 하나 더 추가해주고 activation function을 취해주면서 loss 문제가 해결된 것을 알 수 있었다.
def forward(self, input_ids, attention_mask, token_type_ids):
output = self.bert(input_ids=input_ids,
attention_mask=attention_mask,
token_type_ids=token_type_ids
)
output = self.dropout(output.pooler_output)
output = F.leaky_relu(self.linear_1(output), negative_slope=0.1)
output = self.linear_2(output)
return output
forward
함수는 다음과 같았고 이 외에 random sampler와 optimizer
에 weight_decay
가 추가 되었고 lr_scheduler
함수가 바뀌었다.
(random sampler가 있는 경우 훨씬 학습이 잘 되는 것을 확인했다.)
이후 학습되는 것을 볼 수 있었지만 learning_rate
, dropout
비율 등의 하이퍼 파라미터들을 바꾸어도 validation accuracy가 0,55~0.64 정도의 정확도를 보였다.
베스트는 learning_rate=1e-5
, dropoout_p=0.3
이었다.
Encoder Layer를 이용한 의도 분류 성능 비교
위의 논문에서는 ELECTRA 기반이지만 CNN과 Bi-LSTM을 붙여 Multi Classification의 accuracy를 높인 것을 알 수 있다.
CNN은 문장 분류 성능을 향상시키고 Bi-LSTM 또한 의도 분류 성과를 향상시키는 연구 성과가 있다고 소개되어 있어 accuracy를 높이는 데 시도해볼만 하다고 느꼈다.
BERT로 구현 후에 나중에 ELECTRA기반으로도 다시 구현해볼 예정이다.
코드는 깃헙에 올려두었다.
깃헙이 삭제되었어요