
Cross Encoder Reranker

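This notebook shows how to implement a reranker in a retriever with your own cross encoder, taken either from Hugging Face cross encoder models or from Hugging Face models that implement the cross encoder function (example: BAAI/bge-reranker-base). SagemakerEndpointCrossEncoder lets you use these Hugging Face models when they are loaded on SageMaker.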

This builds on the ideas in the ContextualCompressionRetriever. The overall structure of this document comes from the Cohere Reranker documentation.

For more about why a cross encoder can be used as a reranking mechanism alongside embeddings for better retrieval, refer to the Hugging Face Cross-Encoders documentation.

#!pip install faiss sentence_transformers

# OR (depending on Python version)

#!pip install faiss-cpu sentence_transformers
# Helper function for printing docs


def pretty_print_docs(docs):
    print(
        f"\n{'-' * 100}\n".join(
            [f"Document {i+1}:\n\n" + d.page_content for i, d in enumerate(docs)]
        )
    )

Set up the base vector store retriever

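Let's start by initializing a simple vector store retriever and storing the 2023 State of the Union speech (in chunks). We can set up the retriever to retrieve a large number (20) of documents.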

from langchain_community.document_loaders import TextLoader
from langchain_community.vectorstores import FAISS
from langchain_huggingface import HuggingFaceEmbeddings
from langchain_text_splitters import RecursiveCharacterTextSplitter

documents = TextLoader("../../how_to/state_of_the_union.txt").load()
text_splitter = RecursiveCharacterTextSplitter(chunk_size=500, chunk_overlap=100)
texts = text_splitter.split_documents(documents)
embeddingsModel = HuggingFaceEmbeddings(
    model_name="sentence-transformers/msmarco-distilbert-dot-v5"
)
retriever = FAISS.from_documents(texts, embeddingsModel).as_retriever(
    search_kwargs={"k": 20}
)

query = "What is the plan for the economy?"
docs = retriever.invoke(query)
pretty_print_docs(docs)
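
To make the chunk_size=500 and chunk_overlap=100 settings above more concrete, here is a minimal sketch of fixed-size character chunking with overlap. It is a simplification, not how RecursiveCharacterTextSplitter is implemented (that splitter also tries to break on separators such as paragraphs and sentences); the chunk_text helper is a hypothetical name introduced only for illustration.

```python
from typing import List


def chunk_text(text: str, chunk_size: int = 500, chunk_overlap: int = 100) -> List[str]:
    """Split text into character windows of chunk_size that overlap by chunk_overlap.

    Simplified illustration only: it ignores separators entirely.
    """
    if chunk_overlap >= chunk_size:
        raise ValueError("chunk_overlap must be smaller than chunk_size")
    step = chunk_size - chunk_overlap  # how far each window advances
    chunks = []
    for start in range(0, len(text), step):
        chunks.append(text[start : start + chunk_size])
        if start + chunk_size >= len(text):
            break  # the last window already reached the end of the text
    return chunks


# A 1200-character sample string stands in for the speech text.
sample = "".join(str(i % 10) for i in range(1200))
chunks = chunk_text(sample)
print([len(c) for c in chunks])
```

Each chunk repeats the last 100 characters of the previous one, so a sentence that straddles a boundary still appears whole in at least one chunk.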

Reranking with CrossEncoderReranker

Now let's wrap our base retriever with a ContextualCompressionRetriever. CrossEncoderReranker uses HuggingFaceCrossEncoder to rerank the returned results.

from langchain.retrievers import ContextualCompressionRetriever
from langchain.retrievers.document_compressors import CrossEncoderReranker
from langchain_community.cross_encoders import HuggingFaceCrossEncoder

model = HuggingFaceCrossEncoder(model_name="BAAI/bge-reranker-base")
compressor = CrossEncoderReranker(model=model, top_n=3)
compression_retriever = ContextualCompressionRetriever(
    base_compressor=compressor, base_retriever=retriever
)

compressed_docs = compression_retriever.invoke("What is the plan for the economy?")
pretty_print_docs(compressed_docs)
Document 1:

More infrastructure and innovation in America.

More goods moving faster and cheaper in America.

More jobs where you can earn a good living in America.

And instead of relying on foreign supply chains, let’s make it in America.

Economists call it “increasing the productive capacity of our economy.”

I call it building a better America.

My plan to fight inflation will lower your costs and lower the deficit.
----------------------------------------------------------------------------------------------------
Document 2:

Second – cut energy costs for families an average of $500 a year by combatting climate change.

Let’s provide investments and tax credits to weatherize your homes and businesses to be energy efficient and you get a tax credit; double America’s clean energy production in solar, wind, and so much more; lower the price of electric vehicles, saving you another $80 a month because you’ll never have to pay at the gas pump again.
----------------------------------------------------------------------------------------------------
Document 3:

Look at cars.

Last year, there weren’t enough semiconductors to make all the cars that people wanted to buy.

And guess what, prices of automobiles went up.

So—we have a choice.

One way to fight inflation is to drive down wages and make Americans poorer.

I have a better plan to fight inflation.

Lower your costs, not your wages.

Make more cars and semiconductors in America.

More infrastructure and innovation in America.

More goods moving faster and cheaper in America.
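
The reranking step can be sketched roughly as follows: score every (query, document) pair, sort by score, and keep the top_n results. This is a minimal sketch under stated assumptions, not the library's code. The rerank and toy_overlap_scorer names are hypothetical, and the word-overlap scorer is only a stand-in for a real cross encoder such as BAAI/bge-reranker-base, so the example runs without downloading a model.

```python
from typing import Callable, List, Sequence, Tuple

# A scorer takes (query, document) pairs and returns one relevance score per pair.
PairScorer = Callable[[Sequence[Tuple[str, str]]], List[float]]


def rerank(query: str, docs: List[str], score_pairs: PairScorer, top_n: int = 3) -> List[str]:
    """Return the top_n docs ordered by descending pair score."""
    pairs = [(query, doc) for doc in docs]
    scores = score_pairs(pairs)
    ranked = sorted(zip(docs, scores), key=lambda item: item[1], reverse=True)
    return [doc for doc, _ in ranked[:top_n]]


def toy_overlap_scorer(pairs: Sequence[Tuple[str, str]]) -> List[float]:
    """Hypothetical stand-in for a cross encoder: fraction of query words found in the doc."""
    scores = []
    for query, doc in pairs:
        query_words = set(query.lower().split())
        doc_words = set(doc.lower().split())
        scores.append(len(query_words & doc_words) / len(query_words) if query_words else 0.0)
    return scores


candidates = [
    "the weather is nice today",
    "my plan for the economy lowers costs",
    "the economy added jobs",
    "a plan to fix roads",
]
top = rerank("plan for the economy", candidates, toy_overlap_scorer, top_n=2)
print(top)
```

The base retriever casts a wide net (k=20) with a cheap similarity search, and the more expensive pairwise scorer then only has to look at those candidates.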

Uploading a Hugging Face model to a SageMaker endpoint

Here is a sample inference.py for creating an endpoint that works with SagemakerEndpointCrossEncoder. For more details and step-by-step guidance, refer to this article.

It downloads the Hugging Face model on the fly, so you do not need to keep model artifacts such as pytorch_model.bin in your model.tar.gz.

import json
import logging
from typing import List

import torch
from sagemaker_inference import encoder
from transformers import AutoModelForSequenceClassification, AutoTokenizer

PAIRS = "pairs"
SCORES = "scores"


class CrossEncoder:
    def __init__(self) -> None:
        self.device = (
            torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
        )
        logging.info(f"Using device: {self.device}")
        model_name = "BAAI/bge-reranker-base"
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.model = AutoModelForSequenceClassification.from_pretrained(model_name)
        self.model = self.model.to(self.device)

    def __call__(self, pairs: List[List[str]]) -> List[float]:
        with torch.inference_mode():
            inputs = self.tokenizer(
                pairs,
                padding=True,
                truncation=True,
                return_tensors="pt",
                max_length=512,
            )
            inputs = inputs.to(self.device)
            scores = (
                self.model(**inputs, return_dict=True)
                .logits.view(
                    -1,
                )
                .float()
            )

        return scores.detach().cpu().tolist()


def model_fn(model_dir: str) -> CrossEncoder:
    try:
        return CrossEncoder()
    except Exception:
        logging.exception(f"Failed to load model from: {model_dir}")
        raise


def transform_fn(
    cross_encoder: CrossEncoder, input_data: bytes, content_type: str, accept: str
) -> bytes:
    payload = json.loads(input_data)
    model_output = cross_encoder(**payload)
    output = {SCORES: model_output}
    return encoder.encode(output, accept)
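
Reading transform_fn above, the handler unpacks the JSON request body as keyword arguments to CrossEncoder.__call__, so it appears to expect a body of the form {"pairs": [[query, doc], ...]} and to reply with {"scores": [...]}. The sketch below exercises that shape locally, under the assumption that JSON is used in both directions. The helpers build_request_body, parse_response_body, fake_transform and length_scorer are hypothetical names for illustration; fake_transform uses json.dumps in place of encoder.encode, and length_scorer replaces the real model. Whether the default content handler of SagemakerEndpointCrossEncoder sends exactly this key is not stated here and should be checked against the library.

```python
import json
from typing import Callable, List

PAIRS = "pairs"
SCORES = "scores"


def build_request_body(query: str, docs: List[str]) -> bytes:
    """Serialize (query, doc) pairs in the shape the handler above unpacks."""
    pairs = [[query, doc] for doc in docs]
    return json.dumps({PAIRS: pairs}).encode("utf-8")


def parse_response_body(body: bytes) -> List[float]:
    """Extract the list of scores from a JSON response body."""
    return json.loads(body)[SCORES]


def fake_transform(scorer: Callable[..., List[float]], input_data: bytes) -> bytes:
    """Local stand-in for transform_fn (assumption: JSON in, JSON out)."""
    payload = json.loads(input_data)
    return json.dumps({SCORES: scorer(**payload)}).encode("utf-8")


def length_scorer(pairs: List[List[str]]) -> List[float]:
    """Hypothetical scorer: document length, used only to test the round trip."""
    return [float(len(doc)) for _, doc in pairs]


body = build_request_body("economy", ["short", "a longer passage"])
scores = parse_response_body(fake_transform(length_scorer, body))
print(scores)
```

A local round trip like this can catch a mismatch between the request key and the __call__ parameter name before the endpoint is deployed.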
