UPR Reranker는 Improving Passage Retrieval with Zero-Shot Question Generation이라는 논문에 나온 Passage Reranker이다. 깃허브는 여기에.
리랭크는 retrieve한 passage들을 다시 재정렬 하는 것이다.
그럼 바로 UPR Reranker가 뭔지 알아보자.
UPR Reranker에서는 무식하게 LLM에게 질문을 여러 개 만들게 하지 않는다. 유저의 질문 토큰이 생성될 확률을 이용한다.
LLM은 기본적으로 앞의 토큰들을 통해 뒤에 토큰이 무엇이 올지 확률적으로 예측한다는 사실을 기억하자. 일련의 토큰 뒤에, 모든 토큰들이 나올 확률이 주르르륵 나오게 된다.
위의 3번에서, 질문을 생성하게 하면 생성하는 동안 각 토큰들이 나올 확률이 나온다. 그 중 유저의 질문 토큰이 나올 확률들을 구하는 것이다.
정리하면, LLM에게 passage를 보고 질문을 생성하게 할 때, 유저가 물어본 바로 그 질문을 생성할 확률을 구하는 것이다. 그 확률이 곧 점수가 된다.
여기서 구하고자 하는 의 경우에는, i번째 passage에 대해서 유저의 input question q를 LLM이 생성할 확률이다.
(참고로 즉 유저가 입력한 질문에 가장 비슷한 passage를 찾는 reranker의 task는 와 비례한다고 논문에서 증명함)
여기서 t의 경우에는 질문 중 몇 번째 토큰인지를 뜻한다. |q|의 경우 질문의 토큰 길이 수이다.
그리고 의 경우에는 t번째 토큰 전에 모든 토큰들이 나올 확률이라고 보면 된다.
즉, 는 다음 토큰으로 t번째 토큰이 나올 확률이다.
torch 코드 전문은 RAGchain 레포에 날린 나의 PR에 나와있다.
여기 올리는 코드는 조금 단순화한 코드다. 대충 보면 알겠죠?
def calculate_likelihood(question: str, contexts: List[str]) -> List[str]:
model_name = 't5-large'
tokenizer = T5Tokenizer.from_pretrained(model_name)
model = T5ForConditionalGeneration.from_pretrained(model_name)
prompts = [f"Prompt: {context} Please write a question based on this passage." for context in contexts]
# tokenize contexts and instruction prompts
context_tokens = tokenizer(prompts, padding='longest', pad_to_multiple_of=8, truncation=True, return_tensors='pt')
context_tensor, context_attention_mask = context_tokens.input_ids, context_tokens.attention_mask
# tokenize question
question_tokens = tokenizer([question], max_length=128, truncation=True, return_tensors='pt')
question_tensor = question_tokens.input_ids
question_tensor = torch.repeat_interleave(question_tensor, len(contexts), dim=0)
sharded_nll_list = []
# calculate log likelihood
with torch.no_grad():
logits = model(input_ids=context_tensor,
attention_mask=context_attention_mask,
labels=question_tensor).logits
log_softmax = torch.nn.functional.log_softmax(logits, dim=-1)
nll = -log_softmax.gather(2, question_tensor.unsqueeze(2)).squeeze(2)
avg_nll = torch.sum(nll, dim=1)
sharded_nll_list.append(avg_nll)
topk_scores, indexes = torch.topk(-torch.cat(sharded_nll_list), k=len(context_tensor))
result = [contexts[idx] for idx in indexes]
return result
논문에서는 UPR 리랭커가 다른 retrieval과 섞어 쓰면 그렇게 성능이 좋다고 자랑한다. (zero-shot unsupervised인데 supervised보다 성능이 좋다고...) 성능이 좋아질 수도 있으니 한 번 써보고, 직접 개발하기 귀찮다면 RAGchain를 활용해 보자!