|
""" |
|
|
|
HuggingFaceDataset Class |
|
========================= |
|
|
|
TextAttack allows users to provide their own dataset or load from HuggingFace. |
|
|
|
|
|
""" |
|
|
|
import collections |
|
|
|
import datasets |
|
|
|
import textattack |
|
|
|
from .dataset import Dataset |
|
|
|
|
|
def _cb(s):
    """Return ``s`` rendered as blue ANSI-colored text for terminal output."""
    text = str(s)
    return textattack.shared.utils.color_text(text, color="blue", method="ansi")
|
|
|
|
|
# Known column layouts as (required columns, input columns, output column).
# Entries are checked in order and the first match wins, so layouts that are
# supersets of a later one (e.g. question+sentence+label vs. sentence+label)
# must stay ahead of it.
_KNOWN_DATASET_SCHEMAS = (
    ({"premise", "hypothesis", "label"}, ("premise", "hypothesis"), "label"),
    ({"question", "sentence", "label"}, ("question", "sentence"), "label"),
    ({"sentence1", "sentence2", "label"}, ("sentence1", "sentence2"), "label"),
    ({"question1", "question2", "label"}, ("question1", "question2"), "label"),
    # Question-answering layout (presumably SQuAD-style).
    (
        {"context", "question", "title", "answers"},
        ("title", "context", "question"),
        "answers",
    ),
    ({"text", "label"}, ("text",), "label"),
    ({"sentence", "label"}, ("sentence",), "label"),
    ({"document", "summary"}, ("document",), "summary"),
    ({"content", "summary"}, ("content",), "summary"),
    ({"label", "review"}, ("review",), "label"),
)


def get_datasets_dataset_columns(dataset):
    """Infer input/output columns from common dataset hub schemas.

    Args:
        dataset: Any object exposing a ``column_names`` iterable of column
            name strings.

    Returns:
        tuple: ``(input_columns, output_column)`` where ``input_columns`` is a
        tuple of column names and ``output_column`` is a single column name.

    Raises:
        ValueError: If the columns match none of the known schemas.
    """
    schema = set(dataset.column_names)
    for required, input_columns, output_column in _KNOWN_DATASET_SCHEMAS:
        if required <= schema:
            return input_columns, output_column

    raise ValueError(
        f"Unsupported dataset schema {schema}. Try passing your own `dataset_columns` argument."
    )
|
|
|
|
|
class HuggingFaceDataset(Dataset):
    """Loads a dataset from 🤗 Datasets and prepares it as a TextAttack dataset.

    Args:
        name_or_dataset (:obj:`Union[str, datasets.Dataset]`):
            The dataset name as :obj:`str` or actual :obj:`datasets.Dataset` object.
            If it's your custom :obj:`datasets.Dataset` object, please pass the input and output columns via :obj:`dataset_columns` argument.
        subset (:obj:`str`, `optional`, defaults to :obj:`None`):
            The subset of the main dataset. Dataset will be loaded as :obj:`datasets.load_dataset(name, subset)`.
        split (:obj:`str`, `optional`, defaults to :obj:`"train"`):
            The split of the dataset.
        dataset_columns (:obj:`tuple(list[str], str)`, `optional`, defaults to :obj:`None`):
            Pair of :obj:`list[str]` representing list of input column names (e.g. :obj:`["premise", "hypothesis"]`)
            and :obj:`str` representing the output column name (e.g. :obj:`label`). If not set, we will try to automatically determine column names from known designs.
        label_map (:obj:`dict[int, int]`, `optional`, defaults to :obj:`None`):
            Mapping if output labels of the dataset should be re-mapped. Useful if model was trained with a different label arrangement.
            For example, if dataset's arrangement is 0 for `Negative` and 1 for `Positive`, but model's label
            arrangement is 1 for `Negative` and 0 for `Positive`, passing :obj:`{0: 1, 1: 0}` will remap the dataset's label to match with model's arrangements.
            Could also be used to remap literal labels to numerical labels (e.g. :obj:`{"positive": 1, "negative": 0}`).
        label_names (:obj:`list[str]`, `optional`, defaults to :obj:`None`):
            List of label names in corresponding order (e.g. :obj:`["World", "Sports", "Business", "Sci/Tech"]` for AG-News dataset).
            If not set, labels will be printed as is (e.g. "0", "1", ...). This should be set to :obj:`None` for non-classification datasets.
        output_scale_factor (:obj:`float`, `optional`, defaults to :obj:`None`):
            Factor to divide ground-truth outputs by. Generally, TextAttack goal functions require model outputs between 0 and 1.
            Some datasets are regression tasks, in which case this is necessary.
        shuffle (:obj:`bool`, `optional`, defaults to :obj:`False`): Whether to shuffle the underlying dataset.

            .. note::
                Generally not recommended to shuffle the underlying dataset. Shuffling can be performed using DataLoader or by shuffling the order of indices we attack.

    Raises:
        ValueError: If the input-columns element of ``dataset_columns`` is not
            a list or tuple, or if no columns are given and the schema cannot
            be inferred.
    """

    def __init__(
        self,
        name_or_dataset,
        subset=None,
        split="train",
        dataset_columns=None,
        label_map=None,
        label_names=None,
        output_scale_factor=None,
        shuffle=False,
    ):
        if isinstance(name_or_dataset, datasets.Dataset):
            self._dataset = name_or_dataset
        else:
            self._name = name_or_dataset
            self._subset = subset
            self._dataset = datasets.load_dataset(self._name, subset)[split]
            subset_print_str = f", subset {_cb(subset)}" if subset else ""
            textattack.shared.logger.info(
                f"Loading {_cb('datasets')} dataset {_cb(self._name)}{subset_print_str}, split {_cb(split)}."
            )

        # Use the caller's columns when given; otherwise infer them from the
        # dataset's column names.
        (
            self.input_columns,
            self.output_column,
        ) = dataset_columns or get_datasets_dataset_columns(self._dataset)

        if not isinstance(self.input_columns, (list, tuple)):
            raise ValueError(
                "First element of `dataset_columns` must be a list or a tuple."
            )

        self.label_map = label_map
        self.output_scale_factor = output_scale_factor
        if label_names:
            self.label_names = label_names
        else:
            try:
                self.label_names = self._dataset.features[self.output_column].names
            except (KeyError, AttributeError):
                # The output column is missing from `features`, or its feature
                # type exposes no `names`: fall back to unnamed labels.
                self.label_names = None

        # When labels are remapped, the label names are reordered as well.
        # NOTE(review): entries are read as old_names[label_map[key]] in the
        # map's iteration order; this lines up with the remapped ids for
        # self-inverse maps such as {0: 1, 1: 0} — confirm intent for other maps.
        if self.label_names and label_map:
            self.label_names = [
                self.label_names[self.label_map[i]] for i in self.label_map
            ]

        self.shuffled = shuffle
        if shuffle:
            # `datasets.Dataset.shuffle` returns a new dataset rather than
            # shuffling in place, so the result must be kept.
            self._dataset = self._dataset.shuffle()

    def _format_as_dict(self, example):
        """Convert one raw example into an ``(inputs, output)`` pair.

        ``inputs`` is an ordered mapping of input column name to value, in the
        order of ``self.input_columns``. ``output`` is the output column value
        after the optional ``label_map`` lookup and the optional division by
        ``output_scale_factor``.
        """
        input_dict = collections.OrderedDict(
            [(c, example[c]) for c in self.input_columns]
        )

        output = example[self.output_column]
        if self.label_map:
            output = self.label_map[output]
        if self.output_scale_factor:
            output = output / self.output_scale_factor

        return (input_dict, output)

    def filter_by_labels_(self, labels_to_keep):
        """Filter items by their labels for classification datasets. Performs
        in-place filtering.

        Args:
            labels_to_keep (:obj:`Union[Set, Tuple, List, Iterable]`):
                Set, tuple, list, or iterable of integers representing labels.
        """
        if not isinstance(labels_to_keep, set):
            labels_to_keep = set(labels_to_keep)
        self._dataset = self._dataset.filter(
            lambda x: x[self.output_column] in labels_to_keep
        )

    def __getitem__(self, i):
        """Return the i-th sample, or a list of samples when ``i`` is a slice.

        Slices follow normal Python semantics: ``start``/``stop`` may be
        omitted or negative, out-of-range bounds are clamped, and ``step`` is
        honored.
        """
        if isinstance(i, int):
            return self._format_as_dict(self._dataset[i])
        else:
            # `slice.indices` resolves None/negative/out-of-range bounds and
            # the step against the dataset length.
            return [
                self._format_as_dict(self._dataset[j])
                for j in range(*i.indices(len(self._dataset)))
            ]

    def shuffle(self):
        """Shuffle the underlying dataset and mark this dataset as shuffled."""
        # `datasets.Dataset.shuffle` returns a new dataset; keep the result.
        self._dataset = self._dataset.shuffle()
        self.shuffled = True
|
|