
ScaNN

ScaNN (Scalable Nearest Neighbors) is a method for efficient vector similarity search at scale.
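Similarity search asks, for a query vector q, which stored vectors x score highest under a similarity measure such as the inner product ⟨q, x⟩, or lowest under the Euclidean distance ||q - x||. The exact, brute-force version below (plain Python, made-up 2-D vectors) is the baseline that approximate methods like ScaNN try to speed up. It is an illustration, not ScaNN's algorithm. It also shows that for unit-length vectors the two rankings coincide, because ||q - x||² = 2 - 2⟨q, x⟩.

```python
import math


def dot(a, b):
    """Inner product <a, b> of two equal-length vectors."""
    return sum(x * y for x, y in zip(a, b))


def euclidean(a, b):
    """Euclidean (L2) distance ||a - b||."""
    return math.sqrt(sum((x - y) ** 2 for x, y in zip(a, b)))


def normalize(v):
    """Scale a vector to unit length."""
    norm = math.sqrt(dot(v, v))
    return [x / norm for x in v]


def top_k_mips(query, database, k):
    """Exact maximum inner product search: indices of the k largest <query, x>."""
    ranked = sorted(range(len(database)), key=lambda i: dot(query, database[i]), reverse=True)
    return ranked[:k]


def top_k_l2(query, database, k):
    """Exact Euclidean nearest neighbors: indices of the k smallest ||query - x||."""
    ranked = sorted(range(len(database)), key=lambda i: euclidean(query, database[i]))
    return ranked[:k]


# Made-up 2-D vectors, normalized so that both rankings agree.
database = [normalize(v) for v in ([1.0, 0.0], [0.6, 0.8], [0.0, 1.0], [-1.0, 0.2])]
query = normalize([0.9, 0.3])

print(top_k_mips(query, database, 2))  # [0, 1]
print(top_k_l2(query, database, 2))    # [0, 1]
```

Brute force scores every stored vector, so its per-query cost grows linearly with the collection size; that cost is what pruning and quantization aim to reduce.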

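ScaNN includes search space pruning and quantization for Maximum Inner Product Search (MIPS), and it also supports other distance functions such as Euclidean distance. The implementation is optimized for x86 processors with AVX2 support. For more details, see its Google Research GitHub repository.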
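Search space pruning means scoring only a small subset of the database for each query. The toy sketch below illustrates one common form, partition-based pruning: vectors are grouped under centroids, and a query scans only the `num_probe` partitions whose centroids score highest. This is a deliberate simplification, not ScaNN's implementation: the centroids are hand-picked rather than trained, there is no quantization step, and the names `build_partitions`, `pruned_mips`, and `num_probe` are hypothetical.

```python
def dot(a, b):
    """Inner product of two equal-length vectors."""
    return sum(x * y for x, y in zip(a, b))


def build_partitions(database, centroids):
    """Assign each database vector to the centroid with the largest inner product."""
    partitions = {c: [] for c in range(len(centroids))}
    for i, v in enumerate(database):
        best = max(range(len(centroids)), key=lambda c: dot(v, centroids[c]))
        partitions[best].append(i)
    return partitions


def pruned_mips(query, database, centroids, partitions, k, num_probe):
    """Score only the vectors in the num_probe best-scoring partitions.

    Returns (top-k indices, number of vectors actually scanned).
    """
    probe = sorted(range(len(centroids)), key=lambda c: dot(query, centroids[c]), reverse=True)[:num_probe]
    candidates = [i for c in probe for i in partitions[c]]
    candidates.sort(key=lambda i: dot(query, database[i]), reverse=True)
    return candidates[:k], len(candidates)


# Made-up data: three rough clusters and hand-picked (not trained) centroids.
database = [[1.0, 0.1], [0.9, 0.2], [0.1, 1.0], [0.2, 0.9], [-1.0, 0.0], [-0.9, -0.1]]
centroids = [[1.0, 0.0], [0.0, 1.0], [-1.0, 0.0]]
partitions = build_partitions(database, centroids)

top, scanned = pruned_mips([0.8, 0.3], database, centroids, partitions, k=2, num_probe=1)
print(top, scanned)  # [0, 1] 2
```

Here the query scans 2 of the 6 vectors and still recovers the exact top 2. With poorly chosen partitions or too small a `num_probe`, pruning can miss true neighbors; that accuracy/speed trade-off is what such indexes are tuned for.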

To use this integration, install langchain-community with pip install -qU langchain-community.

Installation

Install ScaNN via pip. Alternatively, you can install it from source by following the instructions on the ScaNN website.

%pip install --upgrade --quiet  scann
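The retrieval demo also imports `langchain_huggingface` and `langchain_text_splitters`, which are provided by the `langchain-huggingface` and `langchain-text-splitters` packages. One way to confirm that everything is importable before running the demos is `importlib.util.find_spec`, which returns `None` for a top-level module that cannot be found. The helper name `check_packages` is hypothetical.

```python
import importlib.util


def check_packages(names):
    """Map each top-level module name to True if it can be imported, else False."""
    return {name: importlib.util.find_spec(name) is not None for name in names}


status = check_packages(
    ["scann", "langchain_community", "langchain_huggingface", "langchain_text_splitters"]
)
for name, ok in status.items():
    print(f"{name}: {'installed' if ok else 'missing'}")
```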

Retrieval Demo

Below, we show how to use ScaNN with Hugging Face embeddings.

from langchain_community.document_loaders import TextLoader
from langchain_community.vectorstores import ScaNN
from langchain_huggingface import HuggingFaceEmbeddings  # from the langchain-huggingface package
from langchain_text_splitters import CharacterTextSplitter

# Load the source text and split it into chunks of roughly 1000 characters.
loader = TextLoader("state_of_the_union.txt")
documents = loader.load()
text_splitter = CharacterTextSplitter(chunk_size=1000, chunk_overlap=0)
docs = text_splitter.split_documents(documents)

# Embed each chunk with a sentence-transformers model and build a ScaNN index.
model_name = "sentence-transformers/all-mpnet-base-v2"
embeddings = HuggingFaceEmbeddings(model_name=model_name)

db = ScaNN.from_documents(docs, embeddings)

# Retrieve the chunks most similar to the query.
query = "What did the president say about Ketanji Brown Jackson"
results = db.similarity_search(query)

results[0]
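Conceptually, `from_documents` embeds every chunk once, and `similarity_search` embeds the query and returns the highest-scoring chunks. The sketch below is a toy, exact-search analogue of that flow, for illustration only: the bag-of-words "embedding", the `ToyVectorStore` class, and the placeholder texts are all hypothetical. They are not LangChain or ScaNN APIs, and a real store returns `Document` objects rather than strings.

```python
import math
import re


def tokenize(text):
    """Lowercase word tokens."""
    return re.findall(r"[a-z]+", text.lower())


class ToyVectorStore:
    """Hypothetical exact-search store: embed texts once, score a query against all of them."""

    def __init__(self, texts):
        self.texts = list(texts)
        # Toy "embedding": bag of words over a vocabulary built from the stored texts.
        vocab = sorted({w for t in self.texts for w in tokenize(t)})
        self.index = {w: i for i, w in enumerate(vocab)}
        self.vectors = [self.embed(t) for t in self.texts]

    def embed(self, text):
        """Unit-length word-count vector; words outside the vocabulary are ignored."""
        vec = [0.0] * len(self.index)
        for w in tokenize(text):
            if w in self.index:
                vec[self.index[w]] += 1.0
        norm = math.sqrt(sum(x * x for x in vec)) or 1.0
        return [x / norm for x in vec]

    def similarity_search(self, query, k=4):
        """Return the k stored texts with the highest inner product with the query."""
        q = self.embed(query)
        scores = [sum(a * b for a, b in zip(q, v)) for v in self.vectors]
        ranked = sorted(range(len(self.texts)), key=lambda i: scores[i], reverse=True)
        return [self.texts[i] for i in ranked[:k]]


# Placeholder texts, not excerpts from the speech.
chunks = [
    "Notes on the Supreme Court nomination of Ketanji Brown Jackson.",
    "Notes on inflation and gas prices.",
    "Notes on new sanctions.",
]
store = ToyVectorStore(chunks)
print(store.similarity_search("What about Ketanji Brown Jackson?", k=1))
# ['Notes on the Supreme Court nomination of Ketanji Brown Jackson.']
```

A learned embedding model such as all-mpnet-base-v2 also matches paraphrases, which word counts cannot; the overall embed, index, and score flow is the same.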

RetrievalQA Demo

Next, we show how to use ScaNN with the Google PaLM API.

You can obtain an API key at https://developers.generativeai.google/tutorials/setup.
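Rather than hard-coding the key in the source, you may prefer to read it from an environment variable. A minimal sketch, assuming the key is stored in a variable named `GOOGLE_API_KEY`; both that name and the `get_api_key` helper are assumptions for illustration.

```python
import os


def get_api_key(env_var="GOOGLE_API_KEY"):
    """Read an API key from an environment variable (the default name is an assumption)."""
    key = os.environ.get(env_var, "").strip()
    if not key:
        raise RuntimeError(f"Set the {env_var} environment variable before running the demo.")
    return key
```

You would then pass the key explicitly, for example `ChatGooglePalm(google_api_key=get_api_key())`, so it never appears in the notebook itself.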

from langchain.chains import RetrievalQA
from langchain_community.chat_models.google_palm import ChatGooglePalm

# Replace the placeholder with your own API key.
palm_client = ChatGooglePalm(google_api_key="YOUR_GOOGLE_PALM_API_KEY")

# "stuff" puts all retrieved chunks into a single prompt; k=10 retrieves the 10 most similar chunks.
qa = RetrievalQA.from_chain_type(
    llm=palm_client,
    chain_type="stuff",
    retriever=db.as_retriever(search_kwargs={"k": 10}),
)
print(qa.run("What did the president say about Ketanji Brown Jackson?"))
The president said that Ketanji Brown Jackson is one of our nation's top legal minds, who will continue Justice Breyer's legacy of excellence.
print(qa.run("What did the president say about Michael Phelps?"))
The president did not mention Michael Phelps in his speech.
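With `chain_type="stuff"`, the retrieved chunks are concatenated into a single prompt together with the question. The helper below is a hypothetical sketch of that idea; its prompt wording is made up and is not LangChain's actual template. Because `chunk_size` counts characters by default, k=10 chunks of about 1000 characters each means the stuffed context can be on the order of 10,000 characters, which must fit in the model's context window.

```python
def build_stuff_prompt(question, chunks):
    """Hypothetical "stuff" prompt: every retrieved chunk is pasted into one prompt."""
    context = "\n\n".join(chunks)
    return (
        "Answer the question using only the context below. "
        "If the context does not contain the answer, say so.\n\n"
        f"Context:\n{context}\n\n"
        f"Question: {question}\n"
        "Answer:"
    )


# Placeholder chunks standing in for the k=10 documents the retriever would return.
retrieved = [f"<text of retrieved chunk {i}>" for i in range(1, 11)]
prompt = build_stuff_prompt("What did the president say about Michael Phelps?", retrieved)
print(prompt)
```

An instruction to admit when the context lacks the answer is one way a prompt can steer the model toward replies like the Michael Phelps answer above instead of an invented one.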

Saving and loading a local retrieval index

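# Save the vector store to /tmp/db under the index name "state_of_union".
db.save_local("/tmp/db", "state_of_union")
# Reload it, passing the same embeddings so that new queries are embedded consistently.
restored_db = ScaNN.load_local("/tmp/db", embeddings, index_name="state_of_union")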
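This page does not describe the on-disk layout that `ScaNN.save_local` writes. As a hypothetical analogue of the same round trip (save under a folder and an index name, reload, get an equivalent store back), the sketch below pickles a plain dictionary. The helpers `save_store` and `load_store` are toy functions, not LangChain methods, and `tempfile` is used instead of a hard-coded `/tmp` path so the example also runs on systems without `/tmp`. Only unpickle files you created yourself or fully trust, since unpickling can execute arbitrary code.

```python
import os
import pickle
import tempfile


def save_store(folder, index_name, store):
    """Hypothetical helper: persist a toy store to <folder>/<index_name>.pkl."""
    os.makedirs(folder, exist_ok=True)
    path = os.path.join(folder, f"{index_name}.pkl")
    with open(path, "wb") as f:
        pickle.dump(store, f)
    return path


def load_store(folder, index_name):
    """Hypothetical helper: reload a store written by save_store (trusted files only)."""
    with open(os.path.join(folder, f"{index_name}.pkl"), "rb") as f:
        return pickle.load(f)


# A toy store: texts plus their (made-up) vectors.
store = {"texts": ["chunk one", "chunk two"], "vectors": [[1.0, 0.0], [0.0, 1.0]]}

with tempfile.TemporaryDirectory() as folder:
    save_store(folder, "state_of_union", store)
    restored = load_store(folder, "state_of_union")

print(restored == store)  # True
```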
