[Paper Review] SURE: SUMMARIZING RETRIEVALS USING ANSWER CANDIDATES FOR OPEN-DOMAIN QA OF LLMS

김진수·2024년 2월 14일

Paper Review

목록 보기
6/10
post-thumbnail

Abstract

Open-Domain QA Task에서 LLM을 할 때, 널리 사용하는 방법은 외부 Retriever로 Retrieve한 Passages를 Question과 같이 Prompt로서 주는 Retrieval-Augmented Generation(RAG)입니다. 하지만 이러한 방법은 Retrieved Passages를 잘 활용하지 못한다는 한계가 있는데, SURE는 질문에 답변이 될만한 Answer Candidate을 생성하고 그에 맞게 각 Passage를 요약하는 Conditional Summarization 을 통해 Quesiton-Aware한 Summary를 만들고, Instance-Wise ValidationPair-Wise Ranking 을 통해 Best Answer를 도출합니다.

Notations

  • RetrieverRetriever : BM25, DPR, Contriever와 같은 pretrained retriever
  • CN+C^+_N : 전체 corpus CC에서 Retriever로 Retrieve한 Top NN개 Passages
  • CN+=Retriever(q,C,N)C^+_N = Retriever(q,C,N)
  • M\mathcal{M} : LLM
  • a^\hat{a} : LLM Prediction
  • a^=M(q,CN+)\hat{a}=\mathcal{M}(q,C^+_N)

Model Pipeline

Answer Candidate Generation

질문 qq와 retrieved passages CN+C^+_N를 프롬프트 pcanp_{can}에 넣어 LLM에게 주고 KK개의 Answer Candidates yk~, k=1,2,...,K\tilde{y_k},\ k=1,2,...,K 를 구합니다.

  • y~k=M(pcan(q,CN+))\tilde{y}_k=\mathcal{M}(p_{can}(q,C^+_N))

Conditional Summarization

qq, yk~\tilde{y_k}, CN+C^+_N를 프롬프트 psump_{sum}에 넣어 LLM에게 주고 각 candidate에 대응하는 KK개의 summary sks_k를 구합니다.

  • sk=M(pcan(q,CN+,y~k))s_k=\mathcal{M}(p_{can}(q,C^+_N, \tilde{y}_k))

Instance-Wise Validation

각 summary sks_k가 valid한지 묻는 프롬프트 pvalp_{val}를 통해 sks_k의 validity v(sk)v(s_k)를 구합니다.

  • v(sk)=1,v(s_k)=1, when M(pval(q,y~k,sk))\mathcal{M}(p_{val}(q,\tilde{y}_k,s_k)) =True or v(sk)=0v(s_k)=0, else.

Pair-Wise Ranking

K개의 summary를 pairwise하게 비교하기 위해 모든 si,sjs_i,s_j쌍에 대해 둘 중 더 나은 summary를 고르도록 하고, 이를 토대로 raking r(sk,SK)r(s_k,S_K)를 구합니다.

  • r(sk,SK)=ikKrpair(sk,si)r(s_k,S_K)=\sum\limits_{i\ne k}^Kr_{pair}(s_k,s_i),
  • rpair(sk,si)={1,M(prank(q,sk,si))=sk0 M(prank(q,sk,si))=si0.5 elser_{pair}(s_k,s_i) = \begin{cases} 1, & \mathcal{M}(p_{rank}(q,s_k,s_i))=s_k\\ 0 &\ \mathcal{M}(p_{rank}(q,s_k,s_i))=s_i \\ 0.5 &\ else\end{cases}

Final Prediction

Validity v(sk)v(s_k)와 ranking r(sk,SK)r(s_k,S_K)의 합을 최대로 하는 answer candidate y~k\tilde{y}_k를 fianl prediction으로 정합니다.

  • a^=y~k\hat{a}=\tilde{y}_{k^*}, k=arg maxkv(sk)+r(sk,SK)k^*=\argmax\limits_kv(s_k)+r(s_k,S_K)
profile
ML Student

0개의 댓글