proj) multi chain chatbot

Jacob Kim·2024년 2월 2일
0

Naver Project Week5

목록 보기
11/12

구글 드라이브 연결

# Mount the user's Google Drive at /content/drive (only works inside Colab).
from google.colab import drive
drive.mount('/content/drive')

필수 설치 라이브러리

!pip install -U langchain openai
import os
from typing import Dict, List

from langchain.chains import ConversationChain, LLMChain, LLMRouterChain
from langchain.chains.router import MultiPromptChain
from langchain.chains.router.llm_router import RouterOutputParser
from langchain.chains.router.multi_prompt_prompt import MULTI_PROMPT_ROUTER_TEMPLATE
from langchain.chat_models import ChatOpenAI
from langchain.prompts import PromptTemplate
from langchain.prompts.chat import ChatPromptTemplate
from pydantic import BaseModel

API 키 입력

import getpass
import os

# SECURITY: a literal API key used to sit in a comment at this spot. It has
# been removed; any key that was ever published this way must be treated as
# compromised and revoked.
# The key is read interactively so it never ends up stored in the notebook.
os.environ["OPENAI_API_KEY"] = getpass.getpass("OpenAI API key: ")

LLM 파트 구현

  • 게임룰에 대한 정보들을 얻는 방법을 프롬프트 체인을 이용해 구성했습니다.
  • 부루마블이라는 보드게임을 진행하기 위한 기본적인 rule과 건물을 지을 수 있는 규칙이 들어간 데이터를 이용해서 문답을 진행합니다.
# Directory that holds the prompt template files.
PATH = "/content/drive/MyDrive/dataset/chain_prompts"

# Template files, one per routing destination.
RULE_1 = os.path.join(PATH, "game_basic.txt")
RULE_2 = os.path.join(PATH, "game_building.txt")


def read_prompt_template(file_path: str) -> str:
    """Return the full text of the prompt template stored at *file_path*.

    The file is decoded as UTF-8 explicitly rather than with the platform's
    default encoding, so non-ASCII template text is read the same way on
    every machine.

    Args:
        file_path: Path of the template text file.

    Returns:
        The file's entire contents as a string.

    Raises:
        OSError: If the file cannot be opened.
        UnicodeDecodeError: If the file is not valid UTF-8.
    """
    with open(file_path, "r", encoding="utf-8") as f:
        return f.read()


def create_chain(llm, template_path, output_key):
    """Build a verbose ``LLMChain`` whose prompt is loaded from a file.

    Args:
        llm: Language model handed to the chain.
        template_path: Path of the text file holding the prompt template.
        output_key: Key under which the chain stores its output.

    Returns:
        The configured ``LLMChain``.
    """
    template_text = read_prompt_template(template_path)
    chat_prompt = ChatPromptTemplate.from_template(template=template_text)
    return LLMChain(
        llm=llm,
        prompt=chat_prompt,
        output_key=output_key,
        verbose=True,
    )


# One chat model instance is shared by the router and every destination chain.
llm = ChatOpenAI(model="gpt-3.5-turbo", temperature=0.1, max_tokens=200)

# Destination chains, one per rule template; both publish under "text".
rule_1 = create_chain(llm, RULE_1, "text")
rule_2 = create_chain(llm, RULE_2, "text")

# "<name>: <description>" lines, one per destination, substituted into the
# router prompt template below.
destinations = "\n".join(
    (
        "basic: This page describes the basic rules used to play the board game Burumble.",
        "building: This is where you'll find the rules for buildings as you play the board game.",
    )
)
router_prompt_template = MULTI_PROMPT_ROUTER_TEMPLATE.format(
    destinations=destinations
)
router_prompt = PromptTemplate.from_template(
    template=router_prompt_template,
    output_parser=RouterOutputParser(),
)
router_chain = LLMRouterChain.from_llm(
    llm=llm,
    prompt=router_prompt,
    verbose=True,
)

# The names used as keys here match the prefixes in `destinations`; the
# default chain is a plain ConversationChain.
multi_prompt_chain = MultiPromptChain(
    router_chain=router_chain,
    destination_chains={"basic": rule_1, "building": rule_2},
    default_chain=ConversationChain(llm=llm, output_key="text"),
)


class UserRequest(BaseModel):
    """Request payload carrying a single user message."""

    # Text of the user's question.
    user_message: str


def gernerate_answer(req: UserRequest) -> Dict[str, str]:
    """Run the user's message through the multi-prompt chain.

    Args:
        req: Request whose ``user_message`` field carries the question.

    Returns:
        A dict with a single ``"answer"`` key holding the chain's output.
    """
    # pydantic v2 deprecates ``.dict()`` in favour of ``.model_dump()``; use
    # whichever one the installed version provides.
    dump = getattr(req, "model_dump", None) or req.dict
    fields = dump()
    # The message is passed under the "input" key; the original fields are
    # kept alongside it, as before.
    context = {**fields, "input": fields["user_message"]}
    answer = multi_prompt_chain.run(context)

    return {"answer": answer}


# Correctly spelled alias; the original (misspelled) name stays available so
# existing callers keep working.
generate_answer = gernerate_answer

User 데이터 입력

  • 유저 데이터 입력 후 결과를 확인합니다.
# Sample request; as the last expression of the cell, the call's return
# value is what the notebook displays.
user_data = dict(user_message="공자는 어떤 철학을 가지고 있어?")
request_instance = UserRequest(**user_data)
gernerate_answer(request_instance)
profile
AI, Information and Communication, Electronics, Computer Science, Bio, Algorithms

0개의 댓글