UPR Reranker 핵심만 알아보기 + 코드까지

Jeffrey Kim·2023년 10월 5일
1

UPR Reranker는 Improving Passage Retrieval with Zero-Shot Question Generation이라는 논문에 나온 Passage Reranker이다. 깃허브는 여기에.

리랭크는 retrieve한 passage들을 다시 재정렬 하는 것이다.

그럼 바로 UPR Reranker가 뭔지 알아보자.

UPR Reranker란

  1. OpenAI의 GPT와 같은 LLM을 준비한다.
  2. BM25, DPR 등의 방법으로 유저의 질문에 대한 passage들을 retrieve한다.
  3. LLM에게 passage만을 주고, passage에서 만들 수 있는 질문을 생성하게 한다.
  4. LLM이 생성한 질문과 유저의 질문이 얼마나 비슷한 지 점수를 매기고, 그 점수로 rerank를 한다.

어떻게 점수를 매길까?

UPR Reranker에서는 무식하게 LLM에게 질문을 여러 개 만들게 하지 않는다. 유저의 질문 토큰이 생성될 확률을 이용한다.
LLM은 기본적으로 앞의 토큰들을 통해 뒤에 토큰이 무엇이 올지 확률적으로 예측한다는 사실을 기억하자. 일련의 토큰 뒤에, 모든 토큰들이 나올 확률이 주르르륵 나오게 된다.
위의 3번에서, 질문을 생성하게 하면 생성하는 동안 각 토큰들이 나올 확률이 나온다. 그 중 유저의 질문 토큰이 나올 확률들을 구하는 것이다.
정리하면, LLM에게 passage를 보고 질문을 생성하게 할 때, 유저가 물어본 바로 그 질문생성할 확률을 구하는 것이다. 그 확률이 곧 점수가 된다.

수식으로 알아보기

logp(qzi)=1qtlogp(qtq<t,zi,θ)\log p(q|z_i) = \frac{1}{|q|}\sum_t\log p(q_t | q_{<t},z_i,\theta)

여기서 구하고자 하는 logp(qzi)\log p(q|z_i)의 경우에는, i번째 passage에 대해서 유저의 input question q를 LLM이 생성할 확률이다.
(참고로 logp(ziq)\log p(z_i|q) 즉 유저가 입력한 질문에 가장 비슷한 passage를 찾는 reranker의 task는 logp(qzi)\log p(q|z_i)와 비례한다고 논문에서 증명함)

여기서 t의 경우에는 질문 중 몇 번째 토큰인지를 뜻한다. |q|의 경우 질문의 토큰 길이 수이다.
그리고 q<tq_{<t}의 경우에는 t번째 토큰 전에 모든 토큰들이 나올 확률이라고 보면 된다.
즉, logp(qtq<t)\log p(q_t|q_{<t})는 다음 토큰으로 t번째 토큰이 나올 확률이다.

코드로 알아보기

torch 코드 전문은 RAGchain 레포에 날린 나의 PR에 나와있다.
여기 올리는 코드는 조금 단순화한 코드다. 대충 보면 알겠죠?

def calculate_likelihood(question: str, contexts: List[str]) -> List[str]:
    model_name = 't5-large'
    tokenizer = T5Tokenizer.from_pretrained(model_name)
    model = T5ForConditionalGeneration.from_pretrained(model_name)
    prompts = [f"Prompt: {context} Please write a question based on this passage." for context in contexts]
    # tokenize contexts and instruction prompts
    context_tokens = tokenizer(prompts, padding='longest', pad_to_multiple_of=8, truncation=True, return_tensors='pt')
    context_tensor, context_attention_mask = context_tokens.input_ids, context_tokens.attention_mask

    # tokenize question
    question_tokens = tokenizer([question], max_length=128, truncation=True, return_tensors='pt')
    question_tensor = question_tokens.input_ids
    question_tensor = torch.repeat_interleave(question_tensor, len(contexts), dim=0)

    sharded_nll_list = []
    # calculate log likelihood
    with torch.no_grad():
        logits = model(input_ids=context_tensor,
                       attention_mask=context_attention_mask,
                       labels=question_tensor).logits

    log_softmax = torch.nn.functional.log_softmax(logits, dim=-1)
    nll = -log_softmax.gather(2, question_tensor.unsqueeze(2)).squeeze(2)
    avg_nll = torch.sum(nll, dim=1)
    sharded_nll_list.append(avg_nll)

    topk_scores, indexes = torch.topk(-torch.cat(sharded_nll_list), k=len(context_tensor))
    result = [contexts[idx] for idx in indexes]
    return result

결론

논문에서는 UPR 리랭커가 다른 retrieval과 섞어 쓰면 그렇게 성능이 좋다고 자랑한다. (zero-shot unsupervised인데 supervised보다 성능이 좋다고...) 성능이 좋아질 수도 있으니 한 번 써보고, 직접 개발하기 귀찮다면 RAGchain를 활용해 보자!

profile
https://github.com/NomaDamas

0개의 댓글