Sentence-BERT: Sentence Embeddings using Siamese BERT-Networks

tobigs-nlp-article·2022년 2월 21일
1
post-custom-banner

SBERT: https://arxiv.org/pdf/1908.10084.pdf

Multilingual SBERT: https://arxiv.org/pdf/2004.09813.pdf

BERT를 공부해본 사람들이라면 모두 의문이 들만한 것이 있습니다. BERT는 2개의 문장을 [SEP] 토큰을 붙여서 입력값으로 넣는데, 그렇다면 과연 개별 문장에 대해서 제대로 sentence embedding을 하냐는 것입니다.

이번에 리뷰할 논문인 SentenceBERT, 통칭 SBERT는 이러한 의문에서 시작된 논문으로 샴 네트워크(siamese network)와 triplet network를 이용해 BERT를 개선합니다.

그렇다면 구체적으로 BERT의 어떠한 점이 문제였는지, 그리고 저자들이 어떠한 구조를 통해 이를 개선하는지 알아보도록 합시다.

1. Introduction: BERT의 문제점 + other works

1-1. Computational Overhead in Cross-Encoder

image-encoders

BERT는 cross-encoder, 다르게 말하면 두 문장을 하나의 tranformer network에 넘겨서 target value를 얻는 형태를 가집니다.

트랜스포머 아키텍처를 활용한 Cross-encoder 방식의 장점은 두 문장이 Self-attention을 통해 상호 간 attend 될 수 있으므로 두 문장 사이의 관계를 면밀하게 파악할 수 있다는 점에 있습니다. 두 문장을 구성하는 모든 토큰들 간 어텐션 연산이 수행되기 때문이죠.

그런데, 이러한 구조는 sentence-pair regression task를 수행할 경우, 두 개의 문장이 하나의 네트워크에 같이 들어가야 하기 때문에 계산량이 방대하게 커지게 됩니다. 즉, 문장 간 유사도를 활용한 검색, 임베딩을 활용한 클러스터링 등에 있어 연산량의 한계가 존재합니다. Cross-encoder 방식에서는 매번 달라지는 서치 쿼리와 새로운 연산을 통해 유사도를 구해야하므로 미리 계산을 해놓을 수 없기 때문입니다.

생각을 해보면, 우리에게 n개의 문장이 있을 때, 가장 유사한 문장을 고르라는 task가 주어지면 n*(n-1)/2번의 연산이 필요하게 됩니다.

Cross Encoder 구조는 10000개의 문장을 hierarchical clustering하는데 약 65시간이 소요하는 반면, SBERT는 5초만에 완료된다고 합니다.⚡️

1-2. BAD sentence embedding

마찬가지로 BERT는 두 개의 문장을 하나의 입력값으로 넣기에 각각이 독립적인 sentence embedding을 가진다고 할 수 없습니다. 즉, 좋은 representation이라고 하기에는 무리가 있습니다.

일반적으로 sentence clustering이나 semantic search를 한다고 하면, 의미적으로 유사한 문장들을 vector space에 가까이 mapping하기 위해 노력합니다. 이를 위해서 연구자들도 BERT에 single sentence를 넣어서 fixed size vector로 만들어서 비교하기 시작했습니다. 크게 두 가지 접근법이 있습니다:
1. avg. BERT output layer
2. [CLS] token

그런데, 문제는 과연 이렇게 얻어낸 임베딩 값들이 좋다고 말할 수 있냐는 것입니다.

뒤의 실험 파트에서 나중에 확인하겠지만, 이렇게 얻어낸 BERT의 sentence embedding값들은 GloVe보다도 별로입니다.

즉, SBERT는 1) 개별 문장의 fixed size embedding vecotr를 얻을 수 있고, 따라서 2) cos-sim이나 Manhatten distance같은 similarity measure들을 이용해 문장을 비교할 수 있습니다.

cosine similarity나 Manhatten/Euclidean 거리는 의미적으로 유사한 문장들을 찾을 수 있을 뿐만 아니라 하드웨어에 매우 효율적입니다. 즉, clustering이나 semantic similarity search에 연산적으로 굉장히 유리한 것이죠.

1-3. Other works

studiesdetails
Skip-Thoughtsurrounding sentences를 예측하는 encoder-decoder 구조
InferSentlabeled Stanford Natural Language Inference dataset과 Multi-Genre NLI dataset을 siamese BiLSTM 으로 training ➡️ SkipThought보다 높은 성능
Universal Sentence Encodertransformer 네트워크 training, SNLI로 비지도학습 증대 ➡️ SNLI dataset이 sentence embedding training에 적절한 dataset임 확인

2. SBERT Model

SBERT는 pre-trained BERT와 RoBERTa에 fine-tune을 진행하는 방식으로 학습합니다.

2-1. Pooling startegies

고정된 사이즈의 sentence embedding 구하기 위해 3가지 방식을 도입해서 실험에 적용했습니다.

  1. the output of the [CLS] token
  2. the mean of all output vectors ➡️ MEAN-strategy
  3. computing a max-over-time of the output vectors ➡️ MAX-strategy

즉, pooling을 통해 (sent_len, n_dim)에서 (n_dim)으로 차원이 변경됩니다.

기본값은 볼드 처리가 되어있는 MEAN-strategy 입니다.

2-2. Architecture

샴 네트워크란 weights를 공유하는 2개의 네트워크 구조입니다.

샴 네트워크는 다음과 같이 학습됩니다.
1. 두 개의 입력 데이터를 준비
2. 각 입력에 대한 임베딩 값 얻기
3. 두 임베딩 사이의 거리를 계산
4. 두 입력이 같은 클래스에 속한다면 거리를 가깝게, 다른 클래스에 속한다면 거리를 멀게 하기

SBERT는 위의 Siamese Network를 이용해 모든 문장을 Cross-encoder 형식이 아닌 단일 인코더 형식으로 연산해 각 문장의 임베딩을 기준으로 유사도를 구하는 아키텍처를 가집니다.

이렇게 되면, 각 문장은 SBERT에 의해 미리 연산되어 저장될 수 있게 됩니다. 즉, 각 문장의 임베딩 연산이 독립적으로 수행되기 때문에 Cross-encoder 방식보다 훨씬 빠르게 문장 간 유사도를 구할 수 있습니다. 즉, 검색 후보가 되는 문서들에 대해서는 미리 임베딩 계산을 해놓은 후, 새로이 들어온 서치 쿼리의 임베딩만 얻은 후 문서들의 임베딩과 유사도만 구하면 되는 문제로 치환할 수 있습니다.

2-3. Objective function

SBERT는 training data에 따라 각기 다른 목적 함수를 이용해서 학습됩니다.

1) classification

enter image description here

o=softmax(Wt(u,v,uv))o = softmax(W_{t}(u, v, |u − v|))

  • WtR3n×kWt ∈ ℝ^{3n×k}
  • nn: the dimesnsion of the sentence embeddings
  • kk: the number of labels
  • cross-entropy loss

이때, element-wise difference |𝑢−𝑣| 는 두 개의 embedding의 차원 사이의 거리를 측정해서 유사한 것들끼리는 가깝게 그렇지 않은 것들끼리는 멀게 만들어줍니다.

2) regression

enter image description here

  • mean-squared error loss

3) triplet

enter image description here

max(saspsasn+ε,0)max(||s_{a} - s_{p}|| − ||s_{a} − s_{n}|| + ε, 0)

  • sxs_{x}: the sentence embeddings for a/n/pa/n/p
  • Anchor / Positive / Negative

즉, anchor sentence와 positive sentence 사이의 거리인 sasp||s_{a} - s_{p}||는 작기를 기대하고, anchor sentence와 negative sentence 사이의 거리인 sasn||s_{a} − s_{n}||는 크기를 기대합니다. 이때, 마진 ε를 설정하여, 적어도 ε만큼 anchor와 positive의 거리가 anchor와 negative보다 가깝도록 할 수 있습니다. 기본 값은 ε = 1입니다.

2-4. Training Detail

  • fine-tuning data
    • SNLI
      • 570,000 sentence pairs
      • labels: contradiction, entailment, neutral
    • MNLI
      • 430,000 sentence pairs
      • spoken and written text
  • batch size = 16
  • optimizer = Adam
  • learning rate = 2e-5
  • a linear learning rate warm-up over 10% of the training data
  • pooling strategy = MEAN

3. Evaluation

3-1. Semantic Textual Similarity

cos-sim을 사용해 두 sentence embedding의 유사도를 비교해봅니다.

1) Unsupervised

BERT는 Glove embedding보다도 낮은 성능을 보였으며, SBERT는 InferSent와 USE보다 더 높은 성능을 보였습니다. 이때, SICK-R에 대해서만 USE가 더 높은 성능올 보였는데, 이는 (S)BERT는 Wikipedia에서만 pre-train된 반면, USE는 news, question-answer pages, discussion forums로 더 다양하기 때문이라고 유추해볼 수 있습니다. 마지막으로, SBERT와 SRoBERTa의 성능 차이는 미미했습니다.

2) Supervised

STSb의 supervised dataset(8,628 sentence paris)으로 finetuning을 진행한 후, 성능 비교를 한 결과입니다.

BERT와 SBERT 모두 NLI에 fine-tuning한 경우 성능이 더 높습니다. 그리고 보면 알겠지만, finetuning을 진행한 경우, BERT가 거의 다 SBERT보다 성능이 좋았네요!

3) Wikipedia Sections Distinction

wiki에서 같은 section에 있는 문장이면 다른 section에 있는 문장보다 더 가까울 것이라고 가정하고, triplet으로 데이터셋을 구성하여 1 epoch 학습을 진행하여 실험을 진행합니다.

accuracy의 기준은 positive sentence(같은 section)이 negative sentence(다른 section)보다 anchor와 가까운지 입니다.

스크린샷 2022-02-17 오후 5.01.39

보다시피, SRoBERTa보다 SBERT가 조금 더 좋은 성능을 보였습니다.

3-2. SentEval (Transfer learning)

transfer tasks
MRSentiment prediction for movie reviews snippets on a five start scale
CRSentiment prediction of customer product reviews
SUBJSubjectivity prediction of sentences from movie reviews and plot summaries
MPQAPhrase level opinion polarity classification from newswire
SSTStanford Sentiment Treebank with binary labels
TRECFine grained question-type classification frmo TREC
MRPCMicrosoft Research Paraphrase Corpus from parallel news sources

여러 모델에서 sentence embedding을 뽑아내고, 이를 logistic regression classifier에 feeding합니다. 10-fold cross-validation setup 으로 train 후, prediction accuracy 계산한 결과는 다음과 같습니다.

SBERT의 sentence embedding이 원래 transfer learning을 목적으로 하지 않았음에도 불구하고 7개의 task중 5개에서 InferSent와 USE보다 더 높은 성능을 보였습니다. 특히 sentiment task인 MR,CR,SST에서 SBERT가 성능 향상이 크다는 것을 확인할 수 있습니다. TREC에서는 USE보다 성능이 많이 떨어졌는데, USE는 Q&A data에 pre-traine되어있기 때문입니다.

4. Ablation Study

1) Evaluate pooling strategies

classification의 경우, pooling strategy의 영향이 적고, concatenation mode의 영향이 컸습니다. InferSent와 USE는 softmax classifier에 (u,v,uv,uv)(u, v, |u-v|, u*v)를 사용했었는데, SBERT에서는 the element-wise uvu*v가 성능을 감소시켰다는 것을 확인할 수 있습니다.(80.78 -> 80.44)

regression의 경우, concat을 사용하지 않기 때문에, pooling strategy의 차이만 확인하자면, MEAN > CLS >>> MAX의 양상을 보였스빈다. InferSent의 BiLSTM에서는 MAX의 성능이 더 좋았던 것과 굉장히 대비됩니다.

5. Knowledge Distillation: Mono to Multilingual

마지막으로, 위에도 링크를 달아두었던 Multi-SBERT, 논문명 Making Monolingual Sentence Embeddings Multilingual using Knowledge Distillation에 대해서 간단하게 소개하려 합니다.

SBERT는 single language에 국한되어 있기 때문에, multilignuality로의 확장이 필요합니다. 그런데, 그냥 무작정 finetuning을 시키면 문제가 생기게 됩니다. 생각을 해보면, 다양한 언어라고 해도 같은 의미를 가지는 문장들끼리는 거리가 가까워야한다는 것이죠.

SBERT가 input sequence의 의미를 보존하면서 sentence embedding을 만든다는 점에서, 이를 여러 언어들의 set에 내부적으로 그리고 동시에 전범위로도 작동하게 해보자는 것이 이 논문의 목표입니다. 그리고 저자들은 이를 위해 knowledge distillation을 제시합니다.

knowledge distillation

  • SBERT 👩🏻‍🏫 Teacher
  • multilingual model 🧑🏻‍💻 Student

제시된 방법론에서는 teacher-student로 불리는 두 모델이 등장합니다. 목적은 student가 teacher의 knowledge를 distil하는 것이죠.

student model Mˉ\bar M은 2가지 속성을 가진 multilingual sentence embedding space를 학습하게 됩니다. 1) Vector spaces are aligned across languages, 즉 다른 언어로 적힌 같은 문장은 벡터 공간 내에서 가까이 위치합니다. 2) teacher model MM의 언어의 vector space properties는 다른 언어들에게 transfer됩니다.

보다 자세히 설명해보면, 모델은 T(sj)S(sj)T(s_{j}) ≈ S(s_{j}) && S(tj)T(sj)S(t_{j}) ≈ T(s_{j})을 목적으로, 즉, Student의 두 언어(source,target)(source, target)의 sentence representation 쌍을 Teacher의 sourcesource language embedding과 가깝게 만드는 목적을 가지고 학습됩니다.

구체적으로, teacher model MM은 English SBERT, student model Mˉ\bar M은 XLM-RoBERTa (XLM-R)가 사용되었습니다.

다만 이 모델에서는 번역이 되어있는 Parallel dataset이 필요하다는 한계가 있습니다.

6. Code Sample : clustering

from sentence_transformers import SentenceTransformer
from sklearn.cluster import KMeans

embedder = SentenceTransformer('paraphrase-distilroberta-base-v1')

# Corpus with example sentences
corpus = ['A man is eating food.',
          'A man is eating a piece of bread.',
          'A man is eating pasta.',
          'The girl is carrying a baby.',
          'The baby is carried by the woman',
          'A man is riding a horse.',
          'A man is riding a white horse on an enclosed ground.',
          'A monkey is playing drums.',
          'Someone in a gorilla costume is playing a set of drums.',
          'A cheetah is running behind its prey.',
          'A cheetah chases prey on across a field.'
          ]
corpus_embeddings = embedder.encode(corpus)

# Perform kmean clustering
num_clusters = 5
clustering_model = KMeans(n_clusters=num_clusters)
clustering_model.fit(corpus_embeddings)
cluster_assignment = clustering_model.labels_

clustered_sentences = [[] for i in range(num_clusters)]
for sentence_id, cluster_id in enumerate(cluster_assignment):
    clustered_sentences[cluster_id].append(corpus[sentence_id])

for i, cluster in enumerate(clustered_sentences):
    print("Cluster ", i+1)
    print(cluster)
    print("")
Cluster  1
['A man is eating food.', 'A man is eating a piece of bread.', 'A man is eating pasta.']

Cluster  2
['The girl is carrying a baby.', 'The baby is carried by the woman']

Cluster  3
['A cheetah is running behind its prey.', 'A cheetah chases prey on across a field.']

Cluster  4
['A monkey is playing drums.', 'Someone in a gorilla costume is playing a set of drums.']

Cluster  5
['A man is riding a horse.', 'A man is riding a white horse on an enclosed ground.']
  • Data: 웰니스 대화 스크립트 데이터셋 from AIhub
from sentence_transformers import SentenceTransformer
from sklearn.cluster import KMeans

embedder = SentenceTransformer('paraphrase-xlm-r-multilingual-v1')

# Corpus with example sentences

corpus = ['더 이상 내 감정을 내가 컨트롤 못 하겠어.', # 감정/감정조절이상
          '맨정신일 때는 저를 주체할 수 가 없었거든요.', # 감정/감정조절이상
          '화가 안 참아져.', # 감정/감정조절이상/화
          '또 실수하지는 않았을까, 걱정이 들어요.', # 감정/걱정
          '원래 병에 잘 걸리지 않는 편이었는데, 한 번 심장이 이상하다고 느끼니까 건강 걱정이 늘었어요.', # 감정/걱정/건강염려
          '누군가 나를 좀 이해해줬으면 좋겠어요.', # 감정/고독감
          '사람들 사이에 있으면 나만 버려져 있는 거 같아요.', # 감정/고독감
          '그때만 생각하면 소름이 돋아요.', # 감정/공포
          '새가 너무 무서워. 보기만해도 소리를 지를 정도야.', # 감정/공포/새
          ]
corpus_embeddings = embedder.encode(corpus)

# Perform kmean clustering
num_clusters = 4
clustering_model = KMeans(n_clusters=num_clusters)
clustering_model.fit(corpus_embeddings)
cluster_assignment = clustering_model.labels_

clustered_sentences = [[] for i in range(num_clusters)]
for sentence_id, cluster_id in enumerate(cluster_assignment):
    clustered_sentences[cluster_id].append(corpus[sentence_id])

for i, cluster in enumerate(clustered_sentences):
    print("Cluster ", i+1)
    print(cluster)
    print("")
Cluster  1
['맨정신일 때는 저를 주체할 수 가 없었거든요.', '또 실수하지는 않았을까, 걱정이 들어요.', '원래 병에 잘 걸리지 않는 편이었는데, 한 번 심장이 이상하다고 느끼니까 건강 걱정이 늘었어요.', '사람들 사이에 있으면 나만 버려져 있는 거 같아요.', '그때만 생각하면 소름이 돋아요.']

Cluster  2
['더 이상 내 감정을 내가 컨트롤 못 하겠어.', '화가 안 참아져.']

Cluster  3
['새가 너무 무서워. 보기만해도 소리를 지를 정도야.']

Cluster  4
['누군가 나를 좀 이해해줬으면 좋겠어요.']

Reference

https://www.sbert.net/examples/training/nli/README.html
https://tyami.github.io/deep%20learning/Siamese-neural-networks/
https://medium.com/dair-ai/making-monolingual-sentence-embeddings-multilingual-using-knowledge-distillation-59d8a7713672
https://towardsdatascience.com/advance-nlp-model-via-transferring-knowledge-from-cross-encoders-to-bi-encoders-3e0fc564f554
https://towardsdatascience.com/a-complete-guide-to-transfer-learning-from-english-to-other-languages-using-sentence-embeddings-8c427f8804a9

profile
투빅스 NLP 논문세미나입니다!
post-custom-banner

0개의 댓글