File size: 6,184 Bytes
4a1df2e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
"""

Dataset Class
======================

TextAttack allows users to provide their own dataset or load one from HuggingFace.


"""

from collections import OrderedDict
import random

import torch


class Dataset(torch.utils.data.Dataset):
    """Basic class for dataset. It operates as a map-style dataset, fetching
    data via :meth:`__getitem__` and :meth:`__len__` methods.

    .. note::
        This class subclasses :obj:`torch.utils.data.Dataset` and therefore can be treated as a regular PyTorch Dataset.

    Args:
        dataset (:obj:`list[tuple]`):
            A list of :obj:`(input, output)` pairs.
            If :obj:`input` consists of multiple fields (e.g. "premise" and "hypothesis" for SNLI),
            :obj:`input` must be of the form :obj:`(input_1, input_2, ...)` and :obj:`input_columns` parameter must be set.
            :obj:`output` can either be an integer representing labels for classification or a string for seq2seq tasks.
            The list is stored by reference (not copied), so shuffling reorders the caller's list too.
        input_columns (:obj:`list[str]`, `optional`, defaults to :obj:`["text"]`):
            List of column names of inputs in order.
        label_map (:obj:`dict[int, int]`, `optional`, defaults to :obj:`None`):
            Mapping if output labels of the dataset should be re-mapped. Useful if model was trained with a different label arrangement.
            For example, if dataset's arrangement is 0 for `Negative` and 1 for `Positive`, but model's label
            arrangement is 1 for `Negative` and 0 for `Positive`, passing :obj:`{0: 1, 1: 0}` will remap the dataset's label to match with model's arrangements.
            Could also be used to remap literal labels to numerical labels (e.g. :obj:`{"positive": 1, "negative": 0}`).
        label_names (:obj:`list[str]`, `optional`, defaults to :obj:`None`):
            List of label names in corresponding order (e.g. :obj:`["World", "Sports", "Business", "Sci/Tech"]` for AG-News dataset).
            If not set, labels will be printed as is (e.g. "0", "1", ...). This should be set to :obj:`None` for non-classification datasets.
        output_scale_factor (:obj:`float`, `optional`, defaults to :obj:`None`):
            Factor to divide ground-truth outputs by. Generally, TextAttack goal functions require model outputs between 0 and 1.
            Some datasets are regression tasks, in which case this is necessary.
        shuffle (:obj:`bool`, `optional`, defaults to :obj:`False`): Whether to shuffle the underlying dataset.

            .. note::
                Generally not recommended to shuffle the underlying dataset. Shuffling can be performed using DataLoader or by shuffling the order of indices we attack.

    Examples::

        >>> import textattack

        >>> # Example of sentiment-classification dataset
        >>> data = [("I enjoyed the movie a lot!", 1), ("Absolutely horrible film.", 0), ("Our family had a fun time!", 1)]
        >>> dataset = textattack.datasets.Dataset(data)
        >>> dataset[1:2]


        >>> # Example for pair of sequence inputs (e.g. SNLI)
        >>> data = [(("A man inspects the uniform of a figure in some East Asian country.", "The man is sleeping"), 1)]
        >>> dataset = textattack.datasets.Dataset(data, input_columns=("premise", "hypothesis"))

        >>> # Example for seq2seq
        >>> data = [("J'aime le film.", "I love the movie.")]
        >>> dataset = textattack.datasets.Dataset(data)
    """

    def __init__(
        self,
        dataset,
        input_columns=None,
        label_map=None,
        label_names=None,
        output_scale_factor=None,
        shuffle=False,
    ):
        self._dataset = dataset
        # Build a fresh list per instance: a mutable default (``["text"]``)
        # would be shared by every Dataset and leak in-place mutations.
        self.input_columns = ["text"] if input_columns is None else input_columns
        self.label_map = label_map
        self.label_names = label_names
        # Only remap names when there are names to remap; ``label_map`` alone
        # (e.g. literal -> numeric labels) is valid without ``label_names``.
        if label_map and label_names:
            # If labels are remapped, the label names have to be remapped as well.
            # NOTE(review): this computes new[k] = old[label_map[key_k]], which
            # matches the intent for swaps such as {0: 1, 1: 0}; for other
            # permutations new[label_map[i]] = old[i] may be intended -- confirm.
            self.label_names = [
                self.label_names[self.label_map[i]] for i in self.label_map
            ]
        self.shuffled = shuffle
        self.output_scale_factor = output_scale_factor

        if shuffle:
            # In-place shuffle of the stored list (shared with the caller).
            random.shuffle(self._dataset)

    def _format_as_dict(self, example):
        """Convert a raw ``(input, output)`` pair into ``(OrderedDict, output)``.

        The output is passed through ``label_map`` and then divided by
        ``output_scale_factor`` when those are set. A plain-string input is
        treated as a single column; any other input is indexed field by field
        in the order of ``input_columns``.

        Raises:
            ValueError: If the number of input fields does not match
                ``len(self.input_columns)``.
        """
        output = example[1]
        if self.label_map:
            output = self.label_map[output]
        if self.output_scale_factor:
            output = output / self.output_scale_factor

        # A bare string counts as exactly one field.
        fields = (example[0],) if isinstance(example[0], str) else example[0]
        if len(self.input_columns) != len(fields):
            raise ValueError(
                "Mismatch between the number of columns in `input_columns` and number of columns of actual input."
            )
        input_dict = OrderedDict(
            [(c, fields[i]) for i, c in enumerate(self.input_columns)]
        )
        return input_dict, output

    def shuffle(self):
        """Shuffle the underlying examples in place and mark the dataset as shuffled."""
        random.shuffle(self._dataset)
        self.shuffled = True

    def filter_by_labels_(self, labels_to_keep):
        """Filter items by their labels for classification datasets. Performs
        in-place filtering.

        Labels are compared against the raw dataset labels (before
        ``label_map`` is applied).

        Args:
            labels_to_keep (:obj:`Union[Set, Tuple, List, Iterable]`):
                Set, tuple, list, or iterable of integers representing labels.
        """
        if not isinstance(labels_to_keep, set):
            labels_to_keep = set(labels_to_keep)
        # Materialize as a list: a lazy ``filter`` object supports neither
        # ``len()`` nor indexing, which ``__len__``/``__getitem__`` rely on.
        self._dataset = [ex for ex in self._dataset if ex[1] in labels_to_keep]

    def __getitem__(self, i):
        """Return the i-th formatted sample, or a list of them if ``i`` is a slice."""
        if isinstance(i, slice):
            return [self._format_as_dict(ex) for ex in self._dataset[i]]
        # Any non-slice index (``int`` or an integer-like object implementing
        # ``__index__``, e.g. a NumPy/torch scalar) selects a single example.
        return self._format_as_dict(self._dataset[i])

    def __len__(self):
        """Returns the size of dataset."""
        return len(self._dataset)