from langchain_core.prompts import (
ChatPromptTemplate,
FewShotChatMessagePromptTemplate,
PromptTemplate,
)
from langchain.agents.mrkl import prompt as react_prompt
# Few-shot demonstrations: each entry pairs a natural-language request
# ("input") with the SQL the model should answer with ("query").
# NOTE(review): the first example selects every column (`SELECT *`), which the
# system instructions tell the model never to do, and `table` is a reserved
# word in most SQL dialects — confirm these placeholder queries are intended.
_EXAMPLE_PAIRS = (
    ("List all schedules.", "SELECT * FROM table;"),
    ("Find all delayed schedules.", "SELECT * FROM table WHERE Delayed > 0;"),
)
examples = [{"input": question, "query": sql} for question, sql in _EXAMPLE_PAIRS]
# System instructions placed at the top of the agent prompt.
# `{table_info}` is the only placeholder: it is filled with
# `str.format(table_info=...)` when the prompt template is assembled, so any
# other literal brace added to this text would have to be doubled.
# The limit of 10 results mirrors the default `top_k` of LangChain's stock
# SQL-agent prompt (the original sentence was cut off before its limit
# clause); adjust as needed.
system_prefix = """
You are an agent designed to interact with a SQL database.
Given an input question, create a syntactically correct query to run, then look at the results of the query and
return the answer. Unless the user specifies a specific number of examples they wish to obtain, always limit your
query to at most 10 results.
You can order the results by a relevant column to return the most interesting examples in the database.
Never query for all the columns from a specific table, only ask for the relevant columns given the question.
You have access to tools for interacting with the database.
Only use the given tools. Only use the information returned by the tools to construct your final answer.
You MUST double check your query before executing it. If you get an error while executing a query, rewrite the query
and try again.
DO NOT make any DML or DDL statements (INSERT, UPDATE, DELETE, DROP etc.) to the database.
You have access to the following tables:
{table_info}
Never access tables that are not listed above.
If the question does not seem related to the database, just return "I don't know" as the answer.
"""
# Closing section of the agent prompt: "Begin!", the user question, the table
# info, a seeded first Thought, then the agent scratchpad. `{input}`,
# `{table_info}` and `{agent_scratchpad}` are left as template variables.
# The blank line before the scratchpad is deliberate and is written out
# explicitly here.
basic_suffix = (
    "\n"
    "Begin!\n"
    "Question: {input}\n"
    "table : {table_info}\n"
    "Thought: I should look at the tables in the database to see what I can query. "
    "Then I should query the schema of the most relevant tables.\n"
    "\n"
    "{agent_scratchpad}\n"
)
# Template for a single few-shot example: a human turn carrying the question
# (`{input}`) followed by an AI turn carrying the SQL answer (`{query}`).
_example_turns = [
    ("human", "{input}"),
    ("ai", "{query}"),
]
example_prompt = ChatPromptTemplate.from_messages(_example_turns)
# Chat-format few-shot block built from the fixed `examples` list and
# `example_prompt` (one human turn + one AI turn per example).
# NOTE(review): "input" and "agent_scratchpad" are declared as input variables,
# but the example template only uses `{input}`/`{query}` from each example
# dict, and this template is rendered with `.format()` and no arguments when
# the prompt string is assembled — confirm whether these declarations are needed.
few_shot_prompt = FewShotChatMessagePromptTemplate(
    example_prompt=example_prompt,
    examples=examples,
    input_variables=["input", "agent_scratchpad"],
)
# ReAct output-format instructions from LangChain's MRKL prompt, followed by a
# lead-in line for the few-shot examples that come right after this section in
# the assembled template. FORMAT_INSTRUCTIONS is interpolated as a plain value,
# so any `{...}` placeholders it contains (presumably `{tool_names}` — confirm)
# survive into the final PromptTemplate.
# Only the first piece needs to be an f-string; the lead-in no longer starts
# with a stray space.
format_instructions = (
    f"{react_prompt.FORMAT_INSTRUCTIONS}\n"
    "Here are some examples of user inputs and "
    "their corresponding SQL queries:\n"
)
# Assemble the single-string ReAct prompt handed to create_sql_agent. Sections
# are joined with blank lines in this order: system instructions, the tool
# descriptions (`{tools}` placeholder), the format instructions, the rendered
# few-shot examples, and the Begin!/Question/scratchpad suffix.
#
# The few-shot section is literal text produced by rendering the examples, so
# its braces are doubled: otherwise a `{` or `}` inside an example query (e.g. a
# JSON literal) would be parsed by PromptTemplate as a template variable and
# break prompt construction. Doubled braces render back as single braces.
# NOTE(review): `{table_info}` in the system prefix is replaced with the literal
# "table" here (matching the `FROM table` example queries), while the suffix
# keeps `{table_info}` as a live template variable — confirm this is intended.
_few_shot_text = few_shot_prompt.format().replace("{", "{{").replace("}", "}}")
template = "\n\n".join(
    [
        system_prefix.format(table_info="table"),
        "{tools}",
        format_instructions,
        _few_shot_text,
        basic_suffix,
    ]
)
prompt = PromptTemplate.from_template(template=template)
# Build the SQL toolkit and agent around the custom ReAct prompt, then run a
# sample question.
# NOTE(review): `db`, `llm`, `SQLDatabaseToolkit` and `create_sql_agent` are not
# defined or imported in the visible source — they must be provided (a database
# handle and a language model, plus the LangChain imports) before this runs.
toolkit = SQLDatabaseToolkit(db=db, llm=llm)
agent = create_sql_agent(
    llm=llm,
    toolkit=toolkit,
    verbose=True,
    prompt=prompt,
)
# `Chain.run` is deprecated; `invoke` takes a dict keyed by the prompt's
# question variable ("input", as used in the suffix) and returns a result dict.
agent.invoke({"input": "what are the delayed schedules?"})
# Reference:
# https://github.com/langchain-ai/langchain/issues/17939
# https://data-newbie.tistory.com/965#pipeline