如何建立自訂 LLM 類別
本筆記本說明如何建立自訂 LLM 包裝函式,以便在您想要使用自己的 LLM 或 LangChain 中不支援的其他包裝函式時使用。
使用標準 LLM
介面包裝您的 LLM,可讓您在現有的 LangChain 程式中使用您的 LLM,並盡可能減少程式碼修改。
作為額外好處,您的 LLM 將自動成為 LangChain Runnable
,並可立即受益於一些最佳化、非同步支援、astream_events
API 等。
實作
自訂 LLM 只需要實作兩個必要項目
方法 | 描述 |
---|---|
_call | 接收字串和一些可選的停止詞,並傳回字串。由 invoke 使用。 |
_llm_type | 傳回字串的屬性,僅用於記錄目的。 |
可選實作
方法 | 描述 |
---|---|
_identifying_params | 用於協助識別模型並列印 LLM;應傳回字典。這是一個 @property。 |
_acall | 提供 _call 的非同步原生實作,由 ainvoke 使用。 |
_stream | 逐個 Token 串流輸出結果的方法。 |
_astream | 提供 _stream 的非同步原生實作;在較新的 LangChain 版本中,預設為 _stream 。 |
讓我們實作一個簡單的自訂 LLM,它只傳回輸入的前 n 個字元。
from typing import Any, Dict, Iterator, List, Mapping, Optional
from langchain_core.callbacks.manager import CallbackManagerForLLMRun
from langchain_core.language_models.llms import LLM
from langchain_core.outputs import GenerationChunk
class CustomLLM(LLM):
"""A custom chat model that echoes the first `n` characters of the input.
When contributing an implementation to LangChain, carefully document
the model including the initialization parameters, include
an example of how to initialize the model and include any relevant
links to the underlying models documentation or API.
Example:
.. code-block:: python
model = CustomChatModel(n=2)
result = model.invoke([HumanMessage(content="hello")])
result = model.batch([[HumanMessage(content="hello")],
[HumanMessage(content="world")]])
"""
n: int
"""The number of characters from the last message of the prompt to be echoed."""
def _call(
self,
prompt: str,
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> str:
"""Run the LLM on the given input.
Override this method to implement the LLM logic.
Args:
prompt: The prompt to generate from.
stop: Stop words to use when generating. Model output is cut off at the
first occurrence of any of the stop substrings.
If stop tokens are not supported consider raising NotImplementedError.
run_manager: Callback manager for the run.
**kwargs: Arbitrary additional keyword arguments. These are usually passed
to the model provider API call.
Returns:
The model output as a string. Actual completions SHOULD NOT include the prompt.
"""
if stop is not None:
raise ValueError("stop kwargs are not permitted.")
return prompt[: self.n]
def _stream(
self,
prompt: str,
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> Iterator[GenerationChunk]:
"""Stream the LLM on the given prompt.
This method should be overridden by subclasses that support streaming.
If not implemented, the default behavior of calls to stream will be to
fallback to the non-streaming version of the model and return
the output as a single chunk.
Args:
prompt: The prompt to generate from.
stop: Stop words to use when generating. Model output is cut off at the
first occurrence of any of these substrings.
run_manager: Callback manager for the run.
**kwargs: Arbitrary additional keyword arguments. These are usually passed
to the model provider API call.
Returns:
An iterator of GenerationChunks.
"""
for char in prompt[: self.n]:
chunk = GenerationChunk(text=char)
if run_manager:
run_manager.on_llm_new_token(chunk.text, chunk=chunk)
yield chunk
@property
def _identifying_params(self) -> Dict[str, Any]:
"""Return a dictionary of identifying parameters."""
return {
# The model name allows users to specify custom token counting
# rules in LLM monitoring applications (e.g., in LangSmith users
# can provide per token pricing for their model and monitor
# costs for the given LLM.)
"model_name": "CustomChatModel",
}
@property
def _llm_type(self) -> str:
"""Get the type of language model used by this chat model. Used for logging purposes only."""
return "custom"
讓我們測試一下 🧪
此 LLM 將實作 LangChain 的標準 Runnable
介面,許多 LangChain 抽象化都支援此介面!
llm = CustomLLM(n=5)
print(llm)
[1mCustomLLM[0m
Params: {'model_name': 'CustomChatModel'}
llm.invoke("This is a foobar thing")
'This '
await llm.ainvoke("world")
'world'
llm.batch(["woof woof woof", "meow meow meow"])
['woof ', 'meow ']
await llm.abatch(["woof woof woof", "meow meow meow"])
['woof ', 'meow ']
async for token in llm.astream("hello"):
print(token, end="|", flush=True)
h|e|l|l|o|
讓我們確認它與其他 LangChain
API 良好整合。
from langchain_core.prompts import ChatPromptTemplate
API 參考:ChatPromptTemplate
prompt = ChatPromptTemplate.from_messages(
[("system", "you are a bot"), ("human", "{input}")]
)
llm = CustomLLM(n=7)
chain = prompt | llm
idx = 0
async for event in chain.astream_events({"input": "hello there!"}, version="v1"):
print(event)
idx += 1
if idx > 7:
# Truncate
break
{'event': 'on_chain_start', 'run_id': '05f24b4f-7ea3-4fb6-8417-3aa21633462f', 'name': 'RunnableSequence', 'tags': [], 'metadata': {}, 'data': {'input': {'input': 'hello there!'}}}
{'event': 'on_prompt_start', 'name': 'ChatPromptTemplate', 'run_id': '7e996251-a926-4344-809e-c425a9846d21', 'tags': ['seq:step:1'], 'metadata': {}, 'data': {'input': {'input': 'hello there!'}}}
{'event': 'on_prompt_end', 'name': 'ChatPromptTemplate', 'run_id': '7e996251-a926-4344-809e-c425a9846d21', 'tags': ['seq:step:1'], 'metadata': {}, 'data': {'input': {'input': 'hello there!'}, 'output': ChatPromptValue(messages=[SystemMessage(content='you are a bot'), HumanMessage(content='hello there!')])}}
{'event': 'on_llm_start', 'name': 'CustomLLM', 'run_id': 'a8766beb-10f4-41de-8750-3ea7cf0ca7e2', 'tags': ['seq:step:2'], 'metadata': {}, 'data': {'input': {'prompts': ['System: you are a bot\nHuman: hello there!']}}}
{'event': 'on_llm_stream', 'name': 'CustomLLM', 'run_id': 'a8766beb-10f4-41de-8750-3ea7cf0ca7e2', 'tags': ['seq:step:2'], 'metadata': {}, 'data': {'chunk': 'S'}}
{'event': 'on_chain_stream', 'run_id': '05f24b4f-7ea3-4fb6-8417-3aa21633462f', 'tags': [], 'metadata': {}, 'name': 'RunnableSequence', 'data': {'chunk': 'S'}}
{'event': 'on_llm_stream', 'name': 'CustomLLM', 'run_id': 'a8766beb-10f4-41de-8750-3ea7cf0ca7e2', 'tags': ['seq:step:2'], 'metadata': {}, 'data': {'chunk': 'y'}}
{'event': 'on_chain_stream', 'run_id': '05f24b4f-7ea3-4fb6-8417-3aa21633462f', 'tags': [], 'metadata': {}, 'name': 'RunnableSequence', 'data': {'chunk': 'y'}}
貢獻
我們感謝所有聊天模型整合的貢獻。
以下是一個檢查清單,可協助確保您的貢獻新增到 LangChain 中
文件
- 模型包含所有初始化參數的文件字串,因為這些參數將在API 參考中顯示。
- 如果模型由服務提供支援,則模型的類別文件字串包含模型 API 的連結。
測試
- 為覆寫的方法新增單元測試或整合測試。如果您已覆寫對應的程式碼,請驗證
invoke
、ainvoke
、batch
、stream
是否運作。
串流(如果您要實作它)
- 請務必調用
on_llm_new_token
回呼 -
on_llm_new_token
在產生區塊之前調用
停止 Token 行為
- 應尊重停止 Token
- 停止 Token 應包含在回應中
密鑰 API 金鑰
- 如果您的模型連線到 API,它可能會接受 API 金鑰作為其初始化的一部分。針對密碼使用 Pydantic 的
SecretStr
類型,這樣人們在列印模型時就不會意外地列印出來。