PDF 파일들을 읽어서 질문에 대한 응답을 생성하는 PDF RetrevalQA를 만들었는데 채팅 기록을 기억하도록 하고싶었다.
잘 찾아보니 이미 관련되어 토론이 진행된 주제였다.
위 토론에서 나온 방식으로 구현을 할까 했지만 기존 RetrevalQA 구조를 유지하고 싶어 소스코드를 분석했다.
기존 prompt에 chat history를 추가하고, RetrevalQA 클래스를 상속해서 _call 함수를 오버라이드해서 해결했다.
아래는 새로만든 RetrevalQA 클래스와 streamlit를 이용한 예제다.
import streamlit as st
from streamlit.delta_generator import DeltaGenerator
from langchain_community.chat_message_histories.in_memory import ChatMessageHistory
from langchain_community.document_loaders.pdf import PyPDFDirectoryLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_core.callbacks import CallbackManagerForChainRun
from langchain_core.prompts.chat import ChatPromptTemplate
from langchain.chains.retrieval_qa.base import RetrievalQA
from langchain_community.vectorstores.faiss import FAISS
from langchain.callbacks.base import BaseCallbackHandler
from langchain_core.prompts import MessagesPlaceholder
from langchain_openai import OpenAIEmbeddings
from langchain_openai import ChatOpenAI
from openai import AuthenticationError
import inspect
class ChatRetrievalQA(RetrievalQA):
def _call(self, inputs, run_manager=None):
_run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager()
question = inputs[self.input_key]
extra_kwargs = {key:inputs[key] for key in inputs if key != self.input_key}
accepts_run_manager = (
"run_manager" in inspect.signature(self._get_docs).parameters
)
if accepts_run_manager:
docs = self._get_docs(question, run_manager=_run_manager)
else:
docs = self._get_docs(question)
answer = self.combine_documents_chain.invoke(
input={
"question":question,
"input_documents":docs,
**extra_kwargs
},
config={
"callbacks":_run_manager.get_child(),
}
)
if self.return_source_documents:
return {self.output_key: answer["output_text"], "source_documents": docs}
else:
return {self.output_key: answer["output_text"]}
@st.cache_resource
def get_pdf_retriever():
loader = PyPDFDirectoryLoader("./pdfs/") # pdf 파일들이 모여있는 폴더 경로
data = loader.load()
text_splitter = RecursiveCharacterTextSplitter(chunk_size=10000, chunk_overlap=20)
text_chunks = text_splitter.split_documents(data)
embeddings = OpenAIEmbeddings()
vector_store = FAISS.from_documents(text_chunks, embedding=embeddings)
return vector_store.as_retriever(search_kwargs={"k": 2})
@st.cache_resource
def get_qa(openai_api_key:str):
llm = ChatOpenAI(
api_key = openai_api_key,
streaming = True,
temperature=0.75,
verbose=True,
)
prompt = ChatPromptTemplate.from_messages([
("system", """다음 컨텍스트를 사용하여 사용자의 질문에 답합니다.
답을 모르면 모르겠다고만 하면 됩니다. 답을 지어내려 하지 마세요.
한국어로 대답하세요.
----------------
{context}"""),
MessagesPlaceholder(variable_name="chat_history"),
("human", "{question}")
])
qa = ChatRetrievalQA.from_chain_type(llm=llm,
chain_type_kwargs={"prompt":prompt},
chain_type="stuff",
retriever=get_pdf_retriever())
return qa
class StreamHandler(BaseCallbackHandler):
def __init__(self, container : DeltaGenerator, initial_text=""):
self.container = container
self.text_container = [initial_text]
def on_llm_new_token(self, token: str, **kwargs) -> None:
self.text_container.append(token)
self.container.markdown(''.join(self.text_container))
st.title("PDF 리더 봇")
if "langchain_messages" not in st.session_state:
st.session_state.langchain_messages = ChatMessageHistory()
st.session_state.langchain_messages.add_ai_message("pdf 파일들을 읽은 내용을 토대로 답변해드립니다. 무엇을 도와드릴까요?")
def chat_clear_btn():
st.session_state.langchain_messages.clear()
st.session_state.langchain_messages.add_ai_message("pdf 파일들을 읽은 내용을 토대로 답변해드립니다. 무엇을 도와드릴까요?")
with st.sidebar:
openai_api_key = st.text_input("OpenAI API Key", type="password")
st.button("채팅 초기화", on_click=chat_clear_btn)
for message in st.session_state.langchain_messages.messages:
with st.chat_message(message.type):
st.markdown(str(message.content))
if prompt := st.chat_input("여기에 입력하세요!"):
qa = get_qa(openai_api_key)
with st.chat_message("user"):
st.markdown(prompt)
if qa:
with st.chat_message("ai"):
try:
handler = StreamHandler(st.empty())
response = qa.invoke(input={"query": prompt, "chat_history":st.session_state.langchain_messages.messages},
config={"callbacks":[handler]})
st.session_state.langchain_messages.add_user_message(prompt)
st.session_state.langchain_messages.add_ai_message(response['result'])
except AuthenticationError:
st.markdown("올바른 openai api key를 입력해주세요")
else:
st.cache_resource.clear()
with st.chat_message("ai"):
st.markdown("openai api key를 입력해 주세요.")
ChatRetrievalQA를 사용하면 invoke를 할 때, prompt에 필요한 변수를 input 값으로 줄 수 있게된다.
아래는 실행한 결과로 이름을 잘 기억해서 답변해주는 모습이다.

만약 streamlit 방식이 아니라 api 방식으로 하고싶으면 [LangChain] Agent의 응답을 스트리밍으로 하는 법!를 보자.