yu-val-weiss committed
Commit: 8f3cd77
Parent(s): 803da62
Update blimp.py
blimp.py
CHANGED
@@ -15,13 +15,83 @@

import datasets
import evaluate
-import numpy as np
import torch
from evaluate import logging
-from torch.nn import CrossEntropyLoss
from transformers import AutoModelForCausalLM, AutoTokenizer

-
@article{warstadt2020blimp,
    author = {Warstadt, Alex and Parrish, Alicia and Liu, Haokun and Mohananey, Anhad and Peng, Wei and Wang, Sheng-Fu and Bowman, Samuel R.},
    title = {BLiMP: The Benchmark of Linguistic Minimal Pairs for English},
@@ -37,8 +107,7 @@ _CITATION = """\
}
"""

-_DESCRIPTION = """
-BLiMP is a challenge set for evaluating what language models (LMs) know about major grammatical phenomena in English.
BLiMP consists of 67 sub-datasets, each containing 1000 minimal pairs isolating specific contrasts in syntax, morphology, or semantics.
The data is automatically generated according to expert-crafted grammars. Aggregate human agreement with the labels is 96.4%.
We use BLiMP to evaluate an n-gram LM, LSTM LM, GPT-2, and Transformer-XL.
@@ -48,9 +117,12 @@ For more info see https://github.com/alexwarstadt/blimp.

_KWARGS_DESCRIPTION = """
Args:
-    model_id (str): model used for calculating Blimp
    batch_size (int): the batch size to run texts through the model. Defaults to 16.
-    device (str): device to run on, defaults to 'cuda' when available
Returns:
    blimp: dictionary containing the blimp scores for each of the 67 sub-datasets, as well as the overall accuracy.
        An LM’s overall accuracy on BLiMP is simply the proportion of the 67,000 minimal pairs in which the model assigns a higher probability to the acceptable sentence.
@@ -60,7 +132,7 @@ Examples:


@evaluate.utils.file_utils.add_start_docstrings(_DESCRIPTION, _KWARGS_DESCRIPTION)
-class Perplexity(evaluate.Metric):
    def _info(self):
        return evaluate.MetricInfo(
            module_type="metric",
@@ -80,12 +152,11 @@ class Perplexity(evaluate.Metric):

    def _compute(
        self,
-        predictions,
        model_id,
        batch_size: int = 16,
-        add_start_token: bool = True,
        device=None,
-
    ):
        if device is not None:
            assert device in ["gpu", "cpu", "cuda", "mps"], (
@@ -102,6 +173,7 @@ class Perplexity(evaluate.Metric):

        model = AutoModelForCausalLM.from_pretrained(model_id)
        model = model.to(device)

        tokenizer = AutoTokenizer.from_pretrained(model_id)

@@ -119,78 +191,93 @@ class Perplexity(evaluate.Metric):
            # assign one of the special tokens to also be the pad token
            tokenizer.add_special_tokens({"pad_token": existing_special_tokens[0]})

-        encodings = tokenizer(
-            predictions,
-            add_special_tokens=False,
-            padding=True,
-            truncation=True if max_tokenized_len else False,
-            max_length=max_tokenized_len,
-            return_tensors="pt",
-            return_attention_mask=True,
-        ).to(device)

-        attn_masks = encodings["attention_mask"]

@@ -15,13 +15,83 @@

import datasets
import evaluate
import torch
from evaluate import logging
from transformers import AutoModelForCausalLM, AutoTokenizer

+datasets.logging.set_verbosity_error()
+
+BLIMP_PHENOMENA = [
+    "adjunct_island",
+    "anaphor_gender_agreement",
+    "anaphor_number_agreement",
+    "animate_subject_passive",
+    "animate_subject_trans",
+    "causative",
+    "complex_NP_island",
+    "coordinate_structure_constraint_complex_left_branch",
+    "coordinate_structure_constraint_object_extraction",
+    "determiner_noun_agreement_1",
+    "determiner_noun_agreement_2",
+    "determiner_noun_agreement_irregular_1",
+    "determiner_noun_agreement_irregular_2",
+    "determiner_noun_agreement_with_adj_2",
+    "determiner_noun_agreement_with_adj_irregular_1",
+    "determiner_noun_agreement_with_adj_irregular_2",
+    "determiner_noun_agreement_with_adjective_1",
+    "distractor_agreement_relational_noun",
+    "distractor_agreement_relative_clause",
+    "drop_argument",
+    "ellipsis_n_bar_1",
+    "ellipsis_n_bar_2",
+    "existential_there_object_raising",
+    "existential_there_quantifiers_1",
+    "existential_there_quantifiers_2",
+    "existential_there_subject_raising",
+    "expletive_it_object_raising",
+    "inchoative",
+    "intransitive",
+    "irregular_past_participle_adjectives",
+    "irregular_past_participle_verbs",
+    "irregular_plural_subject_verb_agreement_1",
+    "irregular_plural_subject_verb_agreement_2",
+    "left_branch_island_echo_question",
+    "left_branch_island_simple_question",
+    "matrix_question_npi_licensor_present",
+    "npi_present_1",
+    "npi_present_2",
+    "only_npi_licensor_present",
+    "only_npi_scope",
+    "passive_1",
+    "passive_2",
+    "principle_A_c_command",
+    "principle_A_case_1",
+    "principle_A_case_2",
+    "principle_A_domain_1",
+    "principle_A_domain_2",
+    "principle_A_domain_3",
+    "principle_A_reconstruction",
+    "regular_plural_subject_verb_agreement_1",
+    "regular_plural_subject_verb_agreement_2",
+    "sentential_negation_npi_licensor_present",
+    "sentential_negation_npi_scope",
+    "sentential_subject_island",
+    "superlative_quantifiers_1",
+    "superlative_quantifiers_2",
+    "tough_vs_raising_1",
+    "tough_vs_raising_2",
+    "transitive",
+    "wh_island",
+    "wh_questions_object_gap",
+    "wh_questions_subject_gap",
+    "wh_questions_subject_gap_long_distance",
+    "wh_vs_that_no_gap",
+    "wh_vs_that_no_gap_long_distance",
+    "wh_vs_that_with_gap",
+    "wh_vs_that_with_gap_long_distance",
+]
+
+_CITATION = r"""
@article{warstadt2020blimp,
    author = {Warstadt, Alex and Parrish, Alicia and Liu, Haokun and Mohananey, Anhad and Peng, Wei and Wang, Sheng-Fu and Bowman, Samuel R.},
    title = {BLiMP: The Benchmark of Linguistic Minimal Pairs for English},
@@ -37,8 +107,7 @@ _CITATION = """\
}
"""

+_DESCRIPTION = """BLiMP is a challenge set for evaluating what language models (LMs) know about major grammatical phenomena in English.
BLiMP consists of 67 sub-datasets, each containing 1000 minimal pairs isolating specific contrasts in syntax, morphology, or semantics.
The data is automatically generated according to expert-crafted grammars. Aggregate human agreement with the labels is 96.4%.
We use BLiMP to evaluate an n-gram LM, LSTM LM, GPT-2, and Transformer-XL.
@@ -48,9 +117,12 @@ For more info see https://github.com/alexwarstadt/blimp.

_KWARGS_DESCRIPTION = """
Args:
+    model_id (str): model used for calculating Blimp, NOTE: should be a causal LM model
+    predictions (list[str]): names of the BLiMP phenomena (sub-datasets) to run; pass an empty list or ["*"] to run all of them
    batch_size (int): the batch size to run texts through the model. Defaults to 16.
+    device (str): device to run on, defaults to 'cuda' when available.
+    samples_per_set (int): the number of samples per phenomenon, defaults to 1_000.
+
Returns:
    blimp: dictionary containing the blimp scores for each of the 67 sub-datasets, as well as the overall accuracy.
        An LM’s overall accuracy on BLiMP is simply the proportion of the 67,000 minimal pairs in which the model assigns a higher probability to the acceptable sentence.
@@ -60,7 +132,7 @@ Examples:


@evaluate.utils.file_utils.add_start_docstrings(_DESCRIPTION, _KWARGS_DESCRIPTION)
+class Blimp(evaluate.Metric):
    def _info(self):
        return evaluate.MetricInfo(
            module_type="metric",
@@ -80,12 +152,11 @@ class Perplexity(evaluate.Metric):

    def _compute(
        self,
        model_id,
+        predictions=None,
        batch_size: int = 16,
        device=None,
+        samples_per_set: int = 1_000,
    ):
        if device is not None:
            assert device in ["gpu", "cpu", "cuda", "mps"], (
@@ -102,6 +173,7 @@ class Perplexity(evaluate.Metric):

        model = AutoModelForCausalLM.from_pretrained(model_id)
        model = model.to(device)
+        model.eval()

        tokenizer = AutoTokenizer.from_pretrained(model_id)

@@ -119,78 +191,93 @@ class Perplexity(evaluate.Metric):
            # assign one of the special tokens to also be the pad token
            tokenizer.add_special_tokens({"pad_token": existing_special_tokens[0]})

+        print("PAD", tokenizer.pad_token_id)
+
+        run_all = not predictions or predictions[0] == "*"
+        blimp_sets = (
+            BLIMP_PHENOMENA
+            if run_all
+            else [p for p in BLIMP_PHENOMENA if p.lower() in predictions]
+        )

+        assert len(blimp_sets) > 0, "no valid phenomena selected"

+        results = {}
+
+        for phenomenon in logging.tqdm(blimp_sets, desc="Evaluating phenomena..."):
+            dataset = datasets.load_dataset("nyu-mll/blimp", phenomenon)["train"]
+
+            # Prepare batches of good and bad sentences
+            sents = [(x["sentence_good"], x["sentence_bad"]) for x in dataset]
+            good_sents, bad_sents = zip(*sents[: min(1000, samples_per_set)])
+
+            # Get probabilities in batches
+            good_probs = get_batch_probabilities(
+                model, tokenizer, good_sents, device, batch_size, phenomenon
+            )
+            bad_probs = get_batch_probabilities(
+                model,
+                tokenizer,
+                bad_sents,
+                device,
+                batch_size,
+                phenomenon,
+                sent_type="bad",
+            )

+            # Compare probabilities
+            correct = sum(g > b for g, b in zip(good_probs, bad_probs))
+            accuracy = correct / len(good_probs)
+            results[phenomenon] = accuracy
+
+        # Calculate overall accuracy
+        overall_accuracy = sum(results.values()) / len(results)
+
+        return {"phenomenon_accuracies": results, "overall_accuracy": overall_accuracy}
+
+
+def get_batch_probabilities(
+    model,
+    tokenizer,
+    sentences: list[str],
+    device: str,
+    batch_size: int,
+    phenomenon: str,
+    sent_type: str = "good",
+):
+    """Compute log probabilities for a batch of sentences"""
+    probs = []
+
+    for i in logging.tqdm(
+        range(0, len(sentences), batch_size),
+        desc=f"{phenomenon} - {sent_type} sentences...",
+        leave=False,
+    ):
+        batch = sentences[i : i + batch_size]
+        inputs = tokenizer(
+            batch, padding=batch_size > 1, return_tensors="pt", truncation=True
+        ).to(device)
+
+        with torch.no_grad():
+            outputs = model(**inputs)
+
+        labels = inputs.input_ids
+
+        # Compute log probabilities
+        log_probs = torch.nn.functional.log_softmax(outputs.logits, dim=-1)
+
+        # Get log probability of each actual token
+        token_log_probs = torch.gather(log_probs, 2, labels.unsqueeze(-1)).squeeze(-1)
+
+        if batch_size > 1:
+            # Create attention mask for padding
+            mask = (labels != tokenizer.pad_token_id).float()
+            token_log_probs *= mask
+
+        # sum log probabilities
+        sequence_log_probs = token_log_probs.sum(dim=1)

+        probs.extend(sequence_log_probs.cpu().tolist())

+        return probs
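A minimal usage sketch for the updated metric (not part of the commit): it assumes the module is published on the Hub under a path that evaluate.load can resolve, and the path "blimp" below is only a placeholder for the actual repo id. The keyword arguments mirror the new _compute signature shown above.

import evaluate

# Hypothetical module path; substitute the actual Hub location of this blimp.py.
blimp = evaluate.load("blimp", module_type="metric")

results = blimp.compute(
    model_id="gpt2",  # any causal LM checkpoint
    predictions=["adjunct_island", "anaphor_gender_agreement"],  # [] or ["*"] runs all 67 phenomena
    batch_size=16,
    samples_per_set=100,  # evaluate fewer pairs per phenomenon for a quick check
)
print(results["phenomenon_accuracies"])
print(results["overall_accuracy"])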