File size: 15,011 Bytes
9a11e1b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
bcb8ccf
0d58633
e00b8f2
bcb8ccf
9a11e1b
bcb8ccf
 
 
0d58633
 
 
bcb8ccf
 
9a11e1b
0d58633
 
9a11e1b
 
bcb8ccf
 
 
 
 
9a11e1b
 
 
 
bcb8ccf
9a11e1b
 
 
 
 
bcb8ccf
9a11e1b
bcb8ccf
 
 
 
 
 
 
9a11e1b
bcb8ccf
 
 
 
 
 
 
9a11e1b
bcb8ccf
9a11e1b
bcb8ccf
 
9a11e1b
 
bcb8ccf
 
 
 
 
 
 
 
 
 
 
 
5cd2907
bcb8ccf
 
 
 
 
 
e00b8f2
5cd2907
 
 
 
bcb8ccf
 
9a11e1b
e00b8f2
 
 
 
9a11e1b
5cd2907
 
 
4bd2962
 
 
 
 
 
 
 
0d58633
4bd2962
0d58633
 
 
 
 
 
 
4bd2962
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9a11e1b
bcb8ccf
 
 
 
9a11e1b
 
bcb8ccf
 
5cd2907
bcb8ccf
9a11e1b
bcb8ccf
 
9a11e1b
bcb8ccf
 
 
 
9a11e1b
 
5cd2907
bcb8ccf
 
 
 
 
 
 
 
 
 
 
4bd2962
bcb8ccf
8cca3d0
 
 
 
 
 
8a3618a
8cca3d0
 
 
 
8a3618a
8cca3d0
8a3618a
 
8cca3d0
8a3618a
8cca3d0
 
8a3618a
 
 
 
 
 
 
 
 
 
8cca3d0
8a3618a
 
 
 
 
 
 
 
 
 
 
 
 
8cca3d0
5cd2907
8cca3d0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
bcb8ccf
 
 
8a3618a
 
 
 
 
bcb8ccf
 
 
8a3618a
bcb8ccf
8a3618a
bcb8ccf
8a3618a
 
bcb8ccf
 
 
 
8a3618a
 
bcb8ccf
 
 
8a3618a
 
bcb8ccf
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7eae31f
bcb8ccf
 
 
 
 
7eae31f
bcb8ccf
 
 
 
 
 
 
 
8059baf
 
 
bcb8ccf
 
 
 
 
 
 
8a3618a
bcb8ccf
 
 
 
8a3618a
bcb8ccf
 
 
7eae31f
 
bcb8ccf
 
8a3618a
bcb8ccf
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7eae31f
bcb8ccf
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
# Copyright 2020 The HuggingFace Datasets Authors and the current dataset script contributor.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""TODO: Add a description here."""

from collections import defaultdict
import logging
from typing import List, Dict, Tuple, NamedTuple

import datasets
import evaluate
import numpy as np
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, \
    PreTrainedTokenizer, PreTrainedTokenizerFast, \
    GPT2TokenizerFast

from .prediction import Prediction

# Module-level logger; used by `prepare_tokenizer` for its fallback warning.
L = logging.getLogger(__name__)


# BibTeX entry for the SyntaxGym paper; passed as `citation` in
# `SyntaxGym._info`.
_CITATION = """\
@inproceedings{Hu:et-al:2020,
  author = {Hu, Jennifer and Gauthier, Jon and Qian, Peng and Wilcox, Ethan and Levy, Roger},
  title = {A systematic assessment of syntactic generalization in neural language models},
  booktitle = {Proceedings of the Association of Computational Linguistics},
  year = {2020}
}
"""

# TODO: Add description of the module here.
# Currently (almost) empty. It is passed, together with `_KWARGS_DESCRIPTION`,
# to `add_start_docstrings` on the `SyntaxGym` class.
_DESCRIPTION = """
"""


# Usage text (arguments / return values) for the metric; passed to
# `add_start_docstrings` on the `SyntaxGym` class.
# TODO: Fill in the "Examples" section of the text below.
# NOTE(review): the text documents the input as `suite`, whereas
# `SyntaxGym._compute` and the declared features name it `dataset`; it also
# omits the `batch_size`, `add_start_token` and `device` arguments, and
# `_compute` returns one `SyntaxGymMetricSuiteResult` per suite name rather
# than the two bare lists described here -- confirm and update the text.
_KWARGS_DESCRIPTION = """
Runs SyntaxGym evaluations on the given model and test suite.
Args:
    suite (Dataset): SyntaxGym test suite loaded as a Dataset.
    model_id (str): model used for calculating surprisals
            NOTE: The SyntaxGym evaluations are only well-defined for causal language models.
                    This includes models such as gpt2, causal variations of bert,
                    causal versions of t5, and more (the full list can be found
                    in the AutoModelForCausalLM documentation here:
                    https://huggingface.co/docs/transformers/master/en/model_doc/auto#transformers.AutoModelForCausalLM )
Returns:
    prediction_results: A list of prediction results per item. A list of lists,
            one per item, containing the boolean prediction result for each
            prediction in the test suite,
    region_totals: A list of total surprisals for each region (nested within
            condition and item). A list of dictionaries (one per item), each
            mapping tuples (condition_name, region_number) to a float
            total surprisal value (i.e. negative log-2 probability).
Examples:
    TODO

    >>> my_new_module = evaluate.load("cpllab/syntaxgym")
    >>> ...
"""


# Feature schema of one condition of a test-suite item.
# The code in this file reads these fields column-wise, e.g.
# `item["conditions"]["condition_name"]` is treated as a list with one entry
# per condition.
SUITE_DATASET_CONDITION_SPEC = {
    # Name identifying the condition within its item.
    "condition_name": datasets.Value("string"),
    # Full sentence for this condition; this is the string that gets tokenized.
    "content": datasets.Value("string"),
    # The regions making up the sentence, each with its number and its text.
    "regions": datasets.Sequence({
        "region_number": datasets.Value("int32"),
        "content": datasets.Value("string")
    })
}


# Feature schema of one test-suite item; declared as the "dataset" input
# feature in `SyntaxGym._info`.
SUITE_DATASET_SPEC = {
    # Name of the suite the item belongs to; results are grouped by this value.
    "suite_name": datasets.Value("string"),
    "item_number": datasets.Value("int32"),
    "conditions": datasets.Sequence(SUITE_DATASET_CONDITION_SPEC),
    # Prediction formulas as strings; each one is handed to `Prediction`
    # in `SyntaxGym._compute_item`.
    "predictions": datasets.Sequence(datasets.Value("string")),
}


class SyntaxGymMetricSuiteResult(NamedTuple):
    """
    Outcome of evaluating one SyntaxGym test suite.
    """
    # Name of the evaluated suite.
    suite_name: str
    # One list per item, holding the boolean outcome of each prediction.
    prediction_results: List[List[bool]]
    # One dict per item: (condition_name, region_number) -> total surprisal.
    region_totals: List[Dict[Tuple[str, int], float]]

    @property
    def accuracy(self) -> float:
        """Fraction of items for which every prediction holds."""
        outcomes = np.array(self.prediction_results)
        # An item counts as correct only when all of its predictions are true.
        item_passed = outcomes.all(axis=1)
        return item_passed.mean(axis=0)


# Complete metric output: one suite-level result per suite name.
SyntaxGymMetricResult = Dict[str, SyntaxGymMetricSuiteResult]


def prepare_tokenizer(model, batch_size, add_start_token=True) -> Tuple[PreTrainedTokenizer, Dict]:
    """
    Build the tokenizer for ``model`` along with the keyword arguments that
    should accompany tokenizer calls during SyntaxGym evaluation.

    Args:
        model: loaded model; its ``name_or_path`` selects the tokenizer and
            its ``config.max_length`` determines the ``max_length`` kwarg.
        batch_size: when greater than 1, a pad token is made available by
            reusing an existing special token if none is set.
        add_start_token: when true, the tokenizer must define a BOS token and
            one position is reserved for it in ``max_length``.

    Returns:
        tokenizer: the loaded tokenizer
        tokenizer_kwargs: suggested kwargs for any tokenizer calls
    """
    checkpoint = model.name_or_path
    tokenizer = AutoTokenizer.from_pretrained(checkpoint)
    if not isinstance(tokenizer, PreTrainedTokenizerFast):
        # Only fast tokenizers support return_offsets_mapping, which this
        # metric relies on. Fall back to the GPT2 fast tokenizer class, which
        # is sufficient for OPT.
        L.warning(f"The model {model.name_or_path} does not have a fast tokenizer, "
                  f"which is required for this metric. Running with GPT2 tokenizer.")
        tokenizer = GPT2TokenizerFast.from_pretrained(checkpoint)

    # Batches larger than one generally need padding. If no pad token is
    # assigned yet, let an already-defined special token double as padding.
    if tokenizer.pad_token is None and batch_size > 1:
        pad_candidates = list(tokenizer.special_tokens_map_extended.values())
        # the tokenizer must define at least one special token to borrow
        assert (
            len(pad_candidates) > 0
        ), "If batch_size > 1, model must have at least one special token to use for padding. Please use a different model or set batch_size=1."
        tokenizer.add_special_tokens({"pad_token": pad_candidates[0]})

    if add_start_token:
        assert (
            tokenizer.bos_token is not None
        ), "Input model must already have a BOS token if using add_start_token=True. Please use a different model, or set add_start_token=False"

    # Reserve one position for <BOS> when a start token is requested.
    max_tokenized_len = (model.config.max_length - 1 if add_start_token
                         else model.config.max_length)

    return tokenizer, {
        "add_special_tokens": False,
        "padding": True,
        "max_length": max_tokenized_len
    }


@evaluate.utils.file_utils.add_start_docstrings(_DESCRIPTION, _KWARGS_DESCRIPTION)
class SyntaxGym(evaluate.EvaluationModule):
    """
    Defines SyntaxGym evaluation logic for causal language models.

    The pipeline implemented here is: tokenize every condition sentence,
    compute per-token surprisals with the model, sum them per region, and
    evaluate each item's prediction formulas on the region totals.
    """

    def _info(self):
        """Return the module metadata and the expected input features."""
        features = datasets.Features({
            "dataset": SUITE_DATASET_SPEC
        })
        return evaluate.EvaluationModuleInfo(
            module_type="metric",
            description="TODO",
            citation=_CITATION,
            inputs_description="TODO",
            features=features,
            homepage="https://syntaxgym.org",
            codebase_urls=["https://github.com/cpllab/syntaxgym-core"],
        )

    def _compute(self, dataset, model_id, batch_size=8, add_start_token=False, device=None) -> SyntaxGymMetricResult:
        """
        Evaluate every suite contained in ``dataset`` with the given model.

        Args:
            dataset: iterable and indexable collection of test-suite items
                (see ``SUITE_DATASET_SPEC``); items of several suites may be
                mixed.
            model_id: identifier passed to
                ``AutoModelForCausalLM.from_pretrained``.
            batch_size: number of sentences per forward pass.
            add_start_token: forwarded to ``prepare_tokenizer``. This method
                itself does not prepend a BOS token to the sentences.
            device: one of ``"gpu"``, ``"cpu"``, ``"cuda"``, or ``None`` to
                use CUDA when available and the CPU otherwise.

        Returns:
            Mapping from suite name to its ``SyntaxGymMetricSuiteResult``.
        """
        if device is not None:
            assert device in ["gpu", "cpu", "cuda"]
            if device == "gpu":
                device = "cuda"
        else:
            device = "cuda" if torch.cuda.is_available() else "cpu"

        model = AutoModelForCausalLM.from_pretrained(model_id)
        model = model.to(device)
        model.eval()

        tokenizer, tokenizer_kwargs = prepare_tokenizer(model, batch_size, add_start_token)

        # Flatten sentences, enforcing that sentences are always ordered by the same condition
        # within-suite. The order recorded for a suite is that of its last item.
        condition_orders = {}
        for item in dataset:
            condition_orders[item["suite_name"]] = item["conditions"]["condition_name"]
        # Flattened batch of sentences
        all_sentences = []
        # Mapping from sentence back to originating suite
        all_sentence_suites = []
        # Mapping from item back to originating suite
        all_item_suites = []
        for item in dataset:
            for condition_name in condition_orders[item["suite_name"]]:
                # Get idx of condition for this item.
                condition_idx = item["conditions"]["condition_name"].index(condition_name)

                all_sentences.append(item["conditions"]["content"][condition_idx])
                all_sentence_suites.append(item["suite_name"])
            all_item_suites.append(item["suite_name"])

        # Tokenize sentences and split into batches.
        all_tokenized_sentences = tokenizer(all_sentences, return_tensors="pt",
                                            return_offsets_mapping=True,
                                            **tokenizer_kwargs).to(device)
        tokenized_batches = torch.split(all_tokenized_sentences["input_ids"], batch_size)

        # Compute surprisal per-batch and combine into a single surprisal tensor.
        # There is no surprisal for the first token, hence T - 1 columns.
        n_sentences, n_timesteps = all_tokenized_sentences["input_ids"].shape
        surprisals = torch.zeros(n_sentences, n_timesteps - 1).float().to(device)
        for i, batch in enumerate(datasets.logging.tqdm(tokenized_batches, desc="Computing surprisals", unit="batch")):
            batch = batch.to(device)
            with torch.no_grad():
                # logits are B * T * V
                b_logits = model(batch)["logits"]
                # Negative log-probability, converted from nats to bits.
                b_surprisals = -b_logits.log_softmax(dim=2) / np.log(2)

            # Get surprisals of ground-truth words: the distribution at
            # position t is scored against the token at position t + 1.
            gt_idxs = batch[:, 1:]
            # Reindexed surprisals: B * (T - 1)
            b_surprisals_gt = torch.gather(b_surprisals[:, :-1, :], 2, gt_idxs.unsqueeze(2)).squeeze(2)

            surprisals[i * batch_size : (i + 1) * batch_size] = b_surprisals_gt

        # Aggregate results within-suite
        results = {}
        all_sentence_suites = np.array(all_sentence_suites)
        all_item_suites = np.array(all_item_suites)
        for suite, condition_order in datasets.logging.tqdm(condition_orders.items(), unit="suite"):
            suite_sentence_idxs = np.where(all_sentence_suites == suite)[0]
            suite_item_idxs = np.where(all_item_suites == suite)[0]
            suite_surprisals = surprisals[suite_sentence_idxs]

            # Reshape to intuitive axes n_items * n_conditions * ...
            suite_surprisals = suite_surprisals.reshape((len(suite_item_idxs), len(condition_order), -1))
            suite_offset_mapping = all_tokenized_sentences["offset_mapping"][suite_sentence_idxs] \
                .reshape((len(suite_item_idxs), len(condition_order), -1, 2))

            # Evaluate per-item
            suite_result = SyntaxGymMetricSuiteResult(suite, [], [])
            # Index with plain Python ints: `np.where` yields NumPy integers,
            # which some `datasets` releases reject as row keys.
            suite_items = datasets.logging.tqdm([dataset[int(idx)] for idx in suite_item_idxs], unit="item")
            for item, item_surprisals, item_offset_mapping in zip(suite_items, suite_surprisals, suite_offset_mapping):
                result_i = self._compute_item(item, item_surprisals, item_offset_mapping, condition_order)
                suite_result.prediction_results.append(result_i["prediction_results"])
                suite_result.region_totals.append(result_i["region_totals"])

            results[suite] = suite_result

        return results

    def _compute_item(self, item, item_surprisals, offset_mapping, condition_order):
        """
        Aggregate token-level surprisals to region-level surprisals for the given item,
        and evaluate the item's predictions.

        Args:
            item: test-suite item (see ``SUITE_DATASET_SPEC``).
            item_surprisals: n_conditions * (T - 1) surprisals, rows ordered
                as ``condition_order``; column ``t - 1`` holds the surprisal
                of token ``t``.
            offset_mapping: n_conditions * T * 2 character offsets of the
                tokens, rows ordered as ``condition_order``.
            condition_order: condition names in the order used for the rows.

        Returns:
            Dict with ``"prediction_results"`` (one value per prediction
            formula of the item) and ``"region_totals"`` (mapping
            ``(condition_name, region_number)`` to the summed surprisal).
            Regions that received no scorable token have no entry.
        """

        #### aggregate
        region_totals = {condition_name: defaultdict(float)
                         for condition_name in condition_order}
        region2tokens = self.compute_region_token_mapping(
            item, condition_order, offset_mapping)

        for cond_i, surprisals_i in zip(condition_order, item_surprisals):
            n_surprisals = surprisals_i.shape[0]
            for region_number, region_tokens in region2tokens[cond_i].items():
                for token in region_tokens:
                    if token == 0:
                        # surprisal not defined. pass.
                        continue
                    # Token t is scored at column t - 1, so t may be at most
                    # the number of surprisal columns.
                    assert token <= n_surprisals, \
                        "%s %s" % (token, n_surprisals)
                    region_totals[cond_i][region_number] += surprisals_i[token - 1]

        region_totals = {(condition_name, region_number): float(total)
                         for condition_name, totals in region_totals.items()
                         for region_number, total in totals.items()}

        results = {
            "prediction_results": [
                Prediction(i, formula, "sum").formula(region_totals)
                for i, formula in enumerate(item["predictions"])
            ],

            "region_totals": region_totals
        }
        return results

    def get_region_edges(self, item, condition_name):
        """
        Get left edge of each region as a character index.

        Args:
            item: test-suite item (see ``SUITE_DATASET_SPEC``).
            condition_name: name of the condition whose regions are measured.

        Returns:
            List with one character index per region, in region order.
        """
        # NB this is coupled with `condition_to_string` logic of course

        condition_idx = item["conditions"]["condition_name"].index(condition_name)
        regions = item["conditions"]["regions"][condition_idx]

        idx = 0
        ret = []
        for region_content in regions["content"]:
            ret.append(idx)

            region_size = len(region_content)
            # If this is not the first nonspace/nonpunct region, then it will
            # be preceded by a joining space.
            if region_content.strip() != "" and idx > 0 and not region_content.startswith(","):
                # Add joining space
                region_size += 1

            idx += region_size

        return ret

    def compute_region_token_mapping(self, item, condition_order,
                                     offset_mapping: List[Tuple[int, int]]
                                     ) -> Dict[str, Dict[int, List[int]]]:
        """
        Assign every non-padding token of each condition to a region.

        Args:
            item: test-suite item (see ``SUITE_DATASET_SPEC``).
            condition_order: condition names, in the row order of
                ``offset_mapping``.
            offset_mapping: per-condition character offsets
                ``(start, end)`` of each token.

        Returns:
            For each condition name, a mapping from region number (1-based
            position of the region in the condition) to the list of token
            indices that fall into that region. Regions without tokens have
            no entry.
        """
        # offset_mapping: n_conditions * T * 2

        region2tokens = {cond: defaultdict(list) for cond in condition_order}

        # Stand-in right edge for the last region, which has no successor.
        max_long = torch.iinfo(torch.int64).max

        for cond_name, i_offsets in zip(condition_order, offset_mapping):
            region_edges = self.get_region_edges(item, cond_name)

            # Walk tokens and regions in lockstep; both only ever advance.
            t_cursor, r_cursor = 0, 0
            while t_cursor < i_offsets.shape[0]:
                # token = i_tokens[t_cursor]
                token_char_start, token_char_end = i_offsets[t_cursor]

                if token_char_start == token_char_end == 0:
                    # This is a padding token. Skip.
                    # TODO what about BOS/EOS? some models incorporate them
                    t_cursor += 1
                    continue

                region_start = region_edges[r_cursor]
                region_end = region_edges[r_cursor + 1] \
                    if r_cursor + 1 < len(region_edges) else max_long

                # NB region boundaries are left edges, hence the >= here.
                if token_char_start >= region_end:
                    # Token starts in a later region; advance the region
                    # cursor and re-examine the same token.
                    r_cursor += 1
                    continue

                region2tokens[cond_name][r_cursor + 1].append(t_cursor)
                t_cursor += 1

        return region2tokens