Yixuan Li committed
Commit 85ba398 · 1 parent: e3e7837

add fairseq folder

This view is limited to 50 files because the commit contains too many changes.
Files changed (50)
  1. fairseq/__init__.py +45 -0
  2. fairseq/__pycache__/__init__.cpython-310.pyc +0 -0
  3. fairseq/__pycache__/__init__.cpython-311.pyc +0 -0
  4. fairseq/__pycache__/checkpoint_utils.cpython-310.pyc +0 -0
  5. fairseq/__pycache__/file_chunker_utils.cpython-310.pyc +0 -0
  6. fairseq/__pycache__/file_io.cpython-310.pyc +0 -0
  7. fairseq/__pycache__/file_utils.cpython-310.pyc +0 -0
  8. fairseq/__pycache__/hub_utils.cpython-310.pyc +0 -0
  9. fairseq/__pycache__/incremental_decoding_utils.cpython-310.pyc +0 -0
  10. fairseq/__pycache__/iterative_refinement_generator.cpython-310.pyc +0 -0
  11. fairseq/__pycache__/ngram_repeat_block.cpython-310.pyc +0 -0
  12. fairseq/__pycache__/options.cpython-310.pyc +0 -0
  13. fairseq/__pycache__/pdb.cpython-310.pyc +0 -0
  14. fairseq/__pycache__/quantization_utils.cpython-310.pyc +0 -0
  15. fairseq/__pycache__/registry.cpython-310.pyc +0 -0
  16. fairseq/__pycache__/search.cpython-310.pyc +0 -0
  17. fairseq/__pycache__/sequence_generator.cpython-310.pyc +0 -0
  18. fairseq/__pycache__/speech_generator.cpython-310.pyc +0 -0
  19. fairseq/__pycache__/token_generation_constraints.cpython-310.pyc +0 -0
  20. fairseq/__pycache__/tokenizer.cpython-310.pyc +0 -0
  21. fairseq/__pycache__/utils.cpython-310.pyc +0 -0
  22. fairseq/__pycache__/version.cpython-310.pyc +0 -0
  23. fairseq/__pycache__/version.cpython-311.pyc +0 -0
  24. fairseq/benchmark/__init__.py +7 -0
  25. fairseq/benchmark/__pycache__/__init__.cpython-310.pyc +0 -0
  26. fairseq/benchmark/__pycache__/dummy_dataset.cpython-310.pyc +0 -0
  27. fairseq/benchmark/__pycache__/dummy_lm.cpython-310.pyc +0 -0
  28. fairseq/benchmark/__pycache__/dummy_masked_lm.cpython-310.pyc +0 -0
  29. fairseq/benchmark/__pycache__/dummy_model.cpython-310.pyc +0 -0
  30. fairseq/benchmark/__pycache__/dummy_mt.cpython-310.pyc +0 -0
  31. fairseq/benchmark/benchmark_multihead_attention.py +172 -0
  32. fairseq/benchmark/dummy_dataset.py +36 -0
  33. fairseq/benchmark/dummy_lm.py +83 -0
  34. fairseq/benchmark/dummy_masked_lm.py +94 -0
  35. fairseq/benchmark/dummy_model.py +96 -0
  36. fairseq/benchmark/dummy_mt.py +119 -0
  37. fairseq/binarizer.py +381 -0
  38. fairseq/checkpoint_utils.py +936 -0
  39. fairseq/clib/cuda/ngram_repeat_block_cuda.cpp +55 -0
  40. fairseq/clib/cuda/ngram_repeat_block_cuda_kernel.cu +82 -0
  41. fairseq/clib/libbase/balanced_assignment.cpp +109 -0
  42. fairseq/clib/libbleu/libbleu.cpp +157 -0
  43. fairseq/clib/libbleu/module.cpp +33 -0
  44. fairseq/clib/libnat/edit_dist.cpp +231 -0
  45. fairseq/clib/libnat_cuda/binding.cpp +67 -0
  46. fairseq/clib/libnat_cuda/edit_dist.cu +344 -0
  47. fairseq/clib/libnat_cuda/edit_dist.h +25 -0
  48. fairseq/config/__init__.py +4 -0
  49. fairseq/config/config.yaml +19 -0
  50. fairseq/config/fb_run_config/slurm.yaml +29 -0
fairseq/__init__.py ADDED
@@ -0,0 +1,45 @@
+ # Copyright (c) Facebook, Inc. and its affiliates.
+ #
+ # This source code is licensed under the MIT license found in the
+ # LICENSE file in the root directory of this source tree.
+ """isort:skip_file"""
+
+ import os
+ import sys
+
+ try:
+     from .version import __version__  # noqa
+ except ImportError:
+     version_txt = os.path.join(os.path.dirname(__file__), "version.txt")
+     with open(version_txt) as f:
+         __version__ = f.read().strip()
+
+ __all__ = ["pdb"]
+
+ # backwards compatibility to support `from fairseq.X import Y`
+ from fairseq.distributed import utils as distributed_utils
+ from fairseq.logging import meters, metrics, progress_bar  # noqa
+
+ sys.modules["fairseq.distributed_utils"] = distributed_utils
+ sys.modules["fairseq.meters"] = meters
+ sys.modules["fairseq.metrics"] = metrics
+ sys.modules["fairseq.progress_bar"] = progress_bar
+
+ # initialize hydra
+ from fairseq.dataclass.initialize import hydra_init
+
+ hydra_init()
+
+ import fairseq.criterions  # noqa
+ import fairseq.distributed  # noqa
+ import fairseq.models  # noqa
+ import fairseq.modules  # noqa
+ import fairseq.optim  # noqa
+ import fairseq.optim.lr_scheduler  # noqa
+ import fairseq.pdb  # noqa
+ import fairseq.scoring  # noqa
+ import fairseq.tasks  # noqa
+ import fairseq.token_generation_constraints  # noqa
+
+ import fairseq.benchmark  # noqa
+ import fairseq.model_parallel  # noqa
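
The sys.modules assignments above are what keep legacy import paths such as fairseq.meters working after the logging utilities moved under fairseq.logging. A minimal sketch of the effect (hypothetical usage, not part of this commit; assumes fairseq and its dependencies are importable):

# Both spellings resolve to the same module object, because fairseq/__init__.py
# registers fairseq.logging.meters under the legacy name in sys.modules.
from fairseq import meters                   # new-style path
from fairseq.meters import StopwatchMeter    # legacy path, kept alive above

assert meters.StopwatchMeter is StopwatchMeter

timer = StopwatchMeter()
timer.start()
timer.stop()
print(timer.sum)  # seconds accumulated between start() and stop()
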
fairseq/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (1.2 kB).

fairseq/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (2.01 kB).

fairseq/__pycache__/checkpoint_utils.cpython-310.pyc ADDED
Binary file (22 kB).

fairseq/__pycache__/file_chunker_utils.cpython-310.pyc ADDED
Binary file (2.77 kB).

fairseq/__pycache__/file_io.cpython-310.pyc ADDED
Binary file (5.11 kB).

fairseq/__pycache__/file_utils.cpython-310.pyc ADDED
Binary file (9.1 kB).

fairseq/__pycache__/hub_utils.cpython-310.pyc ADDED
Binary file (10.9 kB).

fairseq/__pycache__/incremental_decoding_utils.cpython-310.pyc ADDED
Binary file (2.21 kB).

fairseq/__pycache__/iterative_refinement_generator.cpython-310.pyc ADDED
Binary file (8.72 kB).

fairseq/__pycache__/ngram_repeat_block.cpython-310.pyc ADDED
Binary file (3.78 kB).

fairseq/__pycache__/options.cpython-310.pyc ADDED
Binary file (10.6 kB).

fairseq/__pycache__/pdb.cpython-310.pyc ADDED
Binary file (1.32 kB).

fairseq/__pycache__/quantization_utils.cpython-310.pyc ADDED
Binary file (3.57 kB).

fairseq/__pycache__/registry.cpython-310.pyc ADDED
Binary file (2.52 kB).

fairseq/__pycache__/search.cpython-310.pyc ADDED
Binary file (24.3 kB).

fairseq/__pycache__/sequence_generator.cpython-310.pyc ADDED
Binary file (24 kB).

fairseq/__pycache__/speech_generator.cpython-310.pyc ADDED
Binary file (10.7 kB).

fairseq/__pycache__/token_generation_constraints.cpython-310.pyc ADDED
Binary file (16.2 kB).

fairseq/__pycache__/tokenizer.cpython-310.pyc ADDED
Binary file (367 Bytes).

fairseq/__pycache__/utils.cpython-310.pyc ADDED
Binary file (29.8 kB).

fairseq/__pycache__/version.cpython-310.pyc ADDED
Binary file (166 Bytes).

fairseq/__pycache__/version.cpython-311.pyc ADDED
Binary file (238 Bytes).

fairseq/benchmark/__init__.py ADDED
@@ -0,0 +1,7 @@
+ # Copyright (c) Facebook, Inc. and its affiliates.
+ #
+ # This source code is licensed under the MIT license found in the
+ # LICENSE file in the root directory of this source tree.
+
+ # import models/tasks to register them
+ from . import dummy_dataset, dummy_lm, dummy_masked_lm, dummy_model, dummy_mt  # noqa
fairseq/benchmark/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (279 Bytes).

fairseq/benchmark/__pycache__/dummy_dataset.cpython-310.pyc ADDED
Binary file (1.74 kB).

fairseq/benchmark/__pycache__/dummy_lm.cpython-310.pyc ADDED
Binary file (3.12 kB).

fairseq/benchmark/__pycache__/dummy_masked_lm.cpython-310.pyc ADDED
Binary file (3.33 kB).

fairseq/benchmark/__pycache__/dummy_model.cpython-310.pyc ADDED
Binary file (3.42 kB).

fairseq/benchmark/__pycache__/dummy_mt.cpython-310.pyc ADDED
Binary file (4.5 kB).

fairseq/benchmark/benchmark_multihead_attention.py ADDED
@@ -0,0 +1,172 @@
+ # Copyright (c) Facebook, Inc. and its affiliates.
+ #
+ # This source code is licensed under the MIT license found in the
+ # LICENSE file in the root directory of this source tree.
+
+ import itertools
+ import random
+
+ import torch
+ from torch.utils import benchmark
+
+ from fairseq.modules.multihead_attention import MultiheadAttention
+
+ BATCH = [20, 41, 97]
+ SEQ = 64
+ EMB = 48
+ HEADS = 4
+ DROP = 0.1
+ DEVICE = torch.device("cuda")
+ ATTN_MASK_DTYPE = [torch.uint8, torch.bool, torch.float]
+ KEY_PADDING_MASK_DTYPE = [torch.uint8, torch.bool]
+
+
+ def _reset_seeds():
+     torch.manual_seed(0)
+     random.seed(0)
+
+
+ def _get_mask(to_dtype: torch.dtype, dim0: int, dim1: int):
+     if to_dtype == torch.float:
+         mask = torch.randint(0, 2, (dim0, dim1)).to(dtype=torch.bool)
+         return mask.to(dtype=to_dtype).masked_fill(mask, -float("inf"))
+     return torch.randint(0, 2, (dim0, dim1)).to(dtype=to_dtype)
+
+
+ def benchmark_multihead_attention(
+     label="",
+     attn_dtype=torch.uint8,
+     key_padding_dtype=torch.uint8,
+     add_bias_kv=False,
+     add_zero_attn=False,
+     static_kv=False,
+     batch_size=20,
+     embedding=EMB,
+     seq_len=SEQ,
+     num_heads=HEADS,
+ ):
+
+     results = []
+     # device = torch.device("cuda")
+
+     xformers_att_config = '{"name": "scaled_dot_product"}'
+
+     attn_mask = _get_mask(to_dtype=attn_dtype, dim0=seq_len, dim1=seq_len)
+     key_padding_mask = _get_mask(
+         to_dtype=key_padding_dtype, dim0=batch_size, dim1=seq_len
+     )
+
+     q = torch.rand(seq_len, batch_size, embedding, requires_grad=True)
+     k = torch.rand(seq_len, batch_size, embedding, requires_grad=True)
+     v = torch.rand(seq_len, batch_size, embedding, requires_grad=True)
+
+     _reset_seeds()
+
+     original_mha = MultiheadAttention(
+         embedding,
+         num_heads,
+         dropout=0.0,
+         xformers_att_config=None,
+         add_bias_kv=add_bias_kv,
+         add_zero_attn=add_zero_attn,
+     )
+
+     xformers_mha = MultiheadAttention(
+         embedding,
+         num_heads,
+         dropout=0.0,
+         xformers_att_config=xformers_att_config,
+         add_bias_kv=add_bias_kv,
+         add_zero_attn=add_zero_attn,
+     )
+
+     def original_bench_fw(q, k, v, key_padding_mask, attn_mask, static_kv):
+         original_mha(
+             query=q,
+             key=k,
+             value=v,
+             key_padding_mask=key_padding_mask,
+             attn_mask=attn_mask,
+             static_kv=static_kv,
+         )
+
+     def xformers_bench_fw(q, k, v, key_padding_mask, attn_mask, static_kv):
+         xformers_mha(
+             query=q,
+             key=k,
+             value=v,
+             key_padding_mask=key_padding_mask,
+             attn_mask=attn_mask,
+             static_kv=static_kv,
+         )
+
+     def original_bench_fw_bw(q, k, v, key_padding_mask, attn_mask, static_kv):
+         output, _ = original_mha(
+             query=q,
+             key=k,
+             value=v,
+             key_padding_mask=key_padding_mask,
+             attn_mask=attn_mask,
+             static_kv=static_kv,
+         )
+         loss = torch.norm(output)
+         loss.backward()
+
+     def xformers_bench_fw_bw(q, k, v, key_padding_mask, attn_mask, static_kv):
+         output, _ = xformers_mha(
+             query=q,
+             key=k,
+             value=v,
+             key_padding_mask=key_padding_mask,
+             attn_mask=attn_mask,
+             static_kv=static_kv,
+         )
+         loss = torch.norm(output)
+         loss.backward()
+
+     fns = [
+         original_bench_fw,
+         xformers_bench_fw,
+         original_bench_fw_bw,
+         xformers_bench_fw_bw,
+     ]
+
+     for fn in fns:
+         results.append(
+             benchmark.Timer(
+                 stmt="fn(q, k, v, key_padding_mask, attn_mask, static_kv)",
+                 globals={
+                     "q": q,
+                     "k": k,
+                     "v": v,
+                     "key_padding_mask": key_padding_mask,
+                     "attn_mask": attn_mask,
+                     "static_kv": static_kv,
+                     "fn": fn,
+                 },
+                 label="multihead fw + bw",
+                 sub_label=f"{fn.__name__}",
+                 description=label,
+             ).blocked_autorange(min_run_time=1)
+         )
+
+     compare = benchmark.Compare(results)
+     compare.print()
+
+
+ def run_benchmarks():
+     for attn_dtype, key_padding_dtype, add_bias_kv, add_zero_attn in itertools.product(
+         ATTN_MASK_DTYPE, KEY_PADDING_MASK_DTYPE, [True, False], [True, False]
+     ):
+         label = f"attn_dtype {attn_dtype}, key_padding_dtype {key_padding_dtype}, \
+         add_bias_kv {add_bias_kv}, add_zero_attn {add_zero_attn}"
+         benchmark_multihead_attention(
+             label=label,
+             attn_dtype=attn_dtype,
+             key_padding_dtype=key_padding_dtype,
+             add_bias_kv=add_bias_kv,
+             add_zero_attn=add_zero_attn,
+         )
+
+
+ run_benchmarks()
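
Note that run_benchmarks() is executed at module import time, so importing fairseq.benchmark.benchmark_multihead_attention sweeps every combination of attention-mask dtype, key-padding-mask dtype, add_bias_kv, and add_zero_attn. A hedged sketch of timing one extra configuration (hypothetical, not part of this commit; assumes an environment where fairseq's MultiheadAttention can be built with the xformers scaled_dot_product backend, and the import itself still triggers the full sweep):

import torch

# Importing the module runs run_benchmarks() once; the call below then times
# just a single additional configuration.
from fairseq.benchmark.benchmark_multihead_attention import (
    benchmark_multihead_attention,
)

benchmark_multihead_attention(
    label="bool masks, no bias_kv, no zero_attn",
    attn_dtype=torch.bool,
    key_padding_dtype=torch.bool,
    add_bias_kv=False,
    add_zero_attn=False,
)
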
fairseq/benchmark/dummy_dataset.py ADDED
@@ -0,0 +1,36 @@
+ import numpy as np
+ from fairseq.data import FairseqDataset
+
+
+ class DummyDataset(FairseqDataset):
+     def __init__(self, batch, num_items, item_size):
+         super().__init__()
+         self.batch = batch
+         self.num_items = num_items
+         self.item_size = item_size
+
+     def __getitem__(self, index):
+         return index
+
+     def __len__(self):
+         return self.num_items
+
+     def collater(self, samples):
+         return self.batch
+
+     @property
+     def sizes(self):
+         return np.array([self.item_size] * self.num_items)
+
+     def num_tokens(self, index):
+         return self.item_size
+
+     def size(self, index):
+         return self.item_size
+
+     def ordered_indices(self):
+         return np.arange(self.num_items)
+
+     @property
+     def supports_prefetch(self):
+         return False
fairseq/benchmark/dummy_lm.py ADDED
@@ -0,0 +1,83 @@
+ # Copyright (c) Facebook, Inc. and its affiliates.
+ #
+ # This source code is licensed under the MIT license found in the
+ # LICENSE file in the root directory of this source tree.
+
+ import logging
+ from dataclasses import dataclass, field
+ from typing import Optional
+
+ import torch
+ from .dummy_dataset import DummyDataset
+ from fairseq.data import Dictionary
+ from fairseq.dataclass import FairseqDataclass
+ from fairseq.tasks import FairseqTask, register_task
+ from omegaconf import II
+
+
+ logger = logging.getLogger(__name__)
+
+
+ @dataclass
+ class DummyLMConfig(FairseqDataclass):
+     dict_size: int = 49996
+     dataset_size: int = 100000
+     tokens_per_sample: int = field(
+         default=512, metadata={"help": "max sequence length"}
+     )
+     add_bos_token: bool = False
+     batch_size: Optional[int] = II("dataset.batch_size")
+     max_tokens: Optional[int] = II("dataset.max_tokens")
+     max_target_positions: int = II("task.tokens_per_sample")
+
+
+ @register_task("dummy_lm", dataclass=DummyLMConfig)
+ class DummyLMTask(FairseqTask):
+     def __init__(self, cfg: DummyLMConfig):
+         super().__init__(cfg)
+
+         # load dictionary
+         self.dictionary = Dictionary()
+         for i in range(cfg.dict_size):
+             self.dictionary.add_symbol("word{}".format(i))
+         self.dictionary.pad_to_multiple_(8)  # often faster if divisible by 8
+         logger.info("dictionary: {} types".format(len(self.dictionary)))
+
+         seq = torch.arange(cfg.tokens_per_sample + 1) + self.dictionary.pad() + 1
+
+         self.dummy_src = seq[:-1]
+         self.dummy_tgt = seq[1:]
+
+     def load_dataset(self, split, epoch=1, combine=False, **kwargs):
+         """Load a given dataset split.
+         Args:
+             split (str): name of the split (e.g., train, valid, test)
+         """
+         if self.cfg.batch_size is not None:
+             bsz = self.cfg.batch_size
+         else:
+             bsz = max(1, self.cfg.max_tokens // self.cfg.tokens_per_sample)
+         self.datasets[split] = DummyDataset(
+             {
+                 "id": 1,
+                 "net_input": {
+                     "src_tokens": torch.stack([self.dummy_src for _ in range(bsz)]),
+                     "src_lengths": torch.full(
+                         (bsz,), self.cfg.tokens_per_sample, dtype=torch.long
+                     ),
+                 },
+                 "target": torch.stack([self.dummy_tgt for _ in range(bsz)]),
+                 "nsentences": bsz,
+                 "ntokens": bsz * self.cfg.tokens_per_sample,
+             },
+             num_items=self.cfg.dataset_size,
+             item_size=self.cfg.tokens_per_sample,
+         )
+
+     @property
+     def source_dictionary(self):
+         return self.dictionary
+
+     @property
+     def target_dictionary(self):
+         return self.dictionary
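
Hypothetical usage sketch for the dummy_lm task (not part of this commit): the task fabricates one fixed batch, so a single load_dataset() call is enough to inspect the synthetic data. Field values below are illustrative; passing batch_size explicitly sidesteps the II() interpolations, which are only resolved under a full hydra config.

from fairseq.benchmark.dummy_lm import DummyLMConfig, DummyLMTask

cfg = DummyLMConfig(batch_size=8, tokens_per_sample=512)  # illustrative values
task = DummyLMTask(cfg)
task.load_dataset("train")

# DummyDataset.collater() ignores the sampled indices and always returns the
# same pre-built batch, which is what makes it convenient for benchmarking.
batch = task.datasets["train"].collater([0] * 8)
print(batch["net_input"]["src_tokens"].shape)  # expected: torch.Size([8, 512])
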
fairseq/benchmark/dummy_masked_lm.py ADDED
@@ -0,0 +1,94 @@
+ # Copyright (c) Facebook, Inc. and its affiliates.
+ #
+ # This source code is licensed under the MIT license found in the
+ # LICENSE file in the root directory of this source tree.
+
+ import logging
+ from dataclasses import dataclass, field
+ from typing import Optional
+
+ import torch
+ from omegaconf import II
+
+ from .dummy_dataset import DummyDataset
+ from fairseq.data import Dictionary
+ from fairseq.dataclass import FairseqDataclass
+ from fairseq.tasks import FairseqTask, register_task
+
+ logger = logging.getLogger(__name__)
+
+
+ @dataclass
+ class DummyMaskedLMConfig(FairseqDataclass):
+     dict_size: int = 49996
+     dataset_size: int = 100000
+     tokens_per_sample: int = field(
+         default=512,
+         metadata={
+             "help": "max number of total tokens over all"
+             " segments per sample for BERT dataset"
+         },
+     )
+     batch_size: Optional[int] = II("dataset.batch_size")
+     max_tokens: Optional[int] = II("dataset.max_tokens")
+     max_target_positions: int = II("task.tokens_per_sample")
+
+
+ @register_task("dummy_masked_lm", dataclass=DummyMaskedLMConfig)
+ class DummyMaskedLMTask(FairseqTask):
+     def __init__(self, cfg: DummyMaskedLMConfig):
+         super().__init__(cfg)
+
+         self.dictionary = Dictionary()
+         for i in range(cfg.dict_size):
+             self.dictionary.add_symbol("word{}".format(i))
+         logger.info("dictionary: {} types".format(len(self.dictionary)))
+         # add mask token
+         self.mask_idx = self.dictionary.add_symbol("<mask>")
+         self.dictionary.pad_to_multiple_(8)  # often faster if divisible by 8
+
+         mask_idx = 0
+         pad_idx = 1
+         seq = torch.arange(cfg.tokens_per_sample) + pad_idx + 1
+         mask = torch.arange(2, cfg.tokens_per_sample, 7)  # ~15%
+         src = seq.clone()
+         src[mask] = mask_idx
+         tgt = torch.full_like(seq, pad_idx)
+         tgt[mask] = seq[mask]
+
+         self.dummy_src = src
+         self.dummy_tgt = tgt
+
+     def load_dataset(self, split, epoch=1, combine=False, **kwargs):
+         """Load a given dataset split.
+         Args:
+             split (str): name of the split (e.g., train, valid, test)
+         """
+         if self.cfg.batch_size is not None:
+             bsz = self.cfg.batch_size
+         else:
+             bsz = max(1, self.cfg.max_tokens // self.cfg.tokens_per_sample)
+         self.datasets[split] = DummyDataset(
+             {
+                 "id": 1,
+                 "net_input": {
+                     "src_tokens": torch.stack([self.dummy_src for _ in range(bsz)]),
+                     "src_lengths": torch.full(
+                         (bsz,), self.cfg.tokens_per_sample, dtype=torch.long
+                     ),
+                 },
+                 "target": torch.stack([self.dummy_tgt for _ in range(bsz)]),
+                 "nsentences": bsz,
+                 "ntokens": bsz * self.cfg.tokens_per_sample,
+             },
+             num_items=self.cfg.dataset_size,
+             item_size=self.cfg.tokens_per_sample,
+         )
+
+     @property
+     def source_dictionary(self):
+         return self.dictionary
+
+     @property
+     def target_dictionary(self):
+         return self.dictionary
fairseq/benchmark/dummy_model.py ADDED
@@ -0,0 +1,96 @@
+ # Copyright (c) Facebook, Inc. and its affiliates.
+ #
+ # This source code is licensed under the MIT license found in the
+ # LICENSE file in the root directory of this source tree.
+
+ import torch.nn as nn
+ import torch.nn.functional as F
+ from fairseq.data import Dictionary
+ from fairseq.models import (
+     FairseqDecoder,
+     FairseqLanguageModel,
+     register_model,
+     register_model_architecture,
+ )
+
+
+ @register_model("dummy_model")
+ class DummyModel(FairseqLanguageModel):
+     def __init__(self, args, encoder):
+         super().__init__(encoder)
+         self.args = args
+
+     @staticmethod
+     def add_args(parser):
+         parser.add_argument("--num-layers", type=int, default=24)
+         parser.add_argument("--embed-dim", type=int, default=1024)
+
+     @classmethod
+     def build_model(cls, args, task):
+         encoder = DummyEncoder(
+             num_embed=len(task.target_dictionary),
+             embed_dim=args.embed_dim,
+             num_layers=args.num_layers,
+         )
+         return cls(args, encoder)
+
+     def forward(self, src_tokens, masked_tokens=None, **kwargs):
+         return self.decoder(src_tokens, masked_tokens=masked_tokens)
+
+
+ class DummyEncoder(FairseqDecoder):
+     def __init__(self, num_embed=50000, embed_dim=1024, num_layers=24):
+         super().__init__(Dictionary())
+         self.embed = nn.Embedding(
+             num_embeddings=num_embed, embedding_dim=embed_dim, padding_idx=0
+         )
+         self.layers_a = nn.ModuleList(
+             [
+                 nn.Sequential(
+                     nn.LayerNorm(embed_dim),
+                     nn.Linear(embed_dim, 3 * embed_dim),  # q, k, v input projection
+                     nn.Linear(3 * embed_dim, embed_dim),  # skip self-attention
+                     nn.Linear(embed_dim, embed_dim),  # output projection
+                     nn.Dropout(),
+                 )
+                 for i in range(num_layers)
+             ]
+         )
+         self.layers_b = nn.ModuleList(
+             [
+                 nn.Sequential(
+                     nn.LayerNorm(embed_dim),
+                     nn.Linear(embed_dim, 4 * embed_dim),  # FFN
+                     nn.ReLU(),
+                     nn.Linear(4 * embed_dim, embed_dim),  # FFN
+                     nn.Dropout(0.1),
+                 )
+                 for i in range(num_layers)
+             ]
+         )
+         self.out_proj = nn.Linear(embed_dim, num_embed)
+
+     def forward(self, tokens, masked_tokens=None):
+         x = self.embed(tokens)
+         for layer_a, layer_b in zip(self.layers_a, self.layers_b):
+             x = x + layer_a(x)
+             x = x + layer_b(x)
+         x = self.out_proj(x)
+         if masked_tokens is not None:
+             x = x[masked_tokens]
+         return (x,)
+
+     def max_positions(self):
+         return 1024
+
+     def get_normalized_probs(self, net_output, log_probs, sample=None):
+         logits = net_output[0].float()
+         if log_probs:
+             return F.log_softmax(logits, dim=-1)
+         else:
+             return F.softmax(logits, dim=-1)
+
+
+ @register_model_architecture("dummy_model", "dummy_model")
+ def base_architecture(args):
+     pass
fairseq/benchmark/dummy_mt.py ADDED
@@ -0,0 +1,119 @@
+ # Copyright (c) Facebook, Inc. and its affiliates.
+ #
+ # This source code is licensed under the MIT license found in the
+ # LICENSE file in the root directory of this source tree.
+
+ import logging
+
+ import numpy as np
+ import torch
+
+ from fairseq.data import Dictionary, FairseqDataset
+ from fairseq.tasks import LegacyFairseqTask, register_task
+
+ logger = logging.getLogger(__name__)
+
+
+ @register_task("dummy_mt")
+ class DummyMTTask(LegacyFairseqTask):
+     @staticmethod
+     def add_args(parser):
+         """Add task-specific arguments to the parser."""
+         parser.add_argument("--dict-size", default=49996, type=int)
+         parser.add_argument("--dataset-size", default=100000, type=int)
+         parser.add_argument("--src-len", default=30, type=int)
+         parser.add_argument("--tgt-len", default=30, type=int)
+
+     def __init__(self, args, dictionary):
+         super().__init__(args)
+         self.dictionary = dictionary
+         self.seed = args.seed
+
+         dictionary.pad_to_multiple_(8)  # often faster if divisible by 8
+
+         self.dummy_src = torch.arange(args.src_len + 1) + dictionary.pad() + 1
+         self.dummy_tgt = torch.arange(args.tgt_len + 1) + dictionary.pad() + 1
+
+     @classmethod
+     def setup_task(cls, args, **kwargs):
+         """Setup the task."""
+         dictionary = Dictionary()
+         for i in range(args.dict_size):
+             dictionary.add_symbol("word{}".format(i))
+         logger.info("dictionary: {} types".format(len(dictionary)))
+
+         args.max_source_positions = args.src_len + dictionary.pad() + 2
+         args.max_target_positions = args.tgt_len + dictionary.pad() + 2
+
+         return cls(args, dictionary)
+
+     def load_dataset(self, split, epoch=1, combine=False, **kwargs):
+         """Load a given dataset split.
+         Args:
+             split (str): name of the split (e.g., train, valid, test)
+         """
+         item_size = max(self.args.src_len, self.args.tgt_len)
+         if self.args.batch_size is not None:
+             bsz = self.args.batch_size
+         else:
+             bsz = max(1, self.args.max_tokens // item_size)
+         tgt = torch.stack([self.dummy_tgt for _ in range(bsz)])
+         self.datasets[split] = DummyDataset(
+             {
+                 "id": 1,
+                 "net_input": {
+                     "src_tokens": torch.stack([self.dummy_src for _ in range(bsz)]),
+                     "src_lengths": torch.full(
+                         (bsz,), self.args.src_len, dtype=torch.long
+                     ),
+                     "prev_output_tokens": tgt.clone(),
+                 },
+                 "target": tgt,
+                 "nsentences": bsz,
+                 "ntokens": bsz * self.args.tgt_len,
+             },
+             num_items=self.args.dataset_size,
+             item_size=item_size,
+         )
+
+     @property
+     def source_dictionary(self):
+         return self.dictionary
+
+     @property
+     def target_dictionary(self):
+         return self.dictionary
+
+
+ class DummyDataset(FairseqDataset):
+     def __init__(self, batch, num_items, item_size):
+         super().__init__()
+         self.batch = batch
+         self.num_items = num_items
+         self.item_size = item_size
+
+     def __getitem__(self, index):
+         return index
+
+     def __len__(self):
+         return self.num_items
+
+     def collater(self, samples):
+         return self.batch
+
+     @property
+     def sizes(self):
+         return np.array([self.item_size] * self.num_items)
+
+     def num_tokens(self, index):
+         return self.item_size
+
+     def size(self, index):
+         return self.item_size
+
+     def ordered_indices(self):
+         return np.arange(self.num_items)
+
+     @property
+     def supports_prefetch(self):
+         return False
fairseq/binarizer.py ADDED
@@ -0,0 +1,381 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ import logging
7
+ import os
8
+ import typing as tp
9
+ from abc import ABC, abstractmethod
10
+ from collections import Counter
11
+ from dataclasses import dataclass
12
+ from multiprocessing import Pool
13
+
14
+ import torch
15
+
16
+ from fairseq.data import Dictionary, indexed_dataset
17
+ from fairseq.file_chunker_utils import Chunker, find_offsets
18
+ from fairseq.file_io import PathManager
19
+ from fairseq.tokenizer import tokenize_line
20
+
21
+ logger = logging.getLogger("binarizer")
22
+
23
+
24
+ @dataclass
25
+ class BinarizeSummary:
26
+ """
27
+ Keep track of what's going on in the binarizer
28
+ """
29
+
30
+ num_seq: int = 0
31
+ replaced: tp.Optional[Counter] = None
32
+ num_tok: int = 0
33
+
34
+ @property
35
+ def num_replaced(self) -> int:
36
+ if self.replaced is None:
37
+ return 0
38
+ return sum(self.replaced.values())
39
+
40
+ @property
41
+ def replaced_percent(self) -> float:
42
+ return 100 * self.num_replaced / self.num_tok
43
+
44
+ def __str__(self) -> str:
45
+ base = f"{self.num_seq} sents, {self.num_tok} tokens"
46
+ if self.replaced is None:
47
+ return base
48
+
49
+ return f"{base}, {self.replaced_percent:.3}% replaced"
50
+
51
+ def merge(self, other: "BinarizeSummary"):
52
+ replaced = None
53
+ if self.replaced is not None:
54
+ replaced = self.replaced
55
+ if other.replaced is not None:
56
+ if replaced is None:
57
+ replaced = other.replaced
58
+ else:
59
+ replaced += other.replaced
60
+ self.replaced = replaced
61
+ self.num_seq += other.num_seq
62
+ self.num_tok += other.num_tok
63
+
64
+
65
+ class Binarizer(ABC):
66
+ """
67
+ a binarizer describes how to take a string and build a tensor out of it
68
+ """
69
+
70
+ @abstractmethod
71
+ def binarize_line(
72
+ self,
73
+ line: str,
74
+ summary: BinarizeSummary,
75
+ ) -> torch.IntTensor:
76
+ ...
77
+
78
+
79
+ def _worker_prefix(output_prefix: str, worker_id: int):
80
+ return f"{output_prefix}.pt{worker_id}"
81
+
82
+
83
+ class FileBinarizer:
84
+ """
85
+ An file binarizer can take a file, tokenize it, and binarize each line to a tensor
86
+ """
87
+
88
+ @classmethod
89
+ def multiprocess_dataset(
90
+ cls,
91
+ input_file: str,
92
+ dataset_impl: str,
93
+ binarizer: Binarizer,
94
+ output_prefix: str,
95
+ vocab_size=None,
96
+ num_workers=1,
97
+ ) -> BinarizeSummary:
98
+ final_summary = BinarizeSummary()
99
+
100
+ offsets = find_offsets(input_file, num_workers)
101
+ # find_offsets returns a list of position [pos1, pos2, pos3, pos4] but we would want pairs:
102
+ # [(pos1, pos2), (pos2, pos3), (pos3, pos4)] to process the chunks with start/end info
103
+ # we zip the list with itself shifted by one to get all the pairs.
104
+ (first_chunk, *more_chunks) = zip(offsets, offsets[1:])
105
+ pool = None
106
+ if num_workers > 1:
107
+ pool = Pool(processes=num_workers - 1)
108
+ worker_results = [
109
+ pool.apply_async(
110
+ cls._binarize_chunk_and_finalize,
111
+ args=(
112
+ binarizer,
113
+ input_file,
114
+ start_offset,
115
+ end_offset,
116
+ _worker_prefix(
117
+ output_prefix,
118
+ worker_id,
119
+ ),
120
+ dataset_impl,
121
+ ),
122
+ kwds={
123
+ "vocab_size": vocab_size,
124
+ }
125
+ if vocab_size is not None
126
+ else {},
127
+ )
128
+ for worker_id, (start_offset, end_offset) in enumerate(
129
+ more_chunks, start=1
130
+ )
131
+ ]
132
+
133
+ pool.close()
134
+ pool.join()
135
+ for r in worker_results:
136
+ summ = r.get()
137
+ final_summary.merge(summ)
138
+
139
+ # do not close the bin file as we need to merge the worker results in
140
+ final_ds, summ = cls._binarize_file_chunk(
141
+ binarizer,
142
+ input_file,
143
+ offset_start=first_chunk[0],
144
+ offset_end=first_chunk[1],
145
+ output_prefix=output_prefix,
146
+ dataset_impl=dataset_impl,
147
+ vocab_size=vocab_size if vocab_size is not None else None,
148
+ )
149
+ final_summary.merge(summ)
150
+
151
+ if num_workers > 1:
152
+ for worker_id in range(1, num_workers):
153
+ # merge the worker outputs
154
+ worker_output_prefix = _worker_prefix(
155
+ output_prefix,
156
+ worker_id,
157
+ )
158
+ final_ds.merge_file_(worker_output_prefix)
159
+ try:
160
+ os.remove(indexed_dataset.data_file_path(worker_output_prefix))
161
+ os.remove(indexed_dataset.index_file_path(worker_output_prefix))
162
+ except Exception as e:
163
+ logger.error(
164
+ f"couldn't remove {worker_output_prefix}.*", exc_info=e
165
+ )
166
+
167
+ # now we can close the file
168
+ idx_file = indexed_dataset.index_file_path(output_prefix)
169
+ final_ds.finalize(idx_file)
170
+ return final_summary
171
+
172
+ @staticmethod
173
+ def _binarize_file_chunk(
174
+ binarizer: Binarizer,
175
+ filename: str,
176
+ offset_start: int,
177
+ offset_end: int,
178
+ output_prefix: str,
179
+ dataset_impl: str,
180
+ vocab_size=None,
181
+ ) -> tp.Tuple[tp.Any, BinarizeSummary]: # (dataset builder, BinarizeSummary)
182
+ """
183
+ creates a dataset builder and append binarized items to it. This function does not
184
+ finalize the builder, this is useful if you want to do other things with your bin file
185
+ like appending/merging other files
186
+ """
187
+ bin_file = indexed_dataset.data_file_path(output_prefix)
188
+ ds = indexed_dataset.make_builder(
189
+ bin_file,
190
+ impl=dataset_impl,
191
+ vocab_size=vocab_size,
192
+ )
193
+ summary = BinarizeSummary()
194
+
195
+ with Chunker(
196
+ PathManager.get_local_path(filename), offset_start, offset_end
197
+ ) as line_iterator:
198
+ for line in line_iterator:
199
+ ds.add_item(binarizer.binarize_line(line, summary))
200
+
201
+ return ds, summary
202
+
203
+ @classmethod
204
+ def _binarize_chunk_and_finalize(
205
+ cls,
206
+ binarizer: Binarizer,
207
+ filename: str,
208
+ offset_start: int,
209
+ offset_end: int,
210
+ output_prefix: str,
211
+ dataset_impl: str,
212
+ vocab_size=None,
213
+ ):
214
+ """
215
+ same as above, but also finalizes the builder
216
+ """
217
+ ds, summ = cls._binarize_file_chunk(
218
+ binarizer,
219
+ filename,
220
+ offset_start,
221
+ offset_end,
222
+ output_prefix,
223
+ dataset_impl,
224
+ vocab_size=vocab_size,
225
+ )
226
+
227
+ idx_file = indexed_dataset.index_file_path(output_prefix)
228
+ ds.finalize(idx_file)
229
+
230
+ return summ
231
+
232
+
233
+ class VocabularyDatasetBinarizer(Binarizer):
234
+ """
235
+ Takes a Dictionary/Vocabulary, assign ids to each
236
+ token using the dictionary encode_line function.
237
+ """
238
+
239
+ def __init__(
240
+ self,
241
+ dict: Dictionary,
242
+ tokenize: tp.Callable[[str], tp.List[str]] = tokenize_line,
243
+ append_eos: bool = True,
244
+ reverse_order: bool = False,
245
+ already_numberized: bool = False,
246
+ ) -> None:
247
+ self.dict = dict
248
+ self.tokenize = tokenize
249
+ self.append_eos = append_eos
250
+ self.reverse_order = reverse_order
251
+ self.already_numberized = already_numberized
252
+ super().__init__()
253
+
254
+ def binarize_line(
255
+ self,
256
+ line: str,
257
+ summary: BinarizeSummary,
258
+ ):
259
+ if summary.replaced is None:
260
+ summary.replaced = Counter()
261
+
262
+ def replaced_consumer(word, idx):
263
+ if idx == self.dict.unk_index and word != self.dict.unk_word:
264
+ summary.replaced.update([word])
265
+
266
+ if self.already_numberized:
267
+ id_strings = line.strip().split()
268
+ id_list = [int(id_string) for id_string in id_strings]
269
+ if self.reverse_order:
270
+ id_list.reverse()
271
+ if self.append_eos:
272
+ id_list.append(self.dict.eos())
273
+ ids = torch.IntTensor(id_list)
274
+ else:
275
+ ids = self.dict.encode_line(
276
+ line=line,
277
+ line_tokenizer=self.tokenize,
278
+ add_if_not_exist=False,
279
+ consumer=replaced_consumer,
280
+ append_eos=self.append_eos,
281
+ reverse_order=self.reverse_order,
282
+ )
283
+
284
+ summary.num_seq += 1
285
+ summary.num_tok += len(ids)
286
+ return ids
287
+
288
+
289
+ class AlignmentDatasetBinarizer(Binarizer):
290
+ """
291
+ binarize by parsing a set of alignments and packing
292
+ them in a tensor (see utils.parse_alignment)
293
+ """
294
+
295
+ def __init__(
296
+ self,
297
+ alignment_parser: tp.Callable[[str], torch.IntTensor],
298
+ ) -> None:
299
+ super().__init__()
300
+ self.alignment_parser = alignment_parser
301
+
302
+ def binarize_line(
303
+ self,
304
+ line: str,
305
+ summary: BinarizeSummary,
306
+ ):
307
+ ids = self.alignment_parser(line)
308
+ summary.num_seq += 1
309
+ summary.num_tok += len(ids)
310
+ return ids
311
+
312
+
313
+ class LegacyBinarizer:
314
+ @classmethod
315
+ def binarize(
316
+ cls,
317
+ filename: str,
318
+ dico: Dictionary,
319
+ consumer: tp.Callable[[torch.IntTensor], None],
320
+ tokenize: tp.Callable[[str], tp.List[str]] = tokenize_line,
321
+ append_eos: bool = True,
322
+ reverse_order: bool = False,
323
+ offset: int = 0,
324
+ end: int = -1,
325
+ already_numberized: bool = False,
326
+ ) -> tp.Dict[str, int]:
327
+ binarizer = VocabularyDatasetBinarizer(
328
+ dict=dico,
329
+ tokenize=tokenize,
330
+ append_eos=append_eos,
331
+ reverse_order=reverse_order,
332
+ already_numberized=already_numberized,
333
+ )
334
+ return cls._consume_file(
335
+ filename,
336
+ binarizer,
337
+ consumer,
338
+ offset_start=offset,
339
+ offset_end=end,
340
+ )
341
+
342
+ @classmethod
343
+ def binarize_alignments(
344
+ cls,
345
+ filename: str,
346
+ alignment_parser: tp.Callable[[str], torch.IntTensor],
347
+ consumer: tp.Callable[[torch.IntTensor], None],
348
+ offset: int = 0,
349
+ end: int = -1,
350
+ ) -> tp.Dict[str, int]:
351
+ binarizer = AlignmentDatasetBinarizer(alignment_parser)
352
+ return cls._consume_file(
353
+ filename,
354
+ binarizer,
355
+ consumer,
356
+ offset_start=offset,
357
+ offset_end=end,
358
+ )
359
+
360
+ @staticmethod
361
+ def _consume_file(
362
+ filename: str,
363
+ binarizer: Binarizer,
364
+ consumer: tp.Callable[[torch.IntTensor], None],
365
+ offset_start: int,
366
+ offset_end: int,
367
+ ) -> tp.Dict[str, int]:
368
+ summary = BinarizeSummary()
369
+
370
+ with Chunker(
371
+ PathManager.get_local_path(filename), offset_start, offset_end
372
+ ) as line_iterator:
373
+ for line in line_iterator:
374
+ consumer(binarizer.binarize_line(line, summary))
375
+
376
+ return {
377
+ "nseq": summary.num_seq,
378
+ "nunk": summary.num_replaced,
379
+ "ntok": summary.num_tok,
380
+ "replaced": summary.replaced,
381
+ }
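
Hypothetical usage sketch for the binarizer module above (not part of this commit): VocabularyDatasetBinarizer maps each token to its id via Dictionary.encode_line, and FileBinarizer.multiprocess_dataset splits the input file into offset-delimited chunks so several workers can binarize in parallel before the shards are merged. File names and the dataset_impl value below are illustrative.

from fairseq.binarizer import FileBinarizer, VocabularyDatasetBinarizer
from fairseq.data import Dictionary

vocab = Dictionary.load("dict.txt")        # assumes a pre-built fairseq dictionary
binarizer = VocabularyDatasetBinarizer(vocab, append_eos=True)

summary = FileBinarizer.multiprocess_dataset(
    input_file="train.txt",
    dataset_impl="mmap",                   # one of fairseq's indexed_dataset backends
    binarizer=binarizer,
    output_prefix="train-binarized",       # writes train-binarized.bin / .idx
    vocab_size=len(vocab),
    num_workers=4,
)
print(summary)  # e.g. "100000 sents, 2500000 tokens, 0.1% replaced"
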
fairseq/checkpoint_utils.py ADDED
@@ -0,0 +1,936 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ import ast
7
+ import collections
8
+ import contextlib
9
+ import inspect
10
+ import logging
11
+ import os
12
+ import re
13
+ import time
14
+ import traceback
15
+ from collections import OrderedDict
16
+ from pathlib import Path
17
+ from typing import Any, Dict, Optional, Union
18
+
19
+ import numpy as np
20
+ import torch
21
+ from fairseq.data import data_utils
22
+ from fairseq.dataclass.configs import CheckpointConfig
23
+ from fairseq.dataclass.utils import (
24
+ convert_namespace_to_omegaconf,
25
+ overwrite_args_by_name,
26
+ )
27
+ from fairseq.distributed.fully_sharded_data_parallel import FSDP, has_FSDP
28
+ from fairseq.file_io import PathManager
29
+ from fairseq.models import FairseqDecoder, FairseqEncoder
30
+ from omegaconf import DictConfig, OmegaConf, open_dict
31
+
32
+ logger = logging.getLogger(__name__)
33
+
34
+
35
+ def save_checkpoint(cfg: CheckpointConfig, trainer, epoch_itr, val_loss):
36
+ from fairseq import meters
37
+
38
+ # only one worker should attempt to create the required dir
39
+ if trainer.data_parallel_rank == 0:
40
+ os.makedirs(cfg.save_dir, exist_ok=True)
41
+
42
+ prev_best = getattr(save_checkpoint, "best", val_loss)
43
+ if val_loss is not None:
44
+ best_function = max if cfg.maximize_best_checkpoint_metric else min
45
+ save_checkpoint.best = best_function(val_loss, prev_best)
46
+
47
+ if cfg.no_save:
48
+ return None
49
+
50
+ trainer.consolidate_optimizer() # TODO(SS): do we need this if no_save_optimizer_state
51
+
52
+ if not trainer.should_save_checkpoint_on_current_rank:
53
+ if trainer.always_call_state_dict_during_save_checkpoint:
54
+ trainer.state_dict()
55
+ return None
56
+
57
+ write_timer = meters.StopwatchMeter()
58
+ write_timer.start()
59
+
60
+ epoch = epoch_itr.epoch
61
+ end_of_epoch = epoch_itr.end_of_epoch()
62
+ updates = trainer.get_num_updates()
63
+
64
+ logger.info(f"Preparing to save checkpoint for epoch {epoch} @ {updates} updates")
65
+
66
+ def is_better(a, b):
67
+ return a >= b if cfg.maximize_best_checkpoint_metric else a <= b
68
+
69
+ suffix = trainer.checkpoint_suffix
70
+ checkpoint_conds = collections.OrderedDict()
71
+ checkpoint_conds["checkpoint{}{}.pt".format(epoch, suffix)] = (
72
+ end_of_epoch and not cfg.no_epoch_checkpoints and epoch % cfg.save_interval == 0
73
+ )
74
+ checkpoint_conds["checkpoint_{}_{}{}.pt".format(epoch, updates, suffix)] = (
75
+ not end_of_epoch
76
+ and cfg.save_interval_updates > 0
77
+ and updates % cfg.save_interval_updates == 0
78
+ )
79
+ checkpoint_conds["checkpoint_best{}.pt".format(suffix)] = val_loss is not None and (
80
+ not hasattr(save_checkpoint, "best")
81
+ or is_better(val_loss, save_checkpoint.best)
82
+ )
83
+ if val_loss is not None and cfg.keep_best_checkpoints > 0:
84
+ worst_best = getattr(save_checkpoint, "best", None)
85
+ chkpts = checkpoint_paths(
86
+ cfg.save_dir,
87
+ pattern=r"checkpoint\.best_{}_(\d+\.?\d*){}\.pt".format(
88
+ cfg.best_checkpoint_metric, suffix
89
+ ),
90
+ )
91
+ if len(chkpts) > 0:
92
+ p = chkpts[-1] if cfg.maximize_best_checkpoint_metric else chkpts[0]
93
+ worst_best = float(p.rsplit("_")[-1].replace("{}.pt".format(suffix), ""))
94
+ # add random digits to resolve ties
95
+ with data_utils.numpy_seed(epoch, updates, val_loss):
96
+ rand_sfx = np.random.randint(0, cfg.keep_best_checkpoints)
97
+
98
+ checkpoint_conds[
99
+ "checkpoint.best_{}_{:.3f}{}{}.pt".format(
100
+ cfg.best_checkpoint_metric, val_loss, rand_sfx, suffix
101
+ )
102
+ ] = worst_best is None or is_better(val_loss, worst_best)
103
+ checkpoint_conds[
104
+ "checkpoint_last{}.pt".format(suffix)
105
+ ] = not cfg.no_last_checkpoints
106
+
107
+ extra_state = {
108
+ "train_iterator": epoch_itr.state_dict(),
109
+ "val_loss": val_loss,
110
+ }
111
+
112
+ # Going forward, different tasks could expose an API like this to dump all
113
+ # the checkpoint worthy attributes in a dictionary which then will be
114
+ # merged with the parent dictionary to create the "extra_state". This
115
+ # allows for an extensible yet simple design to checkpoint task level
116
+ # attributes
117
+ if hasattr(trainer.task, "get_checkpoint_dict"):
118
+ extra_state = {**extra_state, **trainer.task.get_checkpoint_dict()}
119
+ logger.info(f"State of {trainer.task.__class__.__name__} is ready to be persisted with the checkpoint")
120
+
121
+ if hasattr(save_checkpoint, "best"):
122
+ extra_state.update({"best": save_checkpoint.best})
123
+
124
+ checkpoints = [
125
+ os.path.join(cfg.save_dir, fn) for fn, cond in checkpoint_conds.items() if cond
126
+ ]
127
+ saved_cp = None
128
+ if len(checkpoints) > 0 and trainer.should_save_checkpoint_on_current_rank:
129
+ saved_cp = trainer.save_checkpoint(checkpoints[0], extra_state)
130
+ for cp in checkpoints[1:]:
131
+ if cfg.write_checkpoints_asynchronously:
132
+ # TODO[ioPath]: Need to implement a delayed asynchronous
133
+ # file copying/moving feature.
134
+ logger.warning(
135
+ f"ioPath is not copying {checkpoints[0]} to {cp} "
136
+ "since async write mode is on."
137
+ )
138
+ else:
139
+ assert PathManager.copy(
140
+ checkpoints[0], cp, overwrite=True
141
+ ), f"Failed to copy {checkpoints[0]} to {cp}"
142
+
143
+ write_timer.stop()
144
+ logger.info(
145
+ "Saved checkpoint {} (epoch {} @ {} updates, score {}) (writing took {} seconds)".format(
146
+ checkpoints[0], epoch, updates, val_loss, write_timer.sum
147
+ )
148
+ )
149
+
150
+ if (
151
+ not end_of_epoch
152
+ and cfg.keep_interval_updates > 0
153
+ and trainer.should_save_checkpoint_on_current_rank
154
+ ):
155
+ # remove old checkpoints; checkpoints are sorted in descending order
156
+ if cfg.keep_interval_updates_pattern == -1:
157
+ checkpoints = checkpoint_paths(
158
+ cfg.save_dir, pattern=r"checkpoint_\d+_(\d+){}\.pt".format(suffix)
159
+ )
160
+ else:
161
+ checkpoints = checkpoint_paths(
162
+ cfg.save_dir,
163
+ pattern=r"checkpoint_\d+_(\d+){}\.pt".format(suffix),
164
+ keep_match=True,
165
+ )
166
+ checkpoints = [
167
+ x[0]
168
+ for x in checkpoints
169
+ if x[1] % cfg.keep_interval_updates_pattern != 0
170
+ ]
171
+
172
+ for old_chk in checkpoints[cfg.keep_interval_updates :]:
173
+ if os.path.lexists(old_chk):
174
+ os.remove(old_chk)
175
+ elif PathManager.exists(old_chk):
176
+ PathManager.rm(old_chk)
177
+
178
+ if cfg.keep_last_epochs > 0 and trainer.should_save_checkpoint_on_current_rank:
179
+ # remove old epoch checkpoints; checkpoints are sorted in descending order
180
+ checkpoints = checkpoint_paths(
181
+ cfg.save_dir, pattern=r"checkpoint(\d+){}\.pt".format(suffix)
182
+ )
183
+ for old_chk in checkpoints[cfg.keep_last_epochs :]:
184
+ if os.path.lexists(old_chk):
185
+ os.remove(old_chk)
186
+ elif PathManager.exists(old_chk):
187
+ PathManager.rm(old_chk)
188
+
189
+ if cfg.keep_best_checkpoints > 0 and trainer.should_save_checkpoint_on_current_rank:
190
+ # only keep the best N checkpoints according to validation metric
191
+ checkpoints = checkpoint_paths(
192
+ cfg.save_dir,
193
+ pattern=r"checkpoint\.best_{}_(\d+\.?\d*){}\.pt".format(
194
+ cfg.best_checkpoint_metric, suffix
195
+ ),
196
+ )
197
+ if not cfg.maximize_best_checkpoint_metric:
198
+ checkpoints = checkpoints[::-1]
199
+ for old_chk in checkpoints[cfg.keep_best_checkpoints :]:
200
+ if os.path.lexists(old_chk):
201
+ os.remove(old_chk)
202
+ elif PathManager.exists(old_chk):
203
+ PathManager.rm(old_chk)
204
+
205
+ return saved_cp
206
+
207
+
208
+ def load_checkpoint(cfg: CheckpointConfig, trainer, **passthrough_args):
209
+ """
210
+ Load a checkpoint and restore the training iterator.
211
+
212
+ *passthrough_args* will be passed through to
213
+ ``trainer.get_train_iterator``.
214
+ """
215
+
216
+ reset_optimizer = cfg.reset_optimizer
217
+ reset_lr_scheduler = cfg.reset_lr_scheduler
218
+ optimizer_overrides = ast.literal_eval(cfg.optimizer_overrides)
219
+ reset_meters = cfg.reset_meters
220
+ reset_dataloader = cfg.reset_dataloader
221
+
222
+ if cfg.finetune_from_model is not None and (
223
+ reset_optimizer or reset_lr_scheduler or reset_meters or reset_dataloader
224
+ ):
225
+ raise ValueError(
226
+ "--finetune-from-model can not be set together with either --reset-optimizer"
227
+ " or reset_lr_scheduler or reset_meters or reset_dataloader"
228
+ )
229
+
230
+ suffix = trainer.checkpoint_suffix
231
+ if (
232
+ cfg.restore_file == "checkpoint_last.pt"
233
+ ): # default value of restore_file is 'checkpoint_last.pt'
234
+ checkpoint_path = os.path.join(
235
+ cfg.save_dir, "checkpoint_last{}.pt".format(suffix)
236
+ )
237
+ first_launch = not PathManager.exists(checkpoint_path)
238
+ if first_launch and getattr(cfg, "continue_once", None) is not None:
239
+ checkpoint_path = cfg.continue_once
240
+ elif cfg.finetune_from_model is not None and first_launch:
241
+ # if there is no last checkpoint to restore, start the finetune from pretrained model
242
+ # else just use usual logic to load checkpoint, e.g. restart from last checkpoint and etc.
243
+ if PathManager.exists(cfg.finetune_from_model):
244
+ checkpoint_path = cfg.finetune_from_model
245
+ reset_optimizer = True
246
+ reset_lr_scheduler = True
247
+ reset_meters = True
248
+ reset_dataloader = True
249
+ logger.info(
250
+ f"loading pretrained model from {checkpoint_path}: "
251
+ "optimizer, lr scheduler, meters, dataloader will be reset"
252
+ )
253
+ else:
254
+ raise ValueError(
255
+ f"--finetune-from-model {cfg.finetune_from_model} does not exist"
256
+ )
257
+ elif suffix is not None:
258
+ checkpoint_path = cfg.restore_file.replace(".pt", suffix + ".pt")
259
+ else:
260
+ checkpoint_path = cfg.restore_file
261
+
262
+ if cfg.restore_file != "checkpoint_last.pt" and cfg.finetune_from_model:
263
+ raise ValueError(
264
+ "--finetune-from-model and --restore-file (non-default value) "
265
+ "can not be specified together: " + str(cfg)
266
+ )
267
+
268
+ extra_state = trainer.load_checkpoint(
269
+ checkpoint_path,
270
+ reset_optimizer,
271
+ reset_lr_scheduler,
272
+ optimizer_overrides,
273
+ reset_meters=reset_meters,
274
+ )
275
+
276
+ if (
277
+ extra_state is not None
278
+ and "best" in extra_state
279
+ and not reset_optimizer
280
+ and not reset_meters
281
+ ):
282
+ save_checkpoint.best = extra_state["best"]
283
+
284
+ if extra_state is not None and not reset_dataloader:
285
+ # restore iterator from checkpoint
286
+ itr_state = extra_state["train_iterator"]
287
+ epoch_itr = trainer.get_train_iterator(
288
+ epoch=itr_state["epoch"], load_dataset=True, **passthrough_args
289
+ )
290
+ epoch_itr.load_state_dict(itr_state)
291
+
292
+ # Preload the checkpoint for the task
293
+ task_cp_dict = extra_state.get(trainer.task.__class__.__name__, {})
294
+ if task_cp_dict and hasattr(trainer.task, "set_checkpoint_dict"):
295
+ trainer.task.set_checkpoint_dict(task_cp_dict)
296
+ else:
297
+ epoch_itr = trainer.get_train_iterator(
298
+ epoch=1, load_dataset=True, **passthrough_args
299
+ )
300
+
301
+ trainer.lr_step(epoch_itr.epoch)
302
+
303
+ return extra_state, epoch_itr
304
+
305
+
306
+ def load_checkpoint_to_cpu(path, arg_overrides=None, load_on_all_ranks=False):
307
+ """Loads a checkpoint to CPU (with upgrading for backward compatibility).
308
+
309
+ If doing single-GPU training or if the checkpoint is only being loaded by at
310
+ most one process on each node (current default behavior is for only rank 0
311
+ to read the checkpoint from disk), load_on_all_ranks should be False to
312
+ avoid errors from torch.distributed not having been initialized or
313
+ torch.distributed.barrier() hanging.
314
+
315
+ If all processes on each node may be loading the checkpoint
316
+ simultaneously, load_on_all_ranks should be set to True to avoid I/O
317
+ conflicts.
318
+
319
+ There's currently no support for > 1 but < all processes loading the
320
+ checkpoint on each node.
321
+ """
322
+ local_path = PathManager.get_local_path(path)
323
+ # The locally cached file returned by get_local_path() may be stale for
324
+ # remote files that are periodically updated/overwritten (ex:
325
+ # checkpoint_last.pt) - so we remove the local copy, sync across processes
326
+ # (if needed), and then download a fresh copy.
327
+ if local_path != path and PathManager.path_requires_pathmanager(path):
328
+ try:
329
+ os.remove(local_path)
330
+ except FileNotFoundError:
331
+ # With potentially multiple processes removing the same file, the
332
+ # file being missing is benign (missing_ok isn't available until
333
+ # Python 3.8).
334
+ pass
335
+ if load_on_all_ranks:
336
+ torch.distributed.barrier()
337
+ local_path = PathManager.get_local_path(path)
338
+
339
+ with open(local_path, "rb") as f:
340
+ state = torch.load(f, map_location=torch.device("cpu"), weights_only=False)
341
+
342
+ if "args" in state and state["args"] is not None and arg_overrides is not None:
343
+ args = state["args"]
344
+ for arg_name, arg_val in arg_overrides.items():
345
+ setattr(args, arg_name, arg_val)
346
+
347
+ if "cfg" in state and state["cfg"] is not None:
348
+
349
+ # hack to be able to set Namespace in dict config. this should be removed when we update to newer
350
+ # omegaconf version that supports object flags, or when we migrate all existing models
351
+ from omegaconf import __version__ as oc_version
352
+ from omegaconf import _utils
353
+
354
+ if oc_version < "2.2":
355
+ old_primitive = _utils.is_primitive_type
356
+ _utils.is_primitive_type = lambda _: True
357
+
358
+ state["cfg"] = OmegaConf.create(state["cfg"])
359
+
360
+ _utils.is_primitive_type = old_primitive
361
+ OmegaConf.set_struct(state["cfg"], True)
362
+ else:
363
+ state["cfg"] = OmegaConf.create(state["cfg"], flags={"allow_objects": True})
364
+
365
+ if arg_overrides is not None:
366
+ overwrite_args_by_name(state["cfg"], arg_overrides)
367
+
368
+ state = _upgrade_state_dict(state)
369
+ return state
370
+
371
+
372
+ def load_model_ensemble(
373
+ filenames,
374
+ arg_overrides: Optional[Dict[str, Any]] = None,
375
+ task=None,
376
+ strict=True,
377
+ suffix="",
378
+ num_shards=1,
379
+ state=None,
380
+ ):
381
+ """Loads an ensemble of models.
382
+
383
+ Args:
384
+ filenames (List[str]): checkpoint files to load
385
+ arg_overrides (Dict[str,Any], optional): override model args that
386
+ were used during model training
387
+ task (fairseq.tasks.FairseqTask, optional): task to use for loading
388
+ """
389
+ assert not (
390
+ strict and num_shards > 1
391
+ ), "Cannot load state dict with strict=True and checkpoint shards > 1"
392
+ ensemble, args, _task = load_model_ensemble_and_task(
393
+ filenames,
394
+ arg_overrides,
395
+ task,
396
+ strict,
397
+ suffix,
398
+ num_shards,
399
+ state,
400
+ )
401
+ return ensemble, args
402
+
403
+
404
+ def get_maybe_sharded_checkpoint_filename(
405
+ filename: str, suffix: str, shard_idx: int, num_shards: int
406
+ ) -> str:
407
+ orig_filename = filename
408
+ filename = filename.replace(".pt", suffix + ".pt")
409
+ fsdp_filename = filename[:-3] + f"-shard{shard_idx}.pt"
410
+ model_parallel_filename = orig_filename[:-3] + f"_part{shard_idx}.pt"
411
+ if PathManager.exists(fsdp_filename):
412
+ return fsdp_filename
413
+ elif num_shards > 1:
414
+ return model_parallel_filename
415
+ else:
416
+ return filename
417
+
418
+
419
+ def load_model_ensemble_and_task(
420
+ filenames,
421
+ arg_overrides: Optional[Dict[str, Any]] = None,
422
+ task=None,
423
+ strict=True,
424
+ suffix="",
425
+ num_shards=1,
426
+ state=None,
427
+ ):
428
+ assert state is None or len(filenames) == 1
429
+
430
+ from fairseq import tasks
431
+
432
+ assert not (
433
+ strict and num_shards > 1
434
+ ), "Cannot load state dict with strict=True and checkpoint shards > 1"
435
+ ensemble = []
436
+ cfg = None
437
+ for filename in filenames:
438
+ orig_filename = filename
439
+ model_shard_state = {"shard_weights": [], "shard_metadata": []}
440
+ assert num_shards > 0
441
+ st = time.time()
442
+ for shard_idx in range(num_shards):
443
+ filename = get_maybe_sharded_checkpoint_filename(
444
+ orig_filename, suffix, shard_idx, num_shards
445
+ )
446
+
447
+ if not PathManager.exists(filename):
448
+ raise IOError("Model file not found: {}".format(filename))
449
+ if state is None:
450
+ state = load_checkpoint_to_cpu(filename, arg_overrides)
451
+ if "args" in state and state["args"] is not None:
452
+ cfg = convert_namespace_to_omegaconf(state["args"])
453
+ elif "cfg" in state and state["cfg"] is not None:
454
+ cfg = state["cfg"]
455
+ else:
456
+ raise RuntimeError(
457
+ f"Neither args nor cfg exist in state keys = {state.keys()}"
458
+ )
459
+
460
+ if task is None:
461
+ task = tasks.setup_task(cfg.task, from_checkpoint=True)
462
+
463
+ if "task_state" in state:
464
+ task.load_state_dict(state["task_state"])
465
+
466
+ argspec = inspect.getfullargspec(task.build_model)
467
+
468
+ if "fsdp_metadata" in state and num_shards > 1:
469
+ model_shard_state["shard_weights"].append(state["model"])
470
+ model_shard_state["shard_metadata"].append(state["fsdp_metadata"])
471
+ # check FSDP import before the code goes too far
472
+ if not has_FSDP:
473
+ raise ImportError(
474
+ "Cannot find FullyShardedDataParallel. "
475
+ "Please install fairscale with: pip install fairscale"
476
+ )
477
+ if shard_idx == num_shards - 1:
478
+ consolidated_model_state = FSDP.consolidate_shard_weights(
479
+ shard_weights=model_shard_state["shard_weights"],
480
+ shard_metadata=model_shard_state["shard_metadata"],
481
+ )
482
+ if "from_checkpoint" in argspec.args:
483
+ model = task.build_model(cfg.model, from_checkpoint=True)
484
+ else:
485
+ model = task.build_model(cfg.model)
486
+ if (
487
+ "optimizer_history" in state
488
+ and len(state["optimizer_history"]) > 0
489
+ and "num_updates" in state["optimizer_history"][-1]
490
+ ):
491
+ model.set_num_updates(
492
+ state["optimizer_history"][-1]["num_updates"]
493
+ )
494
+ model.load_state_dict(
495
+ consolidated_model_state, strict=strict, model_cfg=cfg.model
496
+ )
497
+ else:
498
+ # model parallel checkpoint or unsharded checkpoint
499
+ # support old external tasks
500
+
501
+ if "from_checkpoint" in argspec.args:
502
+ model = task.build_model(cfg.model, from_checkpoint=True)
503
+ else:
504
+ model = task.build_model(cfg.model)
505
+ if (
506
+ "optimizer_history" in state
507
+ and len(state["optimizer_history"]) > 0
508
+ and "num_updates" in state["optimizer_history"][-1]
509
+ ):
510
+ model.set_num_updates(state["optimizer_history"][-1]["num_updates"])
511
+ model.load_state_dict(
512
+ state["model"], strict=strict, model_cfg=cfg.model
513
+ )
514
+
515
+ # reset state so it gets loaded for the next model in ensemble
516
+ state = None
517
+ if shard_idx % 10 == 0 and shard_idx > 0:
518
+ elapsed = time.time() - st
519
+ logger.info(
520
+ f"Loaded {shard_idx} shards in {elapsed:.2f}s, {elapsed / (shard_idx+1):.2f}s/shard"
521
+ )
522
+
523
+ # build model for ensemble
524
+ ensemble.append(model)
525
+ return ensemble, cfg, task
526
+
527
+
528
+ def load_model_ensemble_and_task_from_hf_hub(
529
+ model_id,
530
+ cache_dir: Optional[str] = None,
531
+ arg_overrides: Optional[Dict[str, Any]] = None,
532
+ **kwargs: Any,
533
+ ):
534
+ try:
535
+ from huggingface_hub import snapshot_download
536
+ except ImportError:
537
+ raise ImportError(
538
+ "You need to install huggingface_hub to use `load_from_hf_hub`. "
539
+ "See https://pypi.org/project/huggingface-hub/ for installation."
540
+ )
541
+
542
+ library_name = "fairseq"
543
+ cache_dir = cache_dir or (Path.home() / ".cache" / library_name).as_posix()
544
+ cache_dir = snapshot_download(
545
+ model_id, cache_dir=cache_dir, library_name=library_name, **kwargs
546
+ )
547
+
548
+ _arg_overrides = arg_overrides or {}
549
+ _arg_overrides["data"] = cache_dir
550
+ return load_model_ensemble_and_task(
551
+ [p.as_posix() for p in Path(cache_dir).glob("*.pt")],
552
+ arg_overrides=_arg_overrides,
553
+ )
554
+
555
+
556
+ def checkpoint_paths(path, pattern=r"checkpoint(\d+)\.pt", keep_match=False):
557
+ """Retrieves all checkpoints found in `path` directory.
558
+
559
+ Checkpoints are identified by matching filename to the specified pattern. If
560
+ the pattern contains groups, the result will be sorted by the first group in
561
+ descending order.
562
+ """
563
+ pt_regexp = re.compile(pattern)
564
+ files = PathManager.ls(path)
565
+
566
+ entries = []
567
+ for i, f in enumerate(files):
568
+ m = pt_regexp.fullmatch(f)
569
+ if m is not None:
570
+ idx = float(m.group(1)) if len(m.groups()) > 0 else i
571
+ entries.append((idx, m.group(0)))
572
+ if keep_match:
573
+ return [(os.path.join(path, x[1]), x[0]) for x in sorted(entries, reverse=True)]
574
+ else:
575
+ return [os.path.join(path, x[1]) for x in sorted(entries, reverse=True)]
576
+
577
+
578
+ def torch_persistent_save(obj, filename, async_write: bool = False):
579
+ if async_write:
580
+ with PathManager.opena(filename, "wb") as f:
581
+ _torch_persistent_save(obj, f)
582
+ else:
583
+ if PathManager.supports_rename(filename):
584
+ # do atomic save
585
+ with PathManager.open(filename + ".tmp", "wb") as f:
586
+ _torch_persistent_save(obj, f)
587
+ PathManager.rename(filename + ".tmp", filename)
588
+ else:
589
+ # fallback to non-atomic save
590
+ with PathManager.open(filename, "wb") as f:
591
+ _torch_persistent_save(obj, f)
592
+
593
+
594
+ def _torch_persistent_save(obj, f):
595
+ if isinstance(f, str):
596
+ with PathManager.open(f, "wb") as h:
597
+ torch_persistent_save(obj, h)
598
+ return
599
+ for i in range(3):
600
+ try:
601
+ return torch.save(obj, f)
602
+ except Exception:
603
+ if i == 2:
604
+ logger.error(traceback.format_exc())
605
+ raise
606
+ else:
607
+ time.sleep(2.5)
608
+
609
+
610
+ def _upgrade_state_dict(state):
611
+ """Helper for upgrading old model checkpoints."""
612
+
613
+ # add optimizer_history
614
+ if "optimizer_history" not in state:
615
+ state["optimizer_history"] = [
616
+ {"criterion_name": "CrossEntropyCriterion", "best_loss": state["best_loss"]}
617
+ ]
618
+ state["last_optimizer_state"] = state["optimizer"]
619
+ del state["optimizer"]
620
+ del state["best_loss"]
621
+ # move extra_state into sub-dictionary
622
+ if "epoch" in state and "extra_state" not in state:
623
+ state["extra_state"] = {
624
+ "epoch": state["epoch"],
625
+ "batch_offset": state["batch_offset"],
626
+ "val_loss": state["val_loss"],
627
+ }
628
+ del state["epoch"]
629
+ del state["batch_offset"]
630
+ del state["val_loss"]
631
+ # reduce optimizer history's memory usage (only keep the last state)
632
+ if "optimizer" in state["optimizer_history"][-1]:
633
+ state["last_optimizer_state"] = state["optimizer_history"][-1]["optimizer"]
634
+ for optim_hist in state["optimizer_history"]:
635
+ del optim_hist["optimizer"]
636
+ # record the optimizer class name
637
+ if "optimizer_name" not in state["optimizer_history"][-1]:
638
+ state["optimizer_history"][-1]["optimizer_name"] = "FairseqNAG"
639
+ # move best_loss into lr_scheduler_state
640
+ if "lr_scheduler_state" not in state["optimizer_history"][-1]:
641
+ state["optimizer_history"][-1]["lr_scheduler_state"] = {
642
+ "best": state["optimizer_history"][-1]["best_loss"]
643
+ }
644
+ del state["optimizer_history"][-1]["best_loss"]
645
+ # keep track of number of updates
646
+ if "num_updates" not in state["optimizer_history"][-1]:
647
+ state["optimizer_history"][-1]["num_updates"] = 0
648
+ # use stateful training data iterator
649
+ if "train_iterator" not in state["extra_state"]:
650
+ state["extra_state"]["train_iterator"] = {
651
+ "epoch": state["extra_state"].get("epoch", 0),
652
+ "iterations_in_epoch": state["extra_state"].get("batch_offset", 0),
653
+ }
654
+
655
+ # backward compatibility, cfg updates
656
+ if "args" in state and state["args"] is not None:
657
+ # old model checkpoints may not have separate source/target positions
658
+ if hasattr(state["args"], "max_positions") and not hasattr(
659
+ state["args"], "max_source_positions"
660
+ ):
661
+ state["args"].max_source_positions = state["args"].max_positions
662
+ state["args"].max_target_positions = state["args"].max_positions
663
+ # default to translation task
664
+ if not hasattr(state["args"], "task"):
665
+ state["args"].task = "translation"
666
+ # --raw-text and --lazy-load are deprecated
667
+ if getattr(state["args"], "raw_text", False):
668
+ state["args"].dataset_impl = "raw"
669
+ elif getattr(state["args"], "lazy_load", False):
670
+ state["args"].dataset_impl = "lazy"
671
+ # epochs start at 1
672
+ if state["extra_state"]["train_iterator"] is not None:
673
+ state["extra_state"]["train_iterator"]["epoch"] = max(
674
+ state["extra_state"]["train_iterator"].get("epoch", 1), 1
675
+ )
676
+ # --remove-bpe ==> --postprocess
677
+ if hasattr(state["args"], "remove_bpe"):
678
+ state["args"].post_process = state["args"].remove_bpe
679
+ # --min-lr ==> --stop-min-lr
680
+ if hasattr(state["args"], "min_lr"):
681
+ state["args"].stop_min_lr = state["args"].min_lr
682
+ del state["args"].min_lr
683
+ # binary_cross_entropy / kd_binary_cross_entropy => wav2vec criterion
684
+ if hasattr(state["args"], "criterion") and state["args"].criterion in [
685
+ "binary_cross_entropy",
686
+ "kd_binary_cross_entropy",
687
+ ]:
688
+ state["args"].criterion = "wav2vec"
689
+ # remove log_keys if it's None (criteria will supply a default value of [])
690
+ if hasattr(state["args"], "log_keys") and state["args"].log_keys is None:
691
+ delattr(state["args"], "log_keys")
692
+ # speech_pretraining => audio pretraining
693
+ if (
694
+ hasattr(state["args"], "task")
695
+ and state["args"].task == "speech_pretraining"
696
+ ):
697
+ state["args"].task = "audio_pretraining"
698
+ # audio_cpc => wav2vec
699
+ if hasattr(state["args"], "arch") and state["args"].arch == "audio_cpc":
700
+ state["args"].arch = "wav2vec"
701
+ # convert legacy float learning rate to List[float]
702
+ if hasattr(state["args"], "lr") and isinstance(state["args"].lr, float):
703
+ state["args"].lr = [state["args"].lr]
704
+ # convert task data arg to a string instead of List[string]
705
+ if (
706
+ hasattr(state["args"], "data")
707
+ and isinstance(state["args"].data, list)
708
+ and len(state["args"].data) > 0
709
+ ):
710
+ state["args"].data = state["args"].data[0]
711
+
712
+ state["cfg"] = convert_namespace_to_omegaconf(state["args"])
713
+
714
+ if "cfg" in state and state["cfg"] is not None:
715
+ cfg = state["cfg"]
716
+ with open_dict(cfg):
717
+ # any upgrades for Hydra-based configs
718
+ if (
719
+ "task" in cfg
720
+ and "eval_wer_config" in cfg.task
721
+ and isinstance(cfg.task.eval_wer_config.print_alignment, bool)
722
+ ):
723
+ cfg.task.eval_wer_config.print_alignment = "hard"
724
+ if "generation" in cfg and isinstance(cfg.generation.print_alignment, bool):
725
+ cfg.generation.print_alignment = (
726
+ "hard" if cfg.generation.print_alignment else None
727
+ )
728
+ if (
729
+ "model" in cfg
730
+ and "w2v_args" in cfg.model
731
+ and cfg.model.w2v_args is not None
732
+ and (
733
+ hasattr(cfg.model.w2v_args, "task") or "task" in cfg.model.w2v_args
734
+ )
735
+ and hasattr(cfg.model.w2v_args.task, "eval_wer_config")
736
+ and cfg.model.w2v_args.task.eval_wer_config is not None
737
+ and isinstance(
738
+ cfg.model.w2v_args.task.eval_wer_config.print_alignment, bool
739
+ )
740
+ ):
741
+ cfg.model.w2v_args.task.eval_wer_config.print_alignment = "hard"
742
+
743
+ return state
744
+
745
+
746
+ def prune_state_dict(state_dict, model_cfg: Optional[DictConfig]):
747
+ """Prune the given state_dict if desired for LayerDrop
748
+ (https://arxiv.org/abs/1909.11556).
749
+
750
+ Training with LayerDrop allows models to be robust to pruning at inference
751
+ time. This function prunes state_dict to allow smaller models to be loaded
752
+ from a larger model and re-maps the existing state_dict for this to occur.
753
+
754
+ It's called by functions that load models from checkpoints and does not
755
+ need to be called directly.
756
+ """
757
+ arch = None
758
+ if model_cfg is not None:
759
+ arch = (
760
+ model_cfg._name
761
+ if isinstance(model_cfg, DictConfig)
762
+ else getattr(model_cfg, "arch", None)
763
+ )
764
+
765
+ if not model_cfg or arch is None or arch == "ptt_transformer":
766
+ # args should not be None, but don't crash if it is.
767
+ return state_dict
768
+
769
+ encoder_layers_to_keep = getattr(model_cfg, "encoder_layers_to_keep", None)
770
+ decoder_layers_to_keep = getattr(model_cfg, "decoder_layers_to_keep", None)
771
+
772
+ if not encoder_layers_to_keep and not decoder_layers_to_keep:
773
+ return state_dict
774
+
775
+ # apply pruning
776
+ logger.info(
777
+ "Pruning model to specified layer configuration - this works best if the model was trained with LayerDrop"
778
+ )
779
+
780
+ def create_pruning_pass(layers_to_keep, layer_name):
781
+ keep_layers = sorted(
782
+ int(layer_string) for layer_string in layers_to_keep.split(",")
783
+ )
784
+ mapping_dict = {}
785
+ for i in range(len(keep_layers)):
786
+ mapping_dict[str(keep_layers[i])] = str(i)
787
+
788
+ regex = re.compile(r"^{layer}.*\.layers\.(\d+)".format(layer=layer_name))
789
+ return {"substitution_regex": regex, "mapping_dict": mapping_dict}
790
+
791
+ pruning_passes = []
792
+ if encoder_layers_to_keep:
793
+ pruning_passes.append(create_pruning_pass(encoder_layers_to_keep, "encoder"))
794
+ if decoder_layers_to_keep:
795
+ pruning_passes.append(create_pruning_pass(decoder_layers_to_keep, "decoder"))
796
+
797
+ new_state_dict = {}
798
+ for layer_name in state_dict.keys():
799
+ match = re.search(r"\.layers\.(\d+)\.", layer_name)
800
+ # if layer has no number in it, it is a supporting layer, such as an
801
+ # embedding
802
+ if not match:
803
+ new_state_dict[layer_name] = state_dict[layer_name]
804
+ continue
805
+
806
+ # otherwise, layer should be pruned.
807
+ original_layer_number = match.group(1)
808
+ # figure out which mapping dict to replace from
809
+ for pruning_pass in pruning_passes:
810
+ if original_layer_number in pruning_pass["mapping_dict"] and pruning_pass[
811
+ "substitution_regex"
812
+ ].search(layer_name):
813
+ new_layer_number = pruning_pass["mapping_dict"][original_layer_number]
814
+ substitution_match = pruning_pass["substitution_regex"].search(
815
+ layer_name
816
+ )
817
+ new_state_key = (
818
+ layer_name[: substitution_match.start(1)]
819
+ + new_layer_number
820
+ + layer_name[substitution_match.end(1) :]
821
+ )
822
+ new_state_dict[new_state_key] = state_dict[layer_name]
823
+
824
+ # Since layers are now pruned, *_layers_to_keep are no longer needed.
825
+ # This is more of a workaround to make loading work than a proper fix.
826
+ if isinstance(model_cfg, DictConfig):
827
+ context = open_dict(model_cfg)
828
+ else:
829
+ context = contextlib.ExitStack()
830
+ with context:
831
+ if hasattr(model_cfg, "encoder_layers_to_keep"):
832
+ model_cfg.encoder_layers_to_keep = None
833
+ if hasattr(model_cfg, "decoder_layers_to_keep"):
834
+ model_cfg.decoder_layers_to_keep = None
835
+
836
+ return new_state_dict
837
+
838
+
839
+ def load_pretrained_component_from_model(
840
+ component: Union[FairseqEncoder, FairseqDecoder],
841
+ checkpoint: str,
842
+ strict: bool = True,
843
+ ):
844
+ """
845
+ Load a pretrained FairseqEncoder or FairseqDecoder from checkpoint into the
846
+ provided `component` object. If state_dict fails to load, there may be a
847
+ mismatch in the architecture of the corresponding `component` found in the
848
+ `checkpoint` file.
849
+ """
850
+ if not PathManager.exists(checkpoint):
851
+ raise IOError("Model file not found: {}".format(checkpoint))
852
+ state = load_checkpoint_to_cpu(checkpoint)
853
+ if isinstance(component, FairseqEncoder):
854
+ component_type = "encoder"
855
+ elif isinstance(component, FairseqDecoder):
856
+ component_type = "decoder"
857
+ else:
858
+ raise ValueError(
859
+ "component to load must be either a FairseqEncoder or "
860
+ "FairseqDecoder. Loading other component types are not supported."
861
+ )
862
+ component_state_dict = OrderedDict()
863
+ for key in state["model"].keys():
864
+ if key.startswith(component_type):
865
+ # encoder.input_layers.0.0.weight --> input_layers.0.0.weight
866
+ component_subkey = key[len(component_type) + 1 :]
867
+ component_state_dict[component_subkey] = state["model"][key]
868
+ component.load_state_dict(component_state_dict, strict=strict)
869
+ return component
870
+
871
+
872
+ def verify_checkpoint_directory(save_dir: str) -> None:
873
+ if not os.path.exists(save_dir):
874
+ os.makedirs(save_dir, exist_ok=True)
875
+ temp_file_path = os.path.join(save_dir, "dummy")
876
+ try:
877
+ with open(temp_file_path, "w"):
878
+ pass
879
+ except OSError as e:
880
+ logger.warning(
881
+ "Unable to access checkpoint save directory: {}".format(save_dir)
882
+ )
883
+ raise e
884
+ else:
885
+ os.remove(temp_file_path)
886
+
887
+
888
+ def save_ema_as_checkpoint(src_path, dst_path):
889
+ state = load_ema_from_checkpoint(src_path)
890
+ torch_persistent_save(state, dst_path)
891
+
892
+
893
+ def load_ema_from_checkpoint(fpath):
894
+ """Loads exponential moving averaged (EMA) checkpoint from input and
895
+ returns a model with ema weights.
896
+
897
+ Args:
898
+ fpath: A string path of checkpoint to load from.
899
+
900
+ Returns:
901
+ A dict of string keys mapping to various values. The 'model' key
902
+ from the returned dict should correspond to an OrderedDict mapping
903
+ string parameter names to torch Tensors.
904
+ """
905
+ params_dict = collections.OrderedDict()
906
+ new_state = None
907
+
908
+ with PathManager.open(fpath, "rb") as f:
909
+ new_state = torch.load(
910
+ f,
911
+ map_location=(
912
+ lambda s, _: torch.serialization.default_restore_location(s, "cpu")
913
+ ),
914
+ )
915
+
916
+ # EMA model is stored in a separate "extra state"
917
+ model_params = new_state["extra_state"]["ema"]
918
+
919
+ for key in list(model_params.keys()):
920
+ p = model_params[key]
921
+ if isinstance(p, torch.HalfTensor):
922
+ p = p.float()
923
+ if key not in params_dict:
924
+ params_dict[key] = p.clone()
925
+ # NOTE: clone() is needed in case p is a shared parameter
926
+ else:
927
+ raise ValueError("Key {} is repeated in EMA model params.".format(key))
928
+
929
+ if len(params_dict) == 0:
930
+ raise ValueError(
931
+ f"Input checkpoint path '{fpath}' does not contain "
932
+ "ema model weights, is this model trained with EMA?"
933
+ )
934
+
935
+ new_state["model"] = params_dict
936
+ return new_state
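For orientation, here is a minimal usage sketch of the helpers above (the checkpoint filename, data directory, and save directory below are hypothetical placeholders, not part of this commit):

# Hedged sketch: load an ensemble of checkpoints on CPU and put the models in eval mode.
from fairseq import checkpoint_utils

models, cfg, task = checkpoint_utils.load_model_ensemble_and_task(
    ["checkpoint_best.pt"],  # hypothetical checkpoint path
    arg_overrides={"data": "/path/to/data-bin"},  # hypothetical data-directory override
)
for model in models:
    model.eval()

# checkpoint_paths() lists checkpoints in a directory, sorted by the first regex group
# (the epoch number with the default pattern), newest first.
latest_first = checkpoint_utils.checkpoint_paths("/path/to/save-dir")  # hypothetical directory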
fairseq/clib/cuda/ngram_repeat_block_cuda.cpp ADDED
@@ -0,0 +1,55 @@
1
+ /*
2
+ Copyright (c) Microsoft Corporation.
3
+ Licensed under the MIT License.
4
+ */
5
+
6
+ #include <torch/extension.h>
7
+ #include <vector>
8
+
9
+ /*
10
+ CPP Binding for CUDA OP
11
+ */
12
+
13
+ // CUDA forward declarations
14
+ torch::Tensor ngram_repeat_block_cuda_forward(
15
+ torch::Tensor tokens,
16
+ torch::Tensor lprobs,
17
+ int bsz,
18
+ int step,
19
+ int beam_size,
20
+ int no_repeat_ngram_size);
21
+
22
+ #define CHECK_CUDA(x) \
23
+ TORCH_CHECK(x.type().is_cuda(), #x " must be a CUDA tensor")
24
+ #define CHECK_CONTIGUOUS(x) \
25
+ TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
26
+ #define CHECK_INPUT(x) \
27
+ CHECK_CUDA(x); \
28
+ CHECK_CONTIGUOUS(x)
29
+
30
+ // Input check and call to CUDA OP
31
+ // Backward method not required
32
+ torch::Tensor ngram_repeat_block_forward(
33
+ torch::Tensor tokens,
34
+ torch::Tensor lprobs,
35
+ int bsz,
36
+ int step,
37
+ int beam_size,
38
+ int no_repeat_ngram_size) {
39
+ CHECK_INPUT(tokens);
40
+ CHECK_INPUT(lprobs);
41
+ assert(bsz > 0);
42
+ assert(step >= 0);
43
+ assert(beam_size > 0);
44
+ assert(no_repeat_ngram_size > 0);
45
+
46
+ return ngram_repeat_block_cuda_forward(
47
+ tokens, lprobs, bsz, step, beam_size, no_repeat_ngram_size);
48
+ }
49
+
50
+ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
51
+ m.def(
52
+ "forward",
53
+ &ngram_repeat_block_forward,
54
+ "No Repeat Ngram Block forward (CUDA)");
55
+ }
fairseq/clib/cuda/ngram_repeat_block_cuda_kernel.cu ADDED
@@ -0,0 +1,82 @@
1
+ /*
2
+ Copyright (c) Microsoft Corporation.
3
+ Licensed under the MIT License.
4
+ */
5
+
6
+ /*
7
+ Kernel implementation for blocking repeated n-grams.
8
+ */
9
+
10
+ #include <cuda.h>
11
+ #include <cuda_runtime.h>
12
+ #include <math.h>
13
+ #include <torch/extension.h>
14
+ #include <vector>
15
+
16
+ // Ban repeated ngrams of length = 'no_repeat_ngram_size'
17
+ __global__ void banRepeatedTokens(
18
+ long* __restrict__ tokens,
19
+ float* __restrict__ lprobs,
20
+ int max_predict_len,
21
+ int vocab_size,
22
+ int no_repeat_ngram_size) {
23
+ auto row = blockIdx.x;
24
+ auto col = threadIdx.x;
25
+ auto start = row * (max_predict_len) + col;
26
+ // Each thread compares ngram starting from
27
+ // thread index with final ngram starting from
28
+ // step - no_repeat_ngram_size +2
29
+ auto check_start_pos = blockDim.x;
30
+ auto lprob_start = row * vocab_size;
31
+ bool is_banned = true;
32
+ extern __shared__ long tokens_shm[];
33
+ tokens_shm[col] = tokens[start];
34
+ if (col == blockDim.x - 1) {
35
+ for (int i = 1; i < no_repeat_ngram_size; i++) {
36
+ if (col + i < max_predict_len) {
37
+ tokens_shm[col + i] = tokens[start + i];
38
+ }
39
+ }
40
+ }
41
+ __syncthreads();
42
+
43
+ for (int k = 0; k < no_repeat_ngram_size - 1; k++) {
44
+ if (tokens_shm[col + k] != tokens_shm[check_start_pos + k]) {
45
+ is_banned = false;
46
+ }
47
+ }
48
+ if (is_banned == true) {
49
+ auto token_to_be_banned = tokens_shm[col + no_repeat_ngram_size - 1];
50
+ lprobs[lprob_start + token_to_be_banned] = -INFINITY;
51
+ }
52
+ }
53
+
54
+ // Allocate blocks and threads based on
55
+ // batch size and sequence length and launch
56
+ // kernel
57
+ torch::Tensor ngram_repeat_block_cuda_forward(
58
+ const torch::Tensor tokens,
59
+ torch::Tensor lprobs,
60
+ int bsz,
61
+ int step,
62
+ int beam_size,
63
+ int no_repeat_ngram_size) {
64
+ int threads = step - no_repeat_ngram_size + 2;
65
+ if (threads <= 0)
66
+ return lprobs;
67
+ int max_predict_len = tokens.size(1);
68
+ int vocab_size = lprobs.size(1);
69
+ auto token_ptr = tokens.data_ptr<long>();
70
+ auto lprob_ptr = lprobs.data_ptr<float>();
71
+ int blocks = bsz * beam_size;
72
+ int shared_mem_size = (step + 1) * sizeof(long);
73
+
74
+ // Launching N blocks where N is number of samples in a batch (beams*bsz)
75
+ // Launching T threads where T is number of previous ngrams in a sample
76
+ // Allocating shared mem per block for faster access of input tokens since
77
+ // each token will be accessed N times to compare with current Ngram where
78
+ // N is Ngram size.
79
+ banRepeatedTokens<<<blocks, threads, shared_mem_size>>>(
80
+ token_ptr, lprob_ptr, max_predict_len, vocab_size, no_repeat_ngram_size);
81
+ return lprobs;
82
+ }
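As a reading aid, the banning rule this kernel implements can be sketched in plain Python (a reference sketch only; the shipped op runs the same comparison in parallel, one CUDA thread per candidate n-gram start):

# Reference sketch of the n-gram blocking rule for a single hypothesis.
import math

def ban_repeated_ngrams(tokens, lprobs, step, no_repeat_ngram_size):
    # tokens: generated token ids so far (tokens[0..step] are valid)
    # lprobs: log-probabilities over the vocabulary for the next step (modified in place)
    n = no_repeat_ngram_size
    if step - n + 2 <= 0:  # mirrors the "threads <= 0" early return above
        return lprobs
    prefix = tokens[step - n + 2 : step + 1]  # the last n-1 generated tokens
    for start in range(step - n + 2):  # one CUDA thread per value of `start`
        if tokens[start : start + n - 1] == prefix:
            # ban the token that completed the earlier, matching n-gram
            lprobs[tokens[start + n - 1]] = -math.inf
    return lprobs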
fairseq/clib/libbase/balanced_assignment.cpp ADDED
@@ -0,0 +1,109 @@
1
+ /**
2
+ * Copyright 2017-present, Facebook, Inc.
3
+ * All rights reserved.
4
+ *
5
+ * This source code is licensed under the license found in the
6
+ * LICENSE file in the root directory of this source tree.
7
+ */
8
+
9
+ /*
10
+ C++ code for solving the linear assignment problem.
11
+ Based on the Auction Algorithm from
12
+ https://dspace.mit.edu/bitstream/handle/1721.1/3265/P-2108-26912652.pdf and the
13
+ implementation from: https://github.com/bkj/auction-lap Adapted to be more
14
+ efficient when each worker is looking for k jobs instead of 1.
15
+ */
16
+ #include <torch/extension.h>
17
+ #include <iostream>
18
+ using namespace torch::indexing;
19
+ torch::Tensor balanced_assignment(torch::Tensor job_and_worker_to_score) {
20
+ int max_iterations = 100;
21
+ torch::Tensor epsilon =
22
+ (job_and_worker_to_score.max() - job_and_worker_to_score.min()) / 50;
23
+ epsilon.clamp_min_(1e-04);
24
+ torch::Tensor worker_and_job_to_score =
25
+ job_and_worker_to_score.detach().transpose(0, 1).contiguous();
26
+ int num_workers = worker_and_job_to_score.size(0);
27
+ int num_jobs = worker_and_job_to_score.size(1);
28
+ auto device = worker_and_job_to_score.device();
29
+ int jobs_per_worker = num_jobs / num_workers;
30
+ torch::Tensor value = worker_and_job_to_score.clone();
31
+ int counter = 0;
32
+ torch::Tensor max_value = worker_and_job_to_score.max();
33
+
34
+ torch::Tensor bid_indices;
35
+ torch::Tensor cost = worker_and_job_to_score.new_zeros({1, num_jobs});
36
+ torch::Tensor bids =
37
+ worker_and_job_to_score.new_empty({num_workers, num_jobs});
38
+ torch::Tensor bid_increments =
39
+ worker_and_job_to_score.new_empty({num_workers, jobs_per_worker});
40
+ torch::Tensor top_values =
41
+ worker_and_job_to_score.new_empty({num_workers, jobs_per_worker + 1});
42
+ torch::Tensor high_bids = worker_and_job_to_score.new_empty({num_jobs});
43
+
44
+ torch::Tensor top_index = top_values.to(torch::kLong);
45
+ torch::Tensor high_bidders = top_index.new_empty({num_jobs});
46
+ torch::Tensor have_bids = high_bidders.to(torch::kBool);
47
+ torch::Tensor jobs_indices =
48
+ torch::arange({num_jobs}, torch::dtype(torch::kLong).device(device));
49
+ torch::Tensor true_tensor =
50
+ torch::ones({1}, torch::dtype(torch::kBool).device(device));
51
+
52
+ while (true) {
53
+ bids.zero_();
54
+ torch::topk_out(top_values, top_index, value, jobs_per_worker + 1, 1);
55
+
56
+ // Each worker bids the difference in value between that job and the k+1th
57
+ // job
58
+ torch::sub_out(
59
+ bid_increments,
60
+ top_values.index({Slice(None, None), Slice(0, jobs_per_worker)}),
61
+ top_values.index({Slice(None, None), jobs_per_worker}).unsqueeze(1));
62
+
63
+ bid_increments.add_(epsilon);
64
+ bids.scatter_(
65
+ 1,
66
+ top_index.index({Slice(None, None), Slice(0, jobs_per_worker)}),
67
+ bid_increments);
68
+
69
+ if (counter < max_iterations && counter > 0) {
70
+ // Put in a minimal bid to retain items from the last round if no-one else
71
+ // bids for them this round
72
+ bids.view(-1).index_put_({bid_indices}, epsilon);
73
+ }
74
+
75
+ // Find the highest bidding worker per job
76
+ torch::max_out(high_bids, high_bidders, bids, 0);
77
+ torch::gt_out(have_bids, high_bids, 0);
78
+
79
+ if (have_bids.all().item<bool>()) {
80
+ // All jobs were bid for
81
+ break;
82
+ }
83
+
84
+ // Make popular items more expensive
85
+ cost.add_(high_bids);
86
+ torch::sub_out(value, worker_and_job_to_score, cost);
87
+
88
+ bid_indices = ((high_bidders * num_jobs) + jobs_indices).index({have_bids});
89
+
90
+ if (counter < max_iterations) {
91
+ // Make sure that this item will be in the winning worker's top-k next
92
+ // time.
93
+ value.view(-1).index_put_({bid_indices}, max_value);
94
+ } else {
95
+ // Suboptimal approximation that converges quickly from current solution
96
+ value.view(-1).index_put_(
97
+ {bid_indices}, worker_and_job_to_score.view(-1).index({bid_indices}));
98
+ }
99
+
100
+ counter += 1;
101
+ }
102
+
103
+ return top_index.index({Slice(None, None), Slice(0, jobs_per_worker)})
104
+ .reshape(-1);
105
+ }
106
+
107
+ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
108
+ m.def("balanced_assignment", &balanced_assignment, "Balanced Assignment");
109
+ }
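In effect the routine solves a balanced linear assignment: given a [num_jobs, num_workers] score matrix with num_jobs divisible by num_workers, every worker receives exactly num_jobs / num_workers jobs while (approximately) maximizing the total score. A brute-force Python sketch of that objective, for intuition only (the extension uses the auction algorithm instead), is:

from itertools import permutations

def balanced_assignment_bruteforce(scores):
    # scores[j][w]: score of giving job j to worker w; only feasible for tiny inputs.
    num_jobs, num_workers = len(scores), len(scores[0])
    per_worker = num_jobs // num_workers
    best_total, best_order = float("-inf"), None
    for order in permutations(range(num_jobs)):
        # worker w takes the jobs in positions [w * per_worker, (w + 1) * per_worker)
        total = sum(scores[job][pos // per_worker] for pos, job in enumerate(order))
        if total > best_total:
            best_total, best_order = total, order
    return best_total, best_order

# 4 jobs, 2 workers, 2 jobs each
print(balanced_assignment_bruteforce([[3, 1], [2, 5], [4, 2], [1, 6]]))
# -> (18, (0, 2, 1, 3)): worker 0 takes jobs 0 and 2, worker 1 takes jobs 1 and 3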
fairseq/clib/libbleu/libbleu.cpp ADDED
@@ -0,0 +1,157 @@
1
+ /**
2
+ * Copyright 2017-present, Facebook, Inc.
3
+ * All rights reserved.
4
+ *
5
+ * This source code is licensed under the license found in the
6
+ * LICENSE file in the root directory of this source tree.
7
+ */
8
+
9
+ #include <array>
10
+ #include <cstdio>
11
+ #include <cstring>
12
+ #include <map>
13
+
14
+ // NOLINTNEXTLINE
15
+ typedef struct {
16
+ size_t reflen;
17
+ size_t predlen;
18
+ size_t match1;
19
+ size_t count1;
20
+ size_t match2;
21
+ size_t count2;
22
+ size_t match3;
23
+ size_t count3;
24
+ size_t match4;
25
+ size_t count4;
26
+ } bleu_stat;
27
+
28
+ // left trim (remove pad)
29
+ void bleu_ltrim(size_t* len, int** sent, int pad) {
30
+ size_t start = 0;
31
+ while (start < *len) {
32
+ if (*(*sent + start) != pad) {
33
+ break;
34
+ }
35
+ start++;
36
+ }
37
+ *sent += start;
38
+ *len -= start;
39
+ }
40
+
41
+ // right trim (remove eos)
42
+ void bleu_rtrim(size_t* len, int** sent, int pad, int eos) {
43
+ size_t end = *len - 1;
44
+ while (end > 0) {
45
+ if (*(*sent + end) != eos && *(*sent + end) != pad) {
46
+ break;
47
+ }
48
+ end--;
49
+ }
50
+ *len = end + 1;
51
+ }
52
+
53
+ // left and right trim
54
+ void bleu_trim(size_t* len, int** sent, int pad, int eos) {
55
+ bleu_ltrim(len, sent, pad);
56
+ bleu_rtrim(len, sent, pad, eos);
57
+ }
58
+
59
+ size_t bleu_hash(int len, int* data) {
60
+ size_t h = 14695981039346656037ul;
61
+ size_t prime = 0x100000001b3;
62
+ char* b = (char*)data;
63
+ size_t blen = sizeof(int) * len;
64
+
65
+ while (blen-- > 0) {
66
+ h ^= *b++;
67
+ h *= prime;
68
+ }
69
+
70
+ return h;
71
+ }
72
+
73
+ void bleu_addngram(
74
+ size_t* ntotal,
75
+ size_t* nmatch,
76
+ size_t n,
77
+ size_t reflen,
78
+ int* ref,
79
+ size_t predlen,
80
+ int* pred) {
81
+ if (predlen < n) {
82
+ return;
83
+ }
84
+
85
+ predlen = predlen - n + 1;
86
+ (*ntotal) += predlen;
87
+
88
+ if (reflen < n) {
89
+ return;
90
+ }
91
+
92
+ reflen = reflen - n + 1;
93
+
94
+ std::map<size_t, size_t> count;
95
+ while (predlen > 0) {
96
+ size_t w = bleu_hash(n, pred++);
97
+ count[w]++;
98
+ predlen--;
99
+ }
100
+
101
+ while (reflen > 0) {
102
+ size_t w = bleu_hash(n, ref++);
103
+ if (count[w] > 0) {
104
+ (*nmatch)++;
105
+ count[w] -= 1;
106
+ }
107
+ reflen--;
108
+ }
109
+ }
110
+
111
+ extern "C" {
112
+
113
+ #ifdef _WIN64
114
+ __declspec(dllexport)
115
+ #endif
116
+ void bleu_zero_init(bleu_stat* stat) {
117
+ std::memset(stat, 0, sizeof(bleu_stat));
118
+ }
119
+
120
+ #ifdef _WIN64
121
+ __declspec(dllexport)
122
+ #endif
123
+ void bleu_one_init(bleu_stat* stat) {
124
+ bleu_zero_init(stat);
125
+ stat->count1 = 0;
126
+ stat->count2 = 1;
127
+ stat->count3 = 1;
128
+ stat->count4 = 1;
129
+ stat->match1 = 0;
130
+ stat->match2 = 1;
131
+ stat->match3 = 1;
132
+ stat->match4 = 1;
133
+ }
134
+
135
+ #ifdef _WIN64
136
+ __declspec(dllexport)
137
+ #endif
138
+ void bleu_add(
139
+ bleu_stat* stat,
140
+ size_t reflen,
141
+ int* ref,
142
+ size_t predlen,
143
+ int* pred,
144
+ int pad,
145
+ int eos) {
146
+
147
+ bleu_trim(&reflen, &ref, pad, eos);
148
+ bleu_trim(&predlen, &pred, pad, eos);
149
+ stat->reflen += reflen;
150
+ stat->predlen += predlen;
151
+
152
+ bleu_addngram(&stat->count1, &stat->match1, 1, reflen, ref, predlen, pred);
153
+ bleu_addngram(&stat->count2, &stat->match2, 2, reflen, ref, predlen, pred);
154
+ bleu_addngram(&stat->count3, &stat->match3, 3, reflen, ref, predlen, pred);
155
+ bleu_addngram(&stat->count4, &stat->match4, 4, reflen, ref, predlen, pred);
156
+ }
157
+ }
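The n-gram accounting in bleu_addngram has a straightforward Python counterpart (a sketch for readability, not part of the shipped library): the number of predicted n-grams is added to the total, and matches are clipped by how often each n-gram occurs in the reference.

from collections import Counter

def add_ngram_counts(ref, pred, n):
    # Returns (total predicted n-grams, matches clipped by reference counts), as bleu_addngram does.
    pred_ngrams = Counter(tuple(pred[i : i + n]) for i in range(len(pred) - n + 1))
    ref_ngrams = Counter(tuple(ref[i : i + n]) for i in range(len(ref) - n + 1))
    total = max(len(pred) - n + 1, 0)
    match = sum(min(c, ref_ngrams[g]) for g, c in pred_ngrams.items())
    return total, match

print(add_ngram_counts([1, 2, 3, 4], [1, 2, 3, 5], 2))  # (3, 2)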
fairseq/clib/libbleu/module.cpp ADDED
@@ -0,0 +1,33 @@
1
+ /**
2
+ * Copyright 2017-present, Facebook, Inc.
3
+ * All rights reserved.
4
+ *
5
+ * This source code is licensed under the license found in the
6
+ * LICENSE file in the root directory of this source tree.
7
+ */
8
+
9
+ #include <Python.h>
10
+
11
+ static PyMethodDef method_def[] = {{NULL, NULL, 0, NULL}}; // NOLINT
12
+
13
+ static struct PyModuleDef module_def = {
14
+ PyModuleDef_HEAD_INIT,
15
+ "libbleu", /* name of module */
16
+ // NOLINTNEXTLINE
17
+ NULL, /* module documentation, may be NULL */
18
+ -1, /* size of per-interpreter state of the module,
19
+ or -1 if the module keeps state in global variables. */
20
+ method_def}; // NOLINT
21
+
22
+ #if PY_MAJOR_VERSION == 2
23
+ PyMODINIT_FUNC init_libbleu()
24
+ #else
25
+ PyMODINIT_FUNC PyInit_libbleu()
26
+ #endif
27
+ {
28
+ PyObject* m = PyModule_Create(&module_def);
29
+ if (!m) {
30
+ return NULL;
31
+ }
32
+ return m;
33
+ }
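Note that the module deliberately exposes no Python-level methods; the BLEU routines above are exported with C linkage and are typically driven from Python through ctypes. A hedged sketch of that pattern, assuming the compiled fairseq.libbleu extension is importable (the exact loading path used by fairseq's scorer may differ):

import ctypes
from fairseq import libbleu  # the extension module defined by this file

class BleuStat(ctypes.Structure):
    # Mirrors the bleu_stat struct in libbleu.cpp.
    _fields_ = [(name, ctypes.c_size_t) for name in (
        "reflen", "predlen",
        "match1", "count1", "match2", "count2",
        "match3", "count3", "match4", "count4",
    )]

C = ctypes.cdll.LoadLibrary(libbleu.__file__)  # bind the extern "C" symbols
stat = BleuStat()
C.bleu_zero_init(ctypes.byref(stat))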
fairseq/clib/libnat/edit_dist.cpp ADDED
@@ -0,0 +1,231 @@
1
+ /**
2
+ * Copyright 2017-present, Facebook, Inc.
3
+ * All rights reserved.
4
+ *
5
+ * This source code is licensed under the license found in the
6
+ * LICENSE file in the root directory of this source tree.
7
+ */
8
+
9
+ #include <pybind11/detail/common.h>
10
+ #include <pybind11/pybind11.h>
11
+ #include <torch/torch.h> // @manual=//caffe2:torch_extension
12
+ #include <algorithm>
13
+ #include <cstdint>
14
+ #include <iosfwd>
15
+ #include <memory>
16
+ #include <new>
17
+ #include <string>
18
+ #include <utility>
19
+ #include <vector>
20
+
21
+ using namespace ::std;
22
+
23
+ vector<vector<uint32_t>> edit_distance2_with_dp(
24
+ vector<uint32_t>& x,
25
+ vector<uint32_t>& y) {
26
+ uint32_t lx = x.size();
27
+ uint32_t ly = y.size();
28
+ vector<vector<uint32_t>> d(lx + 1, vector<uint32_t>(ly + 1));
29
+ for (uint32_t i = 0; i < lx + 1; i++) {
30
+ d[i][0] = i;
31
+ }
32
+ for (uint32_t j = 0; j < ly + 1; j++) {
33
+ d[0][j] = j;
34
+ }
35
+ for (uint32_t i = 1; i < lx + 1; i++) {
36
+ for (uint32_t j = 1; j < ly + 1; j++) {
37
+ d[i][j] =
38
+ min(min(d[i - 1][j], d[i][j - 1]) + 1,
39
+ d[i - 1][j - 1] + 2 * (x.at(i - 1) == y.at(j - 1) ? 0 : 1));
40
+ }
41
+ }
42
+ return d;
43
+ }
44
+
45
+ vector<vector<uint32_t>> edit_distance2_backtracking(
46
+ vector<vector<uint32_t>>& d,
47
+ vector<uint32_t>& x,
48
+ vector<uint32_t>& y,
49
+ uint32_t terminal_symbol) {
50
+ vector<uint32_t> seq;
51
+ vector<vector<uint32_t>> edit_seqs(x.size() + 2, vector<uint32_t>());
52
+ /*
53
+ edit_seqs:
54
+ cells 0..x.size() hold the insertion sequences
55
+ the last cell holds the deletion sequence
56
+ */
57
+
58
+ if (x.size() == 0) {
59
+ edit_seqs.at(0) = y;
60
+ return edit_seqs;
61
+ }
62
+
63
+ uint32_t i = d.size() - 1;
64
+ uint32_t j = d.at(0).size() - 1;
65
+
66
+ while ((i >= 0) && (j >= 0)) {
67
+ if ((i == 0) && (j == 0)) {
68
+ break;
69
+ }
70
+
71
+ if ((j > 0) && (d.at(i).at(j - 1) < d.at(i).at(j))) {
72
+ seq.push_back(1); // insert
73
+ seq.push_back(y.at(j - 1));
74
+ j--;
75
+ } else if ((i > 0) && (d.at(i - 1).at(j) < d.at(i).at(j))) {
76
+ seq.push_back(2); // delete
77
+ seq.push_back(x.at(i - 1));
78
+ i--;
79
+ } else {
80
+ seq.push_back(3); // keep
81
+ seq.push_back(x.at(i - 1));
82
+ i--;
83
+ j--;
84
+ }
85
+ }
86
+
87
+ uint32_t prev_op, op, s, word;
88
+ prev_op = 0, s = 0;
89
+ for (uint32_t k = 0; k < seq.size() / 2; k++) {
90
+ op = seq.at(seq.size() - 2 * k - 2);
91
+ word = seq.at(seq.size() - 2 * k - 1);
92
+ if (prev_op != 1) {
93
+ s++;
94
+ }
95
+ if (op == 1) // insert
96
+ {
97
+ edit_seqs.at(s - 1).push_back(word);
98
+ } else if (op == 2) // delete
99
+ {
100
+ edit_seqs.at(x.size() + 1).push_back(1);
101
+ } else {
102
+ edit_seqs.at(x.size() + 1).push_back(0);
103
+ }
104
+
105
+ prev_op = op;
106
+ }
107
+
108
+ for (uint32_t k = 0; k < edit_seqs.size(); k++) {
109
+ if (edit_seqs[k].size() == 0) {
110
+ edit_seqs[k].push_back(terminal_symbol);
111
+ }
112
+ }
113
+ return edit_seqs;
114
+ }
115
+
116
+ vector<vector<uint32_t>> edit_distance2_backtracking_with_delete(
117
+ vector<vector<uint32_t>>& d,
118
+ vector<uint32_t>& x,
119
+ vector<uint32_t>& y,
120
+ uint32_t terminal_symbol,
121
+ uint32_t deletion_symbol) {
122
+ vector<uint32_t> seq;
123
+ vector<vector<uint32_t>> edit_seqs(x.size() + 1, vector<uint32_t>());
124
+ /*
125
+ edit_seqs:
126
+ cells 0..x.size() hold the insertion sequences; in this variant deleted
127
+ source tokens are recorded in place with deletion_symbol
128
+ */
129
+
130
+ if (x.size() == 0) {
131
+ edit_seqs.at(0) = y;
132
+ return edit_seqs;
133
+ }
134
+
135
+ uint32_t i = d.size() - 1;
136
+ uint32_t j = d.at(0).size() - 1;
137
+
138
+ while ((i >= 0) && (j >= 0)) {
139
+ if ((i == 0) && (j == 0)) {
140
+ break;
141
+ }
142
+
143
+ if ((j > 0) && (d.at(i).at(j - 1) < d.at(i).at(j))) {
144
+ seq.push_back(1); // insert
145
+ seq.push_back(y.at(j - 1));
146
+ j--;
147
+ } else if ((i > 0) && (d.at(i - 1).at(j) < d.at(i).at(j))) {
148
+ seq.push_back(2); // delete
149
+ seq.push_back(x.at(i - 1));
150
+ i--;
151
+ } else {
152
+ seq.push_back(3); // keep
153
+ seq.push_back(x.at(i - 1));
154
+ i--;
155
+ j--;
156
+ }
157
+ }
158
+
159
+ uint32_t prev_op, op, s, word;
160
+ prev_op = 0, s = 0;
161
+ for (uint32_t k = 0; k < seq.size() / 2; k++) {
162
+ op = seq.at(seq.size() - 2 * k - 2);
163
+ word = seq.at(seq.size() - 2 * k - 1);
164
+ if (prev_op != 1) {
165
+ s++;
166
+ }
167
+ if (op == 1) // insert
168
+ {
169
+ edit_seqs.at(s - 1).push_back(word);
170
+ } else if (op == 2) // delete
171
+ {
172
+ edit_seqs.at(s - 1).push_back(deletion_symbol);
173
+ }
174
+
175
+ prev_op = op;
176
+ }
177
+
178
+ for (uint32_t k = 0; k < edit_seqs.size(); k++) {
179
+ if (edit_seqs.at(k).size() == 0) {
180
+ edit_seqs.at(k).push_back(terminal_symbol);
181
+ }
182
+ }
183
+ return edit_seqs;
184
+ }
185
+
186
+ vector<uint32_t> compute_ed2(
187
+ vector<vector<uint32_t>>& xs,
188
+ vector<vector<uint32_t>>& ys) {
189
+ vector<uint32_t> distances(xs.size());
190
+ for (uint32_t i = 0; i < xs.size(); i++) {
191
+ vector<vector<uint32_t>> d = edit_distance2_with_dp(xs.at(i), ys.at(i));
192
+ distances.at(i) = d.at(xs.at(i).size()).at(ys.at(i).size());
193
+ }
194
+ return distances;
195
+ }
196
+
197
+ vector<vector<vector<uint32_t>>> suggested_ed2_path(
198
+ vector<vector<uint32_t>>& xs,
199
+ vector<vector<uint32_t>>& ys,
200
+ uint32_t terminal_symbol) {
201
+ vector<vector<vector<uint32_t>>> seq(xs.size());
202
+ for (uint32_t i = 0; i < xs.size(); i++) {
203
+ vector<vector<uint32_t>> d = edit_distance2_with_dp(xs.at(i), ys.at(i));
204
+ seq.at(i) =
205
+ edit_distance2_backtracking(d, xs.at(i), ys.at(i), terminal_symbol);
206
+ }
207
+ return seq;
208
+ }
209
+
210
+ vector<vector<vector<uint32_t>>> suggested_ed2_path_with_delete(
211
+ vector<vector<uint32_t>>& xs,
212
+ vector<vector<uint32_t>>& ys,
213
+ uint32_t terminal_symbol,
214
+ uint32_t deletion_symbol) {
215
+ vector<vector<vector<uint32_t>>> seq(xs.size());
216
+ for (uint32_t i = 0; i < xs.size(); i++) {
217
+ vector<vector<uint32_t>> d = edit_distance2_with_dp(xs.at(i), ys.at(i));
218
+ seq.at(i) = edit_distance2_backtracking_with_delete(
219
+ d, xs.at(i), ys.at(i), terminal_symbol, deletion_symbol);
220
+ }
221
+ return seq;
222
+ }
223
+
224
+ PYBIND11_MODULE(libnat, m) {
225
+ m.def("compute_ed2", &compute_ed2, "compute_ed2");
226
+ m.def("suggested_ed2_path", &suggested_ed2_path, "suggested_ed2_path");
227
+ m.def(
228
+ "suggested_ed2_path_with_delete",
229
+ &suggested_ed2_path_with_delete,
230
+ "suggested_ed2_path_with_delete");
231
+ }
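For reference, the dynamic program in edit_distance2_with_dp translates directly to Python (a sketch: insertions and deletions cost 1, substitutions cost 2, so the distance reduces to the insertion/deletion-only edit distance):

def edit_distance2(x, y):
    # DP table with insertion/deletion cost 1 and substitution cost 2, as in edit_distance2_with_dp.
    lx, ly = len(x), len(y)
    d = [[0] * (ly + 1) for _ in range(lx + 1)]
    for i in range(lx + 1):
        d[i][0] = i
    for j in range(ly + 1):
        d[0][j] = j
    for i in range(1, lx + 1):
        for j in range(1, ly + 1):
            d[i][j] = min(
                min(d[i - 1][j], d[i][j - 1]) + 1,
                d[i - 1][j - 1] + (0 if x[i - 1] == y[j - 1] else 2),
            )
    return d

# compute_ed2 returns d[len(x)][len(y)] for each pair in the batch.
print(edit_distance2([1, 2, 3], [1, 3])[3][2])  # 1 (delete the middle token)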
fairseq/clib/libnat_cuda/binding.cpp ADDED
@@ -0,0 +1,67 @@
1
+ /**
2
+ * Copyright 2017-present, Facebook, Inc.
3
+ * All rights reserved.
4
+ *
5
+ * This source code is licensed under the license found in the
6
+ * LICENSE file in the root directory of this source tree.
7
+ */
8
+
9
+ /*
10
+ This code is partially adopted from
11
+ https://github.com/1ytic/pytorch-edit-distance
12
+ */
13
+
14
+ #include <torch/types.h>
15
+ #include "edit_dist.h"
16
+
17
+ #ifndef TORCH_CHECK
18
+ #define TORCH_CHECK AT_CHECK
19
+ #endif
20
+
21
+ #define CHECK_CUDA(x) \
22
+ TORCH_CHECK(x.type().is_cuda(), #x " must be a CUDA tensor")
23
+ #define CHECK_CONTIGUOUS(x) \
24
+ TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
25
+ #define CHECK_INPUT(x) \
26
+ CHECK_CUDA(x); \
27
+ CHECK_CONTIGUOUS(x)
28
+
29
+ torch::Tensor LevenshteinDistance(
30
+ torch::Tensor source,
31
+ torch::Tensor target,
32
+ torch::Tensor source_length,
33
+ torch::Tensor target_length) {
34
+ CHECK_INPUT(source);
35
+ CHECK_INPUT(target);
36
+ CHECK_INPUT(source_length);
37
+ CHECK_INPUT(target_length);
38
+ return LevenshteinDistanceCuda(source, target, source_length, target_length);
39
+ }
40
+
41
+ torch::Tensor GenerateDeletionLabel(
42
+ torch::Tensor source,
43
+ torch::Tensor operations) {
44
+ CHECK_INPUT(source);
45
+ CHECK_INPUT(operations);
46
+ return GenerateDeletionLabelCuda(source, operations);
47
+ }
48
+
49
+ std::pair<torch::Tensor, torch::Tensor> GenerateInsertionLabel(
50
+ torch::Tensor target,
51
+ torch::Tensor operations) {
52
+ CHECK_INPUT(target);
53
+ CHECK_INPUT(operations);
54
+ return GenerateInsertionLabelCuda(target, operations);
55
+ }
56
+
57
+ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
58
+ m.def("levenshtein_distance", &LevenshteinDistance, "Levenshtein distance");
59
+ m.def(
60
+ "generate_deletion_labels",
61
+ &GenerateDeletionLabel,
62
+ "Generate Deletion Label");
63
+ m.def(
64
+ "generate_insertion_labels",
65
+ &GenerateInsertionLabel,
66
+ "Generate Insertion Label");
67
+ }
fairseq/clib/libnat_cuda/edit_dist.cu ADDED
@@ -0,0 +1,344 @@
1
+ /**
2
+ * Copyright 2017-present, Facebook, Inc.
3
+ * All rights reserved.
4
+ *
5
+ * This source code is licensed under the license found in the
6
+ * LICENSE file in the root directory of this source tree.
7
+ */
8
+
9
+ #include "edit_dist.h"
10
+
11
+ #include <c10/cuda/CUDAStream.h>
12
+ #include <cuda.h>
13
+ #include <cuda_runtime.h>
14
+ #include <device_launch_parameters.h>
15
+ #include <utility> // std::pair
16
+
17
+ template <typename scalar_t>
18
+ __global__ void generate_deletion_label_kernel(
19
+ const scalar_t* __restrict__ source,
20
+ const size_t source_size,
21
+ const size_t operation_size,
22
+ int* __restrict__ operations,
23
+ int* __restrict__ labels) {
24
+ const int index = blockIdx.x;
25
+ const int offset = index * operation_size;
26
+ const int offset_label = index * source_size;
27
+
28
+ for (int i = 0; i < source_size; i++) {
29
+ labels[offset_label + i] = 0;
30
+ }
31
+
32
+ int k = 0;
33
+ for (int i = 0; i < operation_size; i++) {
34
+ if (operations[offset + i] == 0) {
35
+ break;
36
+ } else if (operations[offset + i] == 1) {
37
+ continue;
38
+ } else {
39
+ labels[offset_label + k] = 3 - operations[offset + i];
40
+ k++;
41
+ }
42
+ }
43
+ }
44
+
45
+ template <typename scalar_t>
46
+ __global__ void generate_insertion_label_kernel(
47
+ const scalar_t* __restrict__ target,
48
+ const size_t target_size,
49
+ const size_t operation_size,
50
+ int* __restrict__ operations,
51
+ int* __restrict__ labels,
52
+ int* __restrict__ masks) {
53
+ const int index = blockIdx.x;
54
+ const int offset = index * operation_size;
55
+ const int offset_label = index * target_size;
56
+
57
+ int k = 0;
58
+ int u = 0;
59
+ int m = 0;
60
+
61
+ for (int i = 0; i < target_size; i++) {
62
+ labels[offset_label + i] = 0;
63
+ masks[offset_label + i] = 0;
64
+ }
65
+
66
+ for (int i = 0; i < operation_size - 1; i++) {
67
+ if (operations[offset + i] == 0) {
68
+ break;
69
+ } else if (operations[offset + i] == 2) {
70
+ continue;
71
+ } else if (operations[offset + i] == 1) {
72
+ masks[offset_label + m] = 1;
73
+ u++;
74
+ m++;
75
+ } else {
76
+ labels[offset_label + k] = u;
77
+ masks[offset_label + m] = 0;
78
+ k++;
79
+ m++;
80
+ u = 0;
81
+ }
82
+ }
83
+ }
84
+
85
+ template <typename scalar_t>
86
+ __global__ void levenshtein_distance_kernel(
87
+ const scalar_t* __restrict__ source,
88
+ const scalar_t* __restrict__ target,
89
+ const int* __restrict__ source_length,
90
+ const int* __restrict__ target_length,
91
+ const size_t source_size,
92
+ const size_t target_size,
93
+ int* __restrict__ operations,
94
+ int* __restrict__ errors_curr) {
95
+ const int index = blockIdx.x;
96
+ const int offset = index * (source_size + target_size);
97
+ const int d = index * (source_size + 1) * (target_size + 1);
98
+ const int t = target_size + 1;
99
+
100
+ auto err_idx = [d, t](int i, int j) { return d + i * t + j; };
101
+ auto opt_idx = [offset](int k) { return offset + k; };
102
+
103
+ const int hyp_len = source_length[index];
104
+ const int ref_len = target_length[index];
105
+ const scalar_t* hyp_begin = source + index * source_size;
106
+ const scalar_t* ref_begin = target + index * target_size;
107
+
108
+ // dynamic programming
109
+ for (int i = 0; i <= hyp_len; i++) {
110
+ errors_curr[err_idx(i, 0)] = i;
111
+ }
112
+ for (int j = 0; j <= ref_len; j++) {
113
+ errors_curr[err_idx(0, j)] = j;
114
+ }
115
+ for (int i = 1; i <= hyp_len; i++) {
116
+ for (int j = 1; j <= ref_len; j++) {
117
+ errors_curr[err_idx(i, j)] = min(
118
+ min(errors_curr[err_idx(i - 1, j)], errors_curr[err_idx(i, j - 1)]) +
119
+ 1,
120
+ errors_curr[err_idx(i - 1, j - 1)] +
121
+ 2 * (*(hyp_begin + i - 1) == *(ref_begin + j - 1) ? 0 : 1));
122
+ }
123
+ }
124
+
125
+ // back-tracing
126
+ int i = hyp_len;
127
+ int j = ref_len;
128
+ int o = hyp_len + ref_len;
129
+
130
+ for (int k = 0; k < source_size + target_size; k++) {
131
+ operations[opt_idx(k)] = 0;
132
+ }
133
+
134
+ while ((i >= 0) && (j >= 0)) {
135
+ if ((i == 0) && (j == 0)) {
136
+ break;
137
+ }
138
+
139
+ if ((j > 0) &&
140
+ (errors_curr[err_idx(i, j - 1)] < errors_curr[err_idx(i, j)])) {
141
+ o--;
142
+ operations[opt_idx(o)] = 1;
143
+ j--; // insertion
144
+ } else if (
145
+ (i > 0) &&
146
+ (errors_curr[err_idx(i - 1, j)] < errors_curr[err_idx(i, j)])) {
147
+ o--;
148
+ operations[opt_idx(o)] = 2;
149
+ i--; // deletion
150
+ } else {
151
+ o--;
152
+ operations[opt_idx(o)] = 3;
153
+ i--;
154
+ j--; // do nothing
155
+ }
156
+ }
157
+
158
+ // moving to the left
159
+ for (int k = 0; k < hyp_len + ref_len; k++) {
160
+ if (k + o < hyp_len + ref_len) {
161
+ operations[opt_idx(k)] = operations[opt_idx(k + o)];
162
+ } else {
163
+ operations[opt_idx(k)] = 0; // padding
164
+ }
165
+ }
166
+ }
167
+
168
+ template <typename scalar_t>
169
+ __global__ void faster_levenshtein_distance_kernel(
170
+ const scalar_t* __restrict__ source,
171
+ const scalar_t* __restrict__ target,
172
+ const int* __restrict__ source_length,
173
+ const int* __restrict__ target_length,
174
+ const size_t source_size,
175
+ const size_t target_size,
176
+ int* __restrict__ operations) {
177
+ extern __shared__ short errors[];
178
+ auto errors_curr = errors;
179
+
180
+ const int index = blockIdx.x;
181
+ const int offset = index * (source_size + target_size);
182
+ const int t = target_size + 1;
183
+
184
+ auto err_idx = [t](int i, int j) { return i * t + j; };
185
+ auto opt_idx = [offset](int k) { return offset + k; };
186
+
187
+ const int hyp_len = source_length[index];
188
+ const int ref_len = target_length[index];
189
+ const scalar_t* hyp_begin = source + index * source_size;
190
+ const scalar_t* ref_begin = target + index * target_size;
191
+
192
+ // dynamic programming
193
+ for (int i = 0; i <= hyp_len; i++) {
194
+ errors_curr[err_idx(i, 0)] = i;
195
+ }
196
+ for (int j = 0; j <= ref_len; j++) {
197
+ errors_curr[err_idx(0, j)] = j;
198
+ }
199
+ for (int i = 1; i <= hyp_len; i++) {
200
+ for (int j = 1; j <= ref_len; j++) {
201
+ errors_curr[err_idx(i, j)] = min(
202
+ min(errors_curr[err_idx(i - 1, j)], errors_curr[err_idx(i, j - 1)]) +
203
+ 1,
204
+ errors_curr[err_idx(i - 1, j - 1)] +
205
+ 2 * (*(hyp_begin + i - 1) == *(ref_begin + j - 1) ? 0 : 1));
206
+ }
207
+ }
208
+
209
+ // back-tracing
210
+ int i = hyp_len;
211
+ int j = ref_len;
212
+ int o = hyp_len + ref_len;
213
+
214
+ for (int k = 0; k < source_size + target_size; k++) {
215
+ operations[opt_idx(k)] = 0;
216
+ }
217
+
218
+ while ((i >= 0) && (j >= 0)) {
219
+ if ((i == 0) && (j == 0)) {
220
+ break;
221
+ }
222
+
223
+ if ((j > 0) &&
224
+ (errors_curr[err_idx(i, j - 1)] < errors_curr[err_idx(i, j)])) {
225
+ o--;
226
+ operations[opt_idx(o)] = 1;
227
+ j--; // insertion
228
+ } else if (
229
+ (i > 0) &&
230
+ (errors_curr[err_idx(i - 1, j)] < errors_curr[err_idx(i, j)])) {
231
+ o--;
232
+ operations[opt_idx(o)] = 2;
233
+ i--; // deletion
234
+ } else {
235
+ o--;
236
+ operations[opt_idx(o)] = 3;
237
+ i--;
238
+ j--; // do nothing
239
+ }
240
+ }
241
+
242
+ // moving to the left
243
+ for (int k = 0; k < hyp_len + ref_len; k++) {
244
+ if (k + o < hyp_len + ref_len) {
245
+ operations[opt_idx(k)] = operations[opt_idx(k + o)];
246
+ } else {
247
+ operations[opt_idx(k)] = 0; // padding
248
+ }
249
+ }
250
+ }
251
+
252
+ torch::Tensor GenerateDeletionLabelCuda(
253
+ torch::Tensor source,
254
+ torch::Tensor operations) {
255
+ const auto batch_size = source.size(0);
256
+ at::TensorOptions options(source.device());
257
+ options = options.dtype(at::ScalarType::Int);
258
+ auto labels = torch::empty({batch_size, source.size(1)}, options);
259
+ auto stream = at::cuda::getCurrentCUDAStream(source.device().index());
260
+
261
+ AT_DISPATCH_ALL_TYPES(source.scalar_type(), "generate_deletion_labels", ([&] {
262
+ generate_deletion_label_kernel<scalar_t>
263
+ <<<batch_size, 1, 0, stream>>>(
264
+ source.data_ptr<scalar_t>(),
265
+ source.size(1),
266
+ operations.size(1),
267
+ operations.data_ptr<int>(),
268
+ labels.data_ptr<int>());
269
+ }));
270
+
271
+ return labels;
272
+ }
273
+
274
+ std::pair<torch::Tensor, torch::Tensor> GenerateInsertionLabelCuda(
275
+ torch::Tensor target,
276
+ torch::Tensor operations) {
277
+ const auto batch_size = target.size(0);
278
+ at::TensorOptions options(target.device());
279
+ options = options.dtype(at::ScalarType::Int);
280
+ auto labels = torch::empty({batch_size, target.size(1)}, options);
281
+ auto masks = torch::empty({batch_size, target.size(1)}, options);
282
+ auto stream = at::cuda::getCurrentCUDAStream(target.device().index());
283
+
284
+ AT_DISPATCH_ALL_TYPES(
285
+ target.scalar_type(), "generate_insertion_labels", ([&] {
286
+ generate_insertion_label_kernel<scalar_t><<<batch_size, 1, 0, stream>>>(
287
+ target.data_ptr<scalar_t>(),
288
+ target.size(1),
289
+ operations.size(1),
290
+ operations.data_ptr<int>(),
291
+ labels.data_ptr<int>(),
292
+ masks.data_ptr<int>());
293
+ }));
294
+
295
+ return std::make_pair(labels, masks);
296
+ }
297
+
298
+ torch::Tensor LevenshteinDistanceCuda(
299
+ torch::Tensor source,
300
+ torch::Tensor target,
301
+ torch::Tensor source_length,
302
+ torch::Tensor target_length) {
303
+ const auto batch_size = source.size(0);
304
+ const auto shared_size =
305
+ (source.size(1) + 1) * (target.size(1) + 1) * sizeof(short);
306
+
307
+ at::TensorOptions options(source.device());
308
+ options = options.dtype(at::ScalarType::Int);
309
+ auto operations =
310
+ torch::empty({batch_size, source.size(1) + target.size(1)}, options);
311
+ auto stream = at::cuda::getCurrentCUDAStream(source.device().index());
312
+
313
+ if (shared_size > 40000) {
314
+ auto distances = torch::empty(
315
+ {batch_size, (source.size(1) + 1) * (target.size(1) + 1)}, options);
316
+ AT_DISPATCH_ALL_TYPES(source.scalar_type(), "levenshtein_distance", ([&] {
317
+ levenshtein_distance_kernel<scalar_t>
318
+ <<<batch_size, 1, 0, stream>>>(
319
+ source.data_ptr<scalar_t>(),
320
+ target.data_ptr<scalar_t>(),
321
+ source_length.data_ptr<int>(),
322
+ target_length.data_ptr<int>(),
323
+ source.size(1),
324
+ target.size(1),
325
+ operations.data_ptr<int>(),
326
+ distances.data_ptr<int>());
327
+ }));
328
+ } else {
329
+ AT_DISPATCH_ALL_TYPES(
330
+ source.scalar_type(), "faster_levenshtein_distance", ([&] {
331
+ faster_levenshtein_distance_kernel<scalar_t>
332
+ <<<batch_size, 1, shared_size, stream>>>(
333
+ source.data_ptr<scalar_t>(),
334
+ target.data_ptr<scalar_t>(),
335
+ source_length.data_ptr<int>(),
336
+ target_length.data_ptr<int>(),
337
+ source.size(1),
338
+ target.size(1),
339
+ operations.data_ptr<int>());
340
+ }));
341
+ }
342
+
343
+ return operations;
344
+ }
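As a reading aid, the per-example logic of generate_deletion_label_kernel in plain Python (a sketch; the operation codes follow the kernels above: 0 = padding, 1 = insertion, 2 = deletion, 3 = keep):

def generate_deletion_labels(source_len, operations):
    # One label per source position: 1 if that source token should be deleted, 0 if it is kept.
    labels = [0] * source_len
    k = 0
    for op in operations:
        if op == 0:      # padding marks the end of the operation sequence
            break
        if op == 1:      # insertions do not consume a source token
            continue
        labels[k] = 3 - op  # delete (2) -> 1, keep (3) -> 0
        k += 1
    return labels

print(generate_deletion_labels(4, [3, 2, 1, 3, 3, 0]))  # [0, 1, 0, 0]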
fairseq/clib/libnat_cuda/edit_dist.h ADDED
@@ -0,0 +1,25 @@
1
+ /**
2
+ * Copyright 2017-present, Facebook, Inc.
3
+ * All rights reserved.
4
+ *
5
+ * This source code is licensed under the license found in the
6
+ * LICENSE file in the root directory of this source tree.
7
+ */
8
+
9
+ #pragma once
10
+
11
+ #include <torch/extension.h>
12
+
13
+ torch::Tensor LevenshteinDistanceCuda(
14
+ torch::Tensor source,
15
+ torch::Tensor target,
16
+ torch::Tensor source_length,
17
+ torch::Tensor target_length);
18
+
19
+ torch::Tensor GenerateDeletionLabelCuda(
20
+ torch::Tensor source,
21
+ torch::Tensor operations);
22
+
23
+ std::pair<torch::Tensor, torch::Tensor> GenerateInsertionLabelCuda(
24
+ torch::Tensor source,
25
+ torch::Tensor operations);
fairseq/config/__init__.py ADDED
@@ -0,0 +1,4 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
fairseq/config/config.yaml ADDED
@@ -0,0 +1,19 @@
1
+ # @package _group_
2
+
3
+ hydra:
4
+ run:
5
+ dir: .
6
+
7
+ defaults:
8
+ - _self_
9
+ - task: null
10
+ - model: null
11
+ - criterion: cross_entropy
12
+ - optimizer: null
13
+ - lr_scheduler: fixed
14
+ - bpe: null
15
+ - tokenizer: null
16
+ - scoring: null
17
+ - generation: null
18
+ - common_eval: null
19
+ - eval_lm: null
fairseq/config/fb_run_config/slurm.yaml ADDED
@@ -0,0 +1,29 @@
1
+ # @package _global_
2
+
3
+ hydra:
4
+ job:
5
+ config:
6
+ override_dirname:
7
+ kv_sep: ':'
8
+ item_sep: '__'
9
+ exclude_keys:
10
+ - fb_run_config
11
+ - distributed_training.distributed_port
12
+ sweep:
13
+ dir: /checkpoint/${env:USER}/${env:PREFIX}/${hydra.job.config_name}_${hydra.launcher.gpus_per_node}/${hydra.job.override_dirname}
14
+ launcher:
15
+ cpus_per_task: 60
16
+ gpus_per_node: ???
17
+ tasks_per_node: 1
18
+ nodes: 1
19
+ partition: learnfair
20
+ mem_gb: 400
21
+ timeout_min: 4320
22
+ max_num_timeout: 10
23
+ name: ${env:PREFIX}_${hydra.job.config_name}
24
+ submitit_folder: ${hydra.sweep.dir}
25
+
26
+ distributed_training:
27
+ ddp_backend: c10d
28
+ distributed_world_size: ???
29
+ distributed_port: ???