Build a Question/Answering system over SQL data
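Enabling an LLM system to query structured data can be qualitatively different from working with unstructured text data. With unstructured text, it is common to generate text that can be searched against a vector database. With structured data, the usual approach is for the LLM to write and execute queries in a DSL, such as SQL. In this guide we'll go over the basic ways to create a Q&A system over tabular data in databases. We will cover implementations using both chains and agents. These systems will allow us to ask a question about the data in a database and get back a natural language answer. The main difference between the two is that our agent can query the database in a loop as many times as it needs to answer the question.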
⚠️ Security note ⚠️
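Building Q&A systems over SQL databases requires executing model-generated SQL queries. There are inherent risks in doing this. Make sure that your database connection permissions are always scoped as narrowly as possible for your chain/agent's needs. This will mitigate, though not eliminate, the risks of building a model-driven system. For more on general security best practices, see here.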
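With SQLite, one concrete way to scope permissions narrowly is to make the connection itself read-only. The sketch below uses only Python's standard `sqlite3` module and SQLite's `PRAGMA query_only`. `make_read_only` is a hypothetical helper, not part of LangChain, and it does not replace proper account-level permissions on server databases.

```python
import sqlite3


def make_read_only(conn: sqlite3.Connection) -> sqlite3.Connection:
    """Switch an existing SQLite connection into read-only mode.

    PRAGMA query_only makes SQLite reject any statement that would modify
    the database (INSERT, UPDATE, DELETE, CREATE, DROP, ...).
    """
    conn.commit()  # finish any pending transaction before locking down
    conn.execute("PRAGMA query_only = ON")
    return conn


# Usage: build a tiny table, then lock the connection.
conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE Artist (ArtistId INTEGER PRIMARY KEY, Name TEXT)")
conn.execute("INSERT INTO Artist VALUES (1, 'AC/DC')")
make_read_only(conn)

print(conn.execute("SELECT Name FROM Artist").fetchall())  # reads still work
try:
    conn.execute("DELETE FROM Artist")  # a model-generated write...
except sqlite3.DatabaseError as exc:  # ...is refused (OperationalError in practice)
    print("write rejected:", exc)
```

To enforce the same guarantee when a database file is opened, you can call `sqlite3.connect("file:Chinook.db?mode=ro", uri=True)`.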
Architecture
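At a high level, the steps of these systems are:
- Convert question to SQL query: the model converts user input into a SQL query.
- Execute SQL query: execute the query.
- Answer the question: the model responds to the user input using the query results.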
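Here is a minimal sketch of the three steps that runs without any model. `fake_text_to_sql` and `fake_answer` are hypothetical stand-ins for the two LLM calls, and they only understand one hard-coded question. The middle step uses Python's standard `sqlite3` module on a tiny in-memory table.

```python
import sqlite3


def fake_text_to_sql(question: str) -> str:
    """Step 1 stand-in: an LLM would generate this query from the question."""
    if "how many employees" in question.lower():
        return "SELECT COUNT(*) FROM Employee;"
    raise ValueError(f"stub model cannot handle: {question!r}")


def execute_sql(conn: sqlite3.Connection, query: str) -> list:
    """Step 2: run the query and return the raw rows."""
    return conn.execute(query).fetchall()


def fake_answer(question: str, query: str, rows: list) -> str:
    """Step 3 stand-in: an LLM would phrase the answer from question, query and rows."""
    return f"There are {rows[0][0]} employees."


def answer_question(conn: sqlite3.Connection, question: str) -> str:
    query = fake_text_to_sql(question)
    rows = execute_sql(conn, query)
    return fake_answer(question, query, rows)


conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE Employee (EmployeeId INTEGER PRIMARY KEY, LastName TEXT)")
conn.executemany(
    "INSERT INTO Employee (LastName) VALUES (?)",
    [("Adams",), ("Edwards",), ("Peacock",)],
)
print(answer_question(conn, "How many employees are there?"))
```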
Note that querying data in CSVs can follow a similar approach. See our how-to guide on question-answering over CSV data for more detail.
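One way to reuse the same SQL-based steps on a CSV file is to load the file into an in-memory SQLite table first. This is an assumption for illustration; the CSV guide may take a different route. The sketch uses only the standard `csv` and `sqlite3` modules. `csv_to_sqlite` is a hypothetical helper that stores every column as TEXT. It also assumes the header names are trusted, because they are interpolated into the CREATE TABLE statement.

```python
import csv
import io
import sqlite3


def csv_to_sqlite(conn, table, csv_file):
    """Load CSV rows into a new table; the header row becomes the column names."""
    reader = csv.reader(csv_file)
    header = next(reader)
    # Header names are assumed trusted: they are placed directly into the SQL.
    cols = ", ".join(f'"{name}" TEXT' for name in header)
    conn.execute(f'CREATE TABLE "{table}" ({cols})')
    placeholders = ", ".join("?" for _ in header)
    conn.executemany(f'INSERT INTO "{table}" VALUES ({placeholders})', reader)
    return conn


data = io.StringIO("Name,Country\nLuís,Brazil\nLeonie,Germany\nFrançois,Canada\n")
conn = csv_to_sqlite(sqlite3.connect(":memory:"), "customers", data)
print(conn.execute("SELECT Name FROM customers WHERE Country = ?", ("Germany",)).fetchall())
```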
Setup
First, get the required packages and set environment variables:
%%capture --no-stderr
%pip install --upgrade --quiet langchain-community langchainhub langgraph
import getpass
import os
# Comment out the below to opt-out of using LangSmith in this notebook. Not required.
if not os.environ.get("LANGCHAIN_API_KEY"):
    os.environ["LANGCHAIN_API_KEY"] = getpass.getpass()
    os.environ["LANGCHAIN_TRACING_V2"] = "true"
Sample data
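The examples below use a SQLite connection to the Chinook database, a sample database that represents a digital media store. Follow these installation steps to create Chinook.db in the same directory as this notebook. You can also download and build the database from the command line: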
curl -s https://raw.githubusercontent.com/lerocha/chinook-database/master/ChinookDatabase/DataSources/Chinook_Sqlite.sql | sqlite3 Chinook.db
Now that Chinook.db is in our directory, we can interface with it using the SQLAlchemy-driven SQLDatabase class:
from langchain_community.utilities import SQLDatabase
db = SQLDatabase.from_uri("sqlite:///Chinook.db")
print(db.dialect)
print(db.get_usable_table_names())
db.run("SELECT * FROM Artist LIMIT 10;")
sqlite
['Album', 'Artist', 'Customer', 'Employee', 'Genre', 'Invoice', 'InvoiceLine', 'MediaType', 'Playlist', 'PlaylistTrack', 'Track']
"[(1, 'AC/DC'), (2, 'Accept'), (3, 'Aerosmith'), (4, 'Alanis Morissette'), (5, 'Alice In Chains'), (6, 'Antônio Carlos Jobim'), (7, 'Apocalyptica'), (8, 'Audioslave'), (9, 'BackBeat'), (10, 'Billy Cobham')]"
Great! We've got a SQL database that we can query. Now let's try hooking it up to an LLM.
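Before hooking in a model, it may help to see roughly what the `SQLDatabase` calls above return. For SQLite, a rough standard-library approximation looks like the following. It is not LangChain's implementation, which goes through SQLAlchemy and supports many dialects. The helper names are hypothetical, and `run_query` merely mimics the string format of the `db.run` output shown above.

```python
import sqlite3


def usable_table_names(conn):
    """Sorted user table names, similar in spirit to db.get_usable_table_names()."""
    rows = conn.execute(
        "SELECT name FROM sqlite_master "
        "WHERE type = 'table' AND name NOT LIKE 'sqlite_%' ORDER BY name"
    ).fetchall()
    return [name for (name,) in rows]


def run_query(conn, query):
    """Return the fetched rows as a string, mimicking the db.run(...) output above."""
    return str(conn.execute(query).fetchall())


conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE Artist (ArtistId INTEGER PRIMARY KEY, Name TEXT)")
conn.execute("CREATE TABLE Album (AlbumId INTEGER PRIMARY KEY, Title TEXT, ArtistId INTEGER)")
conn.executemany("INSERT INTO Artist VALUES (?, ?)", [(1, "AC/DC"), (2, "Accept")])
print(usable_table_names(conn))  # ['Album', 'Artist']
print(run_query(conn, "SELECT * FROM Artist LIMIT 10;"))  # [(1, 'AC/DC'), (2, 'Accept')]
```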
Chains
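Chains are compositions of predictable steps. In LangGraph, we can represent a chain via a simple sequence of nodes. Let's create a sequence of steps that, given a question, does the following:
- converts the question into a SQL query;
- executes the query;
- uses the result to answer the original question.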
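There are scenarios this arrangement does not support. For example, this system will execute a SQL query for any user input, even "hello". More importantly, as we'll see below, some questions require more than one query to answer. We will address these scenarios in the Agents section.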
Application state
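The LangGraph state of our application controls what data is input to the application, transferred between steps, and output by the application. It is typically a TypedDict, but can also be a Pydantic BaseModel.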
For this application, we can just keep track of the input question, the generated query, the query result, and the generated answer:
from typing_extensions import TypedDict
class State(TypedDict):
    question: str
    query: str
    result: str
    answer: str
Now we just need functions that operate on this state and populate its contents.
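Conceptually, each node receives the current state and returns only the keys it updates. Those partial updates are merged into the state before the next node runs. For keys without a reducer, LangGraph's default is to overwrite the old value. The plain-Python sketch below illustrates that idea for a linear sequence. `run_sequence` and the stub nodes are hypothetical and much simpler than LangGraph's actual runtime.

```python
from typing import Callable, List, TypedDict


class State(TypedDict, total=False):
    question: str
    query: str
    result: str
    answer: str


def run_sequence(state: State, nodes: List[Callable[[State], dict]]) -> State:
    """Run nodes in order, merging each node's partial update into the state."""
    current = dict(state)  # copy so the caller's dict is left untouched
    for node in nodes:
        update = node(current)  # a node returns only the keys it changes
        current.update(update)  # keys without a reducer are simply overwritten
    return current


# Hypothetical stub nodes, each filling in a single field.
def write_query(state):
    return {"query": "SELECT COUNT(*) FROM Employee;"}


def execute_query(state):
    return {"result": "[(8,)]"}


def generate_answer(state):
    return {"answer": f"Based on {state['result']}, there are 8 employees."}


initial = {"question": "How many employees are there?"}
final = run_sequence(initial, [write_query, execute_query, generate_answer])
print(final["answer"])
```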
Convert question to SQL query
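The first step is to take the user input and convert it to a SQL query. To reliably obtain SQL queries (without Markdown formatting, explanations, or clarifications), we will use LangChain's structured output abstraction.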
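This guide relies on structured output. Some models lack that capability. A common fallback for them is to post-process free-form text and pull out the SQL. This is a different technique, not what `with_structured_output` does. `extract_sql` below is a hypothetical helper that strips an optional Markdown code fence around the query.

```python
import re

# Matches a fenced block (optional "sql" language tag) and captures its body.
_FENCE = re.compile(r"```(?:sql)?\s*(.*?)```", re.DOTALL | re.IGNORECASE)


def extract_sql(text: str) -> str:
    """Return the SQL inside the first fenced block, or the whole text if unfenced."""
    match = _FENCE.search(text)
    sql = match.group(1) if match else text
    return sql.strip()


raw = "Here is the query:\n```sql\nSELECT COUNT(EmployeeId) FROM Employee;\n```"
print(extract_sql(raw))  # SELECT COUNT(EmployeeId) FROM Employee;
print(extract_sql("SELECT 1;"))  # SELECT 1;
```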
Let's select a chat model for our application:
pip install -qU langchain-openai
import getpass
import os
if not os.environ.get("OPENAI_API_KEY"):
    os.environ["OPENAI_API_KEY"] = getpass.getpass("Enter API key for OpenAI: ")
from langchain_openai import ChatOpenAI
llm = ChatOpenAI(model="gpt-4o-mini")
We will pull a prompt from the Prompt Hub to instruct the model.
from langchain import hub
query_prompt_template = hub.pull("langchain-ai/sql-query-system-prompt")
assert len(query_prompt_template.messages) == 1
query_prompt_template.messages[0].pretty_print()
================================ System Message ================================
Given an input question, create a syntactically correct {dialect} query to run to help find the answer. Unless the user specifies in his question a specific number of examples they wish to obtain, always limit your query to at most {top_k} results. You can order the results by a relevant column to return the most interesting examples in the database.
Never query for all the columns from a specific table, only ask for a the few relevant columns given the question.
Pay attention to use only the column names that you can see in the schema description. Be careful to not query for columns that do not exist. Also, pay attention to which column is in which table.
Only use the following tables:
{table_info}
Question: {input}
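The prompt includes several parameters we need to populate, such as the SQL dialect and table schemas. LangChain's SQLDatabase object includes methods to help with this. Our write_query step will just populate these parameters and prompt the model to generate the SQL query: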
from typing_extensions import Annotated
class QueryOutput(TypedDict):
    """Generated SQL query."""
    query: Annotated[str, ..., "Syntactically valid SQL query."]
def write_query(state: State):
    """Generate SQL query to fetch information."""
    prompt = query_prompt_template.invoke(
        {
            "dialect": db.dialect,
            "top_k": 10,
            "table_info": db.get_table_info(),
            "input": state["question"],
        }
    )
    structured_llm = llm.with_structured_output(QueryOutput)
    result = structured_llm.invoke(prompt)
    return {"query": result["query"]}
Let's test it out:
write_query({"question": "How many Employees are there?"})
{'query': 'SELECT COUNT(EmployeeId) AS EmployeeCount FROM Employee;'}
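The `table_info` parameter is where the model sees the schema. `db.get_table_info()` returns CREATE TABLE statements followed by a few sample rows, in the format visible in the agent trace later in this guide. The sketch below approximates that for SQLite with the standard library. It then fills an abbreviated, hypothetical version of the prompt template. `table_info` and `TEMPLATE` are illustrative names, not LangChain code.

```python
import sqlite3

# Abbreviated, hypothetical stand-in for the hub prompt and its placeholders.
TEMPLATE = (
    "Given an input question, create a syntactically correct {dialect} query. "
    "Limit results to at most {top_k} rows.\n"
    "Only use the following tables:\n{table_info}\nQuestion: {input}"
)


def table_info(conn, sample_rows=3):
    """Concatenate each table's CREATE statement with a few sample rows."""
    parts = []
    tables = conn.execute(
        "SELECT name, sql FROM sqlite_master "
        "WHERE type = 'table' AND name NOT LIKE 'sqlite_%' ORDER BY name"
    ).fetchall()
    for name, create_sql in tables:
        cursor = conn.execute(f'SELECT * FROM "{name}" LIMIT ?', (sample_rows,))
        columns = [d[0] for d in cursor.description]
        rows = ["\t".join(str(v) for v in row) for row in cursor.fetchall()]
        sample = "\n".join(["\t".join(columns), *rows])
        parts.append(f"{create_sql}\n/*\n{sample_rows} rows from {name} table:\n{sample}\n*/")
    return "\n\n".join(parts)


conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE Employee (EmployeeId INTEGER PRIMARY KEY, LastName TEXT)")
conn.executemany("INSERT INTO Employee (LastName) VALUES (?)", [("Adams",), ("Edwards",)])

prompt = TEMPLATE.format(
    dialect="sqlite",
    top_k=10,
    table_info=table_info(conn),
    input="How many Employees are there?",
)
print(prompt)
```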
Execute query
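This is the most dangerous part of creating a SQL chain. Consider carefully whether it is OK to run automated queries over your data. Minimize the database connection permissions as much as possible. Consider adding a human approval step to your chain before query execution (see below).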
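Besides database-level grants, SQLite lets a Python process veto individual operations while a statement is being prepared, via `sqlite3.Connection.set_authorizer`. The hypothetical `allow_only_tables` helper below is one possible extra safeguard, not something LangChain configures for you. It permits only plain reads, and only from an allow-list of tables. A model-generated query that touches anything else fails before it runs.

```python
import sqlite3


def allow_only_tables(conn, allowed_tables):
    """Install an authorizer that permits only reads from allow-listed tables."""
    allowed = {name.lower() for name in allowed_tables}
    read_actions = {sqlite3.SQLITE_SELECT, sqlite3.SQLITE_READ, sqlite3.SQLITE_FUNCTION}

    def authorizer(action, arg1, arg2, db_name, trigger_or_view):
        if action not in read_actions:
            return sqlite3.SQLITE_DENY  # INSERT, DELETE, DROP, PRAGMA, ATTACH, ...
        if action == sqlite3.SQLITE_READ and (arg1 or "").lower() not in allowed:
            return sqlite3.SQLITE_DENY  # for SQLITE_READ, arg1 is the table name
        return sqlite3.SQLITE_OK

    conn.set_authorizer(authorizer)
    return conn


conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE Employee (EmployeeId INTEGER PRIMARY KEY, LastName TEXT)")
conn.execute("CREATE TABLE Salary (EmployeeId INTEGER, Amount REAL)")
conn.execute("INSERT INTO Employee (LastName) VALUES ('Adams')")
conn.commit()
allow_only_tables(conn, ["Employee"])

print(conn.execute("SELECT LastName FROM Employee").fetchall())
for query in ("SELECT Amount FROM Salary", "DELETE FROM Employee"):
    try:
        conn.execute(query)
    except sqlite3.DatabaseError as exc:  # SQLite reports "not authorized"
        print(f"{query!r} rejected: {exc}")
```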
To execute the query, we will load a tool from langchain-community. Our execute_query node will just wrap this tool:
from langchain_community.tools.sql_database.tool import QuerySQLDatabaseTool
def execute_query(state: State):
    """Execute SQL query."""
    execute_query_tool = QuerySQLDatabaseTool(db=db)
    return {"result": execute_query_tool.invoke(state["query"])}
Let's test this step:
execute_query({"query": "SELECT COUNT(EmployeeId) AS EmployeeCount FROM Employee;"})
{'result': '[(8,)]'}
Generate answer
Finally, our last step generates an answer to the question from the information pulled from the database:
def generate_answer(state: State):
    """Answer question using retrieved information as context."""
    prompt = (
        "Given the following user question, corresponding SQL query, "
        "and SQL result, answer the user question.\n\n"
        f'Question: {state["question"]}\n'
        f'SQL Query: {state["query"]}\n'
        f'SQL Result: {state["result"]}'
    )
    response = llm.invoke(prompt)
    return {"answer": response.content}
Orchestrating with LangGraph
We now compile our application into a single graph object. In this case, we are just connecting the three steps into a single sequence.
from langgraph.graph import START, StateGraph
graph_builder = StateGraph(State).add_sequence(
[write_query, execute_query, generate_answer]
)
graph_builder.add_edge(START, "write_query")
graph = graph_builder.compile()
LangGraph also comes with built-in utilities for visualizing the control flow of your application:
from IPython.display import Image, display
display(Image(graph.get_graph().draw_mermaid_png()))
Let's test our application! Note that we can stream the results of individual steps:
for step in graph.stream(
    {"question": "How many employees are there?"}, stream_mode="updates"
):
    print(step)
{'write_query': {'query': 'SELECT COUNT(EmployeeId) AS EmployeeCount FROM Employee;'}}
{'execute_query': {'result': '[(8,)]'}}
{'generate_answer': {'answer': 'There are 8 employees.'}}
Check out the LangSmith trace.
Human-in-the-loop
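LangGraph supports a number of features that can be useful for this workflow. One of them is human-in-the-loop: we can interrupt our application before sensitive steps (such as the execution of a SQL query) for human review. This is enabled by LangGraph's persistence layer, which saves run progress to your storage of choice. Below, we specify in-memory storage: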
from langgraph.checkpoint.memory import MemorySaver
memory = MemorySaver()
graph = graph_builder.compile(checkpointer=memory, interrupt_before=["execute_query"])
# Now that we're using persistence, we need to specify a thread ID
# so that we can continue the run after review.
config = {"configurable": {"thread_id": "1"}}
display(Image(graph.get_graph().draw_mermaid_png()))
Let's repeat the same run, adding in a simple yes/no approval step:
for step in graph.stream(
    {"question": "How many employees are there?"},
    config,
    stream_mode="updates",
):
    print(step)
try:
    user_approval = input("Do you want to go to execute query? (yes/no): ")
except Exception:
    user_approval = "no"
if user_approval.lower() == "yes":
    # If approved, continue the graph execution
    for step in graph.stream(None, config, stream_mode="updates"):
        print(step)
else:
    print("Operation cancelled by user.")
{'write_query': {'query': 'SELECT COUNT(EmployeeId) AS EmployeeCount FROM Employee;'}}
{'__interrupt__': ()}
Do you want to go to execute query? (yes/no): yes
{'execute_query': {'result': '[(8,)]'}}
{'generate_answer': {'answer': 'There are 8 employees.'}}
For more detail and examples, see this LangGraph guide.
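The interrupt-and-resume pattern can also be illustrated without LangGraph. In the sketch below, a hypothetical `run_with_approval` pauses before any step named in `interrupt_before`. It then asks an `approve` callback, standing in for the `input()` prompt above, whether to continue. Unlike LangGraph's checkpointer, it keeps progress only in a local dict, so a run cannot be resumed later or from another process.

```python
from typing import Callable, List, Set, Tuple

Step = Tuple[str, Callable[[dict], dict]]


def run_with_approval(state: dict, steps: List[Step], interrupt_before: Set[str],
                      approve: Callable[[str, dict], bool]) -> dict:
    """Run named steps in order, pausing for approval before sensitive ones."""
    state = dict(state)
    for name, step in steps:
        # The reviewer sees the state so far, e.g. the generated SQL query.
        if name in interrupt_before and not approve(name, state):
            state["answer"] = "Operation cancelled by user."
            break
        state.update(step(state))
    return state


steps = [
    ("write_query", lambda s: {"query": "SELECT COUNT(*) FROM Employee;"}),
    ("execute_query", lambda s: {"result": "[(8,)]"}),
    ("generate_answer", lambda s: {"answer": "There are 8 employees."}),
]
question = {"question": "How many employees are there?"}

# In a notebook, approve could wrap input(); here the decisions are hard-coded.
approved = run_with_approval(question, steps, {"execute_query"}, lambda name, s: True)
rejected = run_with_approval(question, steps, {"execute_query"}, lambda name, s: False)
print(approved["answer"])
print(rejected["answer"])
```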
Next steps
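For more complex query generation, we may want to create few-shot prompts or add query-checking steps. For advanced techniques like this and more, check out: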
Agents
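Agents leverage the reasoning capabilities of LLMs to make decisions during execution. Using agents lets you hand the model more discretion over the query generation and execution process. Although their behavior is less predictable than the above "chain", they offer some advantages:
- They can query the database as many times as needed to answer the user question.
- They can recover from errors by running a generated query, catching the traceback, and regenerating the query correctly.
- They can answer questions based on the database's schema as well as its content (for example, describing a specific table).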
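The error-recovery point depends on tools reporting failures back to the model instead of raising. The sketch below shows that loop in miniature. The hypothetical `run_query_tool` returns the error text as a string, similar in spirit to what the `sql_db_query` tool description below promises. A scripted list of queries stands in for the model's successive attempts.

```python
import sqlite3


def run_query_tool(conn, query):
    """Execute SQL; return rows as a string, or the error message as a string."""
    try:
        return str(conn.execute(query).fetchall())
    except sqlite3.Error as exc:
        return f"Error: {exc}"


def agent_loop(conn, attempts):
    """Try scripted queries in turn until one succeeds (stand-in for an LLM rewriting)."""
    transcript = []
    for query in attempts:
        observation = run_query_tool(conn, query)
        transcript.append((query, observation))
        if not observation.startswith("Error:"):
            break  # a real agent would now write the final answer
    return transcript


conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE Employee (EmployeeId INTEGER PRIMARY KEY, LastName TEXT)")
conn.execute("INSERT INTO Employee (LastName) VALUES ('Adams')")
transcript = agent_loop(conn, [
    "SELECT COUNT(Id) FROM Employee",          # wrong column: the tool reports an error
    "SELECT COUNT(EmployeeId) FROM Employee",  # corrected attempt
])
for query, observation in transcript:
    print(query, "->", observation)
```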
Below we assemble a minimal SQL agent. We will equip it with a set of tools using LangChain's SQLDatabaseToolkit. Using LangGraph's prebuilt ReAct agent constructor, we can do this in one line.
Check out LangGraph's SQL agent tutorial for a more advanced formulation of a SQL agent.
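To see what the prebuilt ReAct constructor automates, here is a toy version of the loop. The "model" either requests a tool call or returns a final answer. The loop runs tools and feeds observations back until an answer appears. Everything here is hypothetical. `scripted_model` replays a fixed plan loosely mirroring the trace below, and the tools are plain functions over a tiny in-memory database, not LangChain tools.

```python
import sqlite3

conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE Customer (CustomerId INTEGER PRIMARY KEY, Country TEXT)")
conn.execute("CREATE TABLE Invoice (InvoiceId INTEGER PRIMARY KEY, CustomerId INTEGER, Total REAL)")
conn.executemany("INSERT INTO Customer VALUES (?, ?)", [(1, "USA"), (2, "Germany")])
conn.executemany("INSERT INTO Invoice VALUES (?, ?, ?)", [(1, 1, 5.0), (2, 2, 3.0), (3, 1, 2.5)])

# Tool name -> function taking the tool input string and returning an observation.
TOOLS = {
    "sql_db_list_tables": lambda _: ", ".join(
        r[0] for r in conn.execute("SELECT name FROM sqlite_master WHERE type='table' ORDER BY name")
    ),
    "sql_db_query": lambda q: str(conn.execute(q).fetchall()),
}


def scripted_model(messages):
    """Stand-in for an LLM: choose the next action from how many tool results exist."""
    n_observations = sum(1 for role, _ in messages if role == "tool")
    if n_observations == 0:
        return ("call", "sql_db_list_tables", "")
    if n_observations == 1:
        return ("call", "sql_db_query",
                "SELECT c.Country, SUM(i.Total) AS TotalSpent FROM Customer c "
                "JOIN Invoice i ON c.CustomerId = i.CustomerId "
                "GROUP BY c.Country ORDER BY TotalSpent DESC LIMIT 1")
    return ("final", f"Answer based on {messages[-1][1]}")


def react_loop(question, model, tools, max_steps=5):
    """Alternate model decisions and tool calls until the model gives a final answer."""
    messages = [("user", question)]
    for _ in range(max_steps):
        decision = model(messages)
        if decision[0] == "final":
            return decision[1], messages
        _, tool_name, tool_input = decision
        messages.append(("tool", tools[tool_name](tool_input)))
    raise RuntimeError("agent did not finish within max_steps")


answer, history = react_loop("Which country's customers spent the most?", scripted_model, TOOLS)
print(answer)
```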
The SQLDatabaseToolkit includes tools that can:
- Create and execute queries
- Check query syntax
- Retrieve table descriptions
- ... and more
from langchain_community.agent_toolkits import SQLDatabaseToolkit
toolkit = SQLDatabaseToolkit(db=db, llm=llm)
tools = toolkit.get_tools()
tools
[QuerySQLDatabaseTool(description="Input to this tool is a detailed and correct SQL query, output is a result from the database. If the query is not correct, an error message will be returned. If an error is returned, rewrite the query, check the query, and try again. If you encounter an issue with Unknown column 'xxxx' in 'field list', use sql_db_schema to query the correct table fields.", db=<langchain_community.utilities.sql_database.SQLDatabase object at 0x10d5f9120>),
InfoSQLDatabaseTool(description='Input to this tool is a comma-separated list of tables, output is the schema and sample rows for those tables. Be sure that the tables actually exist by calling sql_db_list_tables first! Example Input: table1, table2, table3', db=<langchain_community.utilities.sql_database.SQLDatabase object at 0x10d5f9120>),
ListSQLDatabaseTool(db=<langchain_community.utilities.sql_database.SQLDatabase object at 0x10d5f9120>),
QuerySQLCheckerTool(description='Use this tool to double check if your query is correct before executing it. Always use this tool before executing a query with sql_db_query!', db=<langchain_community.utilities.sql_database.SQLDatabase object at 0x10d5f9120>, llm=ChatOpenAI(client=<openai.resources.chat.completions.Completions object at 0x119315480>, async_client=<openai.resources.chat.completions.AsyncCompletions object at 0x119317550>, root_client=<openai.OpenAI object at 0x10d5f8df0>, root_async_client=<openai.AsyncOpenAI object at 0x1193154e0>, model_name='gpt-4o', temperature=0.0, model_kwargs={}, openai_api_key=SecretStr('**********')), llm_chain=LLMChain(verbose=False, prompt=PromptTemplate(input_variables=['dialect', 'query'], input_types={}, partial_variables={}, template='\n{query}\nDouble check the {dialect} query above for common mistakes, including:\n- Using NOT IN with NULL values\n- Using UNION when UNION ALL should have been used\n- Using BETWEEN for exclusive ranges\n- Data type mismatch in predicates\n- Properly quoting identifiers\n- Using the correct number of arguments for functions\n- Casting to the correct data type\n- Using the proper columns for joins\n\nIf there are any of the above mistakes, rewrite the query. If there are no mistakes, just reproduce the original query.\n\nOutput the final SQL query only.\n\nSQL Query: '), llm=ChatOpenAI(client=<openai.resources.chat.completions.Completions object at 0x119315480>, async_client=<openai.resources.chat.completions.AsyncCompletions object at 0x119317550>, root_client=<openai.OpenAI object at 0x10d5f8df0>, root_async_client=<openai.AsyncOpenAI object at 0x1193154e0>, model_name='gpt-4o', temperature=0.0, model_kwargs={}, openai_api_key=SecretStr('**********')), output_parser=StrOutputParser(), llm_kwargs={}))]
System prompt
We will also want to load a system prompt for our agent. It will contain instructions for how the agent should behave.
from langchain import hub
prompt_template = hub.pull("langchain-ai/sql-agent-system-prompt")
assert len(prompt_template.messages) == 1
prompt_template.messages[0].pretty_print()
================================ System Message ================================
You are an agent designed to interact with a SQL database.
Given an input question, create a syntactically correct {dialect} query to run, then look at the results of the query and return the answer.
Unless the user specifies a specific number of examples they wish to obtain, always limit your query to at most {top_k} results.
You can order the results by a relevant column to return the most interesting examples in the database.
Never query for all the columns from a specific table, only ask for the relevant columns given the question.
You have access to tools for interacting with the database.
Only use the below tools. Only use the information returned by the below tools to construct your final answer.
You MUST double check your query before executing it. If you get an error while executing a query, rewrite the query and try again.
DO NOT make any DML statements (INSERT, UPDATE, DELETE, DROP etc.) to the database.
To start you should ALWAYS look at the tables in the database to see what you can query.
Do NOT skip this step.
Then you should query the schema of the most relevant tables.
Let's populate the parameters highlighted in the prompt:
system_message = prompt_template.format(dialect="SQLite", top_k=5)
Initializing the agent
We will use a prebuilt LangGraph agent to build our agent:
from langchain_core.messages import HumanMessage
from langgraph.prebuilt import create_react_agent
agent_executor = create_react_agent(llm, tools, state_modifier=system_message)
Consider how the agent responds to the question below:
question = "Which country's customers spent the most?"
for step in agent_executor.stream(
    {"messages": [{"role": "user", "content": question}]},
    stream_mode="values",
):
    step["messages"][-1].pretty_print()
================================ Human Message =================================
Which country's customers spent the most?
================================== Ai Message ==================================
Tool Calls:
sql_db_list_tables (call_tFp7HYD6sAAmCShgeqkVZH6Q)
Call ID: call_tFp7HYD6sAAmCShgeqkVZH6Q
Args:
================================= Tool Message =================================
Name: sql_db_list_tables
Album, Artist, Customer, Employee, Genre, Invoice, InvoiceLine, MediaType, Playlist, PlaylistTrack, Track
================================== Ai Message ==================================
Tool Calls:
sql_db_schema (call_KJZ1Jx6JazyDdJa0uH1UeiOz)
Call ID: call_KJZ1Jx6JazyDdJa0uH1UeiOz
Args:
table_names: Customer, Invoice
================================= Tool Message =================================
Name: sql_db_schema
CREATE TABLE "Customer" (
"CustomerId" INTEGER NOT NULL,
"FirstName" NVARCHAR(40) NOT NULL,
"LastName" NVARCHAR(20) NOT NULL,
"Company" NVARCHAR(80),
"Address" NVARCHAR(70),
"City" NVARCHAR(40),
"State" NVARCHAR(40),
"Country" NVARCHAR(40),
"PostalCode" NVARCHAR(10),
"Phone" NVARCHAR(24),
"Fax" NVARCHAR(24),
"Email" NVARCHAR(60) NOT NULL,
"SupportRepId" INTEGER,
PRIMARY KEY ("CustomerId"),
FOREIGN KEY("SupportRepId") REFERENCES "Employee" ("EmployeeId")
)
/*
3 rows from Customer table:
CustomerId FirstName LastName Company Address City State Country PostalCode Phone Fax Email SupportRepId
1 Luís Gonçalves Embraer - Empresa Brasileira de Aeronáutica S.A. Av. Brigadeiro Faria Lima, 2170 São José dos Campos SP Brazil 12227-000 +55 (12) 3923-5555 +55 (12) 3923-5566 luisg@embraer.com.br 3
2 Leonie Köhler None Theodor-Heuss-Straße 34 Stuttgart None Germany 70174 +49 0711 2842222 None leonekohler@surfeu.de 5
3 François Tremblay None 1498 rue Bélanger Montréal QC Canada H2G 1A7 +1 (514) 721-4711 None ftremblay@gmail.com 3
*/
CREATE TABLE "Invoice" (
"InvoiceId" INTEGER NOT NULL,
"CustomerId" INTEGER NOT NULL,
"InvoiceDate" DATETIME NOT NULL,
"BillingAddress" NVARCHAR(70),
"BillingCity" NVARCHAR(40),
"BillingState" NVARCHAR(40),
"BillingCountry" NVARCHAR(40),
"BillingPostalCode" NVARCHAR(10),
"Total" NUMERIC(10, 2) NOT NULL,
PRIMARY KEY ("InvoiceId"),
FOREIGN KEY("CustomerId") REFERENCES "Customer" ("CustomerId")
)
/*
3 rows from Invoice table:
InvoiceId CustomerId InvoiceDate BillingAddress BillingCity BillingState BillingCountry BillingPostalCode Total
1 2 2021-01-01 00:00:00 Theodor-Heuss-Straße 34 Stuttgart None Germany 70174 1.98
2 4 2021-01-02 00:00:00 Ullevålsveien 14 Oslo None Norway 0171 3.96
3 8 2021-01-03 00:00:00 Grétrystraat 63 Brussels None Belgium 1000 5.94
*/
================================== Ai Message ==================================
Tool Calls:
sql_db_query_checker (call_AQuTGbgH63u4gPgyV723yrjX)
Call ID: call_AQuTGbgH63u4gPgyV723yrjX
Args:
query: SELECT c.Country, SUM(i.Total) as TotalSpent FROM Customer c JOIN Invoice i ON c.CustomerId = i.CustomerId GROUP BY c.Country ORDER BY TotalSpent DESC LIMIT 1;
================================= Tool Message =================================
Name: sql_db_query_checker
\`\`\`sql
SELECT c.Country, SUM(i.Total) as TotalSpent FROM Customer c JOIN Invoice i ON c.CustomerId = i.CustomerId GROUP BY c.Country ORDER BY TotalSpent DESC LIMIT 1;
\`\`\`
================================== Ai Message ==================================
Tool Calls:
sql_db_query (call_B88EwU44nwwpQL5M9nlcemSU)
Call ID: call_B88EwU44nwwpQL5M9nlcemSU
Args:
query: SELECT c.Country, SUM(i.Total) as TotalSpent FROM Customer c JOIN Invoice i ON c.CustomerId = i.CustomerId GROUP BY c.Country ORDER BY TotalSpent DESC LIMIT 1;
================================= Tool Message =================================
Name: sql_db_query
[('USA', 523.06)]
================================== Ai Message ==================================
The country whose customers spent the most is the USA, with a total spending of 523.06.
You can also use the LangSmith trace to visualize these steps and the associated metadata.
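Note that the agent executes multiple queries until it has the information it needs:
- it lists the available tables;
- it retrieves the schemas of the relevant tables (Customer and Invoice);
- it checks the query, then queries multiple tables via a join operation.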
The agent is then able to use the result of the final query to generate an answer to the original question.
The agent can similarly handle qualitative questions:
question = "Describe the playlisttrack table"
for step in agent_executor.stream(
    {"messages": [{"role": "user", "content": question}]},
    stream_mode="values",
):
    step["messages"][-1].pretty_print()
================================ Human Message =================================
Describe the playlisttrack table
================================== Ai Message ==================================
Tool Calls:
sql_db_list_tables (call_fMF8eTmX5TJDJjc3Mhdg52TI)
Call ID: call_fMF8eTmX5TJDJjc3Mhdg52TI
Args:
================================= Tool Message =================================
Name: sql_db_list_tables
Album, Artist, Customer, Employee, Genre, Invoice, InvoiceLine, MediaType, Playlist, PlaylistTrack, Track
================================== Ai Message ==================================
Tool Calls:
sql_db_schema (call_W8Vkk4NEodkAAIg8nexAszUH)
Call ID: call_W8Vkk4NEodkAAIg8nexAszUH
Args:
table_names: PlaylistTrack
================================= Tool Message =================================
Name: sql_db_schema
CREATE TABLE "PlaylistTrack" (
"PlaylistId" INTEGER NOT NULL,
"TrackId" INTEGER NOT NULL,
PRIMARY KEY ("PlaylistId", "TrackId"),
FOREIGN KEY("TrackId") REFERENCES "Track" ("TrackId"),
FOREIGN KEY("PlaylistId") REFERENCES "Playlist" ("PlaylistId")
)
/*
3 rows from PlaylistTrack table:
PlaylistId TrackId
1 3402
1 3389
1 3390
*/
================================== Ai Message ==================================
The `PlaylistTrack` table is designed to associate tracks with playlists. It has the following structure:
- **PlaylistId**: An integer that serves as a foreign key referencing the `Playlist` table. It is part of the composite primary key.
- **TrackId**: An integer that serves as a foreign key referencing the `Track` table. It is also part of the composite primary key.
The primary key for this table is a composite key consisting of both `PlaylistId` and `TrackId`, ensuring that each track can be uniquely associated with a playlist. The table enforces referential integrity by linking to the `Track` and `Playlist` tables through foreign keys.
Dealing with high-cardinality columns
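To filter on columns that contain proper nouns, such as addresses, song names, or artists, we first need to double-check the spelling so that we filter the data correctly.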
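We can achieve this by creating a vector store containing all the distinct proper nouns that exist in the database. Then, each time the user includes a proper noun in their question, we can have the agent query that vector store to find the correct spelling for that word. In this way, the agent can make sure it understands which entity the user is referring to before building the target query.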
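The lookup step itself maps a misspelled name onto a value that actually exists, and it can be tried without embeddings. The sketch below swaps in fuzzy string matching from the standard library (`difflib.get_close_matches`) in place of a vector store. This is a different technique from the embedding search used in this guide. `correct_proper_noun` is a hypothetical helper.

```python
import difflib


def correct_proper_noun(guess, known_values, n=3, cutoff=0.5):
    """Return up to n known values whose spelling is closest to the user's guess."""
    # Compare case-insensitively, but return the values as stored in the database.
    lowered = {value.lower(): value for value in known_values}
    matches = difflib.get_close_matches(guess.lower(), list(lowered), n=n, cutoff=cutoff)
    return [lowered[m] for m in matches]


artists = ["AC/DC", "Accept", "Aerosmith", "Alanis Morissette", "Alice In Chains"]
print(correct_proper_noun("alis in chain", artists))
```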
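First, we need the unique values for each entity we want. To get them, we define a function that parses the query result into a list of elements: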
import ast
import re
def query_as_list(db, query):
    res = db.run(query)
    res = [el for sub in ast.literal_eval(res) for el in sub if el]
    res = [re.sub(r"\b\d+\b", "", string).strip() for string in res]
    return list(set(res))
artists = query_as_list(db, "SELECT Name FROM Artist")
albums = query_as_list(db, "SELECT Title FROM Album")
albums[:5]
['In Through The Out Door',
'Transmission',
'Battlestar Galactica (Classic), Season',
'A Copland Celebration, Vol. I',
'Quiet Songs']
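The regular expression in `query_as_list` deletes standalone digit runs, which would explain why a title above ends in "Season". The hypothetical `parse_rows` below applies the same three steps to a `db.run()`-style string, so you can try them without a database. Unlike the original, it sorts the result for a stable order. Like the original, it assumes every value is a string.

```python
import ast
import re


def parse_rows(res: str) -> list:
    """Same steps as query_as_list, applied to a db.run()-style result string."""
    # db.run() returns a string such as "[('a',), ('b',)]"; literal_eval turns it
    # back into a list of tuples, and falsy values (None, '') are dropped.
    values = [el for row in ast.literal_eval(res) for el in row if el]
    # Remove whole-word digit runs, e.g. "Season 1" -> "Season".
    values = [re.sub(r"\b\d+\b", "", value).strip() for value in values]
    return sorted(set(values))  # de-duplicate; sorted only for a stable order


raw = "[('Battlestar Galactica (Classic), Season 1',), ('Quiet Songs',), (None,), ('Quiet Songs',)]"
print(parse_rows(raw))  # ['Battlestar Galactica (Classic), Season', 'Quiet Songs']
```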
Using this function, we can create a retriever tool that the agent can execute at its discretion.
Select an embeddings model:
pip install -qU langchain-openai
import getpass
import os
if not os.environ.get("OPENAI_API_KEY"):
    os.environ["OPENAI_API_KEY"] = getpass.getpass("Enter API key for OpenAI: ")
from langchain_openai import OpenAIEmbeddings
embeddings = OpenAIEmbeddings(model="text-embedding-3-large")
Select a vector store:
pip install -qU langchain-core
from langchain_core.vectorstores import InMemoryVectorStore
vector_store = InMemoryVectorStore(embeddings)
We can now construct a retrieval tool that can search over the relevant proper nouns in the database:
from langchain.agents.agent_toolkits import create_retriever_tool
_ = vector_store.add_texts(artists + albums)
retriever = vector_store.as_retriever(search_kwargs={"k": 5})
description = (
    "Use to look up values to filter on. Input is an approximate spelling "
    "of the proper noun, output is valid proper nouns. Use the noun most "
    "similar to the search."
)
retriever_tool = create_retriever_tool(
    retriever,
    name="search_proper_nouns",
    description=description,
)
Let's try it out:
print(retriever_tool.invoke("Alice Chains"))
Alice In Chains
Alanis Morissette
Pearl Jam
Pearl Jam
Audioslave
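This way, if the agent determines it needs to write a filter based on an artist along the lines of "Alice Chains", it can first use the retriever tool to observe the relevant values of a column.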
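Under the hood, the retriever embeds the query, scores each stored text by similarity, and returns the top `k` (5 here). The following sketch reproduces only those mechanics. It uses a deliberately crude, hypothetical embedding (character-trigram counts) and cosine similarity instead of `text-embedding-3-large`, so its scores and neighbors will differ from the real tool's output.

```python
import math
from collections import Counter


def embed(text):
    """Toy embedding: counts of character trigrams (a stand-in for a real model)."""
    padded = f"  {text.lower()} "
    return Counter(padded[i:i + 3] for i in range(len(padded) - 2))


def cosine(a, b):
    """Cosine similarity between two sparse count vectors."""
    dot = sum(a[g] * b[g] for g in a.keys() & b.keys())
    norm = math.sqrt(sum(v * v for v in a.values())) * math.sqrt(sum(v * v for v in b.values()))
    return dot / norm if norm else 0.0


def top_k(query, texts, k=5):
    """Rank stored texts by cosine similarity to the query and keep the best k."""
    q = embed(query)
    return sorted(texts, key=lambda t: cosine(q, embed(t)), reverse=True)[:k]


names = ["Alice In Chains", "Alanis Morissette", "Pearl Jam", "Audioslave", "AC/DC"]
print(top_k("Alice Chains", names, k=2))
```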
Putting it all together:
# Add to system message
suffix = (
    "If you need to filter on a proper noun like a Name, you must ALWAYS first look up "
    "the filter value using the 'search_proper_nouns' tool! Do not try to "
    "guess at the proper name - use this function to find similar ones."
)
system = f"{system_message}\n\n{suffix}"
tools.append(retriever_tool)
agent = create_react_agent(llm, tools, state_modifier=system)
question = "How many albums does alis in chain have?"
for step in agent.stream(
    {"messages": [{"role": "user", "content": question}]},
    stream_mode="values",
):
    step["messages"][-1].pretty_print()
================================ Human Message =================================
How many albums does alis in chain have?
================================== Ai Message ==================================
Tool Calls:
search_proper_nouns (call_8ryjsRPLAr79mM3Qvnq6gTOH)
Call ID: call_8ryjsRPLAr79mM3Qvnq6gTOH
Args:
query: alis in chain
================================= Tool Message =================================
Name: search_proper_nouns
Alice In Chains
Aisha Duo
Xis
Da Lama Ao Caos
A-Sides
================================== Ai Message ==================================
Tool Calls:
sql_db_list_tables (call_NJjtCpU89MBMplssjn1z0xzq)
Call ID: call_NJjtCpU89MBMplssjn1z0xzq
Args:
search_proper_nouns (call_1BfrueC9koSIyi4OfMu2Ao8q)
Call ID: call_1BfrueC9koSIyi4OfMu2Ao8q
Args:
query: Alice In Chains
================================= Tool Message =================================
Name: search_proper_nouns
Alice In Chains
Pearl Jam
Pearl Jam
Foo Fighters
Soundgarden
================================== Ai Message ==================================
Tool Calls:
sql_db_schema (call_Kn09w9jd9swcNzIZ1b5MlKID)
Call ID: call_Kn09w9jd9swcNzIZ1b5MlKID
Args:
table_names: Album, Artist
================================= Tool Message =================================
Name: sql_db_schema
CREATE TABLE "Album" (
"AlbumId" INTEGER NOT NULL,
"Title" NVARCHAR(160) NOT NULL,
"ArtistId" INTEGER NOT NULL,
PRIMARY KEY ("AlbumId"),
FOREIGN KEY("ArtistId") REFERENCES "Artist" ("ArtistId")
)
/*
3 rows from Album table:
AlbumId Title ArtistId
1 For Those About To Rock We Salute You 1
2 Balls to the Wall 2
3 Restless and Wild 2
*/
CREATE TABLE "Artist" (
"ArtistId" INTEGER NOT NULL,
"Name" NVARCHAR(120),
PRIMARY KEY ("ArtistId")
)
/*
3 rows from Artist table:
ArtistId Name
1 AC/DC
2 Accept
3 Aerosmith
*/
================================== Ai Message ==================================
Tool Calls:
sql_db_query (call_WkHRiPcBoGN9bc58MIupRHKP)
Call ID: call_WkHRiPcBoGN9bc58MIupRHKP
Args:
query: SELECT COUNT(*) FROM Album WHERE ArtistId = (SELECT ArtistId FROM Artist WHERE Name = 'Alice In Chains')
================================= Tool Message =================================
Name: sql_db_query
[(1,)]
================================== Ai Message ==================================
Alice In Chains has released 1 album in the database.
As we can see, both in the streamed steps and in the LangSmith trace, the agent used the search_proper_nouns tool to check how to correctly query the database for this specific artist.