Skip to main content
Open In ColabOpen on GitHub

如何创建自定义 LLM 类

本 notebook 将介绍如何创建自定义 LLM 包装器,以防您想使用自己的 LLM 或 LangChain 不支持的其他包装器。

通过标准的 LLM 接口包装您的 LLM,可以让您在现有的 LangChain 程序中使用您的 LLM,而只需进行最少的代码修改。

此外,您的 LLM 将自动成为 LangChain 的 Runnable,并可直接受益于一些开箱即用的优化、异步支持、astream_events API 等。

caution

您当前所在页面记录的是 文本补全模型 的用法。许多最新和最受欢迎的模型是 聊天补全模型

除非您特别使用了更高级的提示技术,否则您可能是在查找此页面

实现

自定义 LLM 只需要实现两个必需的方法:

方法描述
_call接收一个字符串和一些可选的停止词,并返回一个字符串。供 invoke 使用。
_llm_type一个返回字符串的属性,仅用于日志记录目的。

可选的实现:

方法描述
_identifying_params用于帮助识别模型和打印 LLM;应返回一个字典。这是一个 @property
_acall提供 _call 的异步原生实现,供 ainvoke 使用。
_stream用于逐个 token 流式输出的方法。
_astream提供 _stream 的异步原生实现;在较新版本的 LangChain 中,默认为 _stream

让我们来实现一个简单的自定义 LLM,它只返回输入的前 n 个字符。

from typing import Any, Dict, Iterator, List, Mapping, Optional

from langchain_core.callbacks.manager import CallbackManagerForLLMRun
from langchain_core.language_models.llms import LLM
from langchain_core.outputs import GenerationChunk


class CustomLLM(LLM):
"""A custom chat model that echoes the first `n` characters of the input.

When contributing an implementation to LangChain, carefully document
the model including the initialization parameters, include
an example of how to initialize the model and include any relevant
links to the underlying models documentation or API.

Example:

.. code-block:: python

model = CustomChatModel(n=2)
result = model.invoke([HumanMessage(content="hello")])
result = model.batch([[HumanMessage(content="hello")],
[HumanMessage(content="world")]])
"""

n: int
"""The number of characters from the last message of the prompt to be echoed."""

def _call(
self,
prompt: str,
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> str:
"""Run the LLM on the given input.

Override this method to implement the LLM logic.

Args:
prompt: The prompt to generate from.
stop: Stop words to use when generating. Model output is cut off at the
first occurrence of any of the stop substrings.
If stop tokens are not supported consider raising NotImplementedError.
run_manager: Callback manager for the run.
**kwargs: Arbitrary additional keyword arguments. These are usually passed
to the model provider API call.

Returns:
The model output as a string. Actual completions SHOULD NOT include the prompt.
"""
if stop is not None:
raise ValueError("stop kwargs are not permitted.")
return prompt[: self.n]

def _stream(
self,
prompt: str,
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> Iterator[GenerationChunk]:
"""Stream the LLM on the given prompt.

This method should be overridden by subclasses that support streaming.

If not implemented, the default behavior of calls to stream will be to
fallback to the non-streaming version of the model and return
the output as a single chunk.

Args:
prompt: The prompt to generate from.
stop: Stop words to use when generating. Model output is cut off at the
first occurrence of any of these substrings.
run_manager: Callback manager for the run.
**kwargs: Arbitrary additional keyword arguments. These are usually passed
to the model provider API call.

Returns:
An iterator of GenerationChunks.
"""
for char in prompt[: self.n]:
chunk = GenerationChunk(text=char)
if run_manager:
run_manager.on_llm_new_token(chunk.text, chunk=chunk)

yield chunk

@property
def _identifying_params(self) -> Dict[str, Any]:
"""Return a dictionary of identifying parameters."""
return {
# The model name allows users to specify custom token counting
# rules in LLM monitoring applications (e.g., in LangSmith users
# can provide per token pricing for their model and monitor
# costs for the given LLM.)
"model_name": "CustomChatModel",
}

@property
def _llm_type(self) -> str:
"""Get the type of language model used by this chat model. Used for logging purposes only."""
return "custom"

让我们来测试一下 🧪

这个 LLM 将会实现 LangChain 的标准 Runnable 接口,LangChain 的许多抽象都支持该接口!

llm = CustomLLM(n=5)
print(llm)
CustomLLM
Params: {'model_name': 'CustomChatModel'}
llm.invoke("This is a foobar thing")
'This '
await llm.ainvoke("world")
'world'
llm.batch(["woof woof woof", "meow meow meow"])
['woof ', 'meow ']
await llm.abatch(["woof woof woof", "meow meow meow"])
['woof ', 'meow ']
async for token in llm.astream("hello"):
print(token, end="|", flush=True)
h|e|l|l|o|

让我们确认它能与其他 LangChain API 很好地集成。

from langchain_core.prompts import ChatPromptTemplate
API Reference:ChatPromptTemplate
prompt = ChatPromptTemplate.from_messages(
[("system", "you are a bot"), ("human", "{input}")]
)
llm = CustomLLM(n=7)
chain = prompt | llm
idx = 0
async for event in chain.astream_events({"input": "hello there!"}, version="v1"):
print(event)
idx += 1
if idx > 7:
# Truncate
break
{'event': 'on_chain_start', 'run_id': '05f24b4f-7ea3-4fb6-8417-3aa21633462f', 'name': 'RunnableSequence', 'tags': [], 'metadata': {}, 'data': {'input': {'input': 'hello there!'}}}
{'event': 'on_prompt_start', 'name': 'ChatPromptTemplate', 'run_id': '7e996251-a926-4344-809e-c425a9846d21', 'tags': ['seq:step:1'], 'metadata': {}, 'data': {'input': {'input': 'hello there!'}}}
{'event': 'on_prompt_end', 'name': 'ChatPromptTemplate', 'run_id': '7e996251-a926-4344-809e-c425a9846d21', 'tags': ['seq:step:1'], 'metadata': {}, 'data': {'input': {'input': 'hello there!'}, 'output': ChatPromptValue(messages=[SystemMessage(content='you are a bot'), HumanMessage(content='hello there!')])}}
{'event': 'on_llm_start', 'name': 'CustomLLM', 'run_id': 'a8766beb-10f4-41de-8750-3ea7cf0ca7e2', 'tags': ['seq:step:2'], 'metadata': {}, 'data': {'input': {'prompts': ['System: you are a bot\nHuman: hello there!']}}}
{'event': 'on_llm_stream', 'name': 'CustomLLM', 'run_id': 'a8766beb-10f4-41de-8750-3ea7cf0ca7e2', 'tags': ['seq:step:2'], 'metadata': {}, 'data': {'chunk': 'S'}}
{'event': 'on_chain_stream', 'run_id': '05f24b4f-7ea3-4fb6-8417-3aa21633462f', 'tags': [], 'metadata': {}, 'name': 'RunnableSequence', 'data': {'chunk': 'S'}}
{'event': 'on_llm_stream', 'name': 'CustomLLM', 'run_id': 'a8766beb-10f4-41de-8750-3ea7cf0ca7e2', 'tags': ['seq:step:2'], 'metadata': {}, 'data': {'chunk': 'y'}}
{'event': 'on_chain_stream', 'run_id': '05f24b4f-7ea3-4fb6-8417-3aa21633462f', 'tags': [], 'metadata': {}, 'name': 'RunnableSequence', 'data': {'chunk': 'y'}}

贡献

我们感谢所有关于聊天模型集成的贡献。

以下清单有助于确保您的贡献被添加到 LangChain 中:

文档:

  • 为所有初始化参数添加了文档字符串 (doc-strings),这些文档字符串将显示在 API 参考 中。
  • 如果模型是由某个服务支持的,其类的文档字符串应包含指向该模型 API 的链接。

测试:

  • 为重写的方法添加单元测试或集成测试。如果您重写了相应的代码,请验证 invokeainvokebatchstream 是否正常工作。

流式输出(如果您正在实现):

  • 确保调用了 on_llm_new_token 回调
  • on_llm_new_token 应在 yielding 块 之前 调用

停止标记行为:

  • 应尊重停止标记
  • 停止标记应包含在响应中

秘密 API 密钥:

  • 如果您的模型连接到 API,它可能会接受初始化参数中的 API 密钥。请使用 Pydantic 的 SecretStr 类型来处理密钥,这样人们在打印模型时就不会意外地打印出它们。