Skip to main content
Open In ColabOpen on GitHub

如何创建自定义聊天模型类

先决条件

本指南假定您熟悉以下概念:

在本指南中,我们将学习如何使用 LangChain 抽象来创建自定义 聊天模型

使用标准的 BaseChatModel 接口包装您的 LLM,您可以让您的 LLM 用于现有的 LangChain 程序,只需少量代码修改!

作为一项额外的好处,您的 LLM 将自动成为 LangChain 可运行对象,并可以直接获得一些优化(例如,通过线程池进行批处理)、异步支持、astream_events API 等。

输入和输出

首先,我们需要讨论 消息,它们是聊天模型的输入和输出。

消息

聊天模型以消息作为输入,并返回消息作为输出。

LangChain 有几种 内置消息类型

消息类型描述
SystemMessage用于调整 AI 行为,通常作为一系列输入消息的第一个传入。
HumanMessage代表与聊天模型互动的人员的消息。
AIMessage代表来自聊天模型的消息。这可以是文本,也可以是调用工具的请求。
FunctionMessage / ToolMessage用于将工具调用的结果传递回模型的消息。
AIMessageChunk / HumanMessageChunk / ...每种消息类型的分块变体。
note

ToolMessageFunctionMessage 紧随 OpenAI 的 functiontool 角色。

这是一个快速发展的领域,随着越来越多的模型添加函数调用功能。预计该架构将有所添加。

from langchain_core.messages import (
AIMessage,
BaseMessage,
FunctionMessage,
HumanMessage,
SystemMessage,
ToolMessage,
)

流式传输变体

所有的聊天消息都有一个流式传输变体,其名称包含 Chunk

from langchain_core.messages import (
AIMessageChunk,
FunctionMessageChunk,
HumanMessageChunk,
SystemMessageChunk,
ToolMessageChunk,
)

当流式输出聊天模型时,会使用这些分块,它们都定义了一个附加属性!

AIMessageChunk(content="Hello") + AIMessageChunk(content=" World!")
AIMessageChunk(content='Hello World!')

基础聊天模型

让我们来实现一个聊天模型,它会回显提示中最后一条消息的前 n 个字符!

为此,我们将继承自 BaseChatModel,并且需要实现以下内容:

方法/属性描述必需/可选
_generate用于从提示生成聊天结果必需
_llm_type (属性)用于唯一标识模型类型。用于日志记录。必需
_identifying_params (属性)代表模型参数化以用于跟踪目的。可选
_stream用于实现流式传输。可选
_agenerate用于实现原生的异步方法。可选
_astream用于实现 _stream 的异步版本。可选
tip

_astream 的实现使用 run_in_executor 在单独的线程中启动同步的 _stream(如果 _stream 已实现),否则它将回退使用 _agenerate

如果你想重用 _stream 的实现,可以使用这个技巧,但如果你能实现原生异步代码,那将是更好的解决方案,因为该代码将以更低的开销运行。

实现

from typing import Any, Dict, Iterator, List, Optional

from langchain_core.callbacks import (
CallbackManagerForLLMRun,
)
from langchain_core.language_models import BaseChatModel
from langchain_core.messages import (
AIMessage,
AIMessageChunk,
BaseMessage,
)
from langchain_core.messages.ai import UsageMetadata
from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult
from pydantic import Field


class ChatParrotLink(BaseChatModel):
"""A custom chat model that echoes the first `parrot_buffer_length` characters
of the input.

When contributing an implementation to LangChain, carefully document
the model including the initialization parameters, include
an example of how to initialize the model and include any relevant
links to the underlying models documentation or API.

Example:

.. code-block:: python

model = ChatParrotLink(parrot_buffer_length=2, model="bird-brain-001")
result = model.invoke([HumanMessage(content="hello")])
result = model.batch([[HumanMessage(content="hello")],
[HumanMessage(content="world")]])
"""

model_name: str = Field(alias="model")
"""The name of the model"""
parrot_buffer_length: int
"""The number of characters from the last message of the prompt to be echoed."""
temperature: Optional[float] = None
max_tokens: Optional[int] = None
timeout: Optional[int] = None
stop: Optional[List[str]] = None
max_retries: int = 2

def _generate(
self,
messages: List[BaseMessage],
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> ChatResult:
"""Override the _generate method to implement the chat model logic.

This can be a call to an API, a call to a local model, or any other
implementation that generates a response to the input prompt.

Args:
messages: the prompt composed of a list of messages.
stop: a list of strings on which the model should stop generating.
If generation stops due to a stop token, the stop token itself
SHOULD BE INCLUDED as part of the output. This is not enforced
across models right now, but it's a good practice to follow since
it makes it much easier to parse the output of the model
downstream and understand why generation stopped.
run_manager: A run manager with callbacks for the LLM.
"""
# Replace this with actual logic to generate a response from a list
# of messages.
last_message = messages[-1]
tokens = last_message.content[: self.parrot_buffer_length]
ct_input_tokens = sum(len(message.content) for message in messages)
ct_output_tokens = len(tokens)
message = AIMessage(
content=tokens,
additional_kwargs={}, # Used to add additional payload to the message
response_metadata={ # Use for response metadata
"time_in_seconds": 3,
"model_name": self.model_name,
},
usage_metadata={
"input_tokens": ct_input_tokens,
"output_tokens": ct_output_tokens,
"total_tokens": ct_input_tokens + ct_output_tokens,
},
)
##

generation = ChatGeneration(message=message)
return ChatResult(generations=[generation])

def _stream(
self,
messages: List[BaseMessage],
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> Iterator[ChatGenerationChunk]:
"""Stream the output of the model.

This method should be implemented if the model can generate output
in a streaming fashion. If the model does not support streaming,
do not implement it. In that case streaming requests will be automatically
handled by the _generate method.

Args:
messages: the prompt composed of a list of messages.
stop: a list of strings on which the model should stop generating.
If generation stops due to a stop token, the stop token itself
SHOULD BE INCLUDED as part of the output. This is not enforced
across models right now, but it's a good practice to follow since
it makes it much easier to parse the output of the model
downstream and understand why generation stopped.
run_manager: A run manager with callbacks for the LLM.
"""
last_message = messages[-1]
tokens = str(last_message.content[: self.parrot_buffer_length])
ct_input_tokens = sum(len(message.content) for message in messages)

for token in tokens:
usage_metadata = UsageMetadata(
{
"input_tokens": ct_input_tokens,
"output_tokens": 1,
"total_tokens": ct_input_tokens + 1,
}
)
ct_input_tokens = 0
chunk = ChatGenerationChunk(
message=AIMessageChunk(content=token, usage_metadata=usage_metadata)
)

if run_manager:
# This is optional in newer versions of LangChain
# The on_llm_new_token will be called automatically
run_manager.on_llm_new_token(token, chunk=chunk)

yield chunk

# Let's add some other information (e.g., response metadata)
chunk = ChatGenerationChunk(
message=AIMessageChunk(
content="",
response_metadata={"time_in_sec": 3, "model_name": self.model_name},
)
)
if run_manager:
# This is optional in newer versions of LangChain
# The on_llm_new_token will be called automatically
run_manager.on_llm_new_token(token, chunk=chunk)
yield chunk

@property
def _llm_type(self) -> str:
"""Get the type of language model used by this chat model."""
return "echoing-chat-model-advanced"

@property
def _identifying_params(self) -> Dict[str, Any]:
"""Return a dictionary of identifying parameters.

This information is used by the LangChain callback system, which
is used for tracing purposes make it possible to monitor LLMs.
"""
return {
# The model name allows users to specify custom token counting
# rules in LLM monitoring applications (e.g., in LangSmith users
# can provide per token pricing for their model and monitor
# costs for the given LLM.)
"model_name": self.model_name,
}

来测试一下 🧪

聊天模型将实现 LangChain 的标准 Runnable 接口,许多 LangChain 的抽象都支持这个接口!

model = ChatParrotLink(parrot_buffer_length=3, model="my_custom_model")

model.invoke(
[
HumanMessage(content="hello!"),
AIMessage(content="Hi there human!"),
HumanMessage(content="Meow!"),
]
)
AIMessage(content='Meo', additional_kwargs={}, response_metadata={'time_in_seconds': 3}, id='run-cf11aeb6-8ab6-43d7-8c68-c1ef89b6d78e-0', usage_metadata={'input_tokens': 26, 'output_tokens': 3, 'total_tokens': 29})
model.invoke("hello")
AIMessage(content='hel', additional_kwargs={}, response_metadata={'time_in_seconds': 3}, id='run-618e5ed4-d611-4083-8cf1-c270726be8d9-0', usage_metadata={'input_tokens': 5, 'output_tokens': 3, 'total_tokens': 8})
model.batch(["hello", "goodbye"])
[AIMessage(content='hel', additional_kwargs={}, response_metadata={'time_in_seconds': 3}, id='run-eea4ed7d-d750-48dc-90c0-7acca1ff388f-0', usage_metadata={'input_tokens': 5, 'output_tokens': 3, 'total_tokens': 8}),
AIMessage(content='goo', additional_kwargs={}, response_metadata={'time_in_seconds': 3}, id='run-07cfc5c1-3c62-485f-b1e0-3d46e1547287-0', usage_metadata={'input_tokens': 7, 'output_tokens': 3, 'total_tokens': 10})]
for chunk in model.stream("cat"):
print(chunk.content, end="|")
c|a|t||

请查看模型中 _astream 的实现!如果您不实现它,那么将不会有任何输出流式传输!

async for chunk in model.astream("cat"):
print(chunk.content, end="|")
c|a|t||

让我们尝试使用 astream 事件 API,这也有助于仔细检查所有回调是否都已实现!

async for event in model.astream_events("cat", version="v1"):
print(event)
{'event': 'on_chat_model_start', 'run_id': '3f0b5501-5c78-45b3-92fc-8322a6a5024a', 'name': 'ChatParrotLink', 'tags': [], 'metadata': {}, 'data': {'input': 'cat'}, 'parent_ids': []}
{'event': 'on_chat_model_stream', 'run_id': '3f0b5501-5c78-45b3-92fc-8322a6a5024a', 'tags': [], 'metadata': {}, 'name': 'ChatParrotLink', 'data': {'chunk': AIMessageChunk(content='c', additional_kwargs={}, response_metadata={}, id='run-3f0b5501-5c78-45b3-92fc-8322a6a5024a', usage_metadata={'input_tokens': 3, 'output_tokens': 1, 'total_tokens': 4})}, 'parent_ids': []}
{'event': 'on_chat_model_stream', 'run_id': '3f0b5501-5c78-45b3-92fc-8322a6a5024a', 'tags': [], 'metadata': {}, 'name': 'ChatParrotLink', 'data': {'chunk': AIMessageChunk(content='a', additional_kwargs={}, response_metadata={}, id='run-3f0b5501-5c78-45b3-92fc-8322a6a5024a', usage_metadata={'input_tokens': 0, 'output_tokens': 1, 'total_tokens': 1})}, 'parent_ids': []}
{'event': 'on_chat_model_stream', 'run_id': '3f0b5501-5c78-45b3-92fc-8322a6a5024a', 'tags': [], 'metadata': {}, 'name': 'ChatParrotLink', 'data': {'chunk': AIMessageChunk(content='t', additional_kwargs={}, response_metadata={}, id='run-3f0b5501-5c78-45b3-92fc-8322a6a5024a', usage_metadata={'input_tokens': 0, 'output_tokens': 1, 'total_tokens': 1})}, 'parent_ids': []}
{'event': 'on_chat_model_stream', 'run_id': '3f0b5501-5c78-45b3-92fc-8322a6a5024a', 'tags': [], 'metadata': {}, 'name': 'ChatParrotLink', 'data': {'chunk': AIMessageChunk(content='', additional_kwargs={}, response_metadata={'time_in_sec': 3}, id='run-3f0b5501-5c78-45b3-92fc-8322a6a5024a')}, 'parent_ids': []}
{'event': 'on_chat_model_end', 'name': 'ChatParrotLink', 'run_id': '3f0b5501-5c78-45b3-92fc-8322a6a5024a', 'tags': [], 'metadata': {}, 'data': {'output': AIMessageChunk(content='cat', additional_kwargs={}, response_metadata={'time_in_sec': 3}, id='run-3f0b5501-5c78-45b3-92fc-8322a6a5024a', usage_metadata={'input_tokens': 3, 'output_tokens': 3, 'total_tokens': 6})}, 'parent_ids': []}

贡献

我们非常感谢所有对聊天模型集成的贡献。

以下清单将帮助您确保您的贡献被添加到 LangChain:

文档:

  • 为所有初始化参数添加了文档字符串,因为这些将在API 参考中显示。
  • 如果模型由某个服务提供支持,其类的文档字符串应包含指向该模型 API 的链接。

测试:

  • 为重写的方法添加单元测试或集成测试。如果您重写了相应的代码,请验证 invokeainvokebatchstream 是否正常工作。

流式传输(如果您正在实现):

  • 实现 _stream 方法以启用流式传输

停止令牌行为:

  • 应遵守停止令牌
  • 停止令牌应作为响应的一部分包含在内

秘密 API 密钥:

  • 如果您的模型连接到 API,它可能会在初始化时接受 API 密钥。请使用 Pydantic 的 SecretStr 类型处理秘密信息,以免它们在被打印时意外泄露。

识别参数:

  • 在识别参数中包含 model_name

优化:

考虑提供本地异步支持以减少模型的开销!

  • 提供了 _agenerateainvoke 使用)的本地异步版本
  • 提供了 _astreamastream 使用)的本地异步版本

后续步骤

现在您已经学会了如何创建自己的自定义聊天模型。

接下来,查看此部分中关于聊天模型的其他操作指南,例如如何让模型返回结构化输出如何跟踪聊天模型的令牌使用情况