NLP Course documentation

使用 FAISS 进行语义搜索

Hugging Face's logo
Join the Hugging Face community

and get access to the augmented documentation experience

to get started

使用 FAISS 进行语义搜索

Ask a Question Open In Colab Open In Studio Lab

第五小节, 我们从 🤗 Datasets 存储库创建了一个包含 GitHub 问题和评论的数据集。在本节中,我们将使用这些信息来构建一个搜索引擎,它可以帮助我们找到这个库最紧迫问题的答案!

使用嵌入进行语义搜索

正如我们在第一章,学习的, 基于 Transformer 的语言模型会将文本中的每个标记转换为嵌入向量.事实证明,可以“汇集”各个嵌入向量来创建整个句子、段落或文档(在某些情况下)的向量表示。然后,通过计算每个嵌入之间的点积相似度(或其他一些相似度度量)并返回相似度最大的文档,这些嵌入可用于在语料库中找到相似的文档。在本节中,我们将使用嵌入来开发语义搜索引擎。与基于将查询中的关键字的传统方法相比,这些搜索引擎具有多种优势。

Semantic search.

加载和准备数据集

我们需要做的第一件事是下载我们的 GitHub issues 数据集,所以让我们像往常那样使用 load_dataset()函数:

from datasets import load_dataset

issues_dataset = load_dataset("lewtun/github-issues", split="train")

issues_dataset
Dataset({
    features: ['url', 'repository_url', 'labels_url', 'comments_url', 'events_url', 'html_url', 'id', 'node_id', 'number', 'title', 'user', 'labels', 'state', 'locked', 'assignee', 'assignees', 'milestone', 'comments', 'created_at', 'updated_at', 'closed_at', 'author_association', 'active_lock_reason', 'pull_request', 'body', 'performed_via_github_app', 'is_pull_request'],
    num_rows: 2855
})

这里我们在load_dataset()中使用了默认的训练集分割,所以它返回一个数据集而不是数据集字典。第一项任务是过滤掉pull请求,因为这些请求很少用于回答用户提出的问题,而且会给我们的搜索引擎带来噪声。现在应该很熟悉了,我们可以使用dataset.filter()函数来排除数据集中的这些行。同时,让我们也过滤掉没有注释的行,因为这些行不会是用户提问的答案:

issues_dataset = issues_dataset.filter(
    lambda x: (x["is_pull_request"] == False and len(x["comments"]) > 0)
)
issues_dataset
Dataset({
    features: ['url', 'repository_url', 'labels_url', 'comments_url', 'events_url', 'html_url', 'id', 'node_id', 'number', 'title', 'user', 'labels', 'state', 'locked', 'assignee', 'assignees', 'milestone', 'comments', 'created_at', 'updated_at', 'closed_at', 'author_association', 'active_lock_reason', 'pull_request', 'body', 'performed_via_github_app', 'is_pull_request'],
    num_rows: 771
})

我们可以看到我们的数据集中有很多列,其中大部分我们不需要构建我们的搜索引擎。从搜索的角度来看,信息量最大的列是 title , body , 和 comments ,而 html_url 为我们提供了一个回到源问题的链接。让我们使用 Dataset.remove_columns() 删除其余部分的功能:

columns = issues_dataset.column_names
columns_to_keep = ["title", "body", "html_url", "comments"]
columns_to_remove = set(columns_to_keep).symmetric_difference(columns)
issues_dataset = issues_dataset.remove_columns(columns_to_remove)
issues_dataset
Dataset({
    features: ['html_url', 'title', 'comments', 'body'],
    num_rows: 771
})

为了创建我们的嵌入,我们将用问题的标题和正文来扩充每条评论,因为这些字段通常包含有用的上下文信息。因为我们的 comments 列当前是每个问题的评论列表,我们需要“重新组合”列,以便每一条评论都包含一个 (html_url, title, body, comment) 元组。在 Pandas 中,我们可以使用 DataFrame.explode() 函数, 它为类似列表的列中的每个元素创建一个新行,同时复制所有其他列值。为了看到它的实际效果,让我们首先切换到 Pandas的DataFrame 格式:

issues_dataset.set_format("pandas")
df = issues_dataset[:]

如果我们检查这里的第一行 DataFrame 我们可以看到有四个评论与这个问题相关:

df["comments"][0].tolist()
['the bug code locate in :\r\n    if data_args.task_name is not None:\r\n        # Downloading and loading a dataset from the hub.\r\n        datasets = load_dataset("glue", data_args.task_name, cache_dir=model_args.cache_dir)',
 'Hi @jinec,\r\n\r\nFrom time to time we get this kind of `ConnectionError` coming from the github.com website: https://raw.githubusercontent.com\r\n\r\nNormally, it should work if you wait a little and then retry.\r\n\r\nCould you please confirm if the problem persists?',
 'cannot connect,even by Web browser,please check that  there is some  problems。',
 'I can access https://raw.githubusercontent.com/huggingface/datasets/1.7.0/datasets/glue/glue.py without problem...']

我们希望这些评论中的每一条都得到一行。让我们检查是否是这种情况:

comments_df = df.explode("comments", ignore_index=True)
comments_df.head(4)
html_url title comments body
0 https://github.com/huggingface/datasets/issues/2787 ConnectionError: Couldn't reach https://raw.githubusercontent.com the bug code locate in :\r\n if data_args.task_name is not None... Hello,\r\nI am trying to run run_glue.py and it gives me this error...
1 https://github.com/huggingface/datasets/issues/2787 ConnectionError: Couldn't reach https://raw.githubusercontent.com Hi @jinec,\r\n\r\nFrom time to time we get this kind of `ConnectionError` coming from the github.com website: https://raw.githubusercontent.com... Hello,\r\nI am trying to run run_glue.py and it gives me this error...
2 https://github.com/huggingface/datasets/issues/2787 ConnectionError: Couldn't reach https://raw.githubusercontent.com cannot connect,even by Web browser,please check that there is some problems。 Hello,\r\nI am trying to run run_glue.py and it gives me this error...
3 https://github.com/huggingface/datasets/issues/2787 ConnectionError: Couldn't reach https://raw.githubusercontent.com I can access https://raw.githubusercontent.com/huggingface/datasets/1.7.0/datasets/glue/glue.py without problem... Hello,\r\nI am trying to run run_glue.py and it gives me this error...

太好了,我们可以看到评论成功被扩充, comments 是包含个人评论的列!现在我们已经完成了 Pandas要完成的部分功能,我们可以快速切换回 Dataset 通过加载 DataFrame 在内存中:

from datasets import Dataset

comments_dataset = Dataset.from_pandas(comments_df)
comments_dataset
Dataset({
    features: ['html_url', 'title', 'comments', 'body'],
    num_rows: 2842
})

太好了,我们获取到了几千条的评论!

✏️ Try it out! 看看能不能不用pandas就可以完成列的扩充; 这有点棘手; 你可能会发现 🤗 Datasets 文档的 “Batch mapping” 对这个任务很有用。

现在我们每行有一个评论,让我们创建一个新的 comments_length 列来存放每条评论的字数:

comments_dataset = comments_dataset.map(
    lambda x: {"comment_length": len(x["comments"].split())}
)

我们可以使用这个新列来过滤掉简短的评论,其中通常包括“cc @lewtun”或“谢谢!”之类与我们的搜索引擎无关的内容。虽然无法为过滤器选择的精确数字,但大约大于15 个单词似乎是一个不错的选择:

comments_dataset = comments_dataset.filter(lambda x: x["comment_length"] > 15)
comments_dataset
Dataset({
    features: ['html_url', 'title', 'comments', 'body', 'comment_length'],
    num_rows: 2098
})

稍微清理了我们的数据集后,让我们将问题标题、描述和评论连接到一个新的 text 列。像往常一样,我们可以编写一个简单的函数,并将其传递给 Dataset.map()来做到这些 :

def concatenate_text(examples):
    return {
        "text": examples["title"]
        + " \n "
        + examples["body"]
        + " \n "
        + examples["comments"]
    }


comments_dataset = comments_dataset.map(concatenate_text)

我们终于准备好创建一些嵌入了!让我们来看看。

创建文本嵌入

我们在第二章 学过,我们可以通过使用 AutoModel 类来完成词嵌入。我们需要做的就是选择一个合适的检查点来加载模型。幸运的是,有一个名为 sentence-transformers 专门用于创建词嵌入。如库中文档, 所述的,我们这次要实现的是非对称语义搜索,因为我们有一个简短的查询,我们希望在比如问题评论等更长的文档中找到其答案。通过查看模型概述表 我们可以发现 multi-qa-mpnet-base-dot-v1 检查点在语义搜索方面具有最佳性能,因此我们将在我们的应用程序中使用它。我们还将使用相同的检查点加载标记器:

from transformers import AutoTokenizer, AutoModel

model_ckpt = "sentence-transformers/multi-qa-mpnet-base-dot-v1"
tokenizer = AutoTokenizer.from_pretrained(model_ckpt)
model = AutoModel.from_pretrained(model_ckpt)

为了加快嵌入过程,将模型和输入放在 GPU 设备上,所以现在让我们这样做:

import torch

device = torch.device("cuda")
model.to(device)

正如我们之前提到的,我们希望将 GitHub 问题语料库中的每个条目表示为单个向量,因此我们需要以某种方式“池化”或平均化我们的标记嵌入。一种流行的方法是在我们模型的输出上执行CLS 池化,我们只获取[CLS] 令牌的最后一个隐藏状态。以下函数为我们提供了这样的方法:

def cls_pooling(model_output):
    return model_output.last_hidden_state[:, 0]

接下来,我们将创建一个辅助函数,该函数将标记文档列表,将tensor放在 GPU 上,然后提供给模型,最后对输出使用CLS 池化:

def get_embeddings(text_list):
    encoded_input = tokenizer(
        text_list, padding=True, truncation=True, return_tensors="pt"
    )
    encoded_input = {k: v.to(device) for k, v in encoded_input.items()}
    model_output = model(**encoded_input)
    return cls_pooling(model_output)

我们可以通过在我们的语料库中输入第一个文本条目并检查输出维度来测试该函数是否有效:

embedding = get_embeddings(comments_dataset["text"][0])
embedding.shape
torch.Size([1, 768])

太好了,我们已经将语料库中的第一个条目转换为 768 维向量!我们可以用 Dataset.map() 应用我们的 get_embeddings() 函数到我们语料库中的每一行,所以让我们创建一个新的 embeddings 列如下:

embeddings_dataset = comments_dataset.map(
    lambda x: {"embeddings": get_embeddings(x["text"]).detach().cpu().numpy()[0]}
)

请注意,我们已经将嵌入转换为 NumPy 数组——这是因为当我们尝试使用 FAISS 索引它们时,🤗 Datasets需要这种格式,我们接下来会这样做。

使用 FAISS 进行高效的相似性搜索

现在我们有了一个词嵌入数据集,我们需要一些方法来搜索它们。为此,我们将在 🤗 Datasets中使用一种特殊的数据结构,称为 FAISS指数.FAISS (short for Facebook AI Similarity Search) (Facebook AI Similarity Search 的缩写)是一个提供高效算法来快速搜索和聚类嵌入向量的库。FAISS 背后的基本思想是创建一个特殊的数据结构,称为指数。这允许人们找到哪些嵌入词与输入的词嵌入相似。在 🤗 Datasets中创建一个 FAISS 索引很简单——我们使用 Dataset.add_faiss_index() 函数并指定我们要索引的数据集的哪一列:

embeddings_dataset.add_faiss_index(column="embeddings")

现在,我们可以使用Dataset.get_nearest_examples()函数进行最近邻居查找。让我们通过首先嵌入一个问题来测试这一点,如下所示:

question = "How can I load a dataset offline?"
question_embedding = get_embeddings([question]).cpu().detach().numpy()
question_embedding.shape
torch.Size([1, 768])

就像文档一样,我们现在有一个 768 维向量表示查询,我们可以将其与整个语料库进行比较以找到最相似的嵌入:

scores, samples = embeddings_dataset.get_nearest_examples(
    "embeddings", question_embedding, k=5
)

Dataset.get_nearest_examples() 函数返回一个分数元组,对查询和文档之间的相似度进行排序,以及一组最佳匹配的结果(这里是 5 个)。让我们把这些收集到一个 pandas.DataFrame 以便我们可以轻松地对它们进行排序:

import pandas as pd

samples_df = pd.DataFrame.from_dict(samples)
samples_df["scores"] = scores
samples_df.sort_values("scores", ascending=False, inplace=True)

现在我们可以遍历前几行来查看我们的查询与评论的匹配程度:

for _, row in samples_df.iterrows():
    print(f"COMMENT: {row.comments}")
    print(f"SCORE: {row.scores}")
    print(f"TITLE: {row.title}")
    print(f"URL: {row.html_url}")
    print("=" * 50)
    print()
"""
COMMENT: Requiring online connection is a deal breaker in some cases unfortunately so it'd be great if offline mode is added similar to how `transformers` loads models offline fine.

@mandubian's second bullet point suggests that there's a workaround allowing you to use your offline (custom?) dataset with `datasets`. Could you please elaborate on how that should look like?
SCORE: 25.505046844482422
TITLE: Discussion using datasets in offline mode
URL: https://github.com/huggingface/datasets/issues/824
==================================================

COMMENT: The local dataset builders (csv, text , json and pandas) are now part of the `datasets` package since #1726 :)
You can now use them offline
\`\`\`python
datasets = load_dataset("text", data_files=data_files)
\`\`\`

We'll do a new release soon
SCORE: 24.555509567260742
TITLE: Discussion using datasets in offline mode
URL: https://github.com/huggingface/datasets/issues/824
==================================================

COMMENT: I opened a PR that allows to reload modules that have already been loaded once even if there's no internet.

Let me know if you know other ways that can make the offline mode experience better. I'd be happy to add them :)

I already note the "freeze" modules option, to prevent local modules updates. It would be a cool feature.

----------

> @mandubian's second bullet point suggests that there's a workaround allowing you to use your offline (custom?) dataset with `datasets`. Could you please elaborate on how that should look like?

Indeed `load_dataset` allows to load remote dataset script (squad, glue, etc.) but also you own local ones.
For example if you have a dataset script at `./my_dataset/my_dataset.py` then you can do
\`\`\`python
load_dataset("./my_dataset")
\`\`\`
and the dataset script will generate your dataset once and for all.

----------

About I'm looking into having `csv`, `json`, `text`, `pandas` dataset builders already included in the `datasets` package, so that they are available offline by default, as opposed to the other datasets that require the script to be downloaded.
cf #1724
SCORE: 24.14896583557129
TITLE: Discussion using datasets in offline mode
URL: https://github.com/huggingface/datasets/issues/824
==================================================

COMMENT: > here is my way to load a dataset offline, but it **requires** an online machine
>
> 1. (online machine)
>
> ```
>
> import datasets
>
> data = datasets.load_dataset(...)
>
> data.save_to_disk(/YOUR/DATASET/DIR)
>
> ```
>
> 2. copy the dir from online to the offline machine
>
> 3. (offline machine)
>
> ```
>
> import datasets
>
> data = datasets.load_from_disk(/SAVED/DATA/DIR)
>
> ```
>
>
>
> HTH.


SCORE: 22.893993377685547
TITLE: Discussion using datasets in offline mode
URL: https://github.com/huggingface/datasets/issues/824
==================================================

COMMENT: here is my way to load a dataset offline, but it **requires** an online machine
1. (online machine)
\`\`\`
import datasets
data = datasets.load_dataset(...)
data.save_to_disk(/YOUR/DATASET/DIR)
\`\`\`
2. copy the dir from online to the offline machine
3. (offline machine)
\`\`\`
import datasets
data = datasets.load_from_disk(/SAVED/DATA/DIR)
\`\`\`

HTH.
SCORE: 22.406635284423828
TITLE: Discussion using datasets in offline mode
URL: https://github.com/huggingface/datasets/issues/824
==================================================
"""

我们的第二个搜索结果似乎与查询相符。

✏️ 试试看!创建您自己的查询并查看您是否可以在检索到的文档中找到答案。您可能需要增加参数k以扩大搜索范围。