Self-RAG: Learning to Retrieve, Generate, and Critique through Self-Reflection (link)
논문과 구현 코드 차이점 정리
구현 코드 설명
... (import 생략) ...
... (format_docs 함수 생략) ...
def create_retrieval_grader(model):
"""retriever가 검색한 문서의 연관성을 채점"""
class GradeDocuments(BaseModel):
"""검색된 문서들이 질문과 얼마나 관련이 있는지를 나타내는 이진 점수"""
binary_score: str = Field(
description="문서들이 질문과 연관성이 있는지 '예' 또는 '아니오'로 판단"
)
structured_llm_grader = model.with_structured_output(GradeDocuments)
system = """당신은 사용자 질문에 대해 검색된 문서의 연관성을 평가하는 채점자입니다. \n
엄격하게 채점할 필요는 없습니다. 목적은 잘못된 검색 결과를 걸러내는 것입니다. \n
문서가 질문과 관련된 키워드나 유사한 의미를 포함한다면, 해당 문서를 연관성이 있다고 판단하세요. \n
문서와 질문 사이의 연관성을 나타내기 위해 '예' 또는 '아니오'라는 이진 점수를 사용하세요."""
grade_prompt = ChatPromptTemplate.from_messages(
[
("system", system),
("human", "검색된 문서: \n\n {document} \n\n 사용자 질문: {question}"),
]
)
retrieval_grader = grade_prompt | structured_llm_grader
return retrieval_grader
def create_hallu_grader(model):
"""생성된 응답이 환각인지 여부를 측정"""
class GradeHallucinations(BaseModel):
"""생성된 응답에 환각(hallucination)이 있는지를 나타내는 이진 점수"""
binary_score: str = Field(
description="응답이 사실에 근거하는지 '예' 또는 '아니오'로 판단"
)
# LLM with function call
structured_llm_grader = model.with_structured_output(GradeHallucinations)
# Prompt
system = """당신은 언어모델이 생성한 응답이 검색된 사실들에 기반하고 있는지 여부를 평가하는 채점자입니다. \n
'예' 혹은 '아니오' 둘 중 하나로 대답하세요. '예'는 응답이 사실들에 기반하고 있지 않음을 의미합니다."""
hallucination_prompt = ChatPromptTemplate.from_messages(
[
("system", system),
("human", "사실들: \n\n {documents} \n\n 언어모델 응답: {generation}"),
]
)
hallucination_grader = hallucination_prompt | structured_llm_grader
return hallucination_grader
def create_answer_grader(model):
class GradeAnswer(BaseModel):
"""질문에 적절한 대답인지를 평가하는 이진 점수"""
binary_score: str = Field(
description="질문에 적절한 대답인지 '예' 혹은 '아니오'로 판단"
)
# LLM with function call
structured_llm_grader = model.with_structured_output(GradeAnswer)
# Prompt
system = """당신은 응답이 질문에 적절한지 혹은 응답이 질문을 해결할 수 있는지 여부를 평가하는 채점자입니다. \n
'예' 혹은 '아니오' 둘 중 하나로 대답하세요. '예'는 응답이 질문을 해결할 수 있음을 의미합니다."""
answer_prompt = ChatPromptTemplate.from_messages(
[
("system", system),
("human", "사용자 질문: \n\n {question} \n\n 언어모델 응답: {generation}"),
]
)
answer_grader = answer_prompt | structured_llm_grader
return answer_grader
def create_query_rewriter(model):
"""재검색을 위해 사용자 질문을 재작성"""
system = """당신은 입력된 질문을 문서 검색에 최적화된 바꿔주는 질문 재작성자입니다. \n
입력된 질문을 보고 내재된 의도나 의미를 추론해보세요."""
re_write_prompt = ChatPromptTemplate.from_messages(
[
("system", system),
(
"human",
"최초 질문은 다음과 같습니다: \n\n {question} \n 개선된 질문을 작성해주세요.",
),
]
)
question_rewriter = re_write_prompt | model | StrOutputParser()
return question_rewriter
class SelfRAG:
def __init__(self, initial_question, prompt, retriever, model):
self.initial_question = initial_question # for backup
self.question = initial_question
self.prompt = prompt
self.retriever = retriever
self.model = model
self.retrieval_grader = create_retrieval_grader(model)
self.hallucination_grader = create_hallu_grader(model)
self.answer_grader = create_answer_grader(model)
self.query_rewriter = create_query_rewriter(model)
self.state = "retrieve_documents"
self.documents = None
self.answer = None
self.rate_limit_break_sec = 15
self.patience = 10
def _countdown(self):
self.patience -= 1
def retrieve_documents(self):
logging.info('--- retrieve documents ---')
documents = self.retriever.invoke(self.question)
self.documents = documents
self.state = "grade_documents"
self._countdown()
def grade_documents(self):
logging.info('--- grade documents ---')
filtered_docs = []
for d in self.documents:
while True:
try:
score = self.retrieval_grader.invoke({'question': self.question,
'document': d.page_content})
break
except RateLimitError as e:
logging.warning(f'RateLimitReached가 발생하여 {self.rate_limit_break_sec}초간 쉽니다.')
time.sleep(self.rate_limit_break_sec)
grade = score.binary_score
if grade == '예':
filtered_docs.append(d)
elif grade == '아니오':
continue
if filtered_docs:
self.state = "generate_llm_answer"
self.documents = filtered_docs
else:
self.state = "rewrite_question"
def generate_llm_answer(self):
logging.info('--- generate llm answer ---')
rag_chain = self.prompt | self.model | StrOutputParser()
while True:
try:
answer = rag_chain.invoke({"retrieved": format_docs(self.documents),
"question": self.question})
break
except RateLimitError as e:
logging.warning(f'RateLimitReached가 발생하여 {self.rate_limit_break_sec}초간 쉽니다.')
time.sleep(self.rate_limit_break_sec)
self.answer = answer
self.state = "check_hallucination"
self._countdown()
def rewrite_question(self):
logging.info('--- rewrite question ---')
while True:
try:
new_question = self.query_rewriter.invoke({"question": self.question})
break
except RateLimitError as e:
logging.warning(f'RateLimitReached가 발생하여 {self.rate_limit_break_sec}초간 쉽니다.')
time.sleep(self.rate_limit_break_sec)
self.question = new_question
self.state = "retrieve_documents"
def check_hallucination(self):
logging.info('--- check hallucination ---')
while True:
try:
hallucination_score = self.hallucination_grader.invoke({"documents": self.documents,
"generation": self.answer})
break
except RateLimitError as e:
logging.warning(f'RateLimitReached가 발생하여 {self.rate_limit_break_sec}초간 쉽니다.')
time.sleep(self.rate_limit_break_sec)
hallucination_grade = hallucination_score.binary_score
if hallucination_grade == '예':
self.state = "generate_llm_answer"
else:
self.state = "check_answer_relevance"
def check_answer_relevance(self):
logging.info('--- check answer relevance ---')
while True:
try:
relevancy_score = self.answer_grader.invoke({"question": self.question,
"generation": self.answer})
break
except RateLimitError as e:
logging.warning(f'RateLimitReached가 발생하여 {self.rate_limit_break_sec}초간 쉽니다.')
time.sleep(self.rate_limit_break_sec)
relevancy_grade = relevancy_score.binary_score
if relevancy_grade == '예':
self.state = "finished"
else:
self.state = "rewrite_question"
def run(self):
state_methods = {
"retrieve_documents": self.retrieve_documents,
"grade_documents": self.grade_documents,
"generate_llm_answer": self.generate_llm_answer,
"rewrite_question": self.rewrite_question,
"check_hallucination": self.check_hallucination,
"check_answer_relevance": self.check_answer_relevance,
"finished": lambda: print("Process finished successfully with answer:", self.answer)
}
while self.state != "finished" and self.patience > 0:
state_methods[self.state]()
if self.state == "finished":
logging.info('=== Self RAG is properly finished ===')
return self.answer
else:
logging.info('=== Self RAG cannot generate a proper answer. Use simple RAG ===')
self.question = self.initial_question
self.retrieve_documents()
self.generate_llm_answer()
return self.answer
# Self RAG
# question, prompt, retriever, model (언어모델) 정의 부분 생략
self_rag = SelfRAG(question, prompt, retriever, model)
response = self_rag.run()