Skip to main content
Open In ColabOpen on GitHub

大型数据库的 SQL 问答处理方法

为了编写对数据库有效的查询,我们需要向模型提供表名、表模式和用于查询的特征值。当存在许多表、列和/或高基数(high-cardinality)列时,我们无法在每个提示中转储关于数据库的全部信息。相反,我们必须找到动态地将最相关信息插入提示的方法。

在本指南中,我们将演示识别此类相关信息并将它们输入查询生成步骤的方法。我们将涵盖:

  1. 识别相关的表子集;
  2. 识别相关的列值子集。

设置

首先,获取所需的包并设置环境变量:

%pip install --upgrade --quiet  langchain langchain-community langchain-openai
# Uncomment the below to use LangSmith. Not required.
# import os
# os.environ["LANGSMITH_API_KEY"] = getpass.getpass()
# os.environ["LANGSMITH_TRACING"] = "true"

以下示例将使用 SQLite 连接到 Chinook 数据库。请遵循 这些安装步骤,在与此笔记本相同的目录中创建 Chinook.db

  • 此文件 保存为 Chinook_Sqlite.sql
  • 运行 sqlite3 Chinook.db
  • 运行 .read Chinook_Sqlite.sql
  • 测试 SELECT * FROM Artist LIMIT 10;

现在,Chinook.db 已经位于我们的目录中,我们可以使用 SQLAlchemy 驱动的 SQLDatabase 类来与之交互:

from langchain_community.utilities import SQLDatabase

db = SQLDatabase.from_uri("sqlite:///Chinook.db")
print(db.dialect)
print(db.get_usable_table_names())
print(db.run("SELECT * FROM Artist LIMIT 10;"))
API Reference:SQLDatabase
sqlite
['Album', 'Artist', 'Customer', 'Employee', 'Genre', 'Invoice', 'InvoiceLine', 'MediaType', 'Playlist', 'PlaylistTrack', 'Track']
[(1, 'AC/DC'), (2, 'Accept'), (3, 'Aerosmith'), (4, 'Alanis Morissette'), (5, 'Alice In Chains'), (6, 'Antônio Carlos Jobim'), (7, 'Apocalyptica'), (8, 'Audioslave'), (9, 'BackBeat'), (10, 'Billy Cobham')]

许多表

我们需要在提示中包含的一个主要信息是相关表的模式。当我们有很多表时,无法将所有模式都包含在单个提示中。在这种情况下,我们可以首先提取与用户输入相关的表的名称,然后只包含它们的模式。

一种简单可靠的方法是使用 tool-calling。下面,我们将展示如何使用此功能来获得符合所需格式的输出(在本例中为表名列表)。我们使用聊天模型的 .bind_tools 方法来绑定 Pydantic 格式的工具,并将其馈送到输出解析器中,以从模型的响应中重建对象。

pip install -qU "langchain[google-genai]"
import getpass
import os

if not os.environ.get("GOOGLE_API_KEY"):
os.environ["GOOGLE_API_KEY"] = getpass.getpass("Enter API key for Google Gemini: ")

from langchain.chat_models import init_chat_model

llm = init_chat_model("gemini-2.0-flash", model_provider="google_genai")
from langchain_core.output_parsers.openai_tools import PydanticToolsParser
from langchain_core.prompts import ChatPromptTemplate
from pydantic import BaseModel, Field


class Table(BaseModel):
"""Table in SQL database."""

name: str = Field(description="Name of table in SQL database.")


table_names = "\n".join(db.get_usable_table_names())
system = f"""Return the names of ALL the SQL tables that MIGHT be relevant to the user question. \
The tables are:

{table_names}

Remember to include ALL POTENTIALLY RELEVANT tables, even if you're not sure that they're needed."""

prompt = ChatPromptTemplate.from_messages(
[
("system", system),
("human", "{input}"),
]
)
llm_with_tools = llm.bind_tools([Table])
output_parser = PydanticToolsParser(tools=[Table])

table_chain = prompt | llm_with_tools | output_parser

table_chain.invoke({"input": "What are all the genres of Alanis Morissette songs"})
[Table(name='Genre')]

这效果相当不错!不过,正如我们下面将看到的,我们实际上还需要几个其他的表格。仅凭用户问题,模型很难知道这一点。在这种情况下,我们或许可以考虑通过将表格分组来简化模型的任务。我们将只要求模型在“音乐”和“商业”两个类别之间进行选择,然后由模型来处理选取所有相关表的工作:

system = """Return the names of any SQL tables that are relevant to the user question.
The tables are:

Music
Business
"""

prompt = ChatPromptTemplate.from_messages(
[
("system", system),
("human", "{input}"),
]
)

category_chain = prompt | llm_with_tools | output_parser
category_chain.invoke({"input": "What are all the genres of Alanis Morissette songs"})
[Table(name='Music'), Table(name='Business')]
from typing import List


def get_tables(categories: List[Table]) -> List[str]:
tables = []
for category in categories:
if category.name == "Music":
tables.extend(
[
"Album",
"Artist",
"Genre",
"MediaType",
"Playlist",
"PlaylistTrack",
"Track",
]
)
elif category.name == "Business":
tables.extend(["Customer", "Employee", "Invoice", "InvoiceLine"])
return tables


table_chain = category_chain | get_tables
table_chain.invoke({"input": "What are all the genres of Alanis Morissette songs"})
['Album',
'Artist',
'Genre',
'MediaType',
'Playlist',
'PlaylistTrack',
'Track',
'Customer',
'Employee',
'Invoice',
'InvoiceLine']

现在我们有了一个可以针对任何查询输出相关表的链,我们可以将其与我们的 create_sql_query_chain 结合起来,该链接受 table_names_to_use 列表来确定提示中包含哪些表的模式:

from operator import itemgetter

from langchain.chains import create_sql_query_chain
from langchain_core.runnables import RunnablePassthrough

query_chain = create_sql_query_chain(llm, db)
# Convert "question" key to the "input" key expected by current table_chain.
table_chain = {"input": itemgetter("question")} | table_chain
# Set table_names_to_use using table_chain.
full_chain = RunnablePassthrough.assign(table_names_to_use=table_chain) | query_chain
query = full_chain.invoke(
{"question": "What are all the genres of Alanis Morissette songs"}
)
print(query)
SELECT DISTINCT "g"."Name"
FROM "Genre" g
JOIN "Track" t ON "g"."GenreId" = "t"."GenreId"
JOIN "Album" a ON "t"."AlbumId" = "a"."AlbumId"
JOIN "Artist" ar ON "a"."ArtistId" = "ar"."ArtistId"
WHERE "ar"."Name" = 'Alanis Morissette'
LIMIT 5;
db.run(query)
"[('Rock',)]"

我们可以看到本次运行的 LangSmith trace 在此

我们已经看到了如何在链中的提示中动态包含表模式的子集。解决这个问题的另一种可能方法是让 Agent 自己决定何时查找表,方法是给它提供一个用于此目的的工具。您可以在 SQL: Agents 指南中找到一个示例。

高基数性列

为了过滤包含地址、歌曲名称或艺术家等专有名词的列,我们首先需要仔细检查拼写,以便正确地筛选数据。

一个朴素的策略是创建一个包含数据库中所有不同专有名词的向量存储。然后,我们可以针对每个用户输入查询该向量存储,并将最相关的专有名词注入提示。

首先,我们需要获取我们想要的每个实体的唯一值,为此我们定义了一个函数来解析结果为元素列表:

import ast
import re


def query_as_list(db, query):
res = db.run(query)
res = [el for sub in ast.literal_eval(res) for el in sub if el]
res = [re.sub(r"\b\d+\b", "", string).strip() for string in res]
return res


proper_nouns = query_as_list(db, "SELECT Name FROM Artist")
proper_nouns += query_as_list(db, "SELECT Title FROM Album")
proper_nouns += query_as_list(db, "SELECT Name FROM Genre")
len(proper_nouns)
proper_nouns[:5]
['AC/DC', 'Accept', 'Aerosmith', 'Alanis Morissette', 'Alice In Chains']

现在我们可以将所有值嵌入并存储在向量数据库中了:

from langchain_community.vectorstores import FAISS
from langchain_openai import OpenAIEmbeddings

vector_db = FAISS.from_texts(proper_nouns, OpenAIEmbeddings())
retriever = vector_db.as_retriever(search_kwargs={"k": 15})
API Reference:FAISS | OpenAIEmbeddings

并构建一个查询构建链,该链首先从数据库中检索值,然后将它们插入到提示中:

from operator import itemgetter

from langchain_core.prompts import ChatPromptTemplate
from langchain_core.runnables import RunnablePassthrough

system = """You are a SQLite expert. Given an input question, create a syntactically
correct SQLite query to run. Unless otherwise specificed, do not return more than
{top_k} rows.

Only return the SQL query with no markup or explanation.

Here is the relevant table info: {table_info}

Here is a non-exhaustive list of possible feature values. If filtering on a feature
value make sure to check its spelling against this list first:

{proper_nouns}
"""

prompt = ChatPromptTemplate.from_messages([("system", system), ("human", "{input}")])

query_chain = create_sql_query_chain(llm, db, prompt=prompt)
retriever_chain = (
itemgetter("question")
| retriever
| (lambda docs: "\n".join(doc.page_content for doc in docs))
)
chain = RunnablePassthrough.assign(proper_nouns=retriever_chain) | query_chain

为了试用我们的链,让我们看看在不检索和检索的情况下,对 "elenis moriset"(艾拉妮丝·莫莉塞特拼写错误)进行过滤会发生什么情况:

# Without retrieval
query = query_chain.invoke(
{"question": "What are all the genres of elenis moriset songs", "proper_nouns": ""}
)
print(query)
db.run(query)
SELECT DISTINCT g.Name 
FROM Track t
JOIN Album a ON t.AlbumId = a.AlbumId
JOIN Artist ar ON a.ArtistId = ar.ArtistId
JOIN Genre g ON t.GenreId = g.GenreId
WHERE ar.Name = 'Elenis Moriset';
''
# With retrieval
query = chain.invoke({"question": "What are all the genres of elenis moriset songs"})
print(query)
db.run(query)
SELECT DISTINCT g.Name
FROM Genre g
JOIN Track t ON g.GenreId = t.GenreId
JOIN Album a ON t.AlbumId = a.AlbumId
JOIN Artist ar ON a.ArtistId = ar.ArtistId
WHERE ar.Name = 'Alanis Morissette';
"[('Rock',)]"

我们可以看到,通过检索,我们能够将拼写从“Elenis Moriset”纠正为“Alanis Morissette”,并获得有效结果。

解决这个问题的另一种可能方法是让 Agent 自己决定何时查找专有名词。你可以在 SQL: Agents 指南中看到一个示例。