跳到主要內容
Open In ColabOpen on GitHub

如何建立自訂聊天模型類別

先決條件

本指南假設您熟悉以下概念

在本指南中,我們將學習如何使用 LangChain 抽象概念建立自訂聊天模型

使用標準 BaseChatModel 介面封裝您的 LLM,讓您可以在現有的 LangChain 程式中使用您的 LLM,且只需進行最少的程式碼修改!

作為額外的好處,您的 LLM 將自動成為 LangChain Runnable,並將受益於一些開箱即用的最佳化 (例如,透過 threadpool 進行批次處理)、非同步支援、astream_events API 等。

輸入和輸出

首先,我們需要討論訊息,這是聊天模型的輸入和輸出。

訊息

聊天模型將訊息作為輸入,並傳回訊息作為輸出。

LangChain 有一些內建訊息類型

訊息類型描述
SystemMessage用於啟動 AI 行為,通常作為一系列輸入訊息的第一個傳入。
HumanMessage表示來自與聊天模型互動的人員的訊息。
AIMessage表示來自聊天模型的訊息。這可以是文字或調用工具的請求。
FunctionMessage / ToolMessage用於將工具調用結果傳回模型的訊息。
AIMessageChunk / HumanMessageChunk / ...每種訊息類型的區塊變體。
注意

ToolMessageFunctionMessage 密切遵循 OpenAI 的 functiontool 角色。

這是一個快速發展的領域,隨著越來越多的模型新增函數調用功能。預計此架構將會新增內容。

from langchain_core.messages import (
AIMessage,
BaseMessage,
FunctionMessage,
HumanMessage,
SystemMessage,
ToolMessage,
)

串流變體

所有聊天訊息都有一個串流變體,名稱中包含 Chunk

from langchain_core.messages import (
AIMessageChunk,
FunctionMessageChunk,
HumanMessageChunk,
SystemMessageChunk,
ToolMessageChunk,
)

這些區塊在從聊天模型串流輸出時使用,它們都定義了一個附加屬性!

AIMessageChunk(content="Hello") + AIMessageChunk(content=" World!")
AIMessageChunk(content='Hello World!')

基礎聊天模型

讓我們實作一個聊天模型,該模型會回音提示中最後一則訊息的前 n 個字元!

為此,我們將繼承自 BaseChatModel,並且需要實作以下內容

方法/屬性描述必要/選用
_generate用於從提示產生聊天結果必要
_llm_type (屬性)用於唯一識別模型類型。用於記錄。必要
_identifying_params (屬性)表示用於追蹤目的的模型參數化。選用
_stream用於實作串流。選用
_agenerate用於實作原生非同步方法。選用
_astream用於實作 _stream 的非同步版本。選用
提示

如果已實作 _stream,則 _astream 實作會使用 run_in_executor 在單獨的執行緒中啟動同步 _stream,否則會回退使用 _agenerate

如果您想重複使用 _stream 實作,可以使用此技巧,但如果您能夠實作原生非同步的程式碼,那會是更好的解決方案,因為該程式碼將以更少的 overhead 執行。

實作

from typing import Any, Dict, Iterator, List, Optional

from langchain_core.callbacks import (
CallbackManagerForLLMRun,
)
from langchain_core.language_models import BaseChatModel
from langchain_core.messages import (
AIMessage,
AIMessageChunk,
BaseMessage,
)
from langchain_core.messages.ai import UsageMetadata
from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult
from pydantic import Field


class ChatParrotLink(BaseChatModel):
"""A custom chat model that echoes the first `parrot_buffer_length` characters
of the input.

When contributing an implementation to LangChain, carefully document
the model including the initialization parameters, include
an example of how to initialize the model and include any relevant
links to the underlying models documentation or API.

Example:

.. code-block:: python

model = ChatParrotLink(parrot_buffer_length=2, model="bird-brain-001")
result = model.invoke([HumanMessage(content="hello")])
result = model.batch([[HumanMessage(content="hello")],
[HumanMessage(content="world")]])
"""

model_name: str = Field(alias="model")
"""The name of the model"""
parrot_buffer_length: int
"""The number of characters from the last message of the prompt to be echoed."""
temperature: Optional[float] = None
max_tokens: Optional[int] = None
timeout: Optional[int] = None
stop: Optional[List[str]] = None
max_retries: int = 2

def _generate(
self,
messages: List[BaseMessage],
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> ChatResult:
"""Override the _generate method to implement the chat model logic.

This can be a call to an API, a call to a local model, or any other
implementation that generates a response to the input prompt.

Args:
messages: the prompt composed of a list of messages.
stop: a list of strings on which the model should stop generating.
If generation stops due to a stop token, the stop token itself
SHOULD BE INCLUDED as part of the output. This is not enforced
across models right now, but it's a good practice to follow since
it makes it much easier to parse the output of the model
downstream and understand why generation stopped.
run_manager: A run manager with callbacks for the LLM.
"""
# Replace this with actual logic to generate a response from a list
# of messages.
last_message = messages[-1]
tokens = last_message.content[: self.parrot_buffer_length]
ct_input_tokens = sum(len(message.content) for message in messages)
ct_output_tokens = len(tokens)
message = AIMessage(
content=tokens,
additional_kwargs={}, # Used to add additional payload to the message
response_metadata={ # Use for response metadata
"time_in_seconds": 3,
},
usage_metadata={
"input_tokens": ct_input_tokens,
"output_tokens": ct_output_tokens,
"total_tokens": ct_input_tokens + ct_output_tokens,
},
)
##

generation = ChatGeneration(message=message)
return ChatResult(generations=[generation])

def _stream(
self,
messages: List[BaseMessage],
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> Iterator[ChatGenerationChunk]:
"""Stream the output of the model.

This method should be implemented if the model can generate output
in a streaming fashion. If the model does not support streaming,
do not implement it. In that case streaming requests will be automatically
handled by the _generate method.

Args:
messages: the prompt composed of a list of messages.
stop: a list of strings on which the model should stop generating.
If generation stops due to a stop token, the stop token itself
SHOULD BE INCLUDED as part of the output. This is not enforced
across models right now, but it's a good practice to follow since
it makes it much easier to parse the output of the model
downstream and understand why generation stopped.
run_manager: A run manager with callbacks for the LLM.
"""
last_message = messages[-1]
tokens = str(last_message.content[: self.parrot_buffer_length])
ct_input_tokens = sum(len(message.content) for message in messages)

for token in tokens:
usage_metadata = UsageMetadata(
{
"input_tokens": ct_input_tokens,
"output_tokens": 1,
"total_tokens": ct_input_tokens + 1,
}
)
ct_input_tokens = 0
chunk = ChatGenerationChunk(
message=AIMessageChunk(content=token, usage_metadata=usage_metadata)
)

if run_manager:
# This is optional in newer versions of LangChain
# The on_llm_new_token will be called automatically
run_manager.on_llm_new_token(token, chunk=chunk)

yield chunk

# Let's add some other information (e.g., response metadata)
chunk = ChatGenerationChunk(
message=AIMessageChunk(content="", response_metadata={"time_in_sec": 3})
)
if run_manager:
# This is optional in newer versions of LangChain
# The on_llm_new_token will be called automatically
run_manager.on_llm_new_token(token, chunk=chunk)
yield chunk

@property
def _llm_type(self) -> str:
"""Get the type of language model used by this chat model."""
return "echoing-chat-model-advanced"

@property
def _identifying_params(self) -> Dict[str, Any]:
"""Return a dictionary of identifying parameters.

This information is used by the LangChain callback system, which
is used for tracing purposes make it possible to monitor LLMs.
"""
return {
# The model name allows users to specify custom token counting
# rules in LLM monitoring applications (e.g., in LangSmith users
# can provide per token pricing for their model and monitor
# costs for the given LLM.)
"model_name": self.model_name,
}

讓我們測試一下 🧪

聊天模型將實作 LangChain 的標準 Runnable 介面,許多 LangChain 抽象概念都支援該介面!

model = ChatParrotLink(parrot_buffer_length=3, model="my_custom_model")

model.invoke(
[
HumanMessage(content="hello!"),
AIMessage(content="Hi there human!"),
HumanMessage(content="Meow!"),
]
)
AIMessage(content='Meo', additional_kwargs={}, response_metadata={'time_in_seconds': 3}, id='run-cf11aeb6-8ab6-43d7-8c68-c1ef89b6d78e-0', usage_metadata={'input_tokens': 26, 'output_tokens': 3, 'total_tokens': 29})
model.invoke("hello")
AIMessage(content='hel', additional_kwargs={}, response_metadata={'time_in_seconds': 3}, id='run-618e5ed4-d611-4083-8cf1-c270726be8d9-0', usage_metadata={'input_tokens': 5, 'output_tokens': 3, 'total_tokens': 8})
model.batch(["hello", "goodbye"])
[AIMessage(content='hel', additional_kwargs={}, response_metadata={'time_in_seconds': 3}, id='run-eea4ed7d-d750-48dc-90c0-7acca1ff388f-0', usage_metadata={'input_tokens': 5, 'output_tokens': 3, 'total_tokens': 8}),
AIMessage(content='goo', additional_kwargs={}, response_metadata={'time_in_seconds': 3}, id='run-07cfc5c1-3c62-485f-b1e0-3d46e1547287-0', usage_metadata={'input_tokens': 7, 'output_tokens': 3, 'total_tokens': 10})]
for chunk in model.stream("cat"):
print(chunk.content, end="|")
c|a|t||

請參閱模型中 _astream 的實作!如果您未實作它,則不會串流任何輸出!

async for chunk in model.astream("cat"):
print(chunk.content, end="|")
c|a|t||

讓我們嘗試使用 astream events API,這也將有助於再次檢查是否已實作所有回呼!

async for event in model.astream_events("cat", version="v1"):
print(event)
{'event': 'on_chat_model_start', 'run_id': '3f0b5501-5c78-45b3-92fc-8322a6a5024a', 'name': 'ChatParrotLink', 'tags': [], 'metadata': {}, 'data': {'input': 'cat'}, 'parent_ids': []}
{'event': 'on_chat_model_stream', 'run_id': '3f0b5501-5c78-45b3-92fc-8322a6a5024a', 'tags': [], 'metadata': {}, 'name': 'ChatParrotLink', 'data': {'chunk': AIMessageChunk(content='c', additional_kwargs={}, response_metadata={}, id='run-3f0b5501-5c78-45b3-92fc-8322a6a5024a', usage_metadata={'input_tokens': 3, 'output_tokens': 1, 'total_tokens': 4})}, 'parent_ids': []}
{'event': 'on_chat_model_stream', 'run_id': '3f0b5501-5c78-45b3-92fc-8322a6a5024a', 'tags': [], 'metadata': {}, 'name': 'ChatParrotLink', 'data': {'chunk': AIMessageChunk(content='a', additional_kwargs={}, response_metadata={}, id='run-3f0b5501-5c78-45b3-92fc-8322a6a5024a', usage_metadata={'input_tokens': 0, 'output_tokens': 1, 'total_tokens': 1})}, 'parent_ids': []}
{'event': 'on_chat_model_stream', 'run_id': '3f0b5501-5c78-45b3-92fc-8322a6a5024a', 'tags': [], 'metadata': {}, 'name': 'ChatParrotLink', 'data': {'chunk': AIMessageChunk(content='t', additional_kwargs={}, response_metadata={}, id='run-3f0b5501-5c78-45b3-92fc-8322a6a5024a', usage_metadata={'input_tokens': 0, 'output_tokens': 1, 'total_tokens': 1})}, 'parent_ids': []}
{'event': 'on_chat_model_stream', 'run_id': '3f0b5501-5c78-45b3-92fc-8322a6a5024a', 'tags': [], 'metadata': {}, 'name': 'ChatParrotLink', 'data': {'chunk': AIMessageChunk(content='', additional_kwargs={}, response_metadata={'time_in_sec': 3}, id='run-3f0b5501-5c78-45b3-92fc-8322a6a5024a')}, 'parent_ids': []}
{'event': 'on_chat_model_end', 'name': 'ChatParrotLink', 'run_id': '3f0b5501-5c78-45b3-92fc-8322a6a5024a', 'tags': [], 'metadata': {}, 'data': {'output': AIMessageChunk(content='cat', additional_kwargs={}, response_metadata={'time_in_sec': 3}, id='run-3f0b5501-5c78-45b3-92fc-8322a6a5024a', usage_metadata={'input_tokens': 3, 'output_tokens': 3, 'total_tokens': 6})}, 'parent_ids': []}

貢獻

我們感謝所有聊天模型整合貢獻。

以下是一個檢查清單,可協助確保您的貢獻新增至 LangChain

文件

  • 模型包含所有初始化引數的 doc-string,因為這些引數將在 API 參考中顯示。
  • 如果模型由服務提供支援,則模型的類別 doc-string 包含模型 API 的連結。

測試

  • 將單元或整合測試新增至覆寫的方法。如果您已覆寫對應的程式碼,請驗證 invokeainvokebatchstream 是否正常運作。

串流 (如果您要實作)

  • 實作 _stream 方法以使串流正常運作

停止 token 行為

  • 應遵循停止 token
  • 停止 token 應包含在回應中

秘密 API 金鑰

  • 如果您的模型連線到 API,它可能會接受 API 金鑰作為其初始化的一部分。針對秘密使用 Pydantic 的 SecretStr 類型,這樣當人們列印模型時,它們不會被意外列印出來。

識別參數

  • 在識別參數中包含 model_name

最佳化

考慮提供原生非同步支援以減少來自模型的 overhead!

  • 提供 _agenerate 的原生非同步版本 (由 ainvoke 使用)
  • 提供 _astream 的原生非同步版本 (由 astream 使用)

後續步驟

您現在已學習如何建立自己的自訂聊天模型。

接下來,查看本節中其他關於聊天模型的操作指南,例如如何讓模型傳回結構化輸出如何追蹤聊天模型 token 使用量


此頁面是否有幫助?