NLP Course documentation

使用 FAISS 進行語義搜索

Hugging Face's logo
Join the Hugging Face community

and get access to the augmented documentation experience

to get started

使用 FAISS 進行語義搜索

Ask a Question Open In Colab Open In Studio Lab

第五小節, 我們從 🤗 Datasets 存儲庫創建了一個包含 GitHub 問題和評論的數據集。在本節中,我們將使用這些信息來構建一個搜索引擎,它可以幫助我們找到這個庫最緊迫問題的答案!

使用嵌入進行語義搜索

正如我們在第一章,學習的, 基於 Transformer 的語言模型會將文本中的每個標記轉換為嵌入向量.事實證明,可以“彙集”各個嵌入向量來創建整個句子、段落或文檔(在某些情況下)的向量表示。然後,通過計算每個嵌入之間的點積相似度(或其他一些相似度度量)並返回相似度最大的文檔,這些嵌入可用於在語料庫中找到相似的文檔。在本節中,我們將使用嵌入來開發語義搜索引擎。與基於將查詢中的關鍵字的傳統方法相比,這些搜索引擎具有多種優勢。

Semantic search.

## 加載和準備數據集

我們需要做的第一件事是下載我們的 GitHub 問題數據集,所以讓我們使用 🤗 Hub 庫來解析我們的文件在 Hugging Face Hub 上存儲的數據:

from huggingface_hub import hf_hub_url

data_files = hf_hub_url(
    repo_id="lewtun/github-issues",
    filename="datasets-issues-with-comments.jsonl",
    repo_type="dataset",
)

將 URL 存儲在 data_files ,然後我們可以使用第二小節介紹的方法加載遠程數據集:

from datasets import load_dataset

issues_dataset = load_dataset("json", data_files=data_files, split="train")
issues_dataset
Dataset({
    features: ['url', 'repository_url', 'labels_url', 'comments_url', 'events_url', 'html_url', 'id', 'node_id', 'number', 'title', 'user', 'labels', 'state', 'locked', 'assignee', 'assignees', 'milestone', 'comments', 'created_at', 'updated_at', 'closed_at', 'author_association', 'active_lock_reason', 'pull_request', 'body', 'performed_via_github_app', 'is_pull_request'],
    num_rows: 2855
})

這裡我們在load_dataset()中使用了默認的訓練集分割,所以它返回一個數據集而不是數據集字典。第一項任務是過濾掉pull請求,因為這些請求很少用於回答用戶提出的問題,而且會給我們的搜索引擎帶來噪聲。現在應該很熟悉了,我們可以使用dataset.filter()函數來排除數據集中的這些行。同時,讓我們也過濾掉沒有註釋的行,因為這些行不會是用戶提問的答案:

issues_dataset = issues_dataset.filter(
    lambda x: (x["is_pull_request"] == False and len(x["comments"]) > 0)
)
issues_dataset
Dataset({
    features: ['url', 'repository_url', 'labels_url', 'comments_url', 'events_url', 'html_url', 'id', 'node_id', 'number', 'title', 'user', 'labels', 'state', 'locked', 'assignee', 'assignees', 'milestone', 'comments', 'created_at', 'updated_at', 'closed_at', 'author_association', 'active_lock_reason', 'pull_request', 'body', 'performed_via_github_app', 'is_pull_request'],
    num_rows: 771
})

我們可以看到我們的數據集中有很多列,其中大部分我們不需要構建我們的搜索引擎。從搜索的角度來看,信息量最大的列是 title , body , 和 comments ,而 html_url 為我們提供了一個回到源問題的鏈接。讓我們使用 Dataset.remove_columns() 刪除其餘部分的功能:

columns = issues_dataset.column_names
columns_to_keep = ["title", "body", "html_url", "comments"]
columns_to_remove = set(columns_to_keep).symmetric_difference(columns)
issues_dataset = issues_dataset.remove_columns(columns_to_remove)
issues_dataset
Dataset({
    features: ['html_url', 'title', 'comments', 'body'],
    num_rows: 771
})

為了創建我們的嵌入,我們將用問題的標題和正文來擴充每條評論,因為這些字段通常包含有用的上下文信息。因為我們的 comments 列當前是每個問題的評論列表,我們需要“重新組合”列,以便每一條評論都包含一個 (html_url, title, body, comment) 元組。在 Pandas 中,我們可以使用 DataFrame.explode() 函數, 它為類似列表的列中的每個元素創建一個新行,同時複製所有其他列值。為了看到它的實際效果,讓我們首先切換到 Pandas的DataFrame 格式:

issues_dataset.set_format("pandas")
df = issues_dataset[:]

如果我們檢查這裡的第一行 DataFrame 我們可以看到有四個評論與這個問題相關:

df["comments"][0].tolist()
['the bug code locate in :\r\n    if data_args.task_name is not None:\r\n        # Downloading and loading a dataset from the hub.\r\n        datasets = load_dataset("glue", data_args.task_name, cache_dir=model_args.cache_dir)',
 'Hi @jinec,\r\n\r\nFrom time to time we get this kind of `ConnectionError` coming from the github.com website: https://raw.githubusercontent.com\r\n\r\nNormally, it should work if you wait a little and then retry.\r\n\r\nCould you please confirm if the problem persists?',
 'cannot connect,even by Web browser,please check that  there is some  problems。',
 'I can access https://raw.githubusercontent.com/huggingface/datasets/1.7.0/datasets/glue/glue.py without problem...']

我們希望這些評論中的每一條都得到一行。讓我們檢查是否是這種情況:

comments_df = df.explode("comments", ignore_index=True)
comments_df.head(4)
html_url title comments body
0 https://github.com/huggingface/datasets/issues/2787 ConnectionError: Couldn't reach https://raw.githubusercontent.com the bug code locate in :\r\n if data_args.task_name is not None... Hello,\r\nI am trying to run run_glue.py and it gives me this error...
1 https://github.com/huggingface/datasets/issues/2787 ConnectionError: Couldn't reach https://raw.githubusercontent.com Hi @jinec,\r\n\r\nFrom time to time we get this kind of `ConnectionError` coming from the github.com website: https://raw.githubusercontent.com... Hello,\r\nI am trying to run run_glue.py and it gives me this error...
2 https://github.com/huggingface/datasets/issues/2787 ConnectionError: Couldn't reach https://raw.githubusercontent.com cannot connect,even by Web browser,please check that there is some problems。 Hello,\r\nI am trying to run run_glue.py and it gives me this error...
3 https://github.com/huggingface/datasets/issues/2787 ConnectionError: Couldn't reach https://raw.githubusercontent.com I can access https://raw.githubusercontent.com/huggingface/datasets/1.7.0/datasets/glue/glue.py without problem... Hello,\r\nI am trying to run run_glue.py and it gives me this error...

太好了,我們可以看到評論成功被擴充, comments 是包含個人評論的列!現在我們已經完成了 Pandas要完成的部分功能,我們可以快速切換回 Dataset 通過加載 DataFrame 在內存中:

from datasets import Dataset

comments_dataset = Dataset.from_pandas(comments_df)
comments_dataset
Dataset({
    features: ['html_url', 'title', 'comments', 'body'],
    num_rows: 2842
})

太好了,我們獲取到了幾千條的評論!

✏️ Try it out! 看看能不能不用pandas就可以完成列的擴充; 這有點棘手; 你可能會發現 🤗 Datasets 文檔的 “Batch mapping” 對這個任務很有用。

現在我們每行有一個評論,讓我們創建一個新的 comments_length 列來存放每條評論的字數:

comments_dataset = comments_dataset.map(
    lambda x: {"comment_length": len(x["comments"].split())}
)

我們可以使用這個新列來過濾掉簡短的評論,其中通常包括“cc @lewtun”或“謝謝!”之類與我們的搜索引擎無關的內容。雖然無法為過濾器選擇的精確數字,但大約大於15 個單詞似乎是一個不錯的選擇:

comments_dataset = comments_dataset.filter(lambda x: x["comment_length"] > 15)
comments_dataset
Dataset({
    features: ['html_url', 'title', 'comments', 'body', 'comment_length'],
    num_rows: 2098
})

稍微清理了我們的數據集後,讓我們將問題標題、描述和評論連接到一個新的 text 列。像往常一樣,我們可以編寫一個簡單的函數,並將其傳遞給 Dataset.map()來做到這些 :

def concatenate_text(examples):
    return {
        "text": examples["title"]
        + " \n "
        + examples["body"]
        + " \n "
        + examples["comments"]
    }


comments_dataset = comments_dataset.map(concatenate_text)

我們終於準備好創建一些嵌入了!讓我們來看看。

創建文本嵌入

我們在第二章 學過,我們可以通過使用 AutoModel 類來完成詞嵌入。我們需要做的就是選擇一個合適的檢查點來加載模型。幸運的是,有一個名為 sentence-transformers 專門用於創建詞嵌入。如庫中文檔, 所述的,我們這次要實現的是非對稱語義搜索,因為我們有一個簡短的查詢,我們希望在比如問題評論等更長的文檔中找到其答案。通過查看模型概述表 我們可以發現 multi-qa-mpnet-base-dot-v1 檢查點在語義搜索方面具有最佳性能,因此我們將在我們的應用程序中使用它。我們還將使用相同的檢查點加載標記器:

from transformers import AutoTokenizer, AutoModel

model_ckpt = "sentence-transformers/multi-qa-mpnet-base-dot-v1"
tokenizer = AutoTokenizer.from_pretrained(model_ckpt)
model = AutoModel.from_pretrained(model_ckpt)

為了加快嵌入過程,將模型和輸入放在 GPU 設備上,所以現在讓我們這樣做:

import torch

device = torch.device("cuda")
model.to(device)

正如我們之前提到的,我們希望將 GitHub 問題語料庫中的每個條目表示為單個向量,因此我們需要以某種方式“池化”或平均化我們的標記嵌入。一種流行的方法是在我們模型的輸出上執行CLS 池化,我們只獲取[CLS] 令牌的最後一個隱藏狀態。以下函數為我們提供了這樣的方法:

def cls_pooling(model_output):
    return model_output.last_hidden_state[:, 0]

接下來,我們將創建一個輔助函數,該函數將標記文檔列表,將tensor放在 GPU 上,然後提供給模型,最後對輸出使用CLS 池化:

def get_embeddings(text_list):
    encoded_input = tokenizer(
        text_list, padding=True, truncation=True, return_tensors="pt"
    )
    encoded_input = {k: v.to(device) for k, v in encoded_input.items()}
    model_output = model(**encoded_input)
    return cls_pooling(model_output)

我們可以通過在我們的語料庫中輸入第一個文本條目並檢查輸出維度來測試該函數是否有效:

embedding = get_embeddings(comments_dataset["text"][0])
embedding.shape
torch.Size([1, 768])

太好了,我們已經將語料庫中的第一個條目轉換為 768 維向量!我們可以用 Dataset.map() 應用我們的 get_embeddings() 函數到我們語料庫中的每一行,所以讓我們創建一個新的 embeddings 列如下:

embeddings_dataset = comments_dataset.map(
    lambda x: {"embeddings": get_embeddings(x["text"]).detach().cpu().numpy()[0]}
)

請注意,我們已經將嵌入轉換為 NumPy 數組——這是因為當我們嘗試使用 FAISS 索引它們時,🤗 Datasets需要這種格式,我們接下來會這樣做。

使用 FAISS 進行高效的相似性搜索

現在我們有了一個詞嵌入數據集,我們需要一些方法來搜索它們。為此,我們將在 🤗 Datasets中使用一種特殊的數據結構,稱為 FAISS指數.FAISS (short for Facebook AI Similarity Search) (Facebook AI Similarity Search 的縮寫)是一個提供高效算法來快速搜索和聚類嵌入向量的庫。FAISS 背後的基本思想是創建一個特殊的數據結構,稱為指數。這允許人們找到哪些嵌入詞與輸入的詞嵌入相似。在 🤗 Datasets中創建一個 FAISS 索引很簡單——我們使用 Dataset.add_faiss_index() 函數並指定我們要索引的數據集的哪一列:

embeddings_dataset.add_faiss_index(column="embeddings")

現在,我們可以使用Dataset.get_nearest_examples()函數進行最近鄰居查找。讓我們通過首先嵌入一個問題來測試這一點,如下所示:

question = "How can I load a dataset offline?"
question_embedding = get_embeddings([question]).cpu().detach().numpy()
question_embedding.shape
torch.Size([1, 768])

就像文檔一樣,我們現在有一個 768 維向量表示查詢,我們可以將其與整個語料庫進行比較以找到最相似的嵌入:

scores, samples = embeddings_dataset.get_nearest_examples(
    "embeddings", question_embedding, k=5
)

Dataset.get_nearest_examples() 函數返回一個分數元組,對查詢和文檔之間的相似度進行排序,以及一組最佳匹配的結果(這裡是 5 個)。讓我們把這些收集到一個 pandas.DataFrame 以便我們可以輕鬆地對它們進行排序:

import pandas as pd

samples_df = pd.DataFrame.from_dict(samples)
samples_df["scores"] = scores
samples_df.sort_values("scores", ascending=False, inplace=True)

現在我們可以遍歷前幾行來查看我們的查詢與評論的匹配程度:

for _, row in samples_df.iterrows():
    print(f"COMMENT: {row.comments}")
    print(f"SCORE: {row.scores}")
    print(f"TITLE: {row.title}")
    print(f"URL: {row.html_url}")
    print("=" * 50)
    print()
"""
COMMENT: Requiring online connection is a deal breaker in some cases unfortunately so it'd be great if offline mode is added similar to how `transformers` loads models offline fine.

@mandubian's second bullet point suggests that there's a workaround allowing you to use your offline (custom?) dataset with `datasets`. Could you please elaborate on how that should look like?
SCORE: 25.505046844482422
TITLE: Discussion using datasets in offline mode
URL: https://github.com/huggingface/datasets/issues/824
==================================================

COMMENT: The local dataset builders (csv, text , json and pandas) are now part of the `datasets` package since #1726 :)
You can now use them offline
\`\`\`python
datasets = load_dataset("text", data_files=data_files)
\`\`\`

We'll do a new release soon
SCORE: 24.555509567260742
TITLE: Discussion using datasets in offline mode
URL: https://github.com/huggingface/datasets/issues/824
==================================================

COMMENT: I opened a PR that allows to reload modules that have already been loaded once even if there's no internet.

Let me know if you know other ways that can make the offline mode experience better. I'd be happy to add them :)

I already note the "freeze" modules option, to prevent local modules updates. It would be a cool feature.

----------

> @mandubian's second bullet point suggests that there's a workaround allowing you to use your offline (custom?) dataset with `datasets`. Could you please elaborate on how that should look like?

Indeed `load_dataset` allows to load remote dataset script (squad, glue, etc.) but also you own local ones.
For example if you have a dataset script at `./my_dataset/my_dataset.py` then you can do
\`\`\`python
load_dataset("./my_dataset")
\`\`\`
and the dataset script will generate your dataset once and for all.

----------

About I'm looking into having `csv`, `json`, `text`, `pandas` dataset builders already included in the `datasets` package, so that they are available offline by default, as opposed to the other datasets that require the script to be downloaded.
cf #1724
SCORE: 24.14896583557129
TITLE: Discussion using datasets in offline mode
URL: https://github.com/huggingface/datasets/issues/824
==================================================

COMMENT: > here is my way to load a dataset offline, but it **requires** an online machine
>
> 1. (online machine)
>
> ```
>
> import datasets
>
> data = datasets.load_dataset(...)
>
> data.save_to_disk(/YOUR/DATASET/DIR)
>
> ```
>
> 2. copy the dir from online to the offline machine
>
> 3. (offline machine)
>
> ```
>
> import datasets
>
> data = datasets.load_from_disk(/SAVED/DATA/DIR)
>
> ```
>
>
>
> HTH.


SCORE: 22.893993377685547
TITLE: Discussion using datasets in offline mode
URL: https://github.com/huggingface/datasets/issues/824
==================================================

COMMENT: here is my way to load a dataset offline, but it **requires** an online machine
1. (online machine)
\`\`\`
import datasets
data = datasets.load_dataset(...)
data.save_to_disk(/YOUR/DATASET/DIR)
\`\`\`
2. copy the dir from online to the offline machine
3. (offline machine)
\`\`\`
import datasets
data = datasets.load_from_disk(/SAVED/DATA/DIR)
\`\`\`

HTH.
SCORE: 22.406635284423828
TITLE: Discussion using datasets in offline mode
URL: https://github.com/huggingface/datasets/issues/824
==================================================
"""

我們的第二個搜索結果似乎與查詢相符。

✏️ 試試看!創建您自己的查詢並查看您是否可以在檢索到的文檔中找到答案。您可能需要增加參數k以擴大搜索範圍。