Open-Source AI Cookbook documentation

使用 Cleanlab 检测文本数据集中的问题

Hugging Face's logo
Join the Hugging Face community

and get access to the augmented documentation experience

to get started

Open In Colab

使用 Cleanlab 检测文本数据集中的问题

作者: Aravind Putrevu

在这个 5 分钟的快速入门教程中,我们将使用 Cleanlab 检测一个由在线银行(文本)客户服务请求组成的意图分类数据集中的各种问题。我们考虑的是 Banking77-OOS数据集 的一个子集,包含 1,000 个客户服务请求,根据它们的意图被分类为 10 个类别(你可以在任何文本分类数据集上运行相同的代码)。Cleanlab自动识别我们数据集中的坏例子,包括错误标记的数据、范围外的示例(离群值)或其他模糊不清的示例。在深入建模你的数据之前,请考虑过滤或更正这样的坏例子!

本教程我们将要做的事情概述:

  • 使用预训练的 transformer 模型从客户服务请求中提取文本嵌入

  • 在文本嵌入上训练一个简单的逻辑回归模型,以计算样本外的预测概率

  • 使用这些预测和嵌入运行 Cleanlab 的 Datalab 审核,以识别数据集中的问题,如:标签问题、离群值和近重复项。

快速入门

已经有一个模型在现有标签集上训练得到的(样本外)pred_probs 了吗?也许你还有一些数值特征?运行下面的代码来查找数据集中的任何潜在标签错误。

注意: 如果在 Colab 上运行,可能需要使用 GPU(选择:Runtime > Change runtime type > Hardware accelerator > GPU)

from cleanlab import Datalab

lab = Datalab(data=your_dataset, label_name="column_name_of_labels")
lab.find_issues(pred_probs=your_pred_probs, features=your_features)

lab.report()
lab.get_issues()

安装需要的依赖

你可以使用 pip 按照以下方式安装本教程所需的所有包:

!pip install -U scikit-learn sentence-transformers datasets
!pip install -U "cleanlab[datalab]"
import re
import string
import pandas as pd
from sklearn.metrics import accuracy_score, log_loss
from sklearn.model_selection import cross_val_predict
from sklearn.linear_model import LogisticRegression
from sentence_transformers import SentenceTransformer

from cleanlab import Datalab
import random
import numpy as np

pd.set_option("display.max_colwidth", None)

SEED = 123456  # for reproducibility
np.random.seed(SEED)
random.seed(SEED)

加载和格式化文本数据集

from datasets import load_dataset

dataset = load_dataset("PolyAI/banking77", split="train")
data = pd.DataFrame(dataset[:1000])
data.head()
>>> raw_texts, labels = data["text"].values, data["label"].values
>>> num_classes = len(set(labels))

>>> print(f"This dataset has {num_classes} classes.")
>>> print(f"Classes: {set(labels)}")
This dataset has 7 classes.
Classes: {32, 34, 36, 11, 13, 46, 17}

让我们查看数据集中的第 i 个示例:

>>> i = 1  # change this to view other examples from the dataset
>>> print(f"Example Label: {labels[i]}")
>>> print(f"Example Text: {raw_texts[i]}")
Example Label: 11
Example Text: What can I do if my card still hasn't arrived after 2 weeks?

数据以两个 numpy 数组的形式存储:

  1. raw_texts 以文本格式存储客户服务请求的话语
  2. labels 存储每个示例的意图类别(标签)

自有数据?

你可以轻松地将上述内容替换为你自己的文本数据集,并继续进行教程的其余部分。

接下来,我们将文本字符串转换为更适合作为机器学习模型输入的向量。

我们将使用预训练的 Transformer 模型提供的数值表示作为我们文本的嵌入。Sentence Transformers 库提供了计算文本数据嵌入的简单方法。在这里,我们加载了预训练的 electra-small-discriminator 模型,然后通过网络运行我们的数据,以提取每个示例的向量嵌入。

transformer = SentenceTransformer("google/electra-small-discriminator")
text_embeddings = transformer.encode(raw_texts)

我们后续的机器学习模型将直接在 text_embeddings 的元素上操作,以便对客户服务请求进行分类。

定义一个分类模型并计算样本外的预测概率

为了利用预训练网络进行特定的分类任务,一种典型的方法是添加一个线性输出层,并在新数据上微调网络参数。然而,这可能需要大量的计算资源。另一种方法是冻结网络的预训练权重,只训练输出层,而不依赖于 GPU。在这里,我们通过在提取的嵌入顶部拟合一个 scikit-learn 线性模型来方便地实现这一点。

为了识别标签问题,cleanlab 需要你的模型为每个数据点提供概率预测。然而,对于模型之前训练过的数据点,这些预测将是过拟合的(因此不可靠)。cleanlab 旨在仅与样本外的预测类概率一起使用,即在模型训练期间保持不变的数据点。

在这里,我们使用带有交叉验证的逻辑回归模型来获得数据集中每个示例的样本外预测类概率。 确保你的 pred_probs 列根据类的排序正确排序,对于 Datalab 来说,是:按类名字典顺序排序。

model = LogisticRegression(max_iter=400)

pred_probs = cross_val_predict(model, text_embeddings, labels, method="predict_proba")

使用 Cleanlab 查找数据集中的问题

在给定来自你拥有的任何模型的特征嵌入和(样本外)预测类概率的情况下,cleanlab 可以帮助你快速识别数据中的低质量示例。

在这里,我们使用 Cleanlab 的 Datalab 来查找数据中的问题。Datalab 提供了几种加载数据的方式;我们将简单地在字典中包装训练特征和噪声标签。

data_dict = {"texts": raw_texts, "labels": labels}

审核你的数据所需的全部操作就是调用 find_issues()。我们传入上面获得的预测概率和特征嵌入,但你不一定需要提供所有这些信息,具体取决于你对哪些类型的问题感兴趣。你提供的输入越多,Datalab 就能在你的数据中检测到更多类型的问题。使用更好的模型来生成这些输入将确保 cleanlab 更准确地估计问题。

lab = Datalab(data_dict, label_name="labels")
lab.find_issues(pred_probs=pred_probs, features=text_embeddings)

输出看起来如下:

Finding null issues ...
Finding label issues ...
Finding outlier issues ...
Fitting OOD estimator based on provided features ...
Finding near_duplicate issues ...
Finding non_iid issues ...
Finding class_imbalance issues ...
Finding underperforming_group issues ...

Audit complete. 62 issues found in the dataset.

审计完成后,使用 report 方法来查看审计结果。

>>> lab.report()
Here is a summary of the different kinds of issues found in the data:

    issue_type  num_issues
       outlier          37
near_duplicate          14
         label          10
       non_iid           1

Dataset Information: num_examples: 1000, num_classes: 7


---------------------- outlier issues ----------------------

About this issue:
	Examples that are very different from the rest of the dataset 
    (i.e. potentially out-of-distribution or rare/anomalous instances).
    

Number of examples with this issue: 37
Overall dataset quality in terms of this issue: 0.3671

Examples representing most severe instances of this issue:
     is_outlier_issue  outlier_score
791              True       0.024866
601              True       0.031162
863              True       0.060738
355              True       0.064199
157              True       0.065075


------------------ near_duplicate issues -------------------

About this issue:
	A (near) duplicate issue refers to two or more examples in
    a dataset that are extremely similar to each other, relative
    to the rest of the dataset.  The examples flagged with this issue
    may be exactly duplicated, or lie atypically close together when
    represented as vectors (i.e. feature embeddings).
    

Number of examples with this issue: 14
Overall dataset quality in terms of this issue: 0.5961

Examples representing most severe instances of this issue:
     is_near_duplicate_issue  near_duplicate_score near_duplicate_sets  distance_to_nearest_neighbor
459                     True              0.009544               [429]                      0.000566
429                     True              0.009544               [459]                      0.000566
501                     True              0.046044          [412, 517]                      0.002781
412                     True              0.046044               [501]                      0.002781
698                     True              0.054626               [607]                      0.003314


----------------------- label issues -----------------------

About this issue:
	Examples whose given label is estimated to be potentially incorrect
    (e.g. due to annotation error) are flagged as having label issues.
    

Number of examples with this issue: 10
Overall dataset quality in terms of this issue: 0.9930

Examples representing most severe instances of this issue:
     is_label_issue  label_score  given_label  predicted_label
379           False     0.025486           32               11
100           False     0.032102           11               36
300           False     0.037742           32               46
485            True     0.057666           17               34
159            True     0.059408           13               11


---------------------- non_iid issues ----------------------

About this issue:
	Whether the dataset exhibits statistically significant
    violations of the IID assumption like:
    changepoints or shift, drift, autocorrelation, etc.
    The specific violation considered is whether the
    examples are ordered such that almost adjacent examples
    tend to have more similar feature values.
    

Number of examples with this issue: 1
Overall dataset quality in terms of this issue: 0.0000

Examples representing most severe instances of this issue:
     is_non_iid_issue  non_iid_score
988              True       0.563774
975             False       0.570179
997             False       0.571891
967             False       0.572357
956             False       0.577413

Additional Information: 
p-value: 0.0

标签问题

报告显示 cleanlab 在我们的数据集中识别出了许多标签问题。我们可以使用 get_issues 方法来查看哪些示例被标记为可能标签错误,以及每个示例的标签质量分数,通过指定 label 作为参数来关注数据中的标签问题。

label_issues = lab.get_issues("label")
label_issues.head()
is_label_issue label_score given_label predicted_label
0 False 0.903926 11 11
1 False 0.860544 11 11
2 False 0.658309 11 11
3 False 0.697085 11 11
4 False 0.434934 11 11

此方法返回一个包含每个示例的标签质量分数的数据框。这些数值分数介于 0 和 1 之间,其中较低的分数表示更可能是错误标记的示例。数据框还包含一个布尔列,指定是否将每个示例识别为具有标签问题(表明它可能是错误标记的)。

我们可以获取标记有标签问题的示例的子集,并且还可以按标签质量分数排序,以找到数据集中最可能错误标记的 5 个示例的索引。

>>> identified_label_issues = label_issues[label_issues["is_label_issue"] == True]
>>> lowest_quality_labels = label_issues["label_score"].argsort()[:5].to_numpy()

>>> print(
...     f"cleanlab found {len(identified_label_issues)} potential label errors in the dataset.\n"
...     f"Here are indices of the top 5 most likely errors: \n {lowest_quality_labels}"
... )
cleanlab found 10 potential label errors in the dataset.
Here are indices of the top 5 most likely errors: 
 [379 100 300 485 159]

让我们查看一些最可能的标签错误。

这里我们展示了数据集中被识别为最可能的标签错误的前 5 个示例,以及它们的给定(原始)标签和 cleanlab 提供的建议替代标签。

data_with_suggested_labels = pd.DataFrame(
    {"text": raw_texts, "given_label": labels, "suggested_label": label_issues["predicted_label"]}
)
data_with_suggested_labels.iloc[lowest_quality_labels]

上面命令的输出如下所示:

text given_label suggested_label
379 Is there a specific source that the exchange rate for the transfer I’m planning on making is pulled from? 32 11
100 can you share card tracking number? 11 36
300 If I need to cash foreign transfers, how does that work? 32 46
485 Was I charged more than I should of been for a currency exchange? 17 34
159 Is there any way to see my card in the app? 13 11

这些是 cleanlab 在此数据中识别的非常清晰的标签错误!请注意,given_label 并没有正确反映这些请求的意图,无论谁制作了这个数据集,在建模数据之前都需要解决许多错误。

离群值问题

根据报告,我们的数据集中包含一些离群值。

我们可以通过 get_issues 查看哪些示例是离群值(以及一个数值质量分数,量化每个示例看起来有多么典型)。我们将结果数据框按照 cleanlab 的离群值质量分数排序,以查看数据集中最严重的离群值。

outlier_issues = lab.get_issues("outlier")
outlier_issues.sort_values("outlier_score").head()

输出如下所示:

is_outlier_issue outlier_score
791 True 0.024866
601 True 0.031162
863 True 0.060738
355 True 0.064199
157 True 0.065075
lowest_quality_outliers = outlier_issues["outlier_score"].argsort()[:5]

data.iloc[lowest_quality_outliers]

对于质量最低的离群值,样本输出将如下所示:

index text label
791 withdrawal pending meaning? 46
601 $1 charge in transaction. 34
863 My atm withdraw is stillpending 46
355 explain the interbank exchange rate 32
157 lost card found, want to put it back in app 13

我们看到 cleanlab 已经识别出这个数据集中的条目,这些条目看起来并不是正确的客户请求。此数据集中的离群值似乎是不在范围内的客户请求和其他对意图分类没有意义的非语义文本。仔细考虑这些离群值是否可能对你的数据建模产生不利影响,如果有可能的话,考虑从数据集中移除它们。

近重复问题

根据报告,我们的数据集中包含一些几乎重复的示例集。 我们可以通过 get_issues 查看哪些示例是(几乎)重复的(以及一个数值质量分数,量化每个示例与数据集中最近邻的相似程度)。我们将结果数据框按照 cleanlab 的近重复质量分数排序,以查看数据集中最接近重复的文本示例。

duplicate_issues = lab.get_issues("near_duplicate")
duplicate_issues.sort_values("near_duplicate_score").head()

上面的结果显示了 cleanlab 认为哪些示例是近重复的(is_near_duplicate_issue == True 的行)。在这里,我们看到示例 459 和 429 是近重复的,示例 501 和 412 也是近重复的。

让我们查看这些示例,看看它们有多么相似。

data.iloc[[459, 429]]

样本输出:

index text label
459 I purchased something abroad and the incorrect exchange rate was applied. 17
429 I purchased something overseas and the incorrect exchange rate was applied. 17
data.iloc[[501, 412]]

样本输出:

index text label
501 The exchange rate you are using is really bad.This can’t be the official interbank exchange rate. 17
412 The exchange rate you are using is bad.This can’t be the official interbank exchange rate. 17

我们看到这两组请求确实非常相似!在数据集中包含近重复项可能会对模型产生意想不到的影响,并且要小心不要将它们分割到训练/测试集中。从常见问题解答中了解更多关于处理数据集中的近重复数据的信息。

非独立同分布问题(数据漂移)

根据报告,我们的数据集似乎不是独立同分布的(IID)。数据集的整体非 IID 分数(如下所示)对应于一个统计测试的 p 值,该测试用于判断数据集中样本的排序是否与它们特征值之间的相似性有关。一个低的 p 值强烈表明数据集违反了 IID 假设,这是从数据集产生的结论(模型)推广到更大总体所需的关键假设。

p_value = lab.get_info("non_iid")["p-value"]
p_value

在这里,我们的数据集被标记为非 IID,因为原始数据中的行恰好是按类别标签排序的。如果我们记得在模型训练和数据拆分之前打乱行,这可能是不重要的。但是,如果你不知道为什么你的数据被标记为非IID,那么你应该担心可能的数据漂移或数据点之间的意外交互(它们的价值可能不是统计独立的)。仔细考虑未来的测试数据可能看起来如何(以及你的数据是否代表你关心的人群)。在非 IID 测试运行之前,你不应该打乱数据(这将使结论无效)。

如上所示,cleanlab 可以自动筛选出数据集中最可能的问题,帮助你更好地为后续建模整理数据集。有了这个短名单,你可以选择修复这些标签问题,或者从数据集中移除非语义或重复的示例,以获得更高质量的数据集来训练你的下一个机器学习模型。cleanlab 的问题检测可以与你最初训练的任何类型的模型的输出一起运行。

Cleanlab 开源项目

Cleanlab 是一个标准的以数据为中心的人工智能包,旨在解决混乱的现实世界数据的质量问题。

请考虑给 Cleanlab Github 仓库一个星标,如果你有兴趣,也可以参与到这个项目中来,比如帮助解决一些简单的问题。。

< > Update on GitHub