Corrective Retrieval Augmented Generation (link)
os.environ['TAVILY_API_KEY'] = 'your-tavily-api-key'
def create_doc_grader(model):
"""retriever가 검색한 문서의 연관성을 채점"""
class GradeDocuments(BaseModel):
"""검색된 문서들이 질문과 얼마나 관련이 있는지를 나타내는 이진 점수"""
binary_score: str = Field(
description="문서들이 질문과 연관성이 있는지 '예' 또는 '아니오'로 판단"
)
structured_llm_grader = model.with_structured_output(GradeDocuments)
system = """당신은 사용자 질문에 대해 검색된 문서의 연관성을 평가하는 채점자입니다. \n
문서가 질문과 관련된 키워드나 유사한 의미를 포함한다면, 해당 문서를 연관성이 있다고 판단하세요. \n
문서와 질문 사이의 연관성을 나타내기 위해 '예' 또는 '아니오'라는 이진 점수를 사용하세요."""
grade_prompt = ChatPromptTemplate.from_messages(
[
("system", system),
("human", "검색된 문서: \n\n {document} \n\n 사용자 질문: {question}"),
]
)
retrieval_grader = grade_prompt | structured_llm_grader
return retrieval_grader
def create_query_rewriter(model):
"""웹 검색을 위해 사용자 질문을 재작성"""
system = """당신은 입력된 질문을 웹 검색에 최적화된 형태로 바꿔주는 질문 재작성자입니다. \n
입력된 질문을 보고 내재된 의도나 의미를 추론해보세요."""
re_write_prompt = ChatPromptTemplate.from_messages(
[
("system", system),
(
"human",
"최초 질문은 다음과 같습니다: \n\n {question} \n 개선된 질문을 작성해주세요.",
),
]
)
question_rewriter = re_write_prompt | model | StrOutputParser()
return question_rewriter
# 언어모델, 임베딩모델, 벡터저장소 정의 부분 생략
# 질문이 담긴 데이터셋 호출 부분 생략
# 질문을 활용해 벡터저장소에서 관련된 문서 청크 (docs) 검색하는 부분 생략
##### Corrective RAG #####
retrieval_grader = create_doc_grader(model)
docs = retriever.invoke(question)
# Grade document relevancy
filtered_docs = []
need_web_search = 'No'
for d in docs:
score = retrieval_grader.invoke({'question': question, 'document': d.page_content})
grade = score.binary_score
if grade == '예':
filtered_docs.append(d)
elif grade == '아니오':
need_web_search = 'Yes'
continue
# Decide to generate
if need_web_search == 'Yes':
web_search_num = 검색문서수 - len(filtered_docs)
question_rewriter = create_query_rewriter(model)
better_question = question_rewriter.invoke({"question": question})
# Web search
web_search_tool = TavilySearchResults(max_results=web_search_num)
web_docs = web_search_tool.invoke({"query": better_question})
web_results = [Document(page_content=d["content"] for d in web_docs]
filtered_docs.extend(web_results)
... (중략) ...