[LangChain] RetrievalQA에 chat history 추가하기!

김준기·2024년 4월 4일

PDF 파일들을 읽어서 질문에 대한 응답을 생성하는 PDF RetrevalQA를 만들었는데 채팅 기록을 기억하도록 하고싶었다.

잘 찾아보니 이미 관련되어 토론이 진행된 주제였다.

위 토론에서 나온 방식으로 구현을 할까 했지만 기존 RetrevalQA 구조를 유지하고 싶어 소스코드를 분석했다.
기존 promptchat history를 추가하고, RetrevalQA 클래스를 상속해서 _call 함수를 오버라이드해서 해결했다.

아래는 새로만든 RetrevalQA 클래스와 streamlit를 이용한 예제다.

import streamlit as st
from streamlit.delta_generator import DeltaGenerator

from langchain_community.chat_message_histories.in_memory import ChatMessageHistory
from langchain_community.document_loaders.pdf import PyPDFDirectoryLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_core.callbacks import CallbackManagerForChainRun
from langchain_core.prompts.chat import ChatPromptTemplate
from langchain.chains.retrieval_qa.base import RetrievalQA
from langchain_community.vectorstores.faiss import FAISS
from langchain.callbacks.base import BaseCallbackHandler
from langchain_core.prompts import MessagesPlaceholder
from langchain_openai import OpenAIEmbeddings
from langchain_openai import ChatOpenAI
from openai import AuthenticationError
import inspect


class ChatRetrievalQA(RetrievalQA):
    def _call(self, inputs, run_manager=None):
        _run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager()
        question = inputs[self.input_key]
        extra_kwargs = {key:inputs[key] for key in inputs if key != self.input_key}
        accepts_run_manager = (
            "run_manager" in inspect.signature(self._get_docs).parameters
        )
        if accepts_run_manager:
            docs = self._get_docs(question, run_manager=_run_manager)
        else:
            docs = self._get_docs(question)
        answer = self.combine_documents_chain.invoke(
            input={
                "question":question, 
                "input_documents":docs, 
                **extra_kwargs
            },
            config={
                "callbacks":_run_manager.get_child(), 
            }
        )
        if self.return_source_documents:
            return {self.output_key: answer["output_text"], "source_documents": docs}
        else:
            return {self.output_key: answer["output_text"]}

@st.cache_resource
def get_pdf_retriever():
    loader = PyPDFDirectoryLoader("./pdfs/") # pdf 파일들이 모여있는 폴더 경로
    data = loader.load()
    text_splitter = RecursiveCharacterTextSplitter(chunk_size=10000, chunk_overlap=20)
    text_chunks = text_splitter.split_documents(data)
    embeddings = OpenAIEmbeddings()
    vector_store = FAISS.from_documents(text_chunks, embedding=embeddings)
    return vector_store.as_retriever(search_kwargs={"k": 2})

@st.cache_resource
def get_qa(openai_api_key:str):
    llm = ChatOpenAI(
        api_key = openai_api_key,
        streaming = True,
        temperature=0.75,
        verbose=True,
    )

    prompt = ChatPromptTemplate.from_messages([
        ("system", """다음 컨텍스트를 사용하여 사용자의 질문에 답합니다.
답을 모르면 모르겠다고만 하면 됩니다. 답을 지어내려 하지 마세요.
한국어로 대답하세요.
----------------
{context}"""),
    MessagesPlaceholder(variable_name="chat_history"),
    ("human", "{question}")
    ])

    qa = ChatRetrievalQA.from_chain_type(llm=llm, 
                                         chain_type_kwargs={"prompt":prompt},
                                         chain_type="stuff", 
                                         retriever=get_pdf_retriever())
    return qa

class StreamHandler(BaseCallbackHandler):
    def __init__(self, container : DeltaGenerator, initial_text=""):
        self.container = container
        self.text_container = [initial_text]

    def on_llm_new_token(self, token: str, **kwargs) -> None:
        self.text_container.append(token)
        self.container.markdown(''.join(self.text_container))

st.title("PDF 리더 봇")

if "langchain_messages" not in st.session_state:
    st.session_state.langchain_messages = ChatMessageHistory()
    st.session_state.langchain_messages.add_ai_message("pdf 파일들을 읽은 내용을 토대로 답변해드립니다. 무엇을 도와드릴까요?")

def chat_clear_btn():
    st.session_state.langchain_messages.clear()
    st.session_state.langchain_messages.add_ai_message("pdf 파일들을 읽은 내용을 토대로 답변해드립니다. 무엇을 도와드릴까요?")

with st.sidebar:
    openai_api_key = st.text_input("OpenAI API Key", type="password")

    st.button("채팅 초기화", on_click=chat_clear_btn)

for message in st.session_state.langchain_messages.messages:
    with st.chat_message(message.type):
        st.markdown(str(message.content))
        
if prompt := st.chat_input("여기에 입력하세요!"):
    qa = get_qa(openai_api_key)
    with st.chat_message("user"):
        st.markdown(prompt)

    if qa:
        with st.chat_message("ai"):
                try:
                    handler = StreamHandler(st.empty())
                    response = qa.invoke(input={"query": prompt, "chat_history":st.session_state.langchain_messages.messages},
                                         config={"callbacks":[handler]})
                    st.session_state.langchain_messages.add_user_message(prompt)
                    st.session_state.langchain_messages.add_ai_message(response['result'])
                except AuthenticationError:
                    st.markdown("올바른 openai api key를 입력해주세요")

    else:
        st.cache_resource.clear()
        with st.chat_message("ai"):
            st.markdown("openai api key를 입력해 주세요.")

ChatRetrievalQA를 사용하면 invoke를 할 때, prompt에 필요한 변수를 input 값으로 줄 수 있게된다.

아래는 실행한 결과로 이름을 잘 기억해서 답변해주는 모습이다.

만약 streamlit 방식이 아니라 api 방식으로 하고싶으면 [LangChain] Agent의 응답을 스트리밍으로 하는 법!를 보자.

profile
코딩 잘하고 싶은 백엔드 개발자

0개의 댓글