course documentation

使用 FAISS 进行语义搜索

Hugging Face's logo
Join the Hugging Face community

and get access to the augmented documentation experience

to get started

使用 FAISS 进行语义搜索

Ask a Question Open In Colab Open In Studio Lab

第5小节 ,我们创建了一个 🤗 Datasets 仓库的 GitHub issues 和评论组成的数据集。在本节,我们将使用这些信息构建一个搜索引擎,帮助我们找到关于该库的最相关的 issue 的答案!

使用文本嵌入进行语义搜索

正如我们在 第一章 ,学习的,基于 Transformer 的语言模型会将文本中的每个 token 转换为嵌入向量。事实证明,我们可以 “池化(pool)” 嵌入向量以创建整个句子、段落或(在某些情况下)文档的向量表示。然后,通过计算每个嵌入之间的点积相似度(或其他一些相似度度量)并返回相似度最大的文档,这些嵌入可用于在语料库中找到相似的文档。

在本节中,我们将使用文本嵌入向量来开发语义搜索引擎。与基于将查询中的关键字的传统方法相比,这些搜索引擎具有多种优势。

Semantic search.

加载和准备数据集

首先,我们需要下载我们的 GitHub issues 数据集,所以让我们像往常那样使用 load_dataset() 函数:

from datasets import load_dataset

issues_dataset = load_dataset("lewtun/github-issues", split="train")

issues_dataset
Dataset({
    features: ['url', 'repository_url', 'labels_url', 'comments_url', 'events_url', 'html_url', 'id', 'node_id', 'number', 'title', 'user', 'labels', 'state', 'locked', 'assignee', 'assignees', 'milestone', 'comments', 'created_at', 'updated_at', 'closed_at', 'author_association', 'active_lock_reason', 'pull_request', 'body', 'performed_via_github_app', 'is_pull_request'],
    num_rows: 2855
})

在此,我们在 load_dataset() 中指定了默认的 train(训练集) 部分,因此它返回一个 Dataset 而不是 DatasetDict 。首要任务是排除掉拉取请求(pull),因为这些请求往往很少用于回答提出的 issue,会为我们的搜索引擎引入干扰。正如我们现在熟悉的那样,我们可以使用 Dataset.filter() 函数来排除数据集中的这些行。与此同时,让我们也筛选掉没有评论的行,因为这些行没有为用户提问提供回答:

issues_dataset = issues_dataset.filter(
    lambda x: (x["is_pull_request"] == False and len(x["comments"]) > 0)
)
issues_dataset
Dataset({
    features: ['url', 'repository_url', 'labels_url', 'comments_url', 'events_url', 'html_url', 'id', 'node_id', 'number', 'title', 'user', 'labels', 'state', 'locked', 'assignee', 'assignees', 'milestone', 'comments', 'created_at', 'updated_at', 'closed_at', 'author_association', 'active_lock_reason', 'pull_request', 'body', 'performed_via_github_app', 'is_pull_request'],
    num_rows: 771
})

我们可以看到,我们的数据集中有很多列,其中大部分在构建我们的搜索引擎都不会使用。从搜索的角度来看,信息量最大的列是 titlebody ,和 comments ,而 html_url 为我们提供了一个回到原 issue 的链接。让我们使用 Dataset.remove_columns() 删除其余的列:

columns = issues_dataset.column_names
columns_to_keep = ["title", "body", "html_url", "comments"]
columns_to_remove = set(columns_to_keep).symmetric_difference(columns)
issues_dataset = issues_dataset.remove_columns(columns_to_remove)
issues_dataset
Dataset({
    features: ['html_url', 'title', 'comments', 'body'],
    num_rows: 771
})

为了创建我们的文本嵌入数据集,我们将用 issue 的标题和正文来扩充每条评论,因为这些字段通常包含有用的上下文信息。因为我们的 comments 列当前是每个 issue 的评论列表,我们需要“重新组合”列,使得每一行都是由一个 (html_url, title, body, comment) 元组组成。在 Pandas 中,我们可以使用 DataFrame.explode() 函数 完成这个操作 它为类似列表的列中的每个元素创建一个新行,同时复制所有其他列值。让我们首先切换到 Pandas 的 DataFrame 格式:

issues_dataset.set_format("pandas")
df = issues_dataset[:]

如果我们检查这个 DataFrame 的第一行,我们可以看到这个 issue 有四个相关评论:

df["comments"][0].tolist()
['the bug code locate in :\r\n    if data_args.task_name is not None:\r\n        # Downloading and loading a dataset from the hub.\r\n        datasets = load_dataset("glue", data_args.task_name, cache_dir=model_args.cache_dir)',
 'Hi @jinec,\r\n\r\nFrom time to time we get this kind of `ConnectionError` coming from the github.com website: https://raw.githubusercontent.com\r\n\r\nNormally, it should work if you wait a little and then retry.\r\n\r\nCould you please confirm if the problem persists?',
 'cannot connect,even by Web browser,please check that  there is some  problems。',
 'I can access https://raw.githubusercontent.com/huggingface/datasets/1.7.0/datasets/glue/glue.py without problem...']

我们希望使用 explode() 将这些评论中的每一条都展开成为一行。让我们看看是否可以做到:

comments_df = df.explode("comments", ignore_index=True)
comments_df.head(4)
html_url title comments body
0 https://github.com/huggingface/datasets/issues/2787 ConnectionError: Couldn't reach https://raw.githubusercontent.com the bug code locate in :\r\n if data_args.task_name is not None... Hello,\r\nI am trying to run run_glue.py and it gives me this error...
1 https://github.com/huggingface/datasets/issues/2787 ConnectionError: Couldn't reach https://raw.githubusercontent.com Hi @jinec,\r\n\r\nFrom time to time we get this kind of `ConnectionError` coming from the github.com website: https://raw.githubusercontent.com... Hello,\r\nI am trying to run run_glue.py and it gives me this error...
2 https://github.com/huggingface/datasets/issues/2787 ConnectionError: Couldn't reach https://raw.githubusercontent.com cannot connect,even by Web browser,please check that there is some problems。 Hello,\r\nI am trying to run run_glue.py and it gives me this error...
3 https://github.com/huggingface/datasets/issues/2787 ConnectionError: Couldn't reach https://raw.githubusercontent.com I can access https://raw.githubusercontent.com/huggingface/datasets/1.7.0/datasets/glue/glue.py without problem... Hello,\r\nI am trying to run run_glue.py and it gives me this error...

非常好,我们可以看到其他三列已经被复制了,并且 comments 列里存放着单独的评论!现在我们已经完成了 Pandas 要完成的部分功能,我们可以通过加载内存中的 DataFrame 快速切换回 Dataset

from datasets import Dataset

comments_dataset = Dataset.from_pandas(comments_df)
comments_dataset
Dataset({
    features: ['html_url', 'title', 'comments', 'body'],
    num_rows: 2842
})

太好了,我们获取到了几千条的评论!

✏️ 试试看! 看看你是否可以使用 Dataset.map() 展开 issues_datasetcomments 列,这有点棘手;你可能会发现 🤗 Datasets 文档的 “批处理映射(Batch mapping)” 对这个任务很有用。

既然我们每行有一个评论,让我们创建一个新的 comments_length 列来存放每条评论的字数:

comments_dataset = comments_dataset.map(
    lambda x: {"comment_length": len(x["comments"].split())}
)

我们可以使用这个新列来过滤掉简短的评论,其中通常包括“cc @lewtun”或“谢谢!”之类与我们的搜索引擎无关的内容。筛选的精确数字没有硬性规定,但大约大于 15 个单词似乎是一个不错的选择:

comments_dataset = comments_dataset.filter(lambda x: x["comment_length"] > 15)
comments_dataset
Dataset({
    features: ['html_url', 'title', 'comments', 'body', 'comment_length'],
    num_rows: 2098
})

稍微清理了我们的数据集后,让我们使用 issue 标题、描述和评论构建一个新的 text 列。像往常一样,我们可以编写一个简单的函数,并将其传递给 Dataset.map() 来完成这些操作

def concatenate_text(examples):
    return {
        "text": examples["title"]
        + " \n "
        + examples["body"]
        + " \n "
        + examples["comments"]
    }


comments_dataset = comments_dataset.map(concatenate_text)

我们终于准备好创建一些文本嵌入了!让我们继续。

创建文本嵌入

正如我们在 第二章 学过,我们可以通过使用 AutoModel 类来完成文本嵌入。首先需要做的就是选择一个合适的 checkpoint。幸运的是,有一个名为 sentence-transformers 的库专门用于创建文本嵌入。如库中的 文档 所述的,我们这次要实现的是非对称语义搜索(asymmetric semantic search),因为我们有一个简短的查询,我们希望在比如 issue 评论等更长的文档中找到其匹配的文本。通过查看 模型概述表 我们可以发现 multi-qa-mpnet-base-dot-v1 checkpoint 在语义搜索方面具有最佳性能,因此我们将使用它。我们还将使用这个 checkpoint 加载了对应的 tokenizer :

from transformers import AutoTokenizer, AutoModel

model_ckpt = "sentence-transformers/multi-qa-mpnet-base-dot-v1"
tokenizer = AutoTokenizer.from_pretrained(model_ckpt)
model = AutoModel.from_pretrained(model_ckpt)

将模型和输入的文本放到 GPU 上会加速嵌入过程,所以让我们现在就这么做:

import torch

device = torch.device("cuda")
model.to(device)

根据我们之前的想法,我们希望将我们的 GitHub issue 中的每一条记录转化为一个单一的向量,所以我们需要以某种方式“池化(pool)”或平均每个词的嵌入向量。一种流行的方法是在我们模型的输出上进行 CLS 池化 ,我们只需要收集 [CLS] token 的的最后一个隐藏状态。以下函数实现了这个功能:

def cls_pooling(model_output):
    return model_output.last_hidden_state[:, 0]

接下来,我们将创建一个辅助函数,它将对一组文档进行 tokenize,然后将张量放在 GPU 上,接着将它们喂给模型,最后对输出进行 CLS 池化:

def get_embeddings(text_list):
    encoded_input = tokenizer(
        text_list, padding=True, truncation=True, return_tensors="pt"
    )
    encoded_input = {k: v.to(device) for k, v in encoded_input.items()}
    model_output = model(**encoded_input)
    return cls_pooling(model_output)

我们可以将第一条数据喂给它并检查输出的形状来测试这个函数是否正常工作:

embedding = get_embeddings(comments_dataset["text"][0])
embedding.shape
torch.Size([1, 768])

太好了,我们已经将语料库中的第一个条目转换为了一个 768 维向量!我们可以用 Dataset.map() 将我们的 get_embeddings() 函数应用到我们语料库中的每一行,然后创建一个新的 embeddings 列:

embeddings_dataset = comments_dataset.map(
    lambda x: {"embeddings": get_embeddings(x["text"]).detach().cpu().numpy()[0]}
)

请注意,在上面的代码中我们已经将文本嵌入转换为 NumPy 数组——这是因为当我们尝试使用 FAISS 搜索它们时,🤗 Datasets 需要这种格式,让我们继续吧。

使用 FAISS 进行高效的相似性搜索

现在我们有了一个文本嵌入数据集,我们需要一些方法来搜索它们。为此,我们将使用🤗 Datasets 中一种特殊的数据结构,称为 FAISS 指数。 FAISS (Facebook AI Similarity Search 的缩写)是一个库,提供了用于快速搜索和聚类嵌入向量的高效算法。

FAISS 背后的基本思想是创建一个特殊的数据结构,称为 index(索引) 它可以找到哪些嵌入与输入嵌入相似。在 🤗 Datasets 中创建一个 FAISS index(索引)很简单——我们使用 Dataset.add_faiss_index() 函数并指定我们要索引的数据集的哪一列:

embeddings_dataset.add_faiss_index(column="embeddings")

现在,我们可以使用 Dataset.get_nearest_examples() 函数进行最近邻居查找。让我们通过首先嵌入一个 issue 来测试这一点,如下所示:

question = "How can I load a dataset offline?"
question_embedding = get_embeddings([question]).cpu().detach().numpy()
question_embedding.shape
torch.Size([1, 768])

就像对文档进行嵌入一样,我们现在有一个表示查询的 768 维向量,我们可以将其与整个语料库进行比较以找到最相似的嵌入:

scores, samples = embeddings_dataset.get_nearest_examples(
    "embeddings", question_embedding, k=5
)

Dataset.get_nearest_examples() 函数返回一个元组,包括评分(评价查询和文档之间的相似程度)和对应的样本(这里是 5 个最佳匹配)。让我们把这些收集到一个 pandas.DataFrame ,这样我们就可以轻松地对它们进行排序:

import pandas as pd

samples_df = pd.DataFrame.from_dict(samples)
samples_df["scores"] = scores
samples_df.sort_values("scores", ascending=False, inplace=True)

现在我们可以遍历前几行来查看我们的查询与评论的匹配程度如何:

for _, row in samples_df.iterrows():
    print(f"COMMENT: {row.comments}")
    print(f"SCORE: {row.scores}")
    print(f"TITLE: {row.title}")
    print(f"URL: {row.html_url}")
    print("=" * 50)
    print()
"""
COMMENT: Requiring online connection is a deal breaker in some cases unfortunately so it'd be great if offline mode is added similar to how `transformers` loads models offline fine.

@mandubian's second bullet point suggests that there's a workaround allowing you to use your offline (custom?) dataset with `datasets`. Could you please elaborate on how that should look like?
SCORE: 25.505046844482422
TITLE: Discussion using datasets in offline mode
URL: https://github.com/huggingface/datasets/issues/824
==================================================

COMMENT: The local dataset builders (csv, text , json and pandas) are now part of the `datasets` package since #1726 :)
You can now use them offline
\`\`\`python
datasets = load_dataset("text", data_files=data_files)
\`\`\`

We'll do a new release soon
SCORE: 24.555509567260742
TITLE: Discussion using datasets in offline mode
URL: https://github.com/huggingface/datasets/issues/824
==================================================

COMMENT: I opened a PR that allows to reload modules that have already been loaded once even if there's no internet.

Let me know if you know other ways that can make the offline mode experience better. I'd be happy to add them :)

I already note the "freeze" modules option, to prevent local modules updates. It would be a cool feature.

----------

> @mandubian's second bullet point suggests that there's a workaround allowing you to use your offline (custom?) dataset with `datasets`. Could you please elaborate on how that should look like?

Indeed `load_dataset` allows to load remote dataset script (squad, glue, etc.) but also you own local ones.
For example if you have a dataset script at `./my_dataset/my_dataset.py` then you can do
\`\`\`python
load_dataset("./my_dataset")
\`\`\`
and the dataset script will generate your dataset once and for all.

----------

About I'm looking into having `csv`, `json`, `text`, `pandas` dataset builders already included in the `datasets` package, so that they are available offline by default, as opposed to the other datasets that require the script to be downloaded.
cf #1724
SCORE: 24.14896583557129
TITLE: Discussion using datasets in offline mode
URL: https://github.com/huggingface/datasets/issues/824
==================================================

COMMENT: > here is my way to load a dataset offline, but it **requires** an online machine
>
> 1. (online machine)
>
> ```
>
> import datasets
>
> data = datasets.load_dataset(.)
>
> data.save_to_disk(/YOUR/DATASET/DIR)
>
> ```
>
> 2. copy the dir from online to the offline machine
>
> 3. (offline machine)
>
> ```
>
> import datasets
>
> data = datasets.load_from_disk(/SAVED/DATA/DIR)
>
> ```
>
>
>
> HTH.


SCORE: 22.893993377685547
TITLE: Discussion using datasets in offline mode
URL: https://github.com/huggingface/datasets/issues/824
==================================================

COMMENT: here is my way to load a dataset offline, but it **requires** an online machine
1. (online machine)
\`\`\`
import datasets
data = datasets.load_dataset(...)
data.save_to_disk(/YOUR/DATASET/DIR)
\`\`\`
2. copy the dir from online to the offline machine
3. (offline machine)
\`\`\`
import datasets
data = datasets.load_from_disk(/SAVED/DATA/DIR)
\`\`\`

HTH.
SCORE: 22.406635284423828
TITLE: Discussion using datasets in offline mode
URL: https://github.com/huggingface/datasets/issues/824
==================================================
"""

不错!我们的输出的第 2 个结果似乎与查询匹配。

✏️ 试试看!创建你自己的查询并查看你是否可以在检索到的文档中找到答案。你可能需要在 Dataset.get_nearest_examples() 增加参数 k 以扩大搜索范围。

Update on GitHub