Prototype Network를 NLP를 이용해 구현해보았다.
Phillip Lippe님이 CIFAR100과 DenseNet을 사용하여 작성하신 기존 코드를 Bert모델과 NLP Task에 맞게끔 조금씩 변경하면서 공부를 하였다,
그로인해 코드 구성과 ProtoNet을 구현하는 것이 다소 옳바르지 않을 수도 있다.
Colab을 사용하였고, Bert모델과 Pytorch Lightning을 통해 코드를 구현하였다.
ATIS Dataset을 통해서 Train, Valid, Test을 모두 사용하였으며
ATIS로 학습된 모델에 제 3의 Dataset인 SNIPS Dataset을 사용해 검증을 진행해 보았다.
Dataset을 살펴보면 class당 개수가 적은것들이 존재해 10개 이하의 label을 가지고 있는 class는 삭제해 준 후
Train / Valid / Test로 구분해주기 위해 class의 개수가 비슷하게 (5개, 5개, 6개)로 구성하게끔 수작업으로 Dataset을 분리
pytorch Dataset 클래스를 구성해 tokenizer로 변경하고 train과 valid Dataset을 변환
Quert_set과 Support_set으로 배치가 구성되야 하기 때문에 따로 batch_sampler를 구성
코드별 상세 구성은 코드 이미지 참고
pytorch Dataset과 batch_sampler를 통해 DataLoader 구성
split_batch 함수는 DataLoader를 통해 들어오는 data를 Query_set과 Support_set으로 분리해주기 위한 함수
각 k-shot의 개수 별로 성능을 확인하기 위한 Test 함수
Test_set의 개수가 180개 밖에 안되기 때문에 k-shot의 크기를 16이상 했을 때, zero-division error가 발생하기 때문에 k-shot의 개수를 15 미만으로 잡고 성능확인을 진행하였다.
colab pro를 사용하고 50번 모델을 학습을 하였다.
기존 ATIS Dataset의 Test을 통해 검증해본 결과 k개의 개수가 증가할 수록 모델의 정확도가 증가하는 것을 볼 수 있다.
제 3의 Dataset인 SNIPS를 통해 검증해본 결과 동일하게 k개의 개수가 증가할 수록 모델의 정확도는 증가하는 추세를 보이나 모델이 한 번도 보지못한 Dataset이다 보니 정확도가 생각보다 낮은 것을 볼 수 있다.
epoch을 100 ~ 200으로 늘릴 경우 정확도의 변화가 있을 것으로 예상이 되나, 컴퓨팅 자원과 시간상 시도해 보지 못한 점이 아쉽다.
전체 코드와 코드별 설명은 Github에서 확인이 가능하다.