dmahata commited on
Commit
d0ca0bf
1 Parent(s): da4e718

Upload utils.py

Browse files
Files changed (1) hide show
  1. utils.py +739 -0
utils.py ADDED
@@ -0,0 +1,739 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2020 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import itertools
16
+ import json
17
+ import linecache
18
+ import math
19
+ import os
20
+ import pickle
21
+ import socket
22
+ from logging import getLogger
23
+ from pathlib import Path
24
+ from typing import Callable, Dict, Iterable, List, Tuple, Union
25
+
26
+ import git
27
+ import numpy as np
28
+ import torch
29
+ import torch.distributed as dist
30
+ from rouge_score import rouge_scorer, scoring
31
+ from sacrebleu import corpus_bleu
32
+ from torch import nn
33
+ from torch.utils.data import Dataset, Sampler
34
+
35
+ from sentence_splitter import add_newline_to_end_of_each_sentence
36
+ from transformers import BartTokenizer, EvalPrediction, PreTrainedTokenizer, T5Tokenizer
37
+ from transformers.file_utils import cached_property
38
+ from transformers.models.bart.modeling_bart import shift_tokens_right
39
+
40
+
41
+ try:
42
+ from fairseq.data.data_utils import batch_by_size
43
+
44
+ FAIRSEQ_AVAILABLE = True
45
+ except (ImportError, ModuleNotFoundError):
46
+ FAIRSEQ_AVAILABLE = False
47
+
48
+
49
+ def label_smoothed_nll_loss(lprobs, target, epsilon, ignore_index=-100):
50
+ """From fairseq"""
51
+ if target.dim() == lprobs.dim() - 1:
52
+ target = target.unsqueeze(-1)
53
+ nll_loss = -lprobs.gather(dim=-1, index=target)
54
+ smooth_loss = -lprobs.sum(dim=-1, keepdim=True)
55
+ if ignore_index is not None:
56
+ pad_mask = target.eq(ignore_index)
57
+ nll_loss.masked_fill_(pad_mask, 0.0)
58
+ smooth_loss.masked_fill_(pad_mask, 0.0)
59
+ else:
60
+ nll_loss = nll_loss.squeeze(-1)
61
+ smooth_loss = smooth_loss.squeeze(-1)
62
+
63
+ nll_loss = nll_loss.sum() # mean()? Scared to break other math.
64
+ smooth_loss = smooth_loss.sum()
65
+ eps_i = epsilon / lprobs.size(-1)
66
+ loss = (1.0 - epsilon) * nll_loss + eps_i * smooth_loss
67
+ return loss, nll_loss
68
+
69
+
70
+ def lmap(f: Callable, x: Iterable) -> List:
71
+ """list(map(f, x))"""
72
+ return list(map(f, x))
73
+
74
+
75
+ def calculate_bleu(output_lns, refs_lns, **kwargs) -> dict:
76
+ """Uses sacrebleu's corpus_bleu implementation."""
77
+ return {"bleu": round(corpus_bleu(output_lns, [refs_lns], **kwargs).score, 4)}
78
+
79
+
80
+ def build_compute_metrics_fn(
81
+ task_name: str, tokenizer: PreTrainedTokenizer
82
+ ) -> Callable[[EvalPrediction], Dict]:
83
+ def non_pad_len(tokens: np.ndarray) -> int:
84
+ return np.count_nonzero(tokens != tokenizer.pad_token_id)
85
+
86
+ def decode_pred(pred: EvalPrediction) -> Tuple[List[str], List[str]]:
87
+ pred_ids = pred.predictions
88
+ label_ids = pred.label_ids
89
+ pred_str = tokenizer.batch_decode(pred_ids, skip_special_tokens=True)
90
+ label_ids[label_ids == -100] = tokenizer.pad_token_id
91
+ label_str = tokenizer.batch_decode(label_ids, skip_special_tokens=True)
92
+ pred_str = lmap(str.strip, pred_str)
93
+ label_str = lmap(str.strip, label_str)
94
+ return pred_str, label_str
95
+
96
+ def summarization_metrics(pred: EvalPrediction) -> Dict:
97
+ pred_str, label_str = decode_pred(pred)
98
+ rouge: Dict = calculate_rouge(pred_str, label_str)
99
+ summ_len = np.round(np.mean(lmap(non_pad_len, pred.predictions)), 1)
100
+ rouge.update({"gen_len": summ_len})
101
+ return rouge
102
+
103
+ def translation_metrics(pred: EvalPrediction) -> Dict:
104
+ pred_str, label_str = decode_pred(pred)
105
+ bleu: Dict = calculate_bleu(pred_str, label_str)
106
+ gen_len = np.round(np.mean(lmap(non_pad_len, pred.predictions)), 1)
107
+ bleu.update({"gen_len": gen_len})
108
+ return bleu
109
+
110
+ compute_metrics_fn = (
111
+ summarization_metrics if "summarization" in task_name else translation_metrics
112
+ )
113
+ return compute_metrics_fn
114
+
115
+
116
+ def trim_batch(
117
+ input_ids,
118
+ pad_token_id,
119
+ attention_mask=None,
120
+ ):
121
+ """Remove columns that are populated exclusively by pad_token_id"""
122
+ keep_column_mask = input_ids.ne(pad_token_id).any(dim=0)
123
+ if attention_mask is None:
124
+ return input_ids[:, keep_column_mask]
125
+ else:
126
+ return (input_ids[:, keep_column_mask], attention_mask[:, keep_column_mask])
127
+
128
+
129
+ class AbstractSeq2SeqDataset(Dataset):
130
+ def __init__(
131
+ self,
132
+ tokenizer,
133
+ data_dir,
134
+ max_source_length,
135
+ max_target_length,
136
+ type_path="train",
137
+ n_obs=None,
138
+ prefix="",
139
+ **dataset_kwargs,
140
+ ):
141
+ super().__init__()
142
+ self.src_file = Path(data_dir).joinpath(type_path + ".source")
143
+ self.tgt_file = Path(data_dir).joinpath(type_path + ".target")
144
+ self.len_file = Path(data_dir).joinpath(type_path + ".len")
145
+ if os.path.exists(self.len_file):
146
+ self.src_lens = pickle_load(self.len_file)
147
+ self.used_char_len = False
148
+ else:
149
+ self.src_lens = self.get_char_lens(self.src_file)
150
+ self.used_char_len = True
151
+ self.max_source_length = max_source_length
152
+ self.max_target_length = max_target_length
153
+ assert min(self.src_lens) > 0, f"found empty line in {self.src_file}"
154
+ self.tokenizer = tokenizer
155
+ self.prefix = prefix if prefix is not None else ""
156
+
157
+ if n_obs is not None:
158
+ self.src_lens = self.src_lens[:n_obs]
159
+ self.pad_token_id = self.tokenizer.pad_token_id
160
+ self.dataset_kwargs = dataset_kwargs
161
+ dataset_kwargs.update(
162
+ {"add_prefix_space": True}
163
+ if isinstance(self.tokenizer, BartTokenizer)
164
+ else {}
165
+ )
166
+
167
+ def __len__(self):
168
+ return len(self.src_lens)
169
+
170
+ @staticmethod
171
+ def get_char_lens(data_file):
172
+ return [len(x) for x in Path(data_file).open().readlines()]
173
+
174
+ @cached_property
175
+ def tgt_lens(self):
176
+ """Length in characters of target documents"""
177
+ return self.get_char_lens(self.tgt_file)
178
+
179
+ def make_sortish_sampler(
180
+ self, batch_size, distributed=False, shuffle=True, **kwargs
181
+ ):
182
+ if distributed:
183
+ return DistributedSortishSampler(
184
+ self, batch_size, shuffle=shuffle, **kwargs
185
+ )
186
+ else:
187
+ return SortishSampler(self.src_lens, batch_size, shuffle=shuffle)
188
+
189
+ def make_dynamic_sampler(self, max_tokens_per_batch=1024, **kwargs):
190
+ assert FAIRSEQ_AVAILABLE, "Dynamic batch size requires `pip install fairseq`"
191
+ assert (
192
+ not self.used_char_len
193
+ ), "You must call python make_len_file.py before calling make_dynamic_sampler"
194
+ sorted_indices = list(self.make_sortish_sampler(1024, shuffle=False))
195
+
196
+ def num_tokens_in_example(i):
197
+ return min(self.src_lens[i], self.max_target_length)
198
+
199
+ # call fairseq cython function
200
+ batch_sampler: List[List[int]] = batch_by_size(
201
+ sorted_indices,
202
+ num_tokens_fn=num_tokens_in_example,
203
+ max_tokens=max_tokens_per_batch,
204
+ required_batch_size_multiple=64,
205
+ )
206
+ shuffled_batches = [
207
+ batch_sampler[i] for i in np.random.permutation(range(len(batch_sampler)))
208
+ ]
209
+ # move the largest batch to the front to OOM quickly (uses an approximation for padding)
210
+ approximate_toks_per_batch = [
211
+ max(self.src_lens[i] for i in batch) * len(batch)
212
+ for batch in shuffled_batches
213
+ ]
214
+ largest_batch_idx = np.argmax(approximate_toks_per_batch)
215
+ shuffled_batches[0], shuffled_batches[largest_batch_idx] = (
216
+ shuffled_batches[largest_batch_idx],
217
+ shuffled_batches[0],
218
+ )
219
+ return shuffled_batches
220
+
221
+ def __getitem__(self, item):
222
+ raise NotImplementedError("You must implement this")
223
+
224
+ def collate_fn(self, batch):
225
+ raise NotImplementedError("You must implement this")
226
+
227
+
228
+ class LegacySeq2SeqDataset(AbstractSeq2SeqDataset):
229
+ def __getitem__(self, index) -> Dict[str, torch.Tensor]:
230
+ """Call tokenizer on src and tgt_lines"""
231
+ index = index + 1 # linecache starts at 1
232
+ source_line = self.prefix + linecache.getline(str(self.src_file), index).rstrip(
233
+ "\n"
234
+ )
235
+ tgt_line = linecache.getline(str(self.tgt_file), index).rstrip("\n")
236
+ assert source_line, f"empty source line for index {index}"
237
+ assert tgt_line, f"empty tgt line for index {index}"
238
+ source_inputs = self.encode_line(
239
+ self.tokenizer, source_line, self.max_source_length
240
+ )
241
+ target_inputs = self.encode_line(
242
+ self.tokenizer, tgt_line, self.max_target_length
243
+ )
244
+
245
+ source_ids = source_inputs["input_ids"].squeeze()
246
+ target_ids = target_inputs["input_ids"].squeeze()
247
+ src_mask = source_inputs["attention_mask"].squeeze()
248
+ return {
249
+ "input_ids": source_ids,
250
+ "attention_mask": src_mask,
251
+ "labels": target_ids,
252
+ }
253
+
254
+ def encode_line(
255
+ self, tokenizer, line, max_length, pad_to_max_length=True, return_tensors="pt"
256
+ ):
257
+ """Only used by LegacyDataset"""
258
+ return tokenizer(
259
+ [line],
260
+ max_length=max_length,
261
+ padding="max_length" if pad_to_max_length else None,
262
+ truncation=True,
263
+ return_tensors=return_tensors,
264
+ **self.dataset_kwargs,
265
+ )
266
+
267
+ def collate_fn(self, batch) -> Dict[str, torch.Tensor]:
268
+ input_ids = torch.stack([x["input_ids"] for x in batch])
269
+ masks = torch.stack([x["attention_mask"] for x in batch])
270
+ target_ids = torch.stack([x["labels"] for x in batch])
271
+ pad_token_id = self.pad_token_id
272
+ y = trim_batch(target_ids, pad_token_id)
273
+ source_ids, source_mask = trim_batch(
274
+ input_ids, pad_token_id, attention_mask=masks
275
+ )
276
+ batch = {
277
+ "input_ids": source_ids,
278
+ "attention_mask": source_mask,
279
+ "labels": y,
280
+ }
281
+ return batch
282
+
283
+
284
+ class Seq2SeqDataset(AbstractSeq2SeqDataset):
285
+ """A dataset that calls prepare_seq2seq_batch."""
286
+
287
+ def __getitem__(self, index) -> Dict[str, str]:
288
+ index = index + 1 # linecache starts at 1
289
+ source_line = self.prefix + linecache.getline(str(self.src_file), index).rstrip(
290
+ "\n"
291
+ )
292
+ tgt_line = linecache.getline(str(self.tgt_file), index).rstrip("\n")
293
+ assert source_line, f"empty source line for index {index}"
294
+ assert tgt_line, f"empty tgt line for index {index}"
295
+ return {"tgt_texts": tgt_line, "src_texts": source_line, "id": index - 1}
296
+
297
+ def collate_fn(self, batch) -> Dict[str, torch.Tensor]:
298
+ """Call prepare_seq2seq_batch."""
299
+ batch_encoding: Dict[str, torch.Tensor] = self.tokenizer.prepare_seq2seq_batch(
300
+ [x["src_texts"] for x in batch],
301
+ tgt_texts=[x["tgt_texts"] for x in batch],
302
+ max_length=self.max_source_length,
303
+ max_target_length=self.max_target_length,
304
+ return_tensors="pt",
305
+ **self.dataset_kwargs,
306
+ ).data
307
+ batch_encoding["ids"] = torch.tensor([x["id"] for x in batch])
308
+ return batch_encoding
309
+
310
+
311
+ class Seq2SeqDataCollator:
312
+ def __init__(
313
+ self, tokenizer, data_args, decoder_start_token_id, tpu_num_cores=None
314
+ ):
315
+ self.tokenizer = tokenizer
316
+ self.pad_token_id = tokenizer.pad_token_id
317
+ self.decoder_start_token_id = decoder_start_token_id
318
+ assert (
319
+ self.pad_token_id is not None
320
+ ), f"pad_token_id is not defined for ({self.tokenizer.__class__.__name__}), it must be defined."
321
+ self.data_args = data_args
322
+ self.tpu_num_cores = tpu_num_cores
323
+ self.dataset_kwargs = (
324
+ {"add_prefix_space": True} if isinstance(tokenizer, BartTokenizer) else {}
325
+ )
326
+ if data_args.src_lang is not None:
327
+ self.dataset_kwargs["src_lang"] = data_args.src_lang
328
+ if data_args.tgt_lang is not None:
329
+ self.dataset_kwargs["tgt_lang"] = data_args.tgt_lang
330
+
331
+ def __call__(self, batch) -> Dict[str, torch.Tensor]:
332
+ if hasattr(self.tokenizer, "prepare_seq2seq_batch"):
333
+ batch = self._encode(batch)
334
+ input_ids, attention_mask, labels = (
335
+ batch["input_ids"],
336
+ batch["attention_mask"],
337
+ batch["labels"],
338
+ )
339
+ else:
340
+ input_ids = torch.stack([x["input_ids"] for x in batch])
341
+ attention_mask = torch.stack([x["attention_mask"] for x in batch])
342
+ labels = torch.stack([x["labels"] for x in batch])
343
+
344
+ labels = trim_batch(labels, self.pad_token_id)
345
+ input_ids, attention_mask = trim_batch(
346
+ input_ids, self.pad_token_id, attention_mask=attention_mask
347
+ )
348
+
349
+ if isinstance(self.tokenizer, T5Tokenizer):
350
+ decoder_input_ids = self._shift_right_t5(labels)
351
+ else:
352
+ decoder_input_ids = shift_tokens_right(
353
+ labels, self.pad_token_id, self.decoder_start_token_id
354
+ )
355
+
356
+ batch = {
357
+ "input_ids": input_ids,
358
+ "attention_mask": attention_mask,
359
+ "decoder_input_ids": decoder_input_ids,
360
+ "labels": labels,
361
+ }
362
+ return batch
363
+
364
+ def _shift_right_t5(self, input_ids):
365
+ # shift inputs to the right
366
+ shifted_input_ids = input_ids.new_zeros(input_ids.shape)
367
+ shifted_input_ids[..., 1:] = input_ids[..., :-1].clone()
368
+ shifted_input_ids[..., 0] = self.pad_token_id
369
+ return shifted_input_ids
370
+
371
+ def _encode(self, batch) -> Dict[str, torch.Tensor]:
372
+ batch_encoding = self.tokenizer.prepare_seq2seq_batch(
373
+ [x["src_texts"] for x in batch],
374
+ tgt_texts=[x["tgt_texts"] for x in batch],
375
+ max_length=self.data_args.max_source_length,
376
+ max_target_length=self.data_args.max_target_length,
377
+ padding="max_length"
378
+ if self.tpu_num_cores is not None
379
+ else "longest", # TPU hack
380
+ return_tensors="pt",
381
+ **self.dataset_kwargs,
382
+ )
383
+ return batch_encoding.data
384
+
385
+
386
+ class SortishSampler(Sampler):
387
+ "Go through the text data by order of src length with a bit of randomness. From fastai repo."
388
+
389
+ def __init__(self, data, batch_size, shuffle=True):
390
+ self.data, self.bs, self.shuffle = data, batch_size, shuffle
391
+
392
+ def __len__(self) -> int:
393
+ return len(self.data)
394
+
395
+ def __iter__(self):
396
+ return iter(sortish_sampler_indices(self.data, self.bs, shuffle=self.shuffle))
397
+
398
+
399
+ def sortish_sampler_indices(data: List, bs: int, shuffle=True) -> np.array:
400
+ "Go through the text data by order of src length with a bit of randomness. From fastai repo."
401
+ if not shuffle:
402
+ return np.argsort(np.array(data) * -1)
403
+
404
+ def key_fn(i):
405
+ return data[i]
406
+
407
+ idxs = np.random.permutation(len(data))
408
+ sz = bs * 50
409
+ ck_idx = [idxs[i : i + sz] for i in range(0, len(idxs), sz)]
410
+ sort_idx = np.concatenate([sorted(s, key=key_fn, reverse=True) for s in ck_idx])
411
+ sz = bs
412
+ ck_idx = [sort_idx[i : i + sz] for i in range(0, len(sort_idx), sz)]
413
+ max_ck = np.argmax(
414
+ [key_fn(ck[0]) for ck in ck_idx]
415
+ ) # find the chunk with the largest key,
416
+ ck_idx[0], ck_idx[max_ck] = (
417
+ ck_idx[max_ck],
418
+ ck_idx[0],
419
+ ) # then make sure it goes first.
420
+ sort_idx = (
421
+ np.concatenate(np.random.permutation(ck_idx[1:]))
422
+ if len(ck_idx) > 1
423
+ else np.array([], dtype=np.int)
424
+ )
425
+ sort_idx = np.concatenate((ck_idx[0], sort_idx))
426
+ return sort_idx
427
+
428
+
429
+ class DistributedSortishSampler(Sampler):
430
+ """Copied from torch DistributedSampler"""
431
+
432
+ def __init__(
433
+ self,
434
+ dataset,
435
+ batch_size,
436
+ num_replicas=None,
437
+ rank=None,
438
+ add_extra_examples=True,
439
+ shuffle=True,
440
+ ):
441
+ if num_replicas is None:
442
+ if not dist.is_available():
443
+ raise RuntimeError("Requires distributed package to be available")
444
+ num_replicas = dist.get_world_size()
445
+ if rank is None:
446
+ if not dist.is_available():
447
+ raise RuntimeError("Requires distributed package to be available")
448
+ rank = dist.get_rank()
449
+ self.dataset = dataset
450
+ self.num_replicas = num_replicas
451
+ self.rank = rank
452
+ self.epoch = 0
453
+ if add_extra_examples:
454
+ self.num_samples = int(
455
+ math.ceil(len(self.dataset) * 1.0 / self.num_replicas)
456
+ )
457
+ self.total_size = self.num_samples * self.num_replicas
458
+ else:
459
+ self.total_size = len(dataset)
460
+ self.num_samples = len(self.available_indices)
461
+ self.batch_size = batch_size
462
+ self.add_extra_examples = add_extra_examples
463
+ self.shuffle = shuffle
464
+
465
+ def __iter__(self) -> Iterable:
466
+ g = torch.Generator()
467
+ g.manual_seed(self.epoch)
468
+
469
+ sortish_data = [self.dataset.src_lens[i] for i in self.available_indices]
470
+ sortish_indices = sortish_sampler_indices(
471
+ sortish_data, self.batch_size, shuffle=self.shuffle
472
+ )
473
+ indices = [self.available_indices[i] for i in sortish_indices]
474
+ assert len(indices) == self.num_samples
475
+ return iter(indices)
476
+
477
+ @cached_property
478
+ def available_indices(self) -> np.array:
479
+ indices = list(range(len(self.dataset)))
480
+ # add extra samples to make it evenly divisible
481
+ indices += indices[: (self.total_size - len(indices))]
482
+ assert len(indices) == self.total_size
483
+ # subsample
484
+ available_indices = indices[self.rank : self.total_size : self.num_replicas]
485
+ return available_indices
486
+
487
+ def __len__(self):
488
+ return self.num_samples
489
+
490
+ def set_epoch(self, epoch):
491
+ self.epoch = epoch
492
+
493
+
494
+ logger = getLogger(__name__)
495
+
496
+
497
+ def use_task_specific_params(model, task):
498
+ """Update config with summarization specific params."""
499
+ task_specific_params = model.config.task_specific_params
500
+
501
+ if task_specific_params is not None:
502
+ pars = task_specific_params.get(task, {})
503
+ logger.info(
504
+ f"setting model.config to task specific params for {task}:\n {pars}"
505
+ )
506
+ logger.info("note: command line args may override some of these")
507
+ model.config.update(pars)
508
+
509
+
510
+ def pickle_load(path):
511
+ """pickle.load(path)"""
512
+ with open(path, "rb") as f:
513
+ return pickle.load(f)
514
+
515
+
516
+ def pickle_save(obj, path):
517
+ """pickle.dump(obj, path)"""
518
+ with open(path, "wb") as f:
519
+ return pickle.dump(obj, f)
520
+
521
+
522
+ def flatten_list(summary_ids: List[List]):
523
+ return [x for x in itertools.chain.from_iterable(summary_ids)]
524
+
525
+
526
+ def save_git_info(folder_path: str) -> None:
527
+ """Save git information to output_dir/git_log.json"""
528
+ repo_infos = get_git_info()
529
+ save_json(repo_infos, os.path.join(folder_path, "git_log.json"))
530
+
531
+
532
+ def save_json(content, path, indent=4, **json_dump_kwargs):
533
+ with open(path, "w") as f:
534
+ json.dump(content, f, indent=indent, sort_keys=True, **json_dump_kwargs)
535
+
536
+
537
+ def load_json(path):
538
+ with open(path) as f:
539
+ return json.load(f)
540
+
541
+
542
+ def get_git_info():
543
+ try:
544
+ repo = git.Repo(search_parent_directories=True)
545
+ repo_infos = {
546
+ "repo_id": str(repo),
547
+ "repo_sha": str(repo.head.object.hexsha),
548
+ "repo_branch": str(repo.active_branch),
549
+ "hostname": str(socket.gethostname()),
550
+ }
551
+ return repo_infos
552
+ except TypeError:
553
+ return {
554
+ "repo_id": None,
555
+ "repo_sha": None,
556
+ "repo_branch": None,
557
+ "hostname": None,
558
+ }
559
+
560
+
561
+ ROUGE_KEYS = ["rouge1", "rouge2", "rougeL", "rougeLsum"]
562
+
563
+
564
+ def extract_rouge_mid_statistics(dct):
565
+ new_dict = {}
566
+ for k1, v1 in dct.items():
567
+ mid = v1.mid
568
+ new_dict[k1] = {
569
+ stat: round(getattr(mid, stat), 4)
570
+ for stat in ["precision", "recall", "fmeasure"]
571
+ }
572
+ return new_dict
573
+
574
+
575
+ def calculate_rouge(
576
+ pred_lns: List[str],
577
+ tgt_lns: List[str],
578
+ use_stemmer=True,
579
+ rouge_keys=ROUGE_KEYS,
580
+ return_precision_and_recall=False,
581
+ bootstrap_aggregation=True,
582
+ newline_sep=True,
583
+ ) -> Dict:
584
+ """Calculate rouge using rouge_scorer package.
585
+
586
+ Args:
587
+ pred_lns: list of summaries generated by model
588
+ tgt_lns: list of groundtruth summaries (e.g. contents of val.target)
589
+ use_stemmer: Bool indicating whether Porter stemmer should be used to
590
+ strip word suffixes to improve matching.
591
+ rouge_keys: which metrics to compute, defaults to rouge1, rouge2, rougeL, rougeLsum
592
+ return_precision_and_recall: (False) whether to also return precision and recall.
593
+ bootstrap_aggregation: whether to do the typical bootstrap resampling of scores. Defaults to True, if False
594
+ this function returns a collections.defaultdict[metric: list of values for each observation for each subscore]``
595
+ newline_sep:(default=True) whether to add newline between sentences. This is essential for calculation rougeL
596
+ on multi sentence summaries (CNN/DM dataset).
597
+
598
+ Returns:
599
+ Dict[score: value] if aggregate else defaultdict(list) keyed by rouge_keys
600
+
601
+ """
602
+ scorer = rouge_scorer.RougeScorer(rouge_keys, use_stemmer=use_stemmer)
603
+ aggregator = scoring.BootstrapAggregator()
604
+ for pred, tgt in zip(tgt_lns, pred_lns):
605
+ # rougeLsum expects "\n" separated sentences within a summary
606
+ if newline_sep:
607
+ pred = add_newline_to_end_of_each_sentence(pred)
608
+ tgt = add_newline_to_end_of_each_sentence(tgt)
609
+ scores = scorer.score(pred, tgt)
610
+ aggregator.add_scores(scores)
611
+
612
+ if bootstrap_aggregation:
613
+ result = aggregator.aggregate()
614
+ if return_precision_and_recall:
615
+ return extract_rouge_mid_statistics(result) # here we return dict
616
+ else:
617
+ return {k: round(v.mid.fmeasure * 100, 4) for k, v in result.items()}
618
+
619
+ else:
620
+ return aggregator._scores # here we return defaultdict(list)
621
+
622
+
623
+ # Utilities for freezing parameters and checking whether they are frozen
624
+
625
+
626
+ def freeze_params(model: nn.Module):
627
+ """Set requires_grad=False for each of model.parameters()"""
628
+ for par in model.parameters():
629
+ par.requires_grad = False
630
+
631
+
632
+ def freeze_embeds(model):
633
+ """Freeze token embeddings and positional embeddings for bart, just token embeddings for t5."""
634
+ model_type = model.config.model_type
635
+
636
+ if model_type in ["t5", "mt5"]:
637
+ freeze_params(model.shared)
638
+ for d in [model.encoder, model.decoder]:
639
+ freeze_params(d.embed_tokens)
640
+ elif model_type == "fsmt":
641
+ for d in [model.model.encoder, model.model.decoder]:
642
+ freeze_params(d.embed_positions)
643
+ freeze_params(d.embed_tokens)
644
+ else:
645
+ freeze_params(model.model.shared)
646
+ for d in [model.model.encoder, model.model.decoder]:
647
+ freeze_params(d.embed_positions)
648
+ freeze_params(d.embed_tokens)
649
+
650
+
651
+ def grad_status(model: nn.Module) -> Iterable:
652
+ return (par.requires_grad for par in model.parameters())
653
+
654
+
655
+ def any_requires_grad(model: nn.Module) -> bool:
656
+ return any(grad_status(model))
657
+
658
+
659
+ def assert_all_frozen(model):
660
+ model_grads: List[bool] = list(grad_status(model))
661
+ n_require_grad = sum(lmap(int, model_grads))
662
+ npars = len(model_grads)
663
+ assert not any(
664
+ model_grads
665
+ ), f"{n_require_grad/npars:.1%} of {npars} weights require grad"
666
+
667
+
668
+ def assert_not_all_frozen(model):
669
+ model_grads: List[bool] = list(grad_status(model))
670
+ npars = len(model_grads)
671
+ assert any(model_grads), f"none of {npars} weights require grad"
672
+
673
+
674
+ def parse_numeric_n_bool_cl_kwargs(
675
+ unparsed_args: List[str],
676
+ ) -> Dict[str, Union[int, float, bool]]:
677
+ """
678
+ Parse an argv list of unspecified command line args to a dict.
679
+ Assumes all values are either numeric or boolean in the form of true/false.
680
+ """
681
+ result = {}
682
+ assert (
683
+ len(unparsed_args) % 2 == 0
684
+ ), f"got odd number of unparsed args: {unparsed_args}"
685
+ num_pairs = len(unparsed_args) // 2
686
+ for pair_num in range(num_pairs):
687
+ i = 2 * pair_num
688
+ assert unparsed_args[i].startswith("--")
689
+ if unparsed_args[i + 1].lower() == "true":
690
+ value = True
691
+ elif unparsed_args[i + 1].lower() == "false":
692
+ value = False
693
+ else:
694
+ try:
695
+ value = int(unparsed_args[i + 1])
696
+ except ValueError:
697
+ value = float(
698
+ unparsed_args[i + 1]
699
+ ) # this can raise another informative ValueError
700
+
701
+ result[unparsed_args[i][2:]] = value
702
+ return result
703
+
704
+
705
+ def write_txt_file(ordered_tgt, path):
706
+ f = Path(path).open("w")
707
+ for ln in ordered_tgt:
708
+ f.write(ln + "\n")
709
+ f.flush()
710
+
711
+
712
+ def chunks(lst, n):
713
+ """Yield successive n-sized chunks from lst."""
714
+ for i in range(0, len(lst), n):
715
+ yield lst[i : i + n]
716
+
717
+
718
+ def check_output_dir(args, expected_items=0):
719
+ """
720
+ Checks whether to bail out if output_dir already exists and has more than expected_items in it
721
+
722
+ `args`: needs to have the following attributes of `args`:
723
+ - output_dir
724
+ - do_train
725
+ - overwrite_output_dir
726
+
727
+ `expected_items`: normally 0 (default) - i.e. empty dir, but in some cases a few files are expected (e.g. recovery from OOM)
728
+ """
729
+ if (
730
+ os.path.exists(args.output_dir)
731
+ and len(os.listdir(args.output_dir)) > expected_items
732
+ and args.do_train
733
+ and not args.overwrite_output_dir
734
+ ):
735
+ raise ValueError(
736
+ f"Output directory ({args.output_dir}) already exists and "
737
+ f"has {len(os.listdir(args.output_dir))} items in it (expected {expected_items} items). "
738
+ "Use --overwrite_output_dir to overcome."
739
+ )