PyTorch와 사전 학습된 NLP 모델을 사용하여 텍스트 임베딩을 통해 유사 문서를 찾는 방법을 소개합니다. 이 시스템은 특정 쿼리와 가장 유사한 문서를 반환하는 검색 엔진과 같은 기능을 수행합니다. 코드의 각 함수에 대해 하나씩 설명하며, 각 단계가 어떤 역할을 하는지 알아보겠습니다.
def embed_text(text, tokenizer, model):
tokens = tokenizer([text], return_tensors="pt").input_ids
emb = model(tokens)[0]
emb = torch.mean(emb, dim=1).squeeze()
return emb
embed_text 함수는 입력된 텍스트를 사전 학습된 NLP 모델을 통해 벡터 형식의 임베딩으로 변환하는 역할을 합니다.
def process_docs_embedding(docs, tokenizer, model):
docs_embeddings = torch.zeros([len(docs), model.config.hidden_size])
for i, doc in enumerate(docs):
docs_embeddings[i, :] = embed_text(doc, tokenizer, model)
return docs_embeddings
process_docs_embedding 함수는 여러 개의 문서 리스트를 받아 각 문서의 임베딩을 계산해 하나의 텐서에 저장합니다.
python
def similarity_score(docs_embeddings, query_emb):
scores = docs_embeddings.matmul(query_emb) / (
(docs_embeddings ** 2).sum(dim=1).sqrt() * (query_emb ** 2).sum().sqrt())
return scores
def cosine_similarity(self, query_embedding):
# 내적 (dot product) 계산
dot_product = self.documents_embeddings.matmul(query_embedding)
# documents_embeddings와 query_embedding의 크기 계산
doc_norms = (self.documents_embeddings ** 2).sum(dim=1).sqrt()
query_norm = (query_embedding ** 2).sum().sqrt()
# 코사인 유사도 계산
scores = dot_product / (doc_norms * query_norm)
return scores
def euclidean_distance(self, query_embedding):
distances = torch.linalg.norm(self.documents_embeddings - query_embedding, dim=1)
return distances
def euclidean_distance(self, query_embedding):
# 두 벡터의 차이 계산
differences = self.documents_embeddings - query_embedding
# 각 차이를 제곱하고 합산
squared_differences = differences ** 2
summed_differences = squared_differences.sum(dim=1)
# 제곱근을 취해 최종 거리 계산
distances = summed_differences.sqrt()
return distances
similarity_score 함수는 코사인 유사도를 계산하여 쿼리 임베딩과 각 문서 임베딩 간의 유사도를 측정합니다.
python
def locate_doc(query, docs_embeddings, docs, tokenizer, model):
query_emb = embed_text(query, tokenizer, model)
best_match_idx = torch.argmax(similarity_score(docs_embeddings, query_emb))
doc = docs[best_match_idx]
return doc
locate_doc 함수는 입력 쿼리에 대해 가장 유사한 문서를 찾아 반환하는 역할을 합니다.