HarryLee committed on
Commit
3f1db0e
1 Parent(s): ff775c5

Add data source

data/__init__.py ADDED
File without changes
data/data_utils.py ADDED
@@ -0,0 +1,601 @@
+ # Copyright 2022 The OFA-Sys Team.
+ # All rights reserved.
+ # This source code is licensed under the Apache 2.0 license
+ # found in the LICENSE file in the root directory.
+
+ try:
+     from collections.abc import Iterable
+ except ImportError:
+     from collections import Iterable
+ import contextlib
+ import itertools
+ import logging
+ import os
+ import re
+ import warnings
+ from typing import Optional, Tuple
+
+ import numpy as np
+ import torch
+
+ from fairseq.file_io import PathManager
+ from fairseq import utils
+
+ logger = logging.getLogger(__name__)
+
+
+ def infer_language_pair(path):
+     """Infer language pair from filename: <split>.<lang1>-<lang2>.(...).idx"""
+     src, dst = None, None
+     for filename in PathManager.ls(path):
+         parts = filename.split(".")
+         if len(parts) >= 3 and len(parts[1].split("-")) == 2:
+             return parts[1].split("-")
+     return src, dst
+
+
+ def collate_tokens(
+     values,
+     pad_idx,
+     eos_idx=None,
+     left_pad=False,
+     move_eos_to_beginning=False,
+     pad_to_length=None,
+     pad_to_multiple=1,
+     pad_to_bsz=None,
+ ):
+     """Convert a list of 1d tensors into a padded 2d tensor."""
+     size = max(v.size(0) for v in values)
+     size = size if pad_to_length is None else max(size, pad_to_length)
+     if pad_to_multiple != 1 and size % pad_to_multiple != 0:
+         size = int(((size - 0.1) // pad_to_multiple + 1) * pad_to_multiple)
+
+     # pad the batch dimension as well when pad_to_bsz is given
+     batch_size = len(values) if pad_to_bsz is None else max(len(values), pad_to_bsz)
+
+     def copy_tensor(src, dst):
+         assert dst.numel() == src.numel()
+         if move_eos_to_beginning:
+             if eos_idx is None:
+                 # if no eos_idx is specified, then use the last token in src
+                 dst[0] = src[-1]
+             else:
+                 dst[0] = eos_idx
+             dst[1:] = src[:-1]
+         else:
+             dst.copy_(src)
+
+     if values[0].dim() == 1:
+         res = values[0].new(batch_size, size).fill_(pad_idx)
+     elif values[0].dim() == 2:
+         assert move_eos_to_beginning is False
+         res = values[0].new(batch_size, size, values[0].size(1)).fill_(pad_idx)
+     else:
+         raise NotImplementedError
+
+     for i, v in enumerate(values):
+         copy_tensor(v, res[i][size - len(v):] if left_pad else res[i][: len(v)])
+     return res
+
+
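For orientation, a minimal sketch of what collate_tokens produces (the tensors below are illustrative, not part of the commit):

import torch
from data.data_utils import collate_tokens

seqs = [torch.tensor([5, 6, 7, 2]), torch.tensor([8, 9, 2])]
collate_tokens(seqs, pad_idx=1)
# tensor([[5, 6, 7, 2],
#         [8, 9, 2, 1]])   right-padded to the longest sequence
collate_tokens(seqs, pad_idx=1, eos_idx=2, move_eos_to_beginning=True)[0]
# tensor([2, 5, 6, 7])     EOS rotated to the front, e.g. for prev_output_tokens
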
+ def load_indexed_dataset(
+     path, dictionary=None, dataset_impl=None, combine=False, default="cached"
+ ):
+     """A helper function for loading indexed datasets.
+
+     Args:
+         path (str): path to indexed dataset (e.g., 'data-bin/train')
+         dictionary (~fairseq.data.Dictionary): data dictionary
+         dataset_impl (str, optional): which dataset implementation to use. If
+             not provided, it will be inferred automatically. For legacy indexed
+             data we use the 'cached' implementation by default.
+         combine (bool, optional): automatically load and combine multiple
+             datasets. For example, if *path* is 'data-bin/train', then we will
+             combine 'data-bin/train', 'data-bin/train1', ... and return a
+             single ConcatDataset instance.
+     """
+     import fairseq.data.indexed_dataset as indexed_dataset
+     from fairseq.data.concat_dataset import ConcatDataset
+
+     datasets = []
+     for k in itertools.count():
+         path_k = path + (str(k) if k > 0 else "")
+         try:
+             path_k = indexed_dataset.get_indexed_dataset_to_local(path_k)
+         except Exception as e:
+             if "StorageException: [404] Path not found" in str(e):
+                 logger.warning(f"path_k: {e} not found")
+             else:
+                 raise e
+
+         dataset_impl_k = dataset_impl
+         if dataset_impl_k is None:
+             dataset_impl_k = indexed_dataset.infer_dataset_impl(path_k)
+         dataset = indexed_dataset.make_dataset(
+             path_k,
+             impl=dataset_impl_k or default,
+             fix_lua_indexing=True,
+             dictionary=dictionary,
+         )
+         if dataset is None:
+             break
+         logger.info("loaded {:,} examples from: {}".format(len(dataset), path_k))
+         datasets.append(dataset)
+         if not combine:
+             break
+     if len(datasets) == 0:
+         return None
+     elif len(datasets) == 1:
+         return datasets[0]
+     else:
+         return ConcatDataset(datasets)
+
+
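A usage sketch, assuming a fairseq-preprocessed data-bin directory and dictionary (the paths below are hypothetical):

from fairseq.data import Dictionary
from data.data_utils import load_indexed_dataset

d = Dictionary.load("data-bin/dict.en.txt")
ds = load_indexed_dataset("data-bin/train", dictionary=d, combine=True)
# with combine=True this also picks up data-bin/train1, train2, ... if present
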
+ @contextlib.contextmanager
+ def numpy_seed(seed, *addl_seeds):
+     """Context manager which seeds the NumPy PRNG with the specified seed and
+     restores the state afterward"""
+     if seed is None:
+         yield
+         return
+     if len(addl_seeds) > 0:
+         seed = int(hash((seed, *addl_seeds)) % 1e6)
+     state = np.random.get_state()
+     np.random.seed(seed)
+     try:
+         yield
+     finally:
+         np.random.set_state(state)
+
+
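This is the pattern fairseq datasets use for epoch-deterministic shuffling; a small sketch (the extra seed component is hypothetical):

import numpy as np
from data.data_utils import numpy_seed

epoch = 3  # hypothetical extra seed component
with numpy_seed(42, epoch):
    order = np.random.permutation(100)  # same permutation for the same (seed, epoch)
# the global NumPy RNG state is restored on exit
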
+ def collect_filtered(function, iterable, filtered):
+     """
+     Similar to :func:`filter` but collects filtered elements in ``filtered``.
+
+     Args:
+         function (callable): function that returns ``False`` for elements that
+             should be filtered
+         iterable (iterable): iterable to filter
+         filtered (list): list to store filtered elements
+     """
+     for el in iterable:
+         if function(el):
+             yield el
+         else:
+             filtered.append(el)
+
+
+ def _filter_by_size_dynamic(indices, size_fn, max_positions, raise_exception=False):
+     def compare_leq(a, b):
+         return a <= b if not isinstance(a, tuple) else max(a) <= b
+
+     def check_size(idx):
+         if isinstance(max_positions, float) or isinstance(max_positions, int):
+             return size_fn(idx) <= max_positions
+         elif isinstance(max_positions, dict):
+             idx_size = size_fn(idx)
+             assert isinstance(idx_size, dict)
+             intersect_keys = set(max_positions.keys()) & set(idx_size.keys())
+             return all(
+                 all(
+                     a is None or b is None or a <= b
+                     for a, b in zip(idx_size[key], max_positions[key])
+                 )
+                 for key in intersect_keys
+             )
+         else:
+             # For MultiCorpusSampledDataset, will generalize it later
+             if not isinstance(size_fn(idx), Iterable):
+                 return all(size_fn(idx) <= b for b in max_positions)
+             return all(
+                 a is None or b is None or a <= b
+                 for a, b in zip(size_fn(idx), max_positions)
+             )
+
+     ignored = []
+     itr = collect_filtered(check_size, indices, ignored)
+     indices = np.fromiter(itr, dtype=np.int64, count=-1)
+     return indices, ignored
+
+
+ def filter_by_size(indices, dataset, max_positions, raise_exception=False):
+     """
+     [deprecated] Filter indices based on their size.
+     Use `FairseqDataset::filter_indices_by_size` instead.
+
+     Args:
+         indices (List[int]): ordered list of dataset indices
+         dataset (FairseqDataset): fairseq dataset instance
+         max_positions (tuple): filter elements larger than this size.
+             Comparisons are done component-wise.
+         raise_exception (bool, optional): if ``True``, raise an exception if
+             any elements are filtered (default: False).
+     """
+     warnings.warn(
+         "data_utils.filter_by_size is deprecated. "
+         "Use `FairseqDataset::filter_indices_by_size` instead.",
+         stacklevel=2,
+     )
+     if isinstance(max_positions, float) or isinstance(max_positions, int):
+         if hasattr(dataset, "sizes") and isinstance(dataset.sizes, np.ndarray):
+             ignored = indices[dataset.sizes[indices] > max_positions].tolist()
+             indices = indices[dataset.sizes[indices] <= max_positions]
+         elif (
+             hasattr(dataset, "sizes")
+             and isinstance(dataset.sizes, list)
+             and len(dataset.sizes) == 1
+         ):
+             ignored = indices[dataset.sizes[0][indices] > max_positions].tolist()
+             indices = indices[dataset.sizes[0][indices] <= max_positions]
+         else:
+             indices, ignored = _filter_by_size_dynamic(
+                 indices, dataset.size, max_positions
+             )
+     else:
+         indices, ignored = _filter_by_size_dynamic(indices, dataset.size, max_positions)
+
+     if len(ignored) > 0 and raise_exception:
+         raise Exception(
+             (
+                 "Size of sample #{} is invalid (={}) since max_positions={}, "
+                 "skip this example with --skip-invalid-size-inputs-valid-test"
+             ).format(ignored[0], dataset.size(ignored[0]), max_positions)
+         )
+     if len(ignored) > 0:
+         logger.warning(
+             (
+                 "{} samples have invalid sizes and will be skipped, "
+                 "max_positions={}, first few sample ids={}"
+             ).format(len(ignored), max_positions, ignored[:10])
+         )
+     return indices
+
+
+ def filter_paired_dataset_indices_by_size(src_sizes, tgt_sizes, indices, max_sizes):
+     """Filter a list of sample indices. Remove those that are longer
+     than specified in *max_sizes*.
+
+     Args:
+         src_sizes (np.array): sizes of the source-side samples
+         tgt_sizes (np.array, optional): sizes of the target-side samples
+         indices (np.array): original array of sample indices
+         max_sizes (int or list[int] or tuple[int]): max sample size,
+             can be defined separately for src and tgt (then list or tuple)
+
+     Returns:
+         np.array: filtered sample array
+         list: list of removed indices
+     """
+     if max_sizes is None:
+         return indices, []
+     if type(max_sizes) in (int, float):
+         max_src_size, max_tgt_size = max_sizes, max_sizes
+     else:
+         max_src_size, max_tgt_size = max_sizes
+     if tgt_sizes is None:
+         ignored = indices[src_sizes[indices] > max_src_size]
+     else:
+         ignored = indices[
+             (src_sizes[indices] > max_src_size) | (tgt_sizes[indices] > max_tgt_size)
+         ]
+     if len(ignored) > 0:
+         if tgt_sizes is None:
+             indices = indices[src_sizes[indices] <= max_src_size]
+         else:
+             indices = indices[
+                 (src_sizes[indices] <= max_src_size)
+                 & (tgt_sizes[indices] <= max_tgt_size)
+             ]
+     return indices, ignored.tolist()
+
+
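A small worked example of the paired filter (the arrays are illustrative):

import numpy as np
from data.data_utils import filter_paired_dataset_indices_by_size

src_sizes = np.array([3, 120, 7])
tgt_sizes = np.array([4, 9, 300])
kept, dropped = filter_paired_dataset_indices_by_size(
    src_sizes, tgt_sizes, np.arange(3), max_sizes=(100, 100)
)
# kept == array([0]), dropped == [1, 2]: index 1 fails on src, index 2 on tgt
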
+ def batch_by_size(
+     indices,
+     num_tokens_fn,
+     num_tokens_vec=None,
+     max_tokens=None,
+     max_sentences=None,
+     required_batch_size_multiple=1,
+     fixed_shapes=None,
+ ):
+     """
+     Yield mini-batches of indices bucketed by size. Batches may contain
+     sequences of different lengths.
+
+     Args:
+         indices (List[int]): ordered list of dataset indices
+         num_tokens_fn (callable): function that returns the number of tokens at
+             a given index
+         num_tokens_vec (List[int], optional): precomputed vector of the number
+             of tokens for each index in indices (to enable faster batch generation)
+         max_tokens (int, optional): max number of tokens in each batch
+             (default: None).
+         max_sentences (int, optional): max number of sentences in each
+             batch (default: None).
+         required_batch_size_multiple (int, optional): require batch size to
+             be less than N or a multiple of N (default: 1).
+         fixed_shapes (List[Tuple[int, int]], optional): if given, batches will
+             only be created with the given shapes. *max_sentences* and
+             *required_batch_size_multiple* will be ignored (default: None).
+     """
+     try:
+         from fairseq.data.data_utils_fast import (
+             batch_by_size_fn,
+             batch_by_size_vec,
+             batch_fixed_shapes_fast,
+         )
+     except ImportError:
+         raise ImportError(
+             "Please build Cython components with: "
+             "`python setup.py build_ext --inplace`"
+         )
+     except ValueError:
+         raise ValueError(
+             "Please build (or rebuild) Cython components with `python setup.py build_ext --inplace`."
+         )
+
+     # added int() to avoid TypeError: an integer is required
+     max_tokens = int(max_tokens) if max_tokens is not None else -1
+     max_sentences = max_sentences if max_sentences is not None else -1
+     bsz_mult = required_batch_size_multiple
+
+     if not isinstance(indices, np.ndarray):
+         indices = np.fromiter(indices, dtype=np.int64, count=-1)
+
+     if num_tokens_vec is not None and not isinstance(num_tokens_vec, np.ndarray):
+         num_tokens_vec = np.fromiter(num_tokens_vec, dtype=np.int64, count=-1)
+
+     if fixed_shapes is None:
+         if num_tokens_vec is None:
+             return batch_by_size_fn(
+                 indices,
+                 num_tokens_fn,
+                 max_tokens,
+                 max_sentences,
+                 bsz_mult,
+             )
+         else:
+             return batch_by_size_vec(
+                 indices,
+                 num_tokens_vec,
+                 max_tokens,
+                 max_sentences,
+                 bsz_mult,
+             )
+     else:
+         fixed_shapes = np.array(fixed_shapes, dtype=np.int64)
+         sort_order = np.lexsort(
+             [
+                 fixed_shapes[:, 1].argsort(),  # length
+                 fixed_shapes[:, 0].argsort(),  # bsz
+             ]
+         )
+         fixed_shapes_sorted = fixed_shapes[sort_order]
+         return batch_fixed_shapes_fast(indices, num_tokens_fn, fixed_shapes_sorted)
+
+
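A usage sketch (requires fairseq's compiled data_utils_fast extension; the sizes are illustrative):

import numpy as np
from data.data_utils import batch_by_size

sizes = np.array([12, 7, 30, 30, 5])
batches = batch_by_size(
    np.argsort(sizes),                    # length-sorted indices keep padding tight
    num_tokens_fn=lambda i: int(sizes[i]),
    max_tokens=64,
)
# each returned index array forms a batch whose padded size
# (batch size x longest item) stays within max_tokens
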
+ def post_process(sentence: str, symbol: str):
+     if symbol == "sentencepiece":
+         sentence = sentence.replace(" ", "").replace("\u2581", " ").strip()
+     elif symbol == "wordpiece":
+         sentence = sentence.replace(" ", "").replace("_", " ").strip()
+     elif symbol == "letter":
+         sentence = sentence.replace(" ", "").replace("|", " ").strip()
+     elif symbol == "silence":
+         sentence = sentence.replace("<SIL>", "")
+         sentence = re.sub(' +', ' ', sentence).strip()
+     elif symbol == "_EOW":
+         sentence = sentence.replace(" ", "").replace("_EOW", " ").strip()
+     elif symbol in {"subword_nmt", "@@ ", "@@"}:
+         if symbol == "subword_nmt":
+             symbol = "@@ "
+         sentence = (sentence + " ").replace(symbol, "").rstrip()
+     elif symbol == "none":
+         pass
+     elif symbol is not None:
+         raise NotImplementedError(f"Unknown post_process option: {symbol}")
+     return sentence
+
+
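Two quick examples of the detokenization modes (the strings are illustrative):

from data.data_utils import post_process

post_process("\u2581the \u2581cat", "sentencepiece")  # -> 'the cat'
post_process("th@@ e cat", "subword_nmt")             # -> 'the cat'
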
+ def compute_mask_indices(
+     shape: Tuple[int, int],
+     padding_mask: Optional[torch.Tensor],
+     mask_prob: float,
+     mask_length: int,
+     mask_type: str = "static",
+     mask_other: float = 0.0,
+     min_masks: int = 0,
+     no_overlap: bool = False,
+     min_space: int = 0,
+ ) -> np.ndarray:
+     """
+     Computes random mask spans for a given shape
+
+     Args:
+         shape: the shape for which to compute masks.
+             should be of size 2 where first element is batch size and 2nd is timesteps
+         padding_mask: optional padding mask of the same size as shape, which will prevent masking padded elements
+         mask_prob: probability for each token to be chosen as start of the span to be masked. this will be multiplied by
+             number of timesteps divided by length of mask span to mask approximately this percentage of all elements.
+             however due to overlaps, the actual number will be smaller (unless no_overlap is True)
+         mask_type: how to compute mask lengths
+             static = fixed size
+             uniform = sample from uniform distribution [mask_other, mask_length*2]
+             normal = sample from normal distribution with mean mask_length and stdev mask_other. mask is min 1 element
+             poisson = sample from poisson distribution with lambda = mask_length
+         min_masks: minimum number of masked spans
+         no_overlap: if True, will use an alternative recursive algorithm that prevents spans from overlapping
+         min_space: only used if no_overlap is True, this is how many elements to keep unmasked between spans
+     """
+
+     bsz, all_sz = shape
+     mask = np.full((bsz, all_sz), False)
+
+     all_num_mask = int(
+         # add a random number for probabilistic rounding
+         mask_prob * all_sz / float(mask_length)
+         + np.random.rand()
+     )
+
+     all_num_mask = max(min_masks, all_num_mask)
+
+     mask_idcs = []
+     for i in range(bsz):
+         if padding_mask is not None:
+             sz = all_sz - padding_mask[i].long().sum().item()
+             num_mask = int(
+                 # add a random number for probabilistic rounding
+                 mask_prob * sz / float(mask_length)
+                 + np.random.rand()
+             )
+             num_mask = max(min_masks, num_mask)
+         else:
+             sz = all_sz
+             num_mask = all_num_mask
+
+         if mask_type == "static":
+             lengths = np.full(num_mask, mask_length)
+         elif mask_type == "uniform":
+             lengths = np.random.randint(mask_other, mask_length * 2 + 1, size=num_mask)
+         elif mask_type == "normal":
+             lengths = np.random.normal(mask_length, mask_other, size=num_mask)
+             lengths = [max(1, int(round(x))) for x in lengths]
+         elif mask_type == "poisson":
+             lengths = np.random.poisson(mask_length, size=num_mask)
+             lengths = [int(round(x)) for x in lengths]
+         else:
+             raise Exception("unknown mask selection " + mask_type)
+
+         if sum(lengths) == 0:
+             lengths[0] = min(mask_length, sz - 1)
+
+         if no_overlap:
+             mask_idc = []
+
+             def arrange(s, e, length, keep_length):
+                 span_start = np.random.randint(s, e - length)
+                 mask_idc.extend(span_start + i for i in range(length))
+
+                 # split the remaining free space into parts that can still
+                 # hold a span of at least keep_length elements
+                 new_parts = []
+                 if span_start - s - min_space >= keep_length:
+                     new_parts.append((s, span_start - min_space + 1))
+                 if e - span_start - length - min_space > keep_length:
+                     new_parts.append((span_start + length + min_space, e))
+                 return new_parts
+
+             parts = [(0, sz)]
+             min_length = min(lengths)
+             for length in sorted(lengths, reverse=True):
+                 lens = np.fromiter(
+                     (e - s if e - s >= length + min_space else 0 for s, e in parts),
+                     np.int64,  # concrete dtype; np.int was removed in recent NumPy
+                 )
+                 l_sum = np.sum(lens)
+                 if l_sum == 0:
+                     break
+                 probs = lens / np.sum(lens)
+                 c = np.random.choice(len(parts), p=probs)
+                 s, e = parts.pop(c)
+                 parts.extend(arrange(s, e, length, min_length))
+             mask_idc = np.asarray(mask_idc)
+         else:
+             min_len = min(lengths)
+             if sz - min_len <= num_mask:
+                 min_len = sz - num_mask - 1
+
+             mask_idc = np.random.choice(sz - min_len, num_mask, replace=False)
+
+             mask_idc = np.asarray(
+                 [
+                     mask_idc[j] + offset
+                     for j in range(len(mask_idc))
+                     for offset in range(lengths[j])
+                 ]
+             )
+
+         mask_idcs.append(np.unique(mask_idc[mask_idc < sz]))
+
+     min_len = min([len(m) for m in mask_idcs])
+     for i, mask_idc in enumerate(mask_idcs):
+         if len(mask_idc) > min_len:
+             mask_idc = np.random.choice(mask_idc, min_len, replace=False)
+         mask[i, mask_idc] = True
+
+     return mask
+
+
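A sketch of the common wav2vec-style call (shape and probabilities are illustrative):

import numpy as np
from data.data_utils import compute_mask_indices

mask = compute_mask_indices((2, 100), None, mask_prob=0.65, mask_length=10)
mask.shape, mask.sum(axis=1)
# ((2, 100), roughly 60-70 masked timesteps per row; both rows are
# trimmed to the same count before being returned)
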
+ def get_mem_usage():
+     try:
+         import psutil
+
+         mb = 1024 * 1024
+         return f"used={psutil.virtual_memory().used / mb}Mb; avail={psutil.virtual_memory().available / mb}Mb"
+     except ImportError:
+         return "N/A"
+
+
+ # lens: torch.LongTensor
+ # returns: torch.BoolTensor
+ def lengths_to_padding_mask(lens):
+     bsz, max_lens = lens.size(0), torch.max(lens).item()
+     mask = torch.arange(max_lens).to(lens.device).view(1, max_lens)
+     mask = mask.expand(bsz, -1) >= lens.view(bsz, 1).expand(-1, max_lens)
+     return mask
+
+
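For example:

import torch
from data.data_utils import lengths_to_padding_mask

lengths_to_padding_mask(torch.tensor([3, 1]))
# tensor([[False, False, False],
#         [False,  True,  True]])   True marks padding positions
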
+ # lens: torch.LongTensor
+ # returns: torch.BoolTensor
+ def lengths_to_mask(lens):
+     return ~lengths_to_padding_mask(lens)
+
+
+ def get_buckets(sizes, num_buckets):
+     buckets = np.unique(
+         np.percentile(
+             sizes,
+             np.linspace(0, 100, num_buckets + 1),
+             interpolation='lower',
+         )[1:]
+     )
+     return buckets
+
+
+ def get_bucketed_sizes(orig_sizes, buckets):
+     sizes = np.copy(orig_sizes)
+     assert np.min(sizes) >= 0
+     start_val = -1
+     for end_val in buckets:
+         mask = (sizes > start_val) & (sizes <= end_val)
+         sizes[mask] = end_val
+         start_val = end_val
+     return sizes
+
+
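A worked example of the bucketing helpers (the sizes are illustrative):

import numpy as np
from data.data_utils import get_buckets, get_bucketed_sizes

sizes = np.array([3, 5, 9, 17, 33, 60])
buckets = get_buckets(sizes, num_buckets=3)   # -> array([ 5, 17, 60])
get_bucketed_sizes(sizes, buckets)            # -> array([ 5,  5, 17, 17, 60, 60])
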
+ def _find_extra_valid_paths(dataset_path: str) -> set:
+     paths = utils.split_paths(dataset_path)
+     all_valid_paths = set()
+     for sub_dir in paths:
+         contents = PathManager.ls(sub_dir)
+         valid_paths = [c for c in contents if re.match("valid*[0-9].*", c) is not None]
+         all_valid_paths |= {os.path.basename(p) for p in valid_paths}
+     # Remove .bin, .idx etc
+     roots = {os.path.splitext(p)[0] for p in all_valid_paths}
+     return roots
+
+
+ def raise_if_valid_subsets_unintentionally_ignored(train_cfg) -> None:
+     """Raises if there are paths matching 'valid*[0-9].*' which are not combined or ignored."""
+     if (
+         train_cfg.dataset.ignore_unused_valid_subsets
+         or train_cfg.dataset.combine_valid_subsets
+         or train_cfg.dataset.disable_validation
+         or not hasattr(train_cfg.task, "data")
+     ):
+         return
+     other_paths = _find_extra_valid_paths(train_cfg.task.data)
+     specified_subsets = train_cfg.dataset.valid_subset.split(",")
+     ignored_paths = [p for p in other_paths if p not in specified_subsets]
+     if ignored_paths:
+         advice = "Set --combine-val to combine them or --ignore-unused-valid-subsets to ignore them."
+         msg = f"Valid paths {ignored_paths} will be ignored. {advice}"
+         raise ValueError(msg)
data/file_dataset.py ADDED
@@ -0,0 +1,107 @@
+ # Copyright 2022 The OFA-Sys Team.
+ # All rights reserved.
+ # This source code is licensed under the Apache 2.0 license
+ # found in the LICENSE file in the root directory.
+
+ import os
+ import pickle
+
+ import torch
+
+
+ class FileDataset:
+     def __init__(self, file_path, selected_col_ids=None, dtypes=None, separator="\t", cached_index=False):
+         self.file_path = file_path
+         assert os.path.exists(self.file_path), "Error: The local datafile {} does not exist!".format(self.file_path)
+
+         self.separator = separator
+         if selected_col_ids is None:
+             # default to all fields
+             with open(self.file_path) as fp:
+                 self.selected_col_ids = list(
+                     range(len(fp.readline().rstrip("\n").split(self.separator))))
+         else:
+             self.selected_col_ids = [int(col_id) for col_id in selected_col_ids.split(",")]
+         if dtypes is None:
+             # default to str
+             self.dtypes = [str for col_id in self.selected_col_ids]
+         else:
+             # e.g. dtypes="str,str,int" maps each column name to a builtin type
+             self.dtypes = [eval(col_dtype) for col_dtype in dtypes.split(",")]
+         assert len(self.dtypes) == len(self.selected_col_ids)
+
+         self.data_cnt = 0
+         try:
+             self.slice_id = torch.distributed.get_rank()
+             self.slice_count = torch.distributed.get_world_size()
+         except Exception:
+             # not running under torch.distributed: read the whole file
+             self.slice_id = 0
+             self.slice_count = 1
+         self.cached_index = cached_index
+         self._init_seek_index()
+         self._reader = self._get_reader()
+         print("file {} slice_id {} row count {} total row count {}".format(
+             self.file_path, self.slice_id, self.row_count, self.total_row_count)
+         )
+
+     def _init_seek_index(self):
+         if self.cached_index:
+             cache_path = "{}.index".format(self.file_path)
+             assert os.path.exists(cache_path), "cache file {} does not exist!".format(cache_path)
+             with open(cache_path, "rb") as fp:
+                 self.total_row_count, self.lineid_to_offset = pickle.load(fp)
+             print("local datafile {} slice_id {} use cached row_count and line_idx-to-offset mapping".format(
+                 self.file_path, self.slice_id))
+         else:
+             # make an iteration over the file to get row_count and line_idx-to-offset mapping
+             print("local datafile {} slice_id {} begin to initialize row_count and line_idx-to-offset mapping".format(
+                 self.file_path, self.slice_id))
+             self.total_row_count = 0
+             offset = 0
+             self.lineid_to_offset = []
+             with open(self.file_path, "r") as fp:
+                 for line in fp:
+                     self.lineid_to_offset.append(offset)
+                     self.total_row_count += 1
+                     offset += len(line.encode('utf-8'))
+             print("local datafile {} slice_id {} finished initializing row_count and line_idx-to-offset mapping".format(
+                 self.file_path, self.slice_id))
+         self._compute_start_pos_and_row_count()
+
+     def _compute_start_pos_and_row_count(self):
+         # split total_row_count rows as evenly as possible across slices;
+         # the first (total % slice_count) slices receive one extra row
+         self.row_count = self.total_row_count // self.slice_count
+         if self.slice_id < self.total_row_count - self.row_count * self.slice_count:
+             self.row_count += 1
+             self.start_pos = self.row_count * self.slice_id
+         else:
+             self.start_pos = self.row_count * self.slice_id + (self.total_row_count - self.row_count * self.slice_count)
+
+     def _get_reader(self):
+         fp = open(self.file_path, "r")
+         fp.seek(self.lineid_to_offset[self.start_pos])
+         return fp
+
+     def _seek(self, offset=0):
+         try:
+             print("slice_id {} seek offset {}".format(self.slice_id, self.start_pos + offset))
+             self._reader.seek(self.lineid_to_offset[self.start_pos + offset])
+             self.data_cnt = offset
+         except Exception:
+             print("slice_id {} seek offset {}".format(self.slice_id, offset))
+             self._reader.seek(self.lineid_to_offset[offset])
+             self.data_cnt = offset
+
+     def __del__(self):
+         self._reader.close()
+
+     def __len__(self):
+         return self.row_count
+
+     def get_total_row_count(self):
+         return self.total_row_count
+
+     def __getitem__(self, index):
+         # note: rows are read sequentially from the current file position;
+         # `index` only matters through the wrap-around check below
+         if self.data_cnt == self.row_count:
+             print("reach the end of datafile, start a new reader")
+             self.data_cnt = 0
+             self._reader = self._get_reader()
+         column_l = self._reader.readline().rstrip("\n").split(self.separator)
+         self.data_cnt += 1
+         column_l = [dtype(column_l[col_id]) for col_id, dtype in zip(self.selected_col_ids, self.dtypes)]
+         return column_l
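
A usage sketch, assuming a hypothetical caption TSV whose columns are uniq_id, base64-encoded image, and caption:

from data.file_dataset import FileDataset

dataset = FileDataset("caption_train.tsv", selected_col_ids="0,1,2")
uniq_id, image_b64, caption = dataset[0]
# under torch.distributed, each rank sequentially reads only its own slice of rows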
data/mm_data/caption_dataset.py ADDED
@@ -0,0 +1,155 @@
+ # Copyright 2022 The OFA-Sys Team.
+ # All rights reserved.
+ # This source code is licensed under the Apache 2.0 license
+ # found in the LICENSE file in the root directory.
+
+ from io import BytesIO
+
+ import logging
+ import warnings
+ import string
+
+ import numpy as np
+ import torch
+ import base64
+ from torchvision import transforms
+
+ from PIL import Image, ImageFile
+
+ from data import data_utils
+ from data.ofa_dataset import OFADataset
+
+ ImageFile.LOAD_TRUNCATED_IMAGES = True
+ ImageFile.MAX_IMAGE_PIXELS = None
+ Image.MAX_IMAGE_PIXELS = None
+
+ logger = logging.getLogger(__name__)
+ warnings.filterwarnings("ignore", "(Possibly )?corrupt EXIF data", UserWarning)
+
+ IMAGENET_DEFAULT_MEAN = (0.485, 0.456, 0.406)
+ IMAGENET_DEFAULT_STD = (0.229, 0.224, 0.225)
+
+
+ def collate(samples, pad_idx, eos_idx):
+     if len(samples) == 0:
+         return {}
+
+     def merge(key):
+         return data_utils.collate_tokens(
+             [s[key] for s in samples],
+             pad_idx,
+             eos_idx=eos_idx,
+         )
+
+     id = np.array([s["id"] for s in samples])
+     src_tokens = merge("source")
+     src_lengths = torch.LongTensor([s["source"].ne(pad_idx).long().sum() for s in samples])
+
+     patch_images = torch.stack([sample['patch_image'] for sample in samples], dim=0)
+     patch_masks = torch.cat([sample['patch_mask'] for sample in samples])
+
+     prev_output_tokens = None
+     target = None
+     if samples[0].get("target", None) is not None:
+         target = merge("target")
+         tgt_lengths = torch.LongTensor([s["target"].ne(pad_idx).long().sum() for s in samples])
+         ntokens = tgt_lengths.sum().item()
+
+         if samples[0].get("prev_output_tokens", None) is not None:
+             prev_output_tokens = merge("prev_output_tokens")
+     else:
+         ntokens = src_lengths.sum().item()
+
+     batch = {
+         "id": id,
+         "nsentences": len(samples),
+         "ntokens": ntokens,
+         "net_input": {
+             "src_tokens": src_tokens,
+             "src_lengths": src_lengths,
+             "patch_images": patch_images,
+             "patch_masks": patch_masks,
+             "prev_output_tokens": prev_output_tokens
+         },
+         "target": target,
+     }
+
+     return batch
+
+
+ class CaptionDataset(OFADataset):
+     def __init__(
+         self,
+         split,
+         dataset,
+         bpe,
+         src_dict,
+         tgt_dict=None,
+         max_src_length=128,
+         max_tgt_length=30,
+         patch_image_size=224,
+         imagenet_default_mean_and_std=False,
+         scst=False
+     ):
+         super().__init__(split, dataset, bpe, src_dict, tgt_dict)
+         self.max_src_length = max_src_length
+         self.max_tgt_length = max_tgt_length
+         self.patch_image_size = patch_image_size
+         self.scst = scst
+
+         # translation table that strips all punctuation from captions
+         self.transtab = str.maketrans({key: None for key in string.punctuation})
+
+         if imagenet_default_mean_and_std:
+             mean = IMAGENET_DEFAULT_MEAN
+             std = IMAGENET_DEFAULT_STD
+         else:
+             mean = [0.5, 0.5, 0.5]
+             std = [0.5, 0.5, 0.5]
+
+         self.patch_resize_transform = transforms.Compose([
+             lambda image: image.convert("RGB"),
+             transforms.Resize((patch_image_size, patch_image_size), interpolation=Image.BICUBIC),
+             transforms.ToTensor(),
+             transforms.Normalize(mean=mean, std=std),
+         ])
+
+     def __getitem__(self, index):
+         uniq_id, image, caption = self.dataset[index]
+
+         image = Image.open(BytesIO(base64.urlsafe_b64decode(image)))
+         patch_image = self.patch_resize_transform(image)
+         patch_mask = torch.tensor([True])
+
+         if self.split == 'train' and not self.scst:
+             caption = caption.translate(self.transtab).strip()
+             caption_token_list = caption.strip().split()
+             tgt_caption = ' '.join(caption_token_list[:self.max_tgt_length])
+         else:
+             # evaluation rows may hold multiple reference captions joined by '&&'
+             caption = ' '.join(caption.strip().split())
+             caption_list = [cap.translate(self.transtab).strip() for cap in caption.strip().split('&&')]
+             tgt_caption = '&&'.join(caption_list)
+         src_item = self.encode_text(" what does the image describe?")
+         tgt_item = self.encode_text(" {}".format(tgt_caption))
+
+         src_item = torch.cat([self.bos_item, src_item, self.eos_item])
+         target_item = torch.cat([tgt_item, self.eos_item])
+         prev_output_item = torch.cat([self.bos_item, tgt_item])
+
+         example = {
+             "id": uniq_id,
+             "source": src_item,
+             "patch_image": patch_image,
+             "patch_mask": patch_mask,
+             "target": target_item,
+             "prev_output_tokens": prev_output_item
+         }
+         return example
+
+     def collater(self, samples, pad_to_length=None):
+         """Merge a list of samples to form a mini-batch.
+
+         Args:
+             samples (List[dict]): samples to collate
+
+         Returns:
+             dict: a mini-batch containing the data of the task
+         """
+         return collate(samples, pad_idx=self.pad, eos_idx=self.eos)
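
How the pieces are meant to fit together, as a hedged sketch (the TSV path is hypothetical, and `bpe`, `src_dict`, `tgt_dict` would come from the surrounding task setup):

from data.file_dataset import FileDataset
from data.mm_data.caption_dataset import CaptionDataset

file_ds = FileDataset("caption_train.tsv", selected_col_ids="0,1,2")
caption_ds = CaptionDataset("train", file_ds, bpe, src_dict, tgt_dict)
batch = caption_ds.collater([caption_ds[0], caption_ds[1]])
batch["net_input"]["patch_images"].shape  # torch.Size([2, 3, 224, 224])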
data/ofa_dataset.py ADDED
@@ -0,0 +1,79 @@
+ # Copyright 2022 The OFA-Sys Team.
+ # All rights reserved.
+ # This source code is licensed under the Apache 2.0 license
+ # found in the LICENSE file in the root directory.
+
+ import logging
+ import re
+
+ import torch
+ from fairseq.data import FairseqDataset
+
+ logger = logging.getLogger(__name__)
+
+
+ class OFADataset(FairseqDataset):
+     def __init__(self, split, dataset, bpe, src_dict, tgt_dict):
+         self.split = split
+         self.dataset = dataset
+         self.bpe = bpe
+         self.src_dict = src_dict
+         self.tgt_dict = tgt_dict
+
+         self.bos = src_dict.bos()
+         self.eos = src_dict.eos()
+         self.pad = src_dict.pad()
+         self.bos_item = torch.LongTensor([self.bos])
+         self.eos_item = torch.LongTensor([self.eos])
+
+     def __len__(self):
+         return len(self.dataset)
+
+     def encode_text(self, text, length=None, append_bos=False, append_eos=False, use_bpe=True):
+         s = self.tgt_dict.encode_line(
+             line=self.bpe.encode(text) if use_bpe else text,
+             add_if_not_exist=False,
+             append_eos=False
+         ).long()
+         if length is not None:
+             s = s[:length]
+         if append_bos:
+             s = torch.cat([self.bos_item, s])
+         if append_eos:
+             s = torch.cat([s, self.eos_item])
+         return s
+
+     def pre_question(self, question, max_ques_words):
+         question = question.lower().lstrip(",.!?*#:;~").replace('-', ' ').replace('/', ' ')
+
+         # collapse runs of whitespace into a single space
+         question = re.sub(
+             r"\s{2,}",
+             ' ',
+             question,
+         )
+         question = question.rstrip('\n')
+         question = question.strip(' ')
+
+         # truncate question
+         question_words = question.split(' ')
+         if len(question_words) > max_ques_words:
+             question = ' '.join(question_words[:max_ques_words])
+
+         return question
+
+     def pre_caption(self, caption, max_words):
+         caption = caption.lower().lstrip(",.!?*#:;~").replace('-', ' ').replace('/', ' ').replace('<person>', 'person')
+
+         # collapse runs of whitespace into a single space
+         caption = re.sub(
+             r"\s{2,}",
+             ' ',
+             caption,
+         )
+         caption = caption.rstrip('\n')
+         caption = caption.strip(' ')
+
+         # truncate caption
+         caption_words = caption.split(' ')
+         if len(caption_words) > max_words:
+             caption = ' '.join(caption_words[:max_words])
+
+         return caption
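
An illustrative call of the text helpers on any OFADataset subclass instance (`ds` here is hypothetical):

ds.pre_caption("A man -- riding/skating <person>!", max_words=5)
# -> 'a man riding skating person!'
ids = ds.encode_text(" what does the image describe?", append_bos=True, append_eos=True)
# LongTensor of BPE token ids wrapped in BOS/EOS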