跳至主要內容
Open In ColabOpen on GitHub

如何建立自訂 LLM 類別

本筆記本說明如何建立自訂 LLM 包裝函式,以便在您想要使用自己的 LLM 或 LangChain 中不支援的其他包裝函式時使用。

使用標準 LLM 介面包裝您的 LLM,可讓您在現有的 LangChain 程式中使用您的 LLM,並盡可能減少程式碼修改。

作為額外好處,您的 LLM 將自動成為 LangChain Runnable,並可立即受益於一些最佳化、非同步支援、astream_events API 等。

注意

您目前所在的頁面說明如何使用文字完成模型。許多最新和最流行的模型是聊天完成模型

除非您特別使用更進階的提示技術,否則您可能正在尋找這個頁面

實作

自訂 LLM 只需要實作兩個必要項目

方法描述
_call接收字串和一些可選的停止詞,並傳回字串。由 invoke 使用。
_llm_type傳回字串的屬性,僅用於記錄目的。

可選實作

方法描述
_identifying_params用於協助識別模型並列印 LLM;應傳回字典。這是一個 @property
_acall提供 _call 的非同步原生實作,由 ainvoke 使用。
_stream逐個 Token 串流輸出結果的方法。
_astream提供 _stream 的非同步原生實作;在較新的 LangChain 版本中,預設為 _stream

讓我們實作一個簡單的自訂 LLM,它只傳回輸入的前 n 個字元。

from typing import Any, Dict, Iterator, List, Mapping, Optional

from langchain_core.callbacks.manager import CallbackManagerForLLMRun
from langchain_core.language_models.llms import LLM
from langchain_core.outputs import GenerationChunk


class CustomLLM(LLM):
"""A custom chat model that echoes the first `n` characters of the input.

When contributing an implementation to LangChain, carefully document
the model including the initialization parameters, include
an example of how to initialize the model and include any relevant
links to the underlying models documentation or API.

Example:

.. code-block:: python

model = CustomChatModel(n=2)
result = model.invoke([HumanMessage(content="hello")])
result = model.batch([[HumanMessage(content="hello")],
[HumanMessage(content="world")]])
"""

n: int
"""The number of characters from the last message of the prompt to be echoed."""

def _call(
self,
prompt: str,
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> str:
"""Run the LLM on the given input.

Override this method to implement the LLM logic.

Args:
prompt: The prompt to generate from.
stop: Stop words to use when generating. Model output is cut off at the
first occurrence of any of the stop substrings.
If stop tokens are not supported consider raising NotImplementedError.
run_manager: Callback manager for the run.
**kwargs: Arbitrary additional keyword arguments. These are usually passed
to the model provider API call.

Returns:
The model output as a string. Actual completions SHOULD NOT include the prompt.
"""
if stop is not None:
raise ValueError("stop kwargs are not permitted.")
return prompt[: self.n]

def _stream(
self,
prompt: str,
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> Iterator[GenerationChunk]:
"""Stream the LLM on the given prompt.

This method should be overridden by subclasses that support streaming.

If not implemented, the default behavior of calls to stream will be to
fallback to the non-streaming version of the model and return
the output as a single chunk.

Args:
prompt: The prompt to generate from.
stop: Stop words to use when generating. Model output is cut off at the
first occurrence of any of these substrings.
run_manager: Callback manager for the run.
**kwargs: Arbitrary additional keyword arguments. These are usually passed
to the model provider API call.

Returns:
An iterator of GenerationChunks.
"""
for char in prompt[: self.n]:
chunk = GenerationChunk(text=char)
if run_manager:
run_manager.on_llm_new_token(chunk.text, chunk=chunk)

yield chunk

@property
def _identifying_params(self) -> Dict[str, Any]:
"""Return a dictionary of identifying parameters."""
return {
# The model name allows users to specify custom token counting
# rules in LLM monitoring applications (e.g., in LangSmith users
# can provide per token pricing for their model and monitor
# costs for the given LLM.)
"model_name": "CustomChatModel",
}

@property
def _llm_type(self) -> str:
"""Get the type of language model used by this chat model. Used for logging purposes only."""
return "custom"

讓我們測試一下 🧪

此 LLM 將實作 LangChain 的標準 Runnable 介面,許多 LangChain 抽象化都支援此介面!

llm = CustomLLM(n=5)
print(llm)
CustomLLM
Params: {'model_name': 'CustomChatModel'}
llm.invoke("This is a foobar thing")
'This '
await llm.ainvoke("world")
'world'
llm.batch(["woof woof woof", "meow meow meow"])
['woof ', 'meow ']
await llm.abatch(["woof woof woof", "meow meow meow"])
['woof ', 'meow ']
async for token in llm.astream("hello"):
print(token, end="|", flush=True)
h|e|l|l|o|

讓我們確認它與其他 LangChain API 良好整合。

from langchain_core.prompts import ChatPromptTemplate
API 參考:ChatPromptTemplate
prompt = ChatPromptTemplate.from_messages(
[("system", "you are a bot"), ("human", "{input}")]
)
llm = CustomLLM(n=7)
chain = prompt | llm
idx = 0
async for event in chain.astream_events({"input": "hello there!"}, version="v1"):
print(event)
idx += 1
if idx > 7:
# Truncate
break
{'event': 'on_chain_start', 'run_id': '05f24b4f-7ea3-4fb6-8417-3aa21633462f', 'name': 'RunnableSequence', 'tags': [], 'metadata': {}, 'data': {'input': {'input': 'hello there!'}}}
{'event': 'on_prompt_start', 'name': 'ChatPromptTemplate', 'run_id': '7e996251-a926-4344-809e-c425a9846d21', 'tags': ['seq:step:1'], 'metadata': {}, 'data': {'input': {'input': 'hello there!'}}}
{'event': 'on_prompt_end', 'name': 'ChatPromptTemplate', 'run_id': '7e996251-a926-4344-809e-c425a9846d21', 'tags': ['seq:step:1'], 'metadata': {}, 'data': {'input': {'input': 'hello there!'}, 'output': ChatPromptValue(messages=[SystemMessage(content='you are a bot'), HumanMessage(content='hello there!')])}}
{'event': 'on_llm_start', 'name': 'CustomLLM', 'run_id': 'a8766beb-10f4-41de-8750-3ea7cf0ca7e2', 'tags': ['seq:step:2'], 'metadata': {}, 'data': {'input': {'prompts': ['System: you are a bot\nHuman: hello there!']}}}
{'event': 'on_llm_stream', 'name': 'CustomLLM', 'run_id': 'a8766beb-10f4-41de-8750-3ea7cf0ca7e2', 'tags': ['seq:step:2'], 'metadata': {}, 'data': {'chunk': 'S'}}
{'event': 'on_chain_stream', 'run_id': '05f24b4f-7ea3-4fb6-8417-3aa21633462f', 'tags': [], 'metadata': {}, 'name': 'RunnableSequence', 'data': {'chunk': 'S'}}
{'event': 'on_llm_stream', 'name': 'CustomLLM', 'run_id': 'a8766beb-10f4-41de-8750-3ea7cf0ca7e2', 'tags': ['seq:step:2'], 'metadata': {}, 'data': {'chunk': 'y'}}
{'event': 'on_chain_stream', 'run_id': '05f24b4f-7ea3-4fb6-8417-3aa21633462f', 'tags': [], 'metadata': {}, 'name': 'RunnableSequence', 'data': {'chunk': 'y'}}

貢獻

我們感謝所有聊天模型整合的貢獻。

以下是一個檢查清單,可協助確保您的貢獻新增到 LangChain 中

文件

  • 模型包含所有初始化參數的文件字串,因為這些參數將在API 參考中顯示。
  • 如果模型由服務提供支援,則模型的類別文件字串包含模型 API 的連結。

測試

  • 為覆寫的方法新增單元測試或整合測試。如果您已覆寫對應的程式碼,請驗證 invokeainvokebatchstream 是否運作。

串流(如果您要實作它)

  • 請務必調用 on_llm_new_token 回呼
  • on_llm_new_token 在產生區塊之前調用

停止 Token 行為

  • 應尊重停止 Token
  • 停止 Token 應包含在回應中

密鑰 API 金鑰

  • 如果您的模型連線到 API,它可能會接受 API 金鑰作為其初始化的一部分。針對密碼使用 Pydantic 的 SecretStr 類型,這樣人們在列印模型時就不會意外地列印出來。

此頁面是否有幫助?