구글 드라이브 연결
# Mount Google Drive at /content/drive so the files under
# /content/drive/MyDrive used later in this notebook are readable.
from google.colab import drive
drive.mount('/content/drive')
필수 설치 라이브러리
# Notebook shell command: install/upgrade the two third-party packages used below.
# NOTE(review): versions are unpinned, while the imports below rely on legacy
# module paths such as `langchain.chat_models` — confirm they still resolve with
# the latest releases, or pin known-good versions.
!pip install -U langchain openai
import os
from typing import Dict, List
from langchain.chains import ConversationChain, LLMChain, LLMRouterChain
from langchain.chains.router import MultiPromptChain
from langchain.chains.router.llm_router import RouterOutputParser
from langchain.chains.router.multi_prompt_prompt import MULTI_PROMPT_ROUTER_TEMPLATE
from langchain.chat_models import ChatOpenAI
from langchain.prompts import PromptTemplate
from langchain.prompts.chat import ChatPromptTemplate
from pydantic import BaseModel
API 키 입력
import getpass
import os
# Read the key interactively (input is not echoed) and expose it through the
# OPENAI_API_KEY environment variable instead of hard-coding it in the notebook.
os.environ["OPENAI_API_KEY"] = getpass.getpass()
LLM 파트 구현
- 게임 룰에 대한 정보를 얻는 방법을 프롬프트 체인을 이용해 구성했습니다.
- 부루마블이라는 보드게임을 진행하기 위한 기본 규칙과 건물을 지을 수 있는 규칙이 들어간 데이터를 이용해서 문답을 진행합니다.
# Directory on the mounted Drive that holds the prompt template files.
PATH = "/content/drive/MyDrive/dataset/chain_prompts"

# Template files for the two rule sets: basic play and building.
RULE_1 = os.path.join(PATH, "game_basic.txt")
RULE_2 = os.path.join(PATH, "game_building.txt")
def read_prompt_template(file_path: str) -> str:
    """Return the full text of the prompt template stored at *file_path*.

    Args:
        file_path: Path of the template file to read.

    Returns:
        The file content as a single string, unmodified.

    Raises:
        OSError: If the file cannot be opened (e.g. it does not exist).
        UnicodeDecodeError: If the file is not valid UTF-8.
    """
    # Decode explicitly as UTF-8 so non-ASCII template text is read the same
    # way regardless of the platform's default locale encoding.
    with open(file_path, "r", encoding="utf-8") as f:
        return f.read()
def create_chain(llm, template_path, output_key):
    """Build a verbose ``LLMChain`` whose prompt comes from a template file.

    Args:
        llm: Language model instance handed to the chain.
        template_path: Path of the text file containing the prompt template.
        output_key: Key under which the chain publishes its result.

    Returns:
        An ``LLMChain`` created with ``verbose=True``.
    """
    template_text = read_prompt_template(template_path)
    chat_prompt = ChatPromptTemplate.from_template(template=template_text)
    return LLMChain(
        llm=llm,
        prompt=chat_prompt,
        output_key=output_key,
        verbose=True,
    )
# Chat model shared by every chain in this notebook.
llm = ChatOpenAI(model="gpt-3.5-turbo", temperature=0.1, max_tokens=200)

# One chain per rule template; both publish their result under the "text" key.
rule_1 = create_chain(llm=llm, template_path=RULE_1, output_key="text")
rule_2 = create_chain(llm=llm, template_path=RULE_2, output_key="text")
# Each entry is "<destination name>: <description>". The newline-joined text
# fills the {destinations} slot of MULTI_PROMPT_ROUTER_TEMPLATE.
destinations = "\n".join(
    (
        "basic: This page describes the basic rules used to play the board game Burumble.",
        "building: This is where you'll find the rules for buildings as you play the board game.",
    )
)

# Router prompt whose output is parsed by RouterOutputParser.
router_prompt_template = MULTI_PROMPT_ROUTER_TEMPLATE.format(destinations=destinations)
router_prompt = PromptTemplate.from_template(
    template=router_prompt_template,
    output_parser=RouterOutputParser(),
)
router_chain = LLMRouterChain.from_llm(llm=llm, prompt=router_prompt, verbose=True)
# Combine the router with the two rule chains. The dictionary keys are the
# destination names listed in the router prompt; a plain ConversationChain on
# the same model is supplied as the default chain.
multi_prompt_chain = MultiPromptChain(
    router_chain=router_chain,
    destination_chains={"basic": rule_1, "building": rule_2},
    default_chain=ConversationChain(llm=llm, output_key="text"),
)
class UserRequest(BaseModel):
    """Request payload carrying a single user message."""

    # Raw text of the user's question.
    user_message: str
def gernerate_answer(req: UserRequest) -> Dict[str, str]:
    """Run the user's message through ``multi_prompt_chain`` and wrap the reply.

    The function name is misspelled ("gernerate"); it is kept unchanged so
    existing callers continue to work.

    Args:
        req: Request whose ``user_message`` field holds the question.

    Returns:
        A dictionary with one key, ``"answer"``, holding the chain's output.
    """
    # NOTE(review): `.dict()` is deprecated in pydantic v2 in favour of
    # `model_dump()` — confirm which pydantic version is installed.
    payload = req.dict()
    # The message is duplicated under the "input" key before being handed to
    # the chain; the original "user_message" entry stays in the payload.
    payload["input"] = payload["user_message"]
    return {"answer": multi_prompt_chain.run(payload)}
User 데이터 입력
# Sample request. The question asks, in Korean, what philosophy Confucius
# held, i.e. it is not about either board-game rule set.
user_data = {"user_message": "공자는 어떤 철학을 가지고 있어?"}
request_instance = UserRequest(**user_data)

# Left as a bare expression so the notebook cell displays the returned dict.
gernerate_answer(request_instance)