jgauthier committed on
Commit 8a3618a
1 Parent(s): 8059baf

batch surprisal computation, now GPU friendly

Files changed (1)
  1. syntaxgym.py +59 -48
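For orientation, the heart of this change is scoring every sentence in the suite with one padded tokenization pass and a handful of batched forward passes, instead of calling the model once per item. Below is a minimal, self-contained sketch of that batched surprisal computation; the `gpt2` checkpoint, the example sentences, and the pad-token setup are illustrative assumptions (the module itself configures these via `prepare_tokenizer`), not part of the diff.

```python
# Sketch: next-token surprisal (in bits) for a padded batch of sentences.
import numpy as np
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model = AutoModelForCausalLM.from_pretrained("gpt2").eval()
tokenizer = AutoTokenizer.from_pretrained("gpt2")
tokenizer.pad_token = tokenizer.eos_token  # GPT-2 ships without a pad token

sentences = [
    "The keys to the cabinet are on the table.",  # example condition 1
    "The keys to the cabinet is on the table.",   # example condition 2
]
enc = tokenizer(sentences, return_tensors="pt", padding=True)

with torch.no_grad():
    # Pre-softmax predictive distribution: B * T * V
    logits = model(input_ids=enc["input_ids"],
                   attention_mask=enc["attention_mask"]).logits
    # Negate natural-log probabilities and divide by ln 2 to get surprisal in bits
    surprisals = -logits.log_softmax(dim=2) / np.log(2)

# Surprisal of each ground-truth next token: B * (T - 1)
gt_idxs = enc["input_ids"][:, 1:]
token_surprisals = torch.gather(surprisals[:, :-1, :], 2, gt_idxs.unsqueeze(2)).squeeze(2)
print(token_surprisals.shape)
```

With all sentences tokenized up front, the diff only has to `torch.split` the padded `input_ids` into chunks of `batch_size`, so GPU memory stays bounded per forward pass while the per-item loop is left with cheap aggregation over precomputed surprisals.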
syntaxgym.py CHANGED
@@ -174,68 +174,82 @@ class SyntaxGym(evaluate.EvaluationModule):
 
         tokenizer, tokenizer_kwargs = prepare_tokenizer(model, batch_size, add_start_token)
 
+        # Flatten sentences, enforcing that sentences are always ordered by the same condition.
+        condition_order = dataset[0]["conditions"]["condition_name"]
+        all_sentences = []
+        for item in dataset:
+            for condition_name in condition_order:
+                # Get idx of condition for this item.
+                condition_idx = item["conditions"]["condition_name"].index(condition_name)
+                all_sentences.append(item["conditions"]["content"][condition_idx])
+
+        # Tokenize sentences and split into batches.
+        all_tokenized_sentences = tokenizer(all_sentences, return_tensors="pt",
+                                            return_offsets_mapping=True,
+                                            **tokenizer_kwargs).to(device)
+        tokenized_batches = torch.split(all_tokenized_sentences["input_ids"], batch_size)
+
+        # Compute surprisal per-batch and combine into a single surprisal tensor.
+        n_sentences, n_timesteps = all_tokenized_sentences["input_ids"].shape
+        surprisals = torch.zeros(n_sentences, n_timesteps - 1).float().to(device)
+        for i, batch in enumerate(datasets.logging.tqdm(tokenized_batches)):
+            batch = batch.to(device)
+            with torch.no_grad():
+                # logits are B * T * V
+                b_logits = model(batch)["logits"]
+                b_surprisals = -b_logits.log_softmax(dim=2) / np.log(2)
+
+            # Get surprisals of ground-truth words.
+            gt_idxs = batch[:, 1:]
+            # Reindexed surprisals: B * (T - 1)
+            b_surprisals_gt = torch.gather(b_surprisals[:, :-1, :], 2, gt_idxs.unsqueeze(2)).squeeze(2)
+
+            surprisals[i * batch_size : (i + 1) * batch_size] = b_surprisals_gt
+
+        # Reshape to intuitive axes n_items * n_conditions * ...
+        surprisals = surprisals.reshape((len(dataset), len(condition_order), -1))
+        offset_mapping = all_tokenized_sentences["offset_mapping"] \
+            .reshape((len(dataset), len(condition_order), -1, 2))
+
+        # Now evaluate per-item.
         results = {}
         result_keys = ["prediction_results", "region_totals"]
-        # TODO batch all items together
-        for item in datasets.logging.tqdm(dataset):
-            result_single = self._compute_single(item, tokenizer, tokenizer_kwargs,
-                                                 model, device)
+        for item, item_surprisals, item_offset_mapping in zip(datasets.logging.tqdm(dataset), surprisals, offset_mapping):
+            result_i = self._compute_item(item, item_surprisals, item_offset_mapping, condition_order)
 
             suite_name = item["suite_name"]
             if suite_name not in results:
                 results[suite_name] = SyntaxGymMetricSuiteResult(suite_name, [], [])
             for k in result_keys:
-                getattr(results[suite_name], k).append(result_single[k])
+                getattr(results[suite_name], k).append(result_i[k])
 
         return results
 
-    def _compute_single(self, item, tokenizer, tokenizer_kwargs, model, device):
-        tokenized = tokenizer(item["conditions"]["content"],
-                              return_tensors="pt",
-                              return_offsets_mapping=True,
-                              **tokenizer_kwargs).to(device)
-
-        # input_ids: B * T
-        input_ids = tokenized["input_ids"]
-        assert input_ids.ndim == 2
-
-        # Compute sentence level surprisals.
-        with torch.no_grad():
-            # Pre-softmax predictive distribution B * T * V
-            logits = model(input_ids).logits
-            surprisals = -logits.log_softmax(dim=2) / np.log(2)
-
-        # surprisals: B * T * V
-        assert surprisals.ndim == 3
-
-        # Get surprisals of expected words.
-        surps_shifted = surprisals[:, :-1, :]
-        expected_ids = input_ids[:, 1:]
-
-        # reindexed surprisals: B * (T - 1)
-        surprisals = torch.gather(surps_shifted, 2, expected_ids.unsqueeze(2)) \
-            .squeeze(2)
+    def _compute_item(self, item, item_surprisals, offset_mapping, condition_order):
+        """
+        Aggregate token-level surprisals to region-level surprisals for the given item,
+        and evaluate the item's predictions.
+        """
 
         #### aggregate
-        condition_names = item["conditions"]["condition_name"]
         region_totals = {condition_name: defaultdict(float)
-                         for condition_name in condition_names}
+                         for condition_name in condition_order}
         region2tokens = self.compute_region_token_mapping(
-            item, input_ids, tokenized["offset_mapping"])
+            item, condition_order, offset_mapping)
 
-        for i, (i_cond, i_inputs) in enumerate(zip(condition_names, input_ids)):
-            for region_number, region_tokens in region2tokens[i_cond].items():
+        for i, (cond_i, surprisals_i) in enumerate(zip(condition_order, item_surprisals)):
+            for region_number, region_tokens in region2tokens[cond_i].items():
                 for token in region_tokens:
                     if token == 0:
                         # surprisal not defined. pass.
                         continue
-                    elif token <= surprisals.shape[1]:
-                        region_totals[i_cond][region_number] += surprisals[i, token - 1]
+                    elif token <= item_surprisals.shape[1]:
+                        region_totals[cond_i][region_number] += surprisals_i[token - 1]
                     else:
                         # TODO don't think this is an issue, just should clean
                         # up the aggregation output
-                        assert token == surprisals.shape[1], \
-                            "%s %s" % (token, surprisals.shape[1])
+                        assert token == surprisals_i.shape[1], \
+                            "%s %s" % (token, surprisals_i.shape[1])
 
         region_totals = {(condition_name, region_number): float(total)
                          for condition_name, totals in region_totals.items()
@@ -275,23 +289,20 @@ class SyntaxGym(evaluate.EvaluationModule):
 
         return ret
 
-    def compute_region_token_mapping(self, item, input_ids: torch.LongTensor,
+    def compute_region_token_mapping(self, item, condition_order,
                                      offset_mapping: List[Tuple[int, int]]
                                      ) -> Dict[str, Dict[int, List[int]]]:
-        # input_ids: B * T
         # offset_mapping: B * T * 2
-        # assumes batch is sorted according to item's condition_name order
 
-        condition_names = item["conditions"]["condition_name"]
-        region2tokens = {cond: defaultdict(list) for cond in condition_names}
+        region2tokens = {cond: defaultdict(list) for cond in condition_order}
 
         max_long = torch.iinfo(torch.int64).max
 
-        for i_cond, (i_tokens, i_offsets) in enumerate(zip(input_ids, offset_mapping)):
+        for i_cond, i_offsets in enumerate(offset_mapping):
            region_edges = self.get_region_edges(item, i_cond)
 
            t_cursor, r_cursor = 0, 0
-           while t_cursor < i_tokens.shape[0]:
+           while t_cursor < i_offsets.shape[0]:
                # token = i_tokens[t_cursor]
                token_char_start, token_char_end = i_offsets[t_cursor]
 
@@ -310,7 +321,7 @@ class SyntaxGym(evaluate.EvaluationModule):
                    r_cursor += 1
                    continue
 
-               region2tokens[condition_names[i_cond]][r_cursor + 1].append(t_cursor)
+               region2tokens[condition_order[i_cond]][r_cursor + 1].append(t_cursor)
                t_cursor += 1
 
        return region2tokens
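The region aggregation in `_compute_item` depends on mapping each token back to a SyntaxGym region via the tokenizer's character offsets. A rough sketch of that idea follows; the region strings and the derived character edges are made-up examples, and the cursor logic is simplified relative to `compute_region_token_mapping` / `get_region_edges` (which also handle special tokens and padding).

```python
# Sketch: assign token indices to 1-indexed regions by comparing each token's
# character start offset against the character position where each region begins.
from collections import defaultdict
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")

regions = ["The keys", " to the cabinet", " are", " on the table."]  # hypothetical item regions
sentence = "".join(regions)

# Character index at which each region starts (analogous to what get_region_edges() provides).
region_edges, pos = [], 0
for region in regions:
    region_edges.append(pos)
    pos += len(region)

enc = tokenizer(sentence, return_offsets_mapping=True)

region2tokens = defaultdict(list)
r_cursor = 0
for t_cursor, (char_start, char_end) in enumerate(enc["offset_mapping"]):
    # Advance the region cursor while this token starts at or beyond the next region edge.
    while r_cursor + 1 < len(region_edges) and char_start >= region_edges[r_cursor + 1]:
        r_cursor += 1
    region2tokens[r_cursor + 1].append(t_cursor)

print(dict(region2tokens))  # tokens grouped by region, e.g. {1: [0, 1], 2: [2, 3, 4], ...}
```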