Spaces: Sleeping
Yixuan Li committed · Commit 85ba398 · 1 Parent(s): e3e7837
add fairseq folder
This view is limited to 50 files because it contains too many changes. See raw diff.
- fairseq/__init__.py +45 -0
- fairseq/__pycache__/__init__.cpython-310.pyc +0 -0
- fairseq/__pycache__/__init__.cpython-311.pyc +0 -0
- fairseq/__pycache__/checkpoint_utils.cpython-310.pyc +0 -0
- fairseq/__pycache__/file_chunker_utils.cpython-310.pyc +0 -0
- fairseq/__pycache__/file_io.cpython-310.pyc +0 -0
- fairseq/__pycache__/file_utils.cpython-310.pyc +0 -0
- fairseq/__pycache__/hub_utils.cpython-310.pyc +0 -0
- fairseq/__pycache__/incremental_decoding_utils.cpython-310.pyc +0 -0
- fairseq/__pycache__/iterative_refinement_generator.cpython-310.pyc +0 -0
- fairseq/__pycache__/ngram_repeat_block.cpython-310.pyc +0 -0
- fairseq/__pycache__/options.cpython-310.pyc +0 -0
- fairseq/__pycache__/pdb.cpython-310.pyc +0 -0
- fairseq/__pycache__/quantization_utils.cpython-310.pyc +0 -0
- fairseq/__pycache__/registry.cpython-310.pyc +0 -0
- fairseq/__pycache__/search.cpython-310.pyc +0 -0
- fairseq/__pycache__/sequence_generator.cpython-310.pyc +0 -0
- fairseq/__pycache__/speech_generator.cpython-310.pyc +0 -0
- fairseq/__pycache__/token_generation_constraints.cpython-310.pyc +0 -0
- fairseq/__pycache__/tokenizer.cpython-310.pyc +0 -0
- fairseq/__pycache__/utils.cpython-310.pyc +0 -0
- fairseq/__pycache__/version.cpython-310.pyc +0 -0
- fairseq/__pycache__/version.cpython-311.pyc +0 -0
- fairseq/benchmark/__init__.py +7 -0
- fairseq/benchmark/__pycache__/__init__.cpython-310.pyc +0 -0
- fairseq/benchmark/__pycache__/dummy_dataset.cpython-310.pyc +0 -0
- fairseq/benchmark/__pycache__/dummy_lm.cpython-310.pyc +0 -0
- fairseq/benchmark/__pycache__/dummy_masked_lm.cpython-310.pyc +0 -0
- fairseq/benchmark/__pycache__/dummy_model.cpython-310.pyc +0 -0
- fairseq/benchmark/__pycache__/dummy_mt.cpython-310.pyc +0 -0
- fairseq/benchmark/benchmark_multihead_attention.py +172 -0
- fairseq/benchmark/dummy_dataset.py +36 -0
- fairseq/benchmark/dummy_lm.py +83 -0
- fairseq/benchmark/dummy_masked_lm.py +94 -0
- fairseq/benchmark/dummy_model.py +96 -0
- fairseq/benchmark/dummy_mt.py +119 -0
- fairseq/binarizer.py +381 -0
- fairseq/checkpoint_utils.py +936 -0
- fairseq/clib/cuda/ngram_repeat_block_cuda.cpp +55 -0
- fairseq/clib/cuda/ngram_repeat_block_cuda_kernel.cu +82 -0
- fairseq/clib/libbase/balanced_assignment.cpp +109 -0
- fairseq/clib/libbleu/libbleu.cpp +157 -0
- fairseq/clib/libbleu/module.cpp +33 -0
- fairseq/clib/libnat/edit_dist.cpp +231 -0
- fairseq/clib/libnat_cuda/binding.cpp +67 -0
- fairseq/clib/libnat_cuda/edit_dist.cu +344 -0
- fairseq/clib/libnat_cuda/edit_dist.h +25 -0
- fairseq/config/__init__.py +4 -0
- fairseq/config/config.yaml +19 -0
- fairseq/config/fb_run_config/slurm.yaml +29 -0
fairseq/__init__.py
ADDED
@@ -0,0 +1,45 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+"""isort:skip_file"""
+
+import os
+import sys
+
+try:
+    from .version import __version__  # noqa
+except ImportError:
+    version_txt = os.path.join(os.path.dirname(__file__), "version.txt")
+    with open(version_txt) as f:
+        __version__ = f.read().strip()
+
+__all__ = ["pdb"]
+
+# backwards compatibility to support `from fairseq.X import Y`
+from fairseq.distributed import utils as distributed_utils
+from fairseq.logging import meters, metrics, progress_bar  # noqa
+
+sys.modules["fairseq.distributed_utils"] = distributed_utils
+sys.modules["fairseq.meters"] = meters
+sys.modules["fairseq.metrics"] = metrics
+sys.modules["fairseq.progress_bar"] = progress_bar
+
+# initialize hydra
+from fairseq.dataclass.initialize import hydra_init
+
+hydra_init()
+
+import fairseq.criterions  # noqa
+import fairseq.distributed  # noqa
+import fairseq.models  # noqa
+import fairseq.modules  # noqa
+import fairseq.optim  # noqa
+import fairseq.optim.lr_scheduler  # noqa
+import fairseq.pdb  # noqa
+import fairseq.scoring  # noqa
+import fairseq.tasks  # noqa
+import fairseq.token_generation_constraints  # noqa
+
+import fairseq.benchmark  # noqa
+import fairseq.model_parallel  # noqa
fairseq/__pycache__/__init__.cpython-310.pyc
ADDED
Binary file (1.2 kB)

fairseq/__pycache__/__init__.cpython-311.pyc
ADDED
Binary file (2.01 kB)

fairseq/__pycache__/checkpoint_utils.cpython-310.pyc
ADDED
Binary file (22 kB)

fairseq/__pycache__/file_chunker_utils.cpython-310.pyc
ADDED
Binary file (2.77 kB)

fairseq/__pycache__/file_io.cpython-310.pyc
ADDED
Binary file (5.11 kB)

fairseq/__pycache__/file_utils.cpython-310.pyc
ADDED
Binary file (9.1 kB)

fairseq/__pycache__/hub_utils.cpython-310.pyc
ADDED
Binary file (10.9 kB)

fairseq/__pycache__/incremental_decoding_utils.cpython-310.pyc
ADDED
Binary file (2.21 kB)

fairseq/__pycache__/iterative_refinement_generator.cpython-310.pyc
ADDED
Binary file (8.72 kB)

fairseq/__pycache__/ngram_repeat_block.cpython-310.pyc
ADDED
Binary file (3.78 kB)

fairseq/__pycache__/options.cpython-310.pyc
ADDED
Binary file (10.6 kB)

fairseq/__pycache__/pdb.cpython-310.pyc
ADDED
Binary file (1.32 kB)

fairseq/__pycache__/quantization_utils.cpython-310.pyc
ADDED
Binary file (3.57 kB)

fairseq/__pycache__/registry.cpython-310.pyc
ADDED
Binary file (2.52 kB)

fairseq/__pycache__/search.cpython-310.pyc
ADDED
Binary file (24.3 kB)

fairseq/__pycache__/sequence_generator.cpython-310.pyc
ADDED
Binary file (24 kB)

fairseq/__pycache__/speech_generator.cpython-310.pyc
ADDED
Binary file (10.7 kB)

fairseq/__pycache__/token_generation_constraints.cpython-310.pyc
ADDED
Binary file (16.2 kB)

fairseq/__pycache__/tokenizer.cpython-310.pyc
ADDED
Binary file (367 Bytes)

fairseq/__pycache__/utils.cpython-310.pyc
ADDED
Binary file (29.8 kB)

fairseq/__pycache__/version.cpython-310.pyc
ADDED
Binary file (166 Bytes)

fairseq/__pycache__/version.cpython-311.pyc
ADDED
Binary file (238 Bytes)
fairseq/benchmark/__init__.py
ADDED
@@ -0,0 +1,7 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+# import models/tasks to register them
+from . import dummy_dataset, dummy_lm, dummy_masked_lm, dummy_model, dummy_mt  # noqa
fairseq/benchmark/__pycache__/__init__.cpython-310.pyc
ADDED
Binary file (279 Bytes)

fairseq/benchmark/__pycache__/dummy_dataset.cpython-310.pyc
ADDED
Binary file (1.74 kB)

fairseq/benchmark/__pycache__/dummy_lm.cpython-310.pyc
ADDED
Binary file (3.12 kB)

fairseq/benchmark/__pycache__/dummy_masked_lm.cpython-310.pyc
ADDED
Binary file (3.33 kB)

fairseq/benchmark/__pycache__/dummy_model.cpython-310.pyc
ADDED
Binary file (3.42 kB)

fairseq/benchmark/__pycache__/dummy_mt.cpython-310.pyc
ADDED
Binary file (4.5 kB)
fairseq/benchmark/benchmark_multihead_attention.py
ADDED
@@ -0,0 +1,172 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+import itertools
+import random
+
+import torch
+from torch.utils import benchmark
+
+from fairseq.modules.multihead_attention import MultiheadAttention
+
+BATCH = [20, 41, 97]
+SEQ = 64
+EMB = 48
+HEADS = 4
+DROP = 0.1
+DEVICE = torch.device("cuda")
+ATTN_MASK_DTYPE = [torch.uint8, torch.bool, torch.float]
+KEY_PADDING_MASK_DTYPE = [torch.uint8, torch.bool]
+
+
+def _reset_seeds():
+    torch.manual_seed(0)
+    random.seed(0)
+
+
+def _get_mask(to_dtype: torch.dtype, dim0: int, dim1: int):
+    if to_dtype == torch.float:
+        mask = torch.randint(0, 2, (dim0, dim1)).to(dtype=torch.bool)
+        return mask.to(dtype=to_dtype).masked_fill(mask, -float("inf"))
+    return torch.randint(0, 2, (dim0, dim1)).to(dtype=to_dtype)
+
+
+def benchmark_multihead_attention(
+    label="",
+    attn_dtype=torch.uint8,
+    key_padding_dtype=torch.uint8,
+    add_bias_kv=False,
+    add_zero_attn=False,
+    static_kv=False,
+    batch_size=20,
+    embedding=EMB,
+    seq_len=SEQ,
+    num_heads=HEADS,
+):
+
+    results = []
+    # device = torch.device("cuda")
+
+    xformers_att_config = '{"name": "scaled_dot_product"}'
+
+    attn_mask = _get_mask(to_dtype=attn_dtype, dim0=seq_len, dim1=seq_len)
+    key_padding_mask = _get_mask(
+        to_dtype=key_padding_dtype, dim0=batch_size, dim1=seq_len
+    )
+
+    q = torch.rand(seq_len, batch_size, embedding, requires_grad=True)
+    k = torch.rand(seq_len, batch_size, embedding, requires_grad=True)
+    v = torch.rand(seq_len, batch_size, embedding, requires_grad=True)
+
+    _reset_seeds()
+
+    original_mha = MultiheadAttention(
+        embedding,
+        num_heads,
+        dropout=0.0,
+        xformers_att_config=None,
+        add_bias_kv=add_bias_kv,
+        add_zero_attn=add_zero_attn,
+    )
+
+    xformers_mha = MultiheadAttention(
+        embedding,
+        num_heads,
+        dropout=0.0,
+        xformers_att_config=xformers_att_config,
+        add_bias_kv=add_bias_kv,
+        add_zero_attn=add_zero_attn,
+    )
+
+    def original_bench_fw(q, k, v, key_padding_mask, attn_mask, static_kv):
+        original_mha(
+            query=q,
+            key=k,
+            value=v,
+            key_padding_mask=key_padding_mask,
+            attn_mask=attn_mask,
+            static_kv=static_kv,
+        )
+
+    def xformers_bench_fw(q, k, v, key_padding_mask, attn_mask, static_kv):
+        xformers_mha(
+            query=q,
+            key=k,
+            value=v,
+            key_padding_mask=key_padding_mask,
+            attn_mask=attn_mask,
+            static_kv=static_kv,
+        )
+
+    def original_bench_fw_bw(q, k, v, key_padding_mask, attn_mask, static_kv):
+        output, _ = original_mha(
+            query=q,
+            key=k,
+            value=v,
+            key_padding_mask=key_padding_mask,
+            attn_mask=attn_mask,
+            static_kv=static_kv,
+        )
+        loss = torch.norm(output)
+        loss.backward()
+
+    def xformers_bench_fw_bw(q, k, v, key_padding_mask, attn_mask, static_kv):
+        output, _ = xformers_mha(
+            query=q,
+            key=k,
+            value=v,
+            key_padding_mask=key_padding_mask,
+            attn_mask=attn_mask,
+            static_kv=static_kv,
+        )
+        loss = torch.norm(output)
+        loss.backward()
+
+    fns = [
+        original_bench_fw,
+        xformers_bench_fw,
+        original_bench_fw_bw,
+        xformers_bench_fw_bw,
+    ]
+
+    for fn in fns:
+        results.append(
+            benchmark.Timer(
+                stmt="fn(q, k, v, key_padding_mask, attn_mask, static_kv)",
+                globals={
+                    "q": q,
+                    "k": k,
+                    "v": v,
+                    "key_padding_mask": key_padding_mask,
+                    "attn_mask": attn_mask,
+                    "static_kv": static_kv,
+                    "fn": fn,
+                },
+                label="multihead fw + bw",
+                sub_label=f"{fn.__name__}",
+                description=label,
+            ).blocked_autorange(min_run_time=1)
+        )
+
+    compare = benchmark.Compare(results)
+    compare.print()
+
+
+def run_benchmarks():
+    for attn_dtype, key_padding_dtype, add_bias_kv, add_zero_attn in itertools.product(
+        ATTN_MASK_DTYPE, KEY_PADDING_MASK_DTYPE, [True, False], [True, False]
+    ):
+        label = f"attn_dtype {attn_dtype}, key_padding_dtype {key_padding_dtype}, \
+            add_bias_kv {add_bias_kv}, add_zero_attn {add_zero_attn}"
+        benchmark_multihead_attention(
+            label=label,
+            attn_dtype=attn_dtype,
+            key_padding_dtype=key_padding_dtype,
+            add_bias_kv=add_bias_kv,
+            add_zero_attn=add_zero_attn,
+        )
+
+
+run_benchmarks()
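For readers unfamiliar with `torch.utils.benchmark`, a minimal standalone sketch of the Timer/Compare pattern the file above relies on; the matmul workload is an invented stand-in for the attention calls:

import torch
from torch.utils import benchmark

def matmul_fn(a, b):  # hypothetical toy workload, for illustration only
    return a @ b

a, b = torch.rand(128, 128), torch.rand(128, 128)

results = [
    benchmark.Timer(
        stmt="fn(a, b)",  # statement to time, evaluated with the globals below
        globals={"fn": matmul_fn, "a": a, "b": b},
        label="toy matmul",  # groups rows in the printed comparison table
        sub_label="matmul_fn",
        description="128x128",
    ).blocked_autorange(min_run_time=1)  # measure until >= 1s of samples
]

benchmark.Compare(results).print()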
fairseq/benchmark/dummy_dataset.py
ADDED
@@ -0,0 +1,36 @@
+import numpy as np
+from fairseq.data import FairseqDataset
+
+
+class DummyDataset(FairseqDataset):
+    def __init__(self, batch, num_items, item_size):
+        super().__init__()
+        self.batch = batch
+        self.num_items = num_items
+        self.item_size = item_size
+
+    def __getitem__(self, index):
+        return index
+
+    def __len__(self):
+        return self.num_items
+
+    def collater(self, samples):
+        return self.batch
+
+    @property
+    def sizes(self):
+        return np.array([self.item_size] * self.num_items)
+
+    def num_tokens(self, index):
+        return self.item_size
+
+    def size(self, index):
+        return self.item_size
+
+    def ordered_indices(self):
+        return np.arange(self.num_items)
+
+    @property
+    def supports_prefetch(self):
+        return False
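To make the design explicit: `DummyDataset` ignores the sampled indices and always collates to the same pre-built batch, so the data pipeline costs essentially nothing during benchmarking. A small sketch, assuming fairseq is installed; the batch payload is invented for illustration:

from fairseq.benchmark.dummy_dataset import DummyDataset

fixed_batch = {"id": 1, "ntokens": 64}  # stand-in payload
ds = DummyDataset(batch=fixed_batch, num_items=1000, item_size=16)

assert ds[7] == 7                             # __getitem__ just echoes the index
assert ds.collater([0, 1, 2]) is fixed_batch  # collation returns the fixed batch
assert ds.num_tokens(0) == 16                 # every item reports the same size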
fairseq/benchmark/dummy_lm.py
ADDED
@@ -0,0 +1,83 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+import logging
+from dataclasses import dataclass, field
+from typing import Optional
+
+import torch
+from .dummy_dataset import DummyDataset
+from fairseq.data import Dictionary
+from fairseq.dataclass import FairseqDataclass
+from fairseq.tasks import FairseqTask, register_task
+from omegaconf import II
+
+
+logger = logging.getLogger(__name__)
+
+
+@dataclass
+class DummyLMConfig(FairseqDataclass):
+    dict_size: int = 49996
+    dataset_size: int = 100000
+    tokens_per_sample: int = field(
+        default=512, metadata={"help": "max sequence length"}
+    )
+    add_bos_token: bool = False
+    batch_size: Optional[int] = II("dataset.batch_size")
+    max_tokens: Optional[int] = II("dataset.max_tokens")
+    max_target_positions: int = II("task.tokens_per_sample")
+
+
+@register_task("dummy_lm", dataclass=DummyLMConfig)
+class DummyLMTask(FairseqTask):
+    def __init__(self, cfg: DummyLMConfig):
+        super().__init__(cfg)
+
+        # load dictionary
+        self.dictionary = Dictionary()
+        for i in range(cfg.dict_size):
+            self.dictionary.add_symbol("word{}".format(i))
+        self.dictionary.pad_to_multiple_(8)  # often faster if divisible by 8
+        logger.info("dictionary: {} types".format(len(self.dictionary)))
+
+        seq = torch.arange(cfg.tokens_per_sample + 1) + self.dictionary.pad() + 1
+
+        self.dummy_src = seq[:-1]
+        self.dummy_tgt = seq[1:]
+
+    def load_dataset(self, split, epoch=1, combine=False, **kwargs):
+        """Load a given dataset split.
+        Args:
+            split (str): name of the split (e.g., train, valid, test)
+        """
+        if self.cfg.batch_size is not None:
+            bsz = self.cfg.batch_size
+        else:
+            bsz = max(1, self.cfg.max_tokens // self.cfg.tokens_per_sample)
+        self.datasets[split] = DummyDataset(
+            {
+                "id": 1,
+                "net_input": {
+                    "src_tokens": torch.stack([self.dummy_src for _ in range(bsz)]),
+                    "src_lengths": torch.full(
+                        (bsz,), self.cfg.tokens_per_sample, dtype=torch.long
+                    ),
+                },
+                "target": torch.stack([self.dummy_tgt for _ in range(bsz)]),
+                "nsentences": bsz,
+                "ntokens": bsz * self.cfg.tokens_per_sample,
+            },
+            num_items=self.cfg.dataset_size,
+            item_size=self.cfg.tokens_per_sample,
+        )
+
+    @property
+    def source_dictionary(self):
+        return self.dictionary
+
+    @property
+    def target_dictionary(self):
+        return self.dictionary
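A hedged sketch of exercising the task outside a fairseq-train run; constructing the config dataclass directly (rather than through hydra/omegaconf interpolation) is an assumption made for illustration:

from fairseq.benchmark.dummy_lm import DummyLMConfig, DummyLMTask

# Direct construction bypasses the II("dataset.*") interpolations above,
# so batch_size is set explicitly here (an assumption of this sketch).
cfg = DummyLMConfig(
    dict_size=1000, dataset_size=8, tokens_per_sample=32, batch_size=2
)
task = DummyLMTask(cfg)
task.load_dataset("train")

batch = task.datasets["train"].collater([0, 1])
print(batch["net_input"]["src_tokens"].shape)  # expected: torch.Size([2, 32])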
fairseq/benchmark/dummy_masked_lm.py
ADDED
@@ -0,0 +1,94 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+import logging
+from dataclasses import dataclass, field
+from typing import Optional
+
+import torch
+from omegaconf import II
+
+from .dummy_dataset import DummyDataset
+from fairseq.data import Dictionary
+from fairseq.dataclass import FairseqDataclass
+from fairseq.tasks import FairseqTask, register_task
+
+logger = logging.getLogger(__name__)
+
+
+@dataclass
+class DummyMaskedLMConfig(FairseqDataclass):
+    dict_size: int = 49996
+    dataset_size: int = 100000
+    tokens_per_sample: int = field(
+        default=512,
+        metadata={
+            "help": "max number of total tokens over all"
+            " segments per sample for BERT dataset"
+        },
+    )
+    batch_size: Optional[int] = II("dataset.batch_size")
+    max_tokens: Optional[int] = II("dataset.max_tokens")
+    max_target_positions: int = II("task.tokens_per_sample")
+
+
+@register_task("dummy_masked_lm", dataclass=DummyMaskedLMConfig)
+class DummyMaskedLMTask(FairseqTask):
+    def __init__(self, cfg: DummyMaskedLMConfig):
+        super().__init__(cfg)
+
+        self.dictionary = Dictionary()
+        for i in range(cfg.dict_size):
+            self.dictionary.add_symbol("word{}".format(i))
+        logger.info("dictionary: {} types".format(len(self.dictionary)))
+        # add mask token
+        self.mask_idx = self.dictionary.add_symbol("<mask>")
+        self.dictionary.pad_to_multiple_(8)  # often faster if divisible by 8
+
+        mask_idx = 0
+        pad_idx = 1
+        seq = torch.arange(cfg.tokens_per_sample) + pad_idx + 1
+        mask = torch.arange(2, cfg.tokens_per_sample, 7)  # ~15%
+        src = seq.clone()
+        src[mask] = mask_idx
+        tgt = torch.full_like(seq, pad_idx)
+        tgt[mask] = seq[mask]
+
+        self.dummy_src = src
+        self.dummy_tgt = tgt
+
+    def load_dataset(self, split, epoch=1, combine=False, **kwargs):
+        """Load a given dataset split.
+        Args:
+            split (str): name of the split (e.g., train, valid, test)
+        """
+        if self.cfg.batch_size is not None:
+            bsz = self.cfg.batch_size
+        else:
+            bsz = max(1, self.cfg.max_tokens // self.cfg.tokens_per_sample)
+        self.datasets[split] = DummyDataset(
+            {
+                "id": 1,
+                "net_input": {
+                    "src_tokens": torch.stack([self.dummy_src for _ in range(bsz)]),
+                    "src_lengths": torch.full(
+                        (bsz,), self.cfg.tokens_per_sample, dtype=torch.long
+                    ),
+                },
+                "target": torch.stack([self.dummy_tgt for _ in range(bsz)]),
+                "nsentences": bsz,
+                "ntokens": bsz * self.cfg.tokens_per_sample,
+            },
+            num_items=self.cfg.dataset_size,
+            item_size=self.cfg.tokens_per_sample,
+        )
+
+    @property
+    def source_dictionary(self):
+        return self.dictionary
+
+    @property
+    def target_dictionary(self):
+        return self.dictionary
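The `# ~15%` comment above refers to the stride-7 mask: every 7th position starting at index 2 is masked, i.e. about 1/7 ≈ 14.3% of tokens. A quick check of the arithmetic:

import torch

tokens_per_sample = 512
mask = torch.arange(2, tokens_per_sample, 7)  # positions 2, 9, 16, ...
print(mask.numel(), mask.numel() / tokens_per_sample)  # 73 masked -> ~0.143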
fairseq/benchmark/dummy_model.py
ADDED
@@ -0,0 +1,96 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+import torch.nn as nn
+import torch.nn.functional as F
+from fairseq.data import Dictionary
+from fairseq.models import (
+    FairseqDecoder,
+    FairseqLanguageModel,
+    register_model,
+    register_model_architecture,
+)
+
+
+@register_model("dummy_model")
+class DummyModel(FairseqLanguageModel):
+    def __init__(self, args, encoder):
+        super().__init__(encoder)
+        self.args = args
+
+    @staticmethod
+    def add_args(parser):
+        parser.add_argument("--num-layers", type=int, default=24)
+        parser.add_argument("--embed-dim", type=int, default=1024)
+
+    @classmethod
+    def build_model(cls, args, task):
+        encoder = DummyEncoder(
+            num_embed=len(task.target_dictionary),
+            embed_dim=args.embed_dim,
+            num_layers=args.num_layers,
+        )
+        return cls(args, encoder)
+
+    def forward(self, src_tokens, masked_tokens=None, **kwargs):
+        return self.decoder(src_tokens, masked_tokens=masked_tokens)
+
+
+class DummyEncoder(FairseqDecoder):
+    def __init__(self, num_embed=50000, embed_dim=1024, num_layers=24):
+        super().__init__(Dictionary())
+        self.embed = nn.Embedding(
+            num_embeddings=num_embed, embedding_dim=embed_dim, padding_idx=0
+        )
+        self.layers_a = nn.ModuleList(
+            [
+                nn.Sequential(
+                    nn.LayerNorm(embed_dim),
+                    nn.Linear(embed_dim, 3 * embed_dim),  # q, k, v input projection
+                    nn.Linear(3 * embed_dim, embed_dim),  # skip self-attention
+                    nn.Linear(embed_dim, embed_dim),  # output projection
+                    nn.Dropout(),
+                )
+                for i in range(num_layers)
+            ]
+        )
+        self.layers_b = nn.ModuleList(
+            [
+                nn.Sequential(
+                    nn.LayerNorm(embed_dim),
+                    nn.Linear(embed_dim, 4 * embed_dim),  # FFN
+                    nn.ReLU(),
+                    nn.Linear(4 * embed_dim, embed_dim),  # FFN
+                    nn.Dropout(0.1),
+                )
+                for i in range(num_layers)
+            ]
+        )
+        self.out_proj = nn.Linear(embed_dim, num_embed)
+
+    def forward(self, tokens, masked_tokens=None):
+        x = self.embed(tokens)
+        for layer_a, layer_b in zip(self.layers_a, self.layers_b):
+            x = x + layer_a(x)
+            x = x + layer_b(x)
+        x = self.out_proj(x)
+        if masked_tokens is not None:
+            x = x[masked_tokens]
+        return (x,)
+
+    def max_positions(self):
+        return 1024
+
+    def get_normalized_probs(self, net_output, log_probs, sample=None):
+        logits = net_output[0].float()
+        if log_probs:
+            return F.log_softmax(logits, dim=-1)
+        else:
+            return F.softmax(logits, dim=-1)
+
+
+@register_model_architecture("dummy_model", "dummy_model")
+def base_architecture(args):
+    pass
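`DummyEncoder` mirrors the parameter footprint of a Transformer block (QKV input projection, output projection, 4x FFN) while replacing self-attention with plain linear layers, so benchmarks see realistic weight shapes without attention overhead. A sketch of the per-layer parameter count at the default sizes (the count below is computed here, not taken from the file):

import torch.nn as nn

embed_dim = 1024
layer_a = nn.Sequential(  # LayerNorm + in-proj (d->3d) + (3d->d) + out-proj (d->d)
    nn.LayerNorm(embed_dim),
    nn.Linear(embed_dim, 3 * embed_dim),
    nn.Linear(3 * embed_dim, embed_dim),
    nn.Linear(embed_dim, embed_dim),
)
layer_b = nn.Sequential(  # LayerNorm + standard 4x FFN
    nn.LayerNorm(embed_dim),
    nn.Linear(embed_dim, 4 * embed_dim),
    nn.ReLU(),
    nn.Linear(4 * embed_dim, embed_dim),
)
n = sum(p.numel() for p in layer_a.parameters())
n += sum(p.numel() for p in layer_b.parameters())
print(f"{n / 1e6:.1f}M parameters per layer")  # ~15.7M at embed_dim=1024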
fairseq/benchmark/dummy_mt.py
ADDED
@@ -0,0 +1,119 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+import logging
+
+import numpy as np
+import torch
+
+from fairseq.data import Dictionary, FairseqDataset
+from fairseq.tasks import LegacyFairseqTask, register_task
+
+logger = logging.getLogger(__name__)
+
+
+@register_task("dummy_mt")
+class DummyMTTask(LegacyFairseqTask):
+    @staticmethod
+    def add_args(parser):
+        """Add task-specific arguments to the parser."""
+        parser.add_argument("--dict-size", default=49996, type=int)
+        parser.add_argument("--dataset-size", default=100000, type=int)
+        parser.add_argument("--src-len", default=30, type=int)
+        parser.add_argument("--tgt-len", default=30, type=int)
+
+    def __init__(self, args, dictionary):
+        super().__init__(args)
+        self.dictionary = dictionary
+        self.seed = args.seed
+
+        dictionary.pad_to_multiple_(8)  # often faster if divisible by 8
+
+        self.dummy_src = torch.arange(args.src_len + 1) + dictionary.pad() + 1
+        self.dummy_tgt = torch.arange(args.tgt_len + 1) + dictionary.pad() + 1
+
+    @classmethod
+    def setup_task(cls, args, **kwargs):
+        """Setup the task."""
+        dictionary = Dictionary()
+        for i in range(args.dict_size):
+            dictionary.add_symbol("word{}".format(i))
+        logger.info("dictionary: {} types".format(len(dictionary)))
+
+        args.max_source_positions = args.src_len + dictionary.pad() + 2
+        args.max_target_positions = args.tgt_len + dictionary.pad() + 2
+
+        return cls(args, dictionary)
+
+    def load_dataset(self, split, epoch=1, combine=False, **kwargs):
+        """Load a given dataset split.
+        Args:
+            split (str): name of the split (e.g., train, valid, test)
+        """
+        item_size = max(self.args.src_len, self.args.tgt_len)
+        if self.args.batch_size is not None:
+            bsz = self.args.batch_size
+        else:
+            bsz = max(1, self.args.max_tokens // item_size)
+        tgt = torch.stack([self.dummy_tgt for _ in range(bsz)])
+        self.datasets[split] = DummyDataset(
+            {
+                "id": 1,
+                "net_input": {
+                    "src_tokens": torch.stack([self.dummy_src for _ in range(bsz)]),
+                    "src_lengths": torch.full(
+                        (bsz,), self.args.src_len, dtype=torch.long
+                    ),
+                    "prev_output_tokens": tgt.clone(),
+                },
+                "target": tgt,
+                "nsentences": bsz,
+                "ntokens": bsz * self.args.tgt_len,
+            },
+            num_items=self.args.dataset_size,
+            item_size=item_size,
+        )
+
+    @property
+    def source_dictionary(self):
+        return self.dictionary
+
+    @property
+    def target_dictionary(self):
+        return self.dictionary
+
+
+class DummyDataset(FairseqDataset):
+    def __init__(self, batch, num_items, item_size):
+        super().__init__()
+        self.batch = batch
+        self.num_items = num_items
+        self.item_size = item_size
+
+    def __getitem__(self, index):
+        return index
+
+    def __len__(self):
+        return self.num_items
+
+    def collater(self, samples):
+        return self.batch
+
+    @property
+    def sizes(self):
+        return np.array([self.item_size] * self.num_items)
+
+    def num_tokens(self, index):
+        return self.item_size
+
+    def size(self, index):
+        return self.item_size
+
+    def ordered_indices(self):
+        return np.arange(self.num_items)
+
+    @property
+    def supports_prefetch(self):
+        return False
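`dummy_mt` still uses the legacy argparse-based task API, so it can be driven with a plain namespace. The minimal set of fields below is an assumption for illustration; a real fairseq run supplies many more options:

import argparse

from fairseq.benchmark.dummy_mt import DummyMTTask

# Only the fields this task actually reads are provided (an assumption).
args = argparse.Namespace(
    dict_size=100, dataset_size=8, src_len=5, tgt_len=6,
    seed=1, batch_size=2, max_tokens=None,
)
task = DummyMTTask.setup_task(args)
task.load_dataset("train")

batch = task.datasets["train"].collater([0, 1])
print(batch["net_input"]["src_tokens"].shape)  # expected: torch.Size([2, 6])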
fairseq/binarizer.py
ADDED
@@ -0,0 +1,381 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+import logging
+import os
+import typing as tp
+from abc import ABC, abstractmethod
+from collections import Counter
+from dataclasses import dataclass
+from multiprocessing import Pool
+
+import torch
+
+from fairseq.data import Dictionary, indexed_dataset
+from fairseq.file_chunker_utils import Chunker, find_offsets
+from fairseq.file_io import PathManager
+from fairseq.tokenizer import tokenize_line
+
+logger = logging.getLogger("binarizer")
+
+
+@dataclass
+class BinarizeSummary:
+    """
+    Keep track of what's going on in the binarizer
+    """
+
+    num_seq: int = 0
+    replaced: tp.Optional[Counter] = None
+    num_tok: int = 0
+
+    @property
+    def num_replaced(self) -> int:
+        if self.replaced is None:
+            return 0
+        return sum(self.replaced.values())
+
+    @property
+    def replaced_percent(self) -> float:
+        return 100 * self.num_replaced / self.num_tok
+
+    def __str__(self) -> str:
+        base = f"{self.num_seq} sents, {self.num_tok} tokens"
+        if self.replaced is None:
+            return base
+
+        return f"{base}, {self.replaced_percent:.3}% replaced"
+
+    def merge(self, other: "BinarizeSummary"):
+        replaced = None
+        if self.replaced is not None:
+            replaced = self.replaced
+        if other.replaced is not None:
+            if replaced is None:
+                replaced = other.replaced
+            else:
+                replaced += other.replaced
+        self.replaced = replaced
+        self.num_seq += other.num_seq
+        self.num_tok += other.num_tok
+
+
+class Binarizer(ABC):
+    """
+    a binarizer describes how to take a string and build a tensor out of it
+    """
+
+    @abstractmethod
+    def binarize_line(
+        self,
+        line: str,
+        summary: BinarizeSummary,
+    ) -> torch.IntTensor:
+        ...
+
+
+def _worker_prefix(output_prefix: str, worker_id: int):
+    return f"{output_prefix}.pt{worker_id}"
+
+
+class FileBinarizer:
+    """
+    A file binarizer can take a file, tokenize it, and binarize each line to a tensor
+    """
+
+    @classmethod
+    def multiprocess_dataset(
+        cls,
+        input_file: str,
+        dataset_impl: str,
+        binarizer: Binarizer,
+        output_prefix: str,
+        vocab_size=None,
+        num_workers=1,
+    ) -> BinarizeSummary:
+        final_summary = BinarizeSummary()
+
+        offsets = find_offsets(input_file, num_workers)
+        # find_offsets returns a list of positions [pos1, pos2, pos3, pos4] but we would want pairs:
+        # [(pos1, pos2), (pos2, pos3), (pos3, pos4)] to process the chunks with start/end info
+        # we zip the list with itself shifted by one to get all the pairs.
+        (first_chunk, *more_chunks) = zip(offsets, offsets[1:])
+        pool = None
+        if num_workers > 1:
+            pool = Pool(processes=num_workers - 1)
+            worker_results = [
+                pool.apply_async(
+                    cls._binarize_chunk_and_finalize,
+                    args=(
+                        binarizer,
+                        input_file,
+                        start_offset,
+                        end_offset,
+                        _worker_prefix(
+                            output_prefix,
+                            worker_id,
+                        ),
+                        dataset_impl,
+                    ),
+                    kwds={
+                        "vocab_size": vocab_size,
+                    }
+                    if vocab_size is not None
+                    else {},
+                )
+                for worker_id, (start_offset, end_offset) in enumerate(
+                    more_chunks, start=1
+                )
+            ]
+
+            pool.close()
+            pool.join()
+            for r in worker_results:
+                summ = r.get()
+                final_summary.merge(summ)
+
+        # do not close the bin file as we need to merge the worker results in
+        final_ds, summ = cls._binarize_file_chunk(
+            binarizer,
+            input_file,
+            offset_start=first_chunk[0],
+            offset_end=first_chunk[1],
+            output_prefix=output_prefix,
+            dataset_impl=dataset_impl,
+            vocab_size=vocab_size if vocab_size is not None else None,
+        )
+        final_summary.merge(summ)
+
+        if num_workers > 1:
+            for worker_id in range(1, num_workers):
+                # merge the worker outputs
+                worker_output_prefix = _worker_prefix(
+                    output_prefix,
+                    worker_id,
+                )
+                final_ds.merge_file_(worker_output_prefix)
+                try:
+                    os.remove(indexed_dataset.data_file_path(worker_output_prefix))
+                    os.remove(indexed_dataset.index_file_path(worker_output_prefix))
+                except Exception as e:
+                    logger.error(
+                        f"couldn't remove {worker_output_prefix}.*", exc_info=e
+                    )
+
+        # now we can close the file
+        idx_file = indexed_dataset.index_file_path(output_prefix)
+        final_ds.finalize(idx_file)
+        return final_summary
+
+    @staticmethod
+    def _binarize_file_chunk(
+        binarizer: Binarizer,
+        filename: str,
+        offset_start: int,
+        offset_end: int,
+        output_prefix: str,
+        dataset_impl: str,
+        vocab_size=None,
+    ) -> tp.Tuple[tp.Any, BinarizeSummary]:  # (dataset builder, BinarizeSummary)
+        """
+        creates a dataset builder and appends binarized items to it. This function does not
+        finalize the builder; this is useful if you want to do other things with your bin file
+        like appending/merging other files
+        """
+        bin_file = indexed_dataset.data_file_path(output_prefix)
+        ds = indexed_dataset.make_builder(
+            bin_file,
+            impl=dataset_impl,
+            vocab_size=vocab_size,
+        )
+        summary = BinarizeSummary()
+
+        with Chunker(
+            PathManager.get_local_path(filename), offset_start, offset_end
+        ) as line_iterator:
+            for line in line_iterator:
+                ds.add_item(binarizer.binarize_line(line, summary))
+
+        return ds, summary
+
+    @classmethod
+    def _binarize_chunk_and_finalize(
+        cls,
+        binarizer: Binarizer,
+        filename: str,
+        offset_start: int,
+        offset_end: int,
+        output_prefix: str,
+        dataset_impl: str,
+        vocab_size=None,
+    ):
+        """
+        same as above, but also finalizes the builder
+        """
+        ds, summ = cls._binarize_file_chunk(
+            binarizer,
+            filename,
+            offset_start,
+            offset_end,
+            output_prefix,
+            dataset_impl,
+            vocab_size=vocab_size,
+        )
+
+        idx_file = indexed_dataset.index_file_path(output_prefix)
+        ds.finalize(idx_file)
+
+        return summ
+
+
+class VocabularyDatasetBinarizer(Binarizer):
+    """
+    Takes a Dictionary/Vocabulary, assigns ids to each
+    token using the dictionary encode_line function.
+    """
+
+    def __init__(
+        self,
+        dict: Dictionary,
+        tokenize: tp.Callable[[str], tp.List[str]] = tokenize_line,
+        append_eos: bool = True,
+        reverse_order: bool = False,
+        already_numberized: bool = False,
+    ) -> None:
+        self.dict = dict
+        self.tokenize = tokenize
+        self.append_eos = append_eos
+        self.reverse_order = reverse_order
+        self.already_numberized = already_numberized
+        super().__init__()
+
+    def binarize_line(
+        self,
+        line: str,
+        summary: BinarizeSummary,
+    ):
+        if summary.replaced is None:
+            summary.replaced = Counter()
+
+        def replaced_consumer(word, idx):
+            if idx == self.dict.unk_index and word != self.dict.unk_word:
+                summary.replaced.update([word])
+
+        if self.already_numberized:
+            id_strings = line.strip().split()
+            id_list = [int(id_string) for id_string in id_strings]
+            if self.reverse_order:
+                id_list.reverse()
+            if self.append_eos:
+                id_list.append(self.dict.eos())
+            ids = torch.IntTensor(id_list)
+        else:
+            ids = self.dict.encode_line(
+                line=line,
+                line_tokenizer=self.tokenize,
+                add_if_not_exist=False,
+                consumer=replaced_consumer,
+                append_eos=self.append_eos,
+                reverse_order=self.reverse_order,
+            )
+
+        summary.num_seq += 1
+        summary.num_tok += len(ids)
+        return ids
+
+
+class AlignmentDatasetBinarizer(Binarizer):
+    """
+    binarize by parsing a set of alignments and packing
+    them in a tensor (see utils.parse_alignment)
+    """
+
+    def __init__(
+        self,
+        alignment_parser: tp.Callable[[str], torch.IntTensor],
+    ) -> None:
+        super().__init__()
+        self.alignment_parser = alignment_parser
+
+    def binarize_line(
+        self,
+        line: str,
+        summary: BinarizeSummary,
+    ):
+        ids = self.alignment_parser(line)
+        summary.num_seq += 1
+        summary.num_tok += len(ids)
+        return ids
+
+
+class LegacyBinarizer:
+    @classmethod
+    def binarize(
+        cls,
+        filename: str,
+        dico: Dictionary,
+        consumer: tp.Callable[[torch.IntTensor], None],
+        tokenize: tp.Callable[[str], tp.List[str]] = tokenize_line,
+        append_eos: bool = True,
+        reverse_order: bool = False,
+        offset: int = 0,
+        end: int = -1,
+        already_numberized: bool = False,
+    ) -> tp.Dict[str, int]:
+        binarizer = VocabularyDatasetBinarizer(
+            dict=dico,
+            tokenize=tokenize,
+            append_eos=append_eos,
+            reverse_order=reverse_order,
+            already_numberized=already_numberized,
+        )
+        return cls._consume_file(
+            filename,
+            binarizer,
+            consumer,
+            offset_start=offset,
+            offset_end=end,
+        )
+
+    @classmethod
+    def binarize_alignments(
+        cls,
+        filename: str,
+        alignment_parser: tp.Callable[[str], torch.IntTensor],
+        consumer: tp.Callable[[torch.IntTensor], None],
+        offset: int = 0,
+        end: int = -1,
+    ) -> tp.Dict[str, int]:
+        binarizer = AlignmentDatasetBinarizer(alignment_parser)
+        return cls._consume_file(
+            filename,
+            binarizer,
+            consumer,
+            offset_start=offset,
+            offset_end=end,
+        )
+
+    @staticmethod
+    def _consume_file(
+        filename: str,
+        binarizer: Binarizer,
+        consumer: tp.Callable[[torch.IntTensor], None],
+        offset_start: int,
+        offset_end: int,
+    ) -> tp.Dict[str, int]:
+        summary = BinarizeSummary()
+
+        with Chunker(
+            PathManager.get_local_path(filename), offset_start, offset_end
+        ) as line_iterator:
+            for line in line_iterator:
+                consumer(binarizer.binarize_line(line, summary))
+
+        return {
+            "nseq": summary.num_seq,
+            "nunk": summary.num_replaced,
+            "ntok": summary.num_tok,
+            "replaced": summary.replaced,
+        }
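Tying the pieces together, a minimal in-memory sketch of `VocabularyDatasetBinarizer`; the two-word vocabulary is invented for illustration:

from fairseq.binarizer import BinarizeSummary, VocabularyDatasetBinarizer
from fairseq.data import Dictionary

d = Dictionary()
for w in ["hello", "world"]:
    d.add_symbol(w)

binarizer = VocabularyDatasetBinarizer(dict=d, append_eos=True)
summary = BinarizeSummary()

ids = binarizer.binarize_line("hello world", summary)  # token ids + eos
binarizer.binarize_line("hello mars", summary)         # "mars" -> unk, counted as replaced
print(ids.tolist(), str(summary))  # e.g. "2 sents, 6 tokens, 16.7% replaced"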
fairseq/checkpoint_utils.py
ADDED
|
@@ -0,0 +1,936 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import ast
import collections
import contextlib
import inspect
import logging
import os
import re
import time
import traceback
from collections import OrderedDict
from pathlib import Path
from typing import Any, Dict, Optional, Union

import numpy as np
import torch
from fairseq.data import data_utils
from fairseq.dataclass.configs import CheckpointConfig
from fairseq.dataclass.utils import (
    convert_namespace_to_omegaconf,
    overwrite_args_by_name,
)
from fairseq.distributed.fully_sharded_data_parallel import FSDP, has_FSDP
from fairseq.file_io import PathManager
from fairseq.models import FairseqDecoder, FairseqEncoder
from omegaconf import DictConfig, OmegaConf, open_dict

logger = logging.getLogger(__name__)


def save_checkpoint(cfg: CheckpointConfig, trainer, epoch_itr, val_loss):
    from fairseq import meters

    # only one worker should attempt to create the required dir
    if trainer.data_parallel_rank == 0:
        os.makedirs(cfg.save_dir, exist_ok=True)

    prev_best = getattr(save_checkpoint, "best", val_loss)
    if val_loss is not None:
        best_function = max if cfg.maximize_best_checkpoint_metric else min
        save_checkpoint.best = best_function(val_loss, prev_best)

    if cfg.no_save:
        return None

    trainer.consolidate_optimizer()  # TODO(SS): do we need this if no_save_optimizer_state

    if not trainer.should_save_checkpoint_on_current_rank:
        if trainer.always_call_state_dict_during_save_checkpoint:
            trainer.state_dict()
        return None

    write_timer = meters.StopwatchMeter()
    write_timer.start()

    epoch = epoch_itr.epoch
    end_of_epoch = epoch_itr.end_of_epoch()
    updates = trainer.get_num_updates()

    logger.info(f"Preparing to save checkpoint for epoch {epoch} @ {updates} updates")

    def is_better(a, b):
        return a >= b if cfg.maximize_best_checkpoint_metric else a <= b

    suffix = trainer.checkpoint_suffix
    checkpoint_conds = collections.OrderedDict()
    checkpoint_conds["checkpoint{}{}.pt".format(epoch, suffix)] = (
        end_of_epoch and not cfg.no_epoch_checkpoints and epoch % cfg.save_interval == 0
    )
    checkpoint_conds["checkpoint_{}_{}{}.pt".format(epoch, updates, suffix)] = (
        not end_of_epoch
        and cfg.save_interval_updates > 0
        and updates % cfg.save_interval_updates == 0
    )
    checkpoint_conds["checkpoint_best{}.pt".format(suffix)] = val_loss is not None and (
        not hasattr(save_checkpoint, "best")
        or is_better(val_loss, save_checkpoint.best)
    )
    if val_loss is not None and cfg.keep_best_checkpoints > 0:
        worst_best = getattr(save_checkpoint, "best", None)
        chkpts = checkpoint_paths(
            cfg.save_dir,
            pattern=r"checkpoint\.best_{}_(\d+\.?\d*){}\.pt".format(
                cfg.best_checkpoint_metric, suffix
            ),
        )
        if len(chkpts) > 0:
            p = chkpts[-1] if cfg.maximize_best_checkpoint_metric else chkpts[0]
            worst_best = float(p.rsplit("_")[-1].replace("{}.pt".format(suffix), ""))
        # add random digits to resolve ties
        with data_utils.numpy_seed(epoch, updates, val_loss):
            rand_sfx = np.random.randint(0, cfg.keep_best_checkpoints)

        checkpoint_conds[
            "checkpoint.best_{}_{:.3f}{}{}.pt".format(
                cfg.best_checkpoint_metric, val_loss, rand_sfx, suffix
            )
        ] = worst_best is None or is_better(val_loss, worst_best)
    checkpoint_conds[
        "checkpoint_last{}.pt".format(suffix)
    ] = not cfg.no_last_checkpoints

    extra_state = {
        "train_iterator": epoch_itr.state_dict(),
        "val_loss": val_loss,
    }

    # Going forward, different tasks could expose an API like this to dump all
    # the checkpoint worthy attributes in a dictionary which then will be
    # merged with the parent dictionary to create the "extra_state". This
    # allows for an extensible yet simple design to checkpoint task level
    # attributes
    if hasattr(trainer.task, "get_checkpoint_dict"):
        extra_state = {**extra_state, **trainer.task.get_checkpoint_dict()}
        logger.info(
            f"State of {trainer.task.__class__.__name__} is ready to be persisted with the checkpoint"
        )

    if hasattr(save_checkpoint, "best"):
        extra_state.update({"best": save_checkpoint.best})

    checkpoints = [
        os.path.join(cfg.save_dir, fn) for fn, cond in checkpoint_conds.items() if cond
    ]
    saved_cp = None
    if len(checkpoints) > 0 and trainer.should_save_checkpoint_on_current_rank:
        saved_cp = trainer.save_checkpoint(checkpoints[0], extra_state)
        for cp in checkpoints[1:]:
            if cfg.write_checkpoints_asynchronously:
                # TODO[ioPath]: Need to implement a delayed asynchronous
                # file copying/moving feature.
                logger.warning(
                    f"ioPath is not copying {checkpoints[0]} to {cp} "
                    "since async write mode is on."
                )
            else:
                assert PathManager.copy(
                    checkpoints[0], cp, overwrite=True
                ), f"Failed to copy {checkpoints[0]} to {cp}"

        write_timer.stop()
        logger.info(
            "Saved checkpoint {} (epoch {} @ {} updates, score {}) (writing took {} seconds)".format(
                checkpoints[0], epoch, updates, val_loss, write_timer.sum
            )
        )

    if (
        not end_of_epoch
        and cfg.keep_interval_updates > 0
        and trainer.should_save_checkpoint_on_current_rank
    ):
        # remove old checkpoints; checkpoints are sorted in descending order
        if cfg.keep_interval_updates_pattern == -1:
            checkpoints = checkpoint_paths(
                cfg.save_dir, pattern=r"checkpoint_\d+_(\d+){}\.pt".format(suffix)
            )
        else:
            checkpoints = checkpoint_paths(
                cfg.save_dir,
                pattern=r"checkpoint_\d+_(\d+){}\.pt".format(suffix),
                keep_match=True,
            )
            checkpoints = [
                x[0]
                for x in checkpoints
                if x[1] % cfg.keep_interval_updates_pattern != 0
            ]

        for old_chk in checkpoints[cfg.keep_interval_updates :]:
            if os.path.lexists(old_chk):
                os.remove(old_chk)
            elif PathManager.exists(old_chk):
                PathManager.rm(old_chk)

    if cfg.keep_last_epochs > 0 and trainer.should_save_checkpoint_on_current_rank:
        # remove old epoch checkpoints; checkpoints are sorted in descending order
        checkpoints = checkpoint_paths(
            cfg.save_dir, pattern=r"checkpoint(\d+){}\.pt".format(suffix)
        )
        for old_chk in checkpoints[cfg.keep_last_epochs :]:
            if os.path.lexists(old_chk):
                os.remove(old_chk)
            elif PathManager.exists(old_chk):
                PathManager.rm(old_chk)

    if cfg.keep_best_checkpoints > 0 and trainer.should_save_checkpoint_on_current_rank:
        # only keep the best N checkpoints according to validation metric
        checkpoints = checkpoint_paths(
            cfg.save_dir,
            pattern=r"checkpoint\.best_{}_(\d+\.?\d*){}\.pt".format(
                cfg.best_checkpoint_metric, suffix
            ),
        )
        if not cfg.maximize_best_checkpoint_metric:
            checkpoints = checkpoints[::-1]
        for old_chk in checkpoints[cfg.keep_best_checkpoints :]:
            if os.path.lexists(old_chk):
                os.remove(old_chk)
            elif PathManager.exists(old_chk):
                PathManager.rm(old_chk)

    return saved_cp
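

# Editor's note: an illustrative sketch, not part of the original file. The
# filename patterns assembled by save_checkpoint() above can be previewed
# without a trainer; the epoch/update values below are hypothetical.
def _example_checkpoint_names():  # pragma: no cover - illustrative only
    epoch, updates, suffix = 3, 12000, ""
    return [
        "checkpoint{}{}.pt".format(epoch, suffix),  # end-of-epoch checkpoint
        "checkpoint_{}_{}{}.pt".format(epoch, updates, suffix),  # mid-epoch save
        "checkpoint_best{}.pt".format(suffix),  # best validation metric so far
        "checkpoint_last{}.pt".format(suffix),  # refreshed on every save
    ]
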

def load_checkpoint(cfg: CheckpointConfig, trainer, **passthrough_args):
    """
    Load a checkpoint and restore the training iterator.

    *passthrough_args* will be passed through to
    ``trainer.get_train_iterator``.
    """

    reset_optimizer = cfg.reset_optimizer
    reset_lr_scheduler = cfg.reset_lr_scheduler
    optimizer_overrides = ast.literal_eval(cfg.optimizer_overrides)
    reset_meters = cfg.reset_meters
    reset_dataloader = cfg.reset_dataloader

    if cfg.finetune_from_model is not None and (
        reset_optimizer or reset_lr_scheduler or reset_meters or reset_dataloader
    ):
        raise ValueError(
            "--finetune-from-model can not be set together with either --reset-optimizer"
            " or reset_lr_scheduler or reset_meters or reset_dataloader"
        )

    suffix = trainer.checkpoint_suffix
    if (
        cfg.restore_file == "checkpoint_last.pt"
    ):  # default value of restore_file is 'checkpoint_last.pt'
        checkpoint_path = os.path.join(
            cfg.save_dir, "checkpoint_last{}.pt".format(suffix)
        )
        first_launch = not PathManager.exists(checkpoint_path)
        if first_launch and getattr(cfg, "continue_once", None) is not None:
            checkpoint_path = cfg.continue_once
        elif cfg.finetune_from_model is not None and first_launch:
            # if there is no last checkpoint to restore, start the finetune from pretrained model
            # else just use usual logic to load checkpoint, e.g. restart from last checkpoint and etc.
            if PathManager.exists(cfg.finetune_from_model):
                checkpoint_path = cfg.finetune_from_model
                reset_optimizer = True
                reset_lr_scheduler = True
                reset_meters = True
                reset_dataloader = True
                logger.info(
                    f"loading pretrained model from {checkpoint_path}: "
                    "optimizer, lr scheduler, meters, dataloader will be reset"
                )
            else:
                raise ValueError(
                    f"--finetune-from-model {cfg.finetune_from_model} does not exist"
                )
    elif suffix is not None:
        checkpoint_path = cfg.restore_file.replace(".pt", suffix + ".pt")
    else:
        checkpoint_path = cfg.restore_file

    if cfg.restore_file != "checkpoint_last.pt" and cfg.finetune_from_model:
        raise ValueError(
            "--finetune-from-model and --restore-file (non-default value) "
            "can not be specified together: " + str(cfg)
        )

    extra_state = trainer.load_checkpoint(
        checkpoint_path,
        reset_optimizer,
        reset_lr_scheduler,
        optimizer_overrides,
        reset_meters=reset_meters,
    )

    if (
        extra_state is not None
        and "best" in extra_state
        and not reset_optimizer
        and not reset_meters
    ):
        save_checkpoint.best = extra_state["best"]

    if extra_state is not None and not reset_dataloader:
        # restore iterator from checkpoint
        itr_state = extra_state["train_iterator"]
        epoch_itr = trainer.get_train_iterator(
            epoch=itr_state["epoch"], load_dataset=True, **passthrough_args
        )
        epoch_itr.load_state_dict(itr_state)

        # Preload the checkpoint for the task
        task_cp_dict = extra_state.get(trainer.task.__class__.__name__, {})
        if task_cp_dict and hasattr(trainer.task, "set_checkpoint_dict"):
            trainer.task.set_checkpoint_dict(task_cp_dict)
    else:
        epoch_itr = trainer.get_train_iterator(
            epoch=1, load_dataset=True, **passthrough_args
        )

    trainer.lr_step(epoch_itr.epoch)

    return extra_state, epoch_itr


def load_checkpoint_to_cpu(path, arg_overrides=None, load_on_all_ranks=False):
    """Loads a checkpoint to CPU (with upgrading for backward compatibility).

    If doing single-GPU training or if the checkpoint is only being loaded by at
    most one process on each node (current default behavior is for only rank 0
    to read the checkpoint from disk), load_on_all_ranks should be False to
    avoid errors from torch.distributed not having been initialized or
    torch.distributed.barrier() hanging.

    If all processes on each node may be loading the checkpoint
    simultaneously, load_on_all_ranks should be set to True to avoid I/O
    conflicts.

    There's currently no support for > 1 but < all processes loading the
    checkpoint on each node.
    """
    local_path = PathManager.get_local_path(path)
    # The locally cached file returned by get_local_path() may be stale for
    # remote files that are periodically updated/overwritten (ex:
    # checkpoint_last.pt) - so we remove the local copy, sync across processes
    # (if needed), and then download a fresh copy.
    if local_path != path and PathManager.path_requires_pathmanager(path):
        try:
            os.remove(local_path)
        except FileNotFoundError:
            # With potentially multiple processes removing the same file, the
            # file being missing is benign (missing_ok isn't available until
            # Python 3.8).
            pass
        if load_on_all_ranks:
            torch.distributed.barrier()
        local_path = PathManager.get_local_path(path)

    with open(local_path, "rb") as f:
        state = torch.load(f, map_location=torch.device("cpu"), weights_only=False)

    if "args" in state and state["args"] is not None and arg_overrides is not None:
        args = state["args"]
        for arg_name, arg_val in arg_overrides.items():
            setattr(args, arg_name, arg_val)

    if "cfg" in state and state["cfg"] is not None:

        # hack to be able to set Namespace in dict config. this should be removed when we update to newer
        # omegaconf version that supports object flags, or when we migrate all existing models
        from omegaconf import __version__ as oc_version
        from omegaconf import _utils

        if oc_version < "2.2":
            old_primitive = _utils.is_primitive_type
            _utils.is_primitive_type = lambda _: True

            state["cfg"] = OmegaConf.create(state["cfg"])

            _utils.is_primitive_type = old_primitive
            OmegaConf.set_struct(state["cfg"], True)
        else:
            state["cfg"] = OmegaConf.create(state["cfg"], flags={"allow_objects": True})

        if arg_overrides is not None:
            overwrite_args_by_name(state["cfg"], arg_overrides)

    state = _upgrade_state_dict(state)
    return state
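

# Editor's note: a minimal usage sketch, not part of the original file; the
# checkpoint path and override value are hypothetical.
def _example_load_checkpoint_to_cpu():  # pragma: no cover - illustrative only
    state = load_checkpoint_to_cpu(
        "checkpoints/checkpoint_last.pt",
        arg_overrides={"data": "/path/to/data-bin"},
    )
    return state["model"], state.get("cfg")
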

def load_model_ensemble(
    filenames,
    arg_overrides: Optional[Dict[str, Any]] = None,
    task=None,
    strict=True,
    suffix="",
    num_shards=1,
    state=None,
):
    """Loads an ensemble of models.

    Args:
        filenames (List[str]): checkpoint files to load
        arg_overrides (Dict[str,Any], optional): override model args that
            were used during model training
        task (fairseq.tasks.FairseqTask, optional): task to use for loading
    """
    assert not (
        strict and num_shards > 1
    ), "Cannot load state dict with strict=True and checkpoint shards > 1"
    ensemble, args, _task = load_model_ensemble_and_task(
        filenames,
        arg_overrides,
        task,
        strict,
        suffix,
        num_shards,
        state,
    )
    return ensemble, args
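

# Editor's note: a usage sketch, not part of the original file; the checkpoint
# paths are hypothetical.
def _example_load_model_ensemble():  # pragma: no cover - illustrative only
    models, args = load_model_ensemble(
        ["checkpoints/checkpoint1.pt", "checkpoints/checkpoint2.pt"]
    )
    for model in models:
        model.eval()  # ensembles are typically used for inference
    return models
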

def get_maybe_sharded_checkpoint_filename(
    filename: str, suffix: str, shard_idx: int, num_shards: int
) -> str:
    orig_filename = filename
    filename = filename.replace(".pt", suffix + ".pt")
    fsdp_filename = filename[:-3] + f"-shard{shard_idx}.pt"
    model_parallel_filename = orig_filename[:-3] + f"_part{shard_idx}.pt"
    if PathManager.exists(fsdp_filename):
        return fsdp_filename
    elif num_shards > 1:
        return model_parallel_filename
    else:
        return filename


def load_model_ensemble_and_task(
    filenames,
    arg_overrides: Optional[Dict[str, Any]] = None,
    task=None,
    strict=True,
    suffix="",
    num_shards=1,
    state=None,
):
    assert state is None or len(filenames) == 1

    from fairseq import tasks

    assert not (
        strict and num_shards > 1
    ), "Cannot load state dict with strict=True and checkpoint shards > 1"
    ensemble = []
    cfg = None
    for filename in filenames:
        orig_filename = filename
        model_shard_state = {"shard_weights": [], "shard_metadata": []}
        assert num_shards > 0
        st = time.time()
        for shard_idx in range(num_shards):
            filename = get_maybe_sharded_checkpoint_filename(
                orig_filename, suffix, shard_idx, num_shards
            )

            if not PathManager.exists(filename):
                raise IOError("Model file not found: {}".format(filename))
            if state is None:
                state = load_checkpoint_to_cpu(filename, arg_overrides)
            if "args" in state and state["args"] is not None:
                cfg = convert_namespace_to_omegaconf(state["args"])
            elif "cfg" in state and state["cfg"] is not None:
                cfg = state["cfg"]
            else:
                raise RuntimeError(
                    f"Neither args nor cfg exist in state keys = {state.keys()}"
                )

            if task is None:
                task = tasks.setup_task(cfg.task, from_checkpoint=True)

            if "task_state" in state:
                task.load_state_dict(state["task_state"])

            argspec = inspect.getfullargspec(task.build_model)

            if "fsdp_metadata" in state and num_shards > 1:
                model_shard_state["shard_weights"].append(state["model"])
                model_shard_state["shard_metadata"].append(state["fsdp_metadata"])
                # check FSDP import before the code goes too far
                if not has_FSDP:
                    raise ImportError(
                        "Cannot find FullyShardedDataParallel. "
                        "Please install fairscale with: pip install fairscale"
                    )
                if shard_idx == num_shards - 1:
                    consolidated_model_state = FSDP.consolidate_shard_weights(
                        shard_weights=model_shard_state["shard_weights"],
                        shard_metadata=model_shard_state["shard_metadata"],
                    )
                    if "from_checkpoint" in argspec.args:
                        model = task.build_model(cfg.model, from_checkpoint=True)
                    else:
                        model = task.build_model(cfg.model)
                    if (
                        "optimizer_history" in state
                        and len(state["optimizer_history"]) > 0
                        and "num_updates" in state["optimizer_history"][-1]
                    ):
                        model.set_num_updates(
                            state["optimizer_history"][-1]["num_updates"]
                        )
                    model.load_state_dict(
                        consolidated_model_state, strict=strict, model_cfg=cfg.model
                    )
            else:
                # model parallel checkpoint or unsharded checkpoint
                # support old external tasks

                if "from_checkpoint" in argspec.args:
                    model = task.build_model(cfg.model, from_checkpoint=True)
                else:
                    model = task.build_model(cfg.model)
                if (
                    "optimizer_history" in state
                    and len(state["optimizer_history"]) > 0
                    and "num_updates" in state["optimizer_history"][-1]
                ):
                    model.set_num_updates(state["optimizer_history"][-1]["num_updates"])
                model.load_state_dict(
                    state["model"], strict=strict, model_cfg=cfg.model
                )

            # reset state so it gets loaded for the next model in ensemble
            state = None
            if shard_idx % 10 == 0 and shard_idx > 0:
                elapsed = time.time() - st
                logger.info(
                    f"Loaded {shard_idx} shards in {elapsed:.2f}s, {elapsed / (shard_idx+1):.2f}s/shard"
                )

        # build model for ensemble
        ensemble.append(model)
    return ensemble, cfg, task
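

# Editor's note: a sketch, not part of the original file, of the shard-naming
# convention resolved by get_maybe_sharded_checkpoint_filename() above: FSDP
# shards use a "-shard{i}" suffix, model-parallel parts use "_part{i}".
def _example_shard_names():  # pragma: no cover - illustrative only
    filename, shard_idx = "model.pt", 0
    fsdp_name = filename[:-3] + "-shard{}.pt".format(shard_idx)  # model-shard0.pt
    mp_name = filename[:-3] + "_part{}.pt".format(shard_idx)  # model_part0.pt
    return fsdp_name, mp_name
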

def load_model_ensemble_and_task_from_hf_hub(
    model_id,
    cache_dir: Optional[str] = None,
    arg_overrides: Optional[Dict[str, Any]] = None,
    **kwargs: Any,
):
    try:
        from huggingface_hub import snapshot_download
    except ImportError:
        raise ImportError(
            "You need to install huggingface_hub to use `load_from_hf_hub`. "
            "See https://pypi.org/project/huggingface-hub/ for installation."
        )

    library_name = "fairseq"
    cache_dir = cache_dir or (Path.home() / ".cache" / library_name).as_posix()
    cache_dir = snapshot_download(
        model_id, cache_dir=cache_dir, library_name=library_name, **kwargs
    )

    _arg_overrides = arg_overrides or {}
    _arg_overrides["data"] = cache_dir
    return load_model_ensemble_and_task(
        [p.as_posix() for p in Path(cache_dir).glob("*.pt")],
        arg_overrides=_arg_overrides,
    )
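

# Editor's note: a usage sketch, not part of the original file; the Hub repo
# id is hypothetical and requires the huggingface_hub package.
def _example_load_from_hf_hub():  # pragma: no cover - illustrative only
    models, cfg, task = load_model_ensemble_and_task_from_hf_hub(
        "username/fairseq-model"  # hypothetical Hub repo id
    )
    return models[0], task
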

def checkpoint_paths(path, pattern=r"checkpoint(\d+)\.pt", keep_match=False):
    """Retrieves all checkpoints found in `path` directory.

    Checkpoints are identified by matching filename to the specified pattern. If
    the pattern contains groups, the result will be sorted by the first group in
    descending order.
    """
    pt_regexp = re.compile(pattern)
    files = PathManager.ls(path)

    entries = []
    for i, f in enumerate(files):
        m = pt_regexp.fullmatch(f)
        if m is not None:
            idx = float(m.group(1)) if len(m.groups()) > 0 else i
            entries.append((idx, m.group(0)))
    if keep_match:
        return [(os.path.join(path, x[1]), x[0]) for x in sorted(entries, reverse=True)]
    else:
        return [os.path.join(path, x[1]) for x in sorted(entries, reverse=True)]
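

# Editor's note: a sketch, not part of the original file, of the sort order
# implemented by checkpoint_paths(): the first regex group is parsed as a
# float and sorted in descending order, so newer checkpoints come first.
def _example_checkpoint_sort_order():  # pragma: no cover - illustrative only
    files = ["checkpoint1.pt", "checkpoint10.pt", "checkpoint2.pt"]
    pt_regexp = re.compile(r"checkpoint(\d+)\.pt")
    entries = [
        (float(m.group(1)), m.group(0))
        for m in (pt_regexp.fullmatch(f) for f in files)
        if m is not None
    ]
    # -> ["checkpoint10.pt", "checkpoint2.pt", "checkpoint1.pt"]
    return [name for _, name in sorted(entries, reverse=True)]
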

def torch_persistent_save(obj, filename, async_write: bool = False):
    if async_write:
        with PathManager.opena(filename, "wb") as f:
            _torch_persistent_save(obj, f)
    else:
        if PathManager.supports_rename(filename):
            # do atomic save
            with PathManager.open(filename + ".tmp", "wb") as f:
                _torch_persistent_save(obj, f)
            PathManager.rename(filename + ".tmp", filename)
        else:
            # fallback to non-atomic save
            with PathManager.open(filename, "wb") as f:
                _torch_persistent_save(obj, f)


def _torch_persistent_save(obj, f):
    if isinstance(f, str):
        with PathManager.open(f, "wb") as h:
            torch_persistent_save(obj, h)
        return
    for i in range(3):
        try:
            return torch.save(obj, f)
        except Exception:
            if i == 2:
                logger.error(traceback.format_exc())
                raise
            else:
                time.sleep(2.5)
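

# Editor's note, not part of the original file: the write-to-temp-then-rename
# sequence in torch_persistent_save() makes the save atomic where the backend
# supports rename, so a concurrent reader sees either the old or the new
# checkpoint, never a partial file. A plain-filesystem sketch of the pattern:
def _example_atomic_write(payload: bytes, filename: str):  # pragma: no cover
    tmp = filename + ".tmp"
    with open(tmp, "wb") as f:
        f.write(payload)
    os.replace(tmp, filename)  # atomic replacement on POSIX filesystems
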

def _upgrade_state_dict(state):
    """Helper for upgrading old model checkpoints."""

    # add optimizer_history
    if "optimizer_history" not in state:
        state["optimizer_history"] = [
            {"criterion_name": "CrossEntropyCriterion", "best_loss": state["best_loss"]}
        ]
        state["last_optimizer_state"] = state["optimizer"]
        del state["optimizer"]
        del state["best_loss"]
    # move extra_state into sub-dictionary
    if "epoch" in state and "extra_state" not in state:
        state["extra_state"] = {
            "epoch": state["epoch"],
            "batch_offset": state["batch_offset"],
            "val_loss": state["val_loss"],
        }
        del state["epoch"]
        del state["batch_offset"]
        del state["val_loss"]
    # reduce optimizer history's memory usage (only keep the last state)
    if "optimizer" in state["optimizer_history"][-1]:
        state["last_optimizer_state"] = state["optimizer_history"][-1]["optimizer"]
        for optim_hist in state["optimizer_history"]:
            del optim_hist["optimizer"]
    # record the optimizer class name
    if "optimizer_name" not in state["optimizer_history"][-1]:
        state["optimizer_history"][-1]["optimizer_name"] = "FairseqNAG"
    # move best_loss into lr_scheduler_state
    if "lr_scheduler_state" not in state["optimizer_history"][-1]:
        state["optimizer_history"][-1]["lr_scheduler_state"] = {
            "best": state["optimizer_history"][-1]["best_loss"]
        }
        del state["optimizer_history"][-1]["best_loss"]
    # keep track of number of updates
    if "num_updates" not in state["optimizer_history"][-1]:
        state["optimizer_history"][-1]["num_updates"] = 0
    # use stateful training data iterator
    if "train_iterator" not in state["extra_state"]:
        state["extra_state"]["train_iterator"] = {
            "epoch": state["extra_state"].get("epoch", 0),
            "iterations_in_epoch": state["extra_state"].get("batch_offset", 0),
        }

    # backward compatibility, cfg updates
    if "args" in state and state["args"] is not None:
        # old model checkpoints may not have separate source/target positions
        if hasattr(state["args"], "max_positions") and not hasattr(
            state["args"], "max_source_positions"
        ):
            state["args"].max_source_positions = state["args"].max_positions
            state["args"].max_target_positions = state["args"].max_positions
        # default to translation task
        if not hasattr(state["args"], "task"):
            state["args"].task = "translation"
        # --raw-text and --lazy-load are deprecated
        if getattr(state["args"], "raw_text", False):
            state["args"].dataset_impl = "raw"
        elif getattr(state["args"], "lazy_load", False):
            state["args"].dataset_impl = "lazy"
        # epochs start at 1
        if state["extra_state"]["train_iterator"] is not None:
            state["extra_state"]["train_iterator"]["epoch"] = max(
                state["extra_state"]["train_iterator"].get("epoch", 1), 1
            )
        # --remove-bpe ==> --postprocess
        if hasattr(state["args"], "remove_bpe"):
            state["args"].post_process = state["args"].remove_bpe
        # --min-lr ==> --stop-min-lr
        if hasattr(state["args"], "min_lr"):
            state["args"].stop_min_lr = state["args"].min_lr
            del state["args"].min_lr
        # binary_cross_entropy / kd_binary_cross_entropy => wav2vec criterion
        if hasattr(state["args"], "criterion") and state["args"].criterion in [
            "binary_cross_entropy",
            "kd_binary_cross_entropy",
        ]:
            state["args"].criterion = "wav2vec"
        # remove log_keys if it's None (criteria will supply a default value of [])
        if hasattr(state["args"], "log_keys") and state["args"].log_keys is None:
            delattr(state["args"], "log_keys")
        # speech_pretraining => audio pretraining
        if (
            hasattr(state["args"], "task")
            and state["args"].task == "speech_pretraining"
        ):
            state["args"].task = "audio_pretraining"
        # audio_cpc => wav2vec
        if hasattr(state["args"], "arch") and state["args"].arch == "audio_cpc":
            state["args"].arch = "wav2vec"
        # convert legacy float learning rate to List[float]
        if hasattr(state["args"], "lr") and isinstance(state["args"].lr, float):
            state["args"].lr = [state["args"].lr]
        # convert task data arg to a string instead of List[string]
        if (
            hasattr(state["args"], "data")
            and isinstance(state["args"].data, list)
            and len(state["args"].data) > 0
        ):
            state["args"].data = state["args"].data[0]

        state["cfg"] = convert_namespace_to_omegaconf(state["args"])

    if "cfg" in state and state["cfg"] is not None:
        cfg = state["cfg"]
        with open_dict(cfg):
            # any upgrades for Hydra-based configs
            if (
                "task" in cfg
                and "eval_wer_config" in cfg.task
                and isinstance(cfg.task.eval_wer_config.print_alignment, bool)
            ):
                cfg.task.eval_wer_config.print_alignment = "hard"
            if "generation" in cfg and isinstance(cfg.generation.print_alignment, bool):
                cfg.generation.print_alignment = (
                    "hard" if cfg.generation.print_alignment else None
                )
            if (
                "model" in cfg
                and "w2v_args" in cfg.model
                and cfg.model.w2v_args is not None
                and (
                    hasattr(cfg.model.w2v_args, "task") or "task" in cfg.model.w2v_args
                )
                and hasattr(cfg.model.w2v_args.task, "eval_wer_config")
                and cfg.model.w2v_args.task.eval_wer_config is not None
                and isinstance(
                    cfg.model.w2v_args.task.eval_wer_config.print_alignment, bool
                )
            ):
                cfg.model.w2v_args.task.eval_wer_config.print_alignment = "hard"

    return state
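

# Editor's note: a sketch, not part of the original file, exercising
# _upgrade_state_dict() on a minimal pre-"optimizer_history" checkpoint; all
# field values are hypothetical.
def _example_upgrade_old_checkpoint():  # pragma: no cover - illustrative only
    old_state = {
        "model": {},
        "optimizer": {},
        "best_loss": 2.5,
        "epoch": 3,
        "batch_offset": 0,
        "val_loss": 2.5,
    }
    new_state = _upgrade_state_dict(old_state)
    assert "optimizer_history" in new_state
    assert "train_iterator" in new_state["extra_state"]
    return new_state
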

def prune_state_dict(state_dict, model_cfg: Optional[DictConfig]):
    """Prune the given state_dict if desired for LayerDrop
    (https://arxiv.org/abs/1909.11556).

    Training with LayerDrop allows models to be robust to pruning at inference
    time. This function prunes state_dict to allow smaller models to be loaded
    from a larger model and re-maps the existing state_dict for this to occur.

    It's called by functions that load models from checkpoints and does not
    need to be called directly.
    """
    arch = None
    if model_cfg is not None:
        arch = (
            model_cfg._name
            if isinstance(model_cfg, DictConfig)
            else getattr(model_cfg, "arch", None)
        )

    if not model_cfg or arch is None or arch == "ptt_transformer":
        # args should not be none, but don't crash if it is.
        return state_dict

    encoder_layers_to_keep = getattr(model_cfg, "encoder_layers_to_keep", None)
    decoder_layers_to_keep = getattr(model_cfg, "decoder_layers_to_keep", None)

    if not encoder_layers_to_keep and not decoder_layers_to_keep:
        return state_dict

    # apply pruning
    logger.info(
        "Pruning model to specified layer configuration - this works best if the model was trained with LayerDrop"
    )

    def create_pruning_pass(layers_to_keep, layer_name):
        keep_layers = sorted(
            int(layer_string) for layer_string in layers_to_keep.split(",")
        )
        mapping_dict = {}
        for i in range(len(keep_layers)):
            mapping_dict[str(keep_layers[i])] = str(i)

        regex = re.compile(r"^{layer}.*\.layers\.(\d+)".format(layer=layer_name))
        return {"substitution_regex": regex, "mapping_dict": mapping_dict}

    pruning_passes = []
    if encoder_layers_to_keep:
        pruning_passes.append(create_pruning_pass(encoder_layers_to_keep, "encoder"))
    if decoder_layers_to_keep:
        pruning_passes.append(create_pruning_pass(decoder_layers_to_keep, "decoder"))

    new_state_dict = {}
    for layer_name in state_dict.keys():
        match = re.search(r"\.layers\.(\d+)\.", layer_name)
        # if layer has no number in it, it is a supporting layer, such as an
        # embedding
        if not match:
            new_state_dict[layer_name] = state_dict[layer_name]
            continue

        # otherwise, layer should be pruned.
        original_layer_number = match.group(1)
        # figure out which mapping dict to replace from
        for pruning_pass in pruning_passes:
            if original_layer_number in pruning_pass["mapping_dict"] and pruning_pass[
                "substitution_regex"
            ].search(layer_name):
                new_layer_number = pruning_pass["mapping_dict"][original_layer_number]
                substitution_match = pruning_pass["substitution_regex"].search(
                    layer_name
                )
                new_state_key = (
                    layer_name[: substitution_match.start(1)]
                    + new_layer_number
                    + layer_name[substitution_match.end(1) :]
                )
                new_state_dict[new_state_key] = state_dict[layer_name]

    # Since layers are now pruned, *_layers_to_keep are no longer needed.
    # This is more of "It would make it work fix" rather than a proper fix.
    if isinstance(model_cfg, DictConfig):
        context = open_dict(model_cfg)
    else:
        context = contextlib.ExitStack()
    with context:
        if hasattr(model_cfg, "encoder_layers_to_keep"):
            model_cfg.encoder_layers_to_keep = None
        if hasattr(model_cfg, "decoder_layers_to_keep"):
            model_cfg.decoder_layers_to_keep = None

    return new_state_dict
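

# Editor's note: a sketch, not part of the original file, pruning a toy
# state_dict down to encoder layers 0 and 2, which are renumbered 0 and 1;
# old layer 1 is dropped and non-layer keys pass through untouched.
def _example_prune_state_dict():  # pragma: no cover - illustrative only
    from argparse import Namespace

    state_dict = {
        "encoder.layers.0.weight": 0.0,
        "encoder.layers.1.weight": 1.0,
        "encoder.layers.2.weight": 2.0,
        "encoder.embed_tokens.weight": 9.0,
    }
    model_cfg = Namespace(
        arch="transformer",  # any arch other than "ptt_transformer"
        encoder_layers_to_keep="0,2",
        decoder_layers_to_keep=None,
    )
    return prune_state_dict(state_dict, model_cfg)
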

def load_pretrained_component_from_model(
    component: Union[FairseqEncoder, FairseqDecoder],
    checkpoint: str,
    strict: bool = True,
):
    """
    Load a pretrained FairseqEncoder or FairseqDecoder from checkpoint into the
    provided `component` object. If state_dict fails to load, there may be a
    mismatch in the architecture of the corresponding `component` found in the
    `checkpoint` file.
    """
    if not PathManager.exists(checkpoint):
        raise IOError("Model file not found: {}".format(checkpoint))
    state = load_checkpoint_to_cpu(checkpoint)
    if isinstance(component, FairseqEncoder):
        component_type = "encoder"
    elif isinstance(component, FairseqDecoder):
        component_type = "decoder"
    else:
        raise ValueError(
            "component to load must be either a FairseqEncoder or "
            "FairseqDecoder. Loading other component types are not supported."
        )
    component_state_dict = OrderedDict()
    for key in state["model"].keys():
        if key.startswith(component_type):
            # encoder.input_layers.0.0.weight --> input_layers.0.0.weight
            component_subkey = key[len(component_type) + 1 :]
            component_state_dict[component_subkey] = state["model"][key]
    component.load_state_dict(component_state_dict, strict=strict)
    return component


def verify_checkpoint_directory(save_dir: str) -> None:
    if not os.path.exists(save_dir):
        os.makedirs(save_dir, exist_ok=True)
    temp_file_path = os.path.join(save_dir, "dummy")
    try:
        with open(temp_file_path, "w"):
            pass
    except OSError as e:
        logger.warning(
            "Unable to access checkpoint save directory: {}".format(save_dir)
        )
        raise e
    else:
        os.remove(temp_file_path)


def save_ema_as_checkpoint(src_path, dst_path):
    state = load_ema_from_checkpoint(src_path)
    torch_persistent_save(state, dst_path)


def load_ema_from_checkpoint(fpath):
    """Loads exponential moving averaged (EMA) checkpoint from input and
    returns a model with ema weights.

    Args:
      fpath: A string path of checkpoint to load from.

    Returns:
      A dict of string keys mapping to various values. The 'model' key
      from the returned dict should correspond to an OrderedDict mapping
      string parameter names to torch Tensors.
    """
    params_dict = collections.OrderedDict()
    new_state = None

    with PathManager.open(fpath, "rb") as f:
        new_state = torch.load(
            f,
            map_location=(
                lambda s, _: torch.serialization.default_restore_location(s, "cpu")
            ),
        )

    # EMA model is stored in a separate "extra state"
    model_params = new_state["extra_state"]["ema"]

    for key in list(model_params.keys()):
        p = model_params[key]
        if isinstance(p, torch.HalfTensor):
            p = p.float()
        if key not in params_dict:
            params_dict[key] = p.clone()
            # NOTE: clone() is needed in case p is a shared parameter
        else:
            raise ValueError("Key {} is repeated in EMA model params.".format(key))

    if len(params_dict) == 0:
        raise ValueError(
            f"Input checkpoint path '{fpath}' does not contain "
            "ema model weights, is this model trained with EMA?"
        )

    new_state["model"] = params_dict
    return new_state
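
The EMA helpers above close out checkpoint_utils.py. As a usage sketch (not
part of the diff; the paths are hypothetical and the source checkpoint must
have been trained with EMA enabled), extracting the EMA weights into a
standalone checkpoint looks like:

    from fairseq.checkpoint_utils import save_ema_as_checkpoint

    save_ema_as_checkpoint("checkpoints/checkpoint_last.pt", "checkpoints/ema.pt")
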
fairseq/clib/cuda/ngram_repeat_block_cuda.cpp
ADDED
@@ -0,0 +1,55 @@
/*
Copyright (c) Microsoft Corporation.
Licensed under the MIT License.
*/

#include <torch/extension.h>
#include <vector>

/*
CPP Binding for CUDA OP
*/

// CUDA forward declarations
torch::Tensor ngram_repeat_block_cuda_forward(
    torch::Tensor tokens,
    torch::Tensor lprobs,
    int bsz,
    int step,
    int beam_size,
    int no_repeat_ngram_size);

#define CHECK_CUDA(x) \
  TORCH_CHECK(x.type().is_cuda(), #x " must be a CUDA tensor")
#define CHECK_CONTIGUOUS(x) \
  TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
#define CHECK_INPUT(x) \
  CHECK_CUDA(x);       \
  CHECK_CONTIGUOUS(x)

// Input check and call to CUDA OP
// Backward method not required
torch::Tensor ngram_repeat_block_forward(
    torch::Tensor tokens,
    torch::Tensor lprobs,
    int bsz,
    int step,
    int beam_size,
    int no_repeat_ngram_size) {
  CHECK_INPUT(tokens);
  CHECK_INPUT(lprobs);
  assert(bsz > 0);
  assert(step >= 0);
  assert(beam_size > 0);
  assert(no_repeat_ngram_size > 0);

  return ngram_repeat_block_cuda_forward(
      tokens, lprobs, bsz, step, beam_size, no_repeat_ngram_size);
}

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def(
      "forward",
      &ngram_repeat_block_forward,
      "No Repeat Ngram Block forward (CUDA)");
}
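
This binding only validates its inputs and forwards to the CUDA kernel defined
in the .cu file below. As a sketch (not part of the diff; the source paths
simply point at the two files in this commit), such a pybind11 extension can
be built and loaded on the fly with PyTorch's JIT extension loader:

    from torch.utils.cpp_extension import load

    ngram_repeat_block = load(
        name="ngram_repeat_block_cuda",
        sources=[
            "fairseq/clib/cuda/ngram_repeat_block_cuda.cpp",
            "fairseq/clib/cuda/ngram_repeat_block_cuda_kernel.cu",
        ],
    )
    # ngram_repeat_block.forward(tokens, lprobs, bsz, step, beam_size, ngram_size)
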
fairseq/clib/cuda/ngram_repeat_block_cuda_kernel.cu
ADDED
@@ -0,0 +1,82 @@
/*
Copyright (c) Microsoft Corporation.
Licensed under the MIT License.
*/

/*
Kernel implementation for blocking repeated n-grams.
*/

#include <cuda.h>
#include <cuda_runtime.h>
#include <math.h>
#include <torch/extension.h>
#include <vector>

// Ban repeated ngrams of length = 'no_repeat_ngram_size'
__global__ void banRepeatedTokens(
    long* __restrict__ tokens,
    float* __restrict__ lprobs,
    int max_predict_len,
    int vocab_size,
    int no_repeat_ngram_size) {
  auto row = blockIdx.x;
  auto col = threadIdx.x;
  auto start = row * (max_predict_len) + col;
  // Each thread compares the ngram starting from its
  // thread index with the final ngram starting from
  // step - no_repeat_ngram_size + 2
  auto check_start_pos = blockDim.x;
  auto lprob_start = row * vocab_size;
  bool is_banned = true;
  extern __shared__ long tokens_shm[];
  tokens_shm[col] = tokens[start];
  if (col == blockDim.x - 1) {
    for (int i = 1; i < no_repeat_ngram_size; i++) {
      if (col + i < max_predict_len) {
        tokens_shm[col + i] = tokens[start + i];
      }
    }
  }
  __syncthreads();

  for (int k = 0; k < no_repeat_ngram_size - 1; k++) {
    if (tokens_shm[col + k] != tokens_shm[check_start_pos + k]) {
      is_banned = false;
    }
  }
  if (is_banned == true) {
    auto token_to_be_banned = tokens_shm[col + no_repeat_ngram_size - 1];
    lprobs[lprob_start + token_to_be_banned] = -INFINITY;
  }
}

// Allocate blocks and threads based on
// batch size and sequence length and launch
// kernel
torch::Tensor ngram_repeat_block_cuda_forward(
    const torch::Tensor tokens,
    torch::Tensor lprobs,
    int bsz,
    int step,
    int beam_size,
    int no_repeat_ngram_size) {
  int threads = step - no_repeat_ngram_size + 2;
  if (threads <= 0)
    return lprobs;
  int max_predict_len = tokens.size(1);
  int vocab_size = lprobs.size(1);
  auto token_ptr = tokens.data_ptr<long>();
  auto lprob_ptr = lprobs.data_ptr<float>();
  int blocks = bsz * beam_size;
  int shared_mem_size = (step + 1) * sizeof(long);

  // Launching N blocks where N is number of samples in a batch (beams*bsz)
  // Launching T threads where T is number of previous ngrams in a sample
  // Allocating shared mem per block for faster access of input tokens since
  // each token will be accessed N times to compare with current Ngram where
  // N is Ngram size.
  banRepeatedTokens<<<blocks, threads, shared_mem_size>>>(
      token_ptr, lprob_ptr, max_predict_len, vocab_size, no_repeat_ngram_size);
  return lprobs;
}
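
For reference, the banning rule the kernel implements can be written in pure
Python for a single hypothesis (a sketch, not part of the diff): every earlier
n-gram whose first n-1 tokens equal the last n-1 generated tokens bans its
final token at the current step.

    import math

    def ban_repeated_ngrams(tokens, lprobs, step, no_repeat_ngram_size):
        n = no_repeat_ngram_size
        prefix = tokens[step - n + 2 : step + 1]   # last n-1 generated tokens
        for i in range(step - n + 2):              # one "thread" per position
            if tokens[i : i + n - 1] == prefix:
                lprobs[tokens[i + n - 1]] = -math.inf  # ban the continuation
        return lprobs
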
fairseq/clib/libbase/balanced_assignment.cpp
ADDED
@@ -0,0 +1,109 @@
/**
 * Copyright 2017-present, Facebook, Inc.
 * All rights reserved.
 *
 * This source code is licensed under the license found in the
 * LICENSE file in the root directory of this source tree.
 */

/*
C++ code for solving the linear assignment problem.
Based on the Auction Algorithm from
https://dspace.mit.edu/bitstream/handle/1721.1/3265/P-2108-26912652.pdf and the
implementation from: https://github.com/bkj/auction-lap Adapted to be more
efficient when each worker is looking for k jobs instead of 1.
*/
#include <torch/extension.h>
#include <iostream>
using namespace torch::indexing;
torch::Tensor balanced_assignment(torch::Tensor job_and_worker_to_score) {
  int max_iterations = 100;
  torch::Tensor epsilon =
      (job_and_worker_to_score.max() - job_and_worker_to_score.min()) / 50;
  epsilon.clamp_min_(1e-04);
  torch::Tensor worker_and_job_to_score =
      job_and_worker_to_score.detach().transpose(0, 1).contiguous();
  int num_workers = worker_and_job_to_score.size(0);
  int num_jobs = worker_and_job_to_score.size(1);
  auto device = worker_and_job_to_score.device();
  int jobs_per_worker = num_jobs / num_workers;
  torch::Tensor value = worker_and_job_to_score.clone();
  int counter = 0;
  torch::Tensor max_value = worker_and_job_to_score.max();

  torch::Tensor bid_indices;
  torch::Tensor cost = worker_and_job_to_score.new_zeros({1, num_jobs});
  torch::Tensor bids =
      worker_and_job_to_score.new_empty({num_workers, num_jobs});
  torch::Tensor bid_increments =
      worker_and_job_to_score.new_empty({num_workers, jobs_per_worker});
  torch::Tensor top_values =
      worker_and_job_to_score.new_empty({num_workers, jobs_per_worker + 1});
  torch::Tensor high_bids = worker_and_job_to_score.new_empty({num_jobs});

  torch::Tensor top_index = top_values.to(torch::kLong);
  torch::Tensor high_bidders = top_index.new_empty({num_jobs});
  torch::Tensor have_bids = high_bidders.to(torch::kBool);
  torch::Tensor jobs_indices =
      torch::arange({num_jobs}, torch::dtype(torch::kLong).device(device));
  torch::Tensor true_tensor =
      torch::ones({1}, torch::dtype(torch::kBool).device(device));

  while (true) {
    bids.zero_();
    torch::topk_out(top_values, top_index, value, jobs_per_worker + 1, 1);

    // Each worker bids the difference in value between that job and the k+1th
    // job
    torch::sub_out(
        bid_increments,
        top_values.index({Slice(None, None), Slice(0, jobs_per_worker)}),
        top_values.index({Slice(None, None), jobs_per_worker}).unsqueeze(1));

    bid_increments.add_(epsilon);
    bids.scatter_(
        1,
        top_index.index({Slice(None, None), Slice(0, jobs_per_worker)}),
        bid_increments);

    if (counter < max_iterations && counter > 0) {
      // Put in a minimal bid to retain items from the last round if no-one else
      // bids for them this round
      bids.view(-1).index_put_({bid_indices}, epsilon);
    }

    // Find the highest bidding worker per job
    torch::max_out(high_bids, high_bidders, bids, 0);
    torch::gt_out(have_bids, high_bids, 0);

    if (have_bids.all().item<bool>()) {
      // All jobs were bid for
      break;
    }

    // Make popular items more expensive
    cost.add_(high_bids);
    torch::sub_out(value, worker_and_job_to_score, cost);

    bid_indices = ((high_bidders * num_jobs) + jobs_indices).index({have_bids});

    if (counter < max_iterations) {
      // Make sure that this item will be in the winning worker's top-k next
      // time.
      value.view(-1).index_put_({bid_indices}, max_value);
    } else {
      // Suboptimal approximation that converges quickly from current solution
      value.view(-1).index_put_(
          {bid_indices}, worker_and_job_to_score.view(-1).index({bid_indices}));
    }

    counter += 1;
  }

  return top_index.index({Slice(None, None), Slice(0, jobs_per_worker)})
      .reshape(-1);
}

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("balanced_assignment", &balanced_assignment, "Balanced Assignment");
}
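
The routine above runs Bertsekas-style auction rounds as batched tensor ops,
with each worker bidding on its top jobs_per_worker jobs at once. A
pure-Python sketch of the underlying auction idea for the simpler
one-job-per-worker case (not part of the diff; simplified, no epsilon
scaling, assumes a square score matrix with n >= 2):

    import numpy as np

    def auction_assignment(score, eps=1e-3):
        # score[i, j] = value of job j to worker i; returns one job per worker.
        n = score.shape[0]
        prices = np.zeros(n)
        owner = -np.ones(n, dtype=int)  # owner[j]: worker currently holding job j
        unassigned = list(range(n))
        while unassigned:
            i = unassigned.pop()
            values = score[i] - prices
            j = int(np.argmax(values))             # best job at current prices
            second = np.partition(values, -2)[-2]  # runner-up value
            prices[j] += values[j] - second + eps  # bid the price up
            if owner[j] >= 0:
                unassigned.append(owner[j])        # displace the previous owner
            owner[j] = i
        assignment = np.empty(n, dtype=int)
        assignment[owner] = np.arange(n)           # job index per worker
        return assignment
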
fairseq/clib/libbleu/libbleu.cpp
ADDED
@@ -0,0 +1,157 @@
/**
 * Copyright 2017-present, Facebook, Inc.
 * All rights reserved.
 *
 * This source code is licensed under the license found in the
 * LICENSE file in the root directory of this source tree.
 */

#include <array>
#include <cstdio>
#include <cstring>
#include <map>

// NOLINTNEXTLINE
typedef struct {
  size_t reflen;
  size_t predlen;
  size_t match1;
  size_t count1;
  size_t match2;
  size_t count2;
  size_t match3;
  size_t count3;
  size_t match4;
  size_t count4;
} bleu_stat;

// left trim (remove pad)
void bleu_ltrim(size_t* len, int** sent, int pad) {
  size_t start = 0;
  while (start < *len) {
    if (*(*sent + start) != pad) {
      break;
    }
    start++;
  }
  *sent += start;
  *len -= start;
}

// right trim remove (eos)
void bleu_rtrim(size_t* len, int** sent, int pad, int eos) {
  size_t end = *len - 1;
  while (end > 0) {
    if (*(*sent + end) != eos && *(*sent + end) != pad) {
      break;
    }
    end--;
  }
  *len = end + 1;
}

// left and right trim
void bleu_trim(size_t* len, int** sent, int pad, int eos) {
  bleu_ltrim(len, sent, pad);
  bleu_rtrim(len, sent, pad, eos);
}

size_t bleu_hash(int len, int* data) {
  size_t h = 14695981039346656037ul;
  size_t prime = 0x100000001b3;
  char* b = (char*)data;
  size_t blen = sizeof(int) * len;

  while (blen-- > 0) {
    h ^= *b++;
    h *= prime;
  }

  return h;
}

void bleu_addngram(
    size_t* ntotal,
    size_t* nmatch,
    size_t n,
    size_t reflen,
    int* ref,
    size_t predlen,
    int* pred) {
  if (predlen < n) {
    return;
  }

  predlen = predlen - n + 1;
  (*ntotal) += predlen;

  if (reflen < n) {
    return;
  }

  reflen = reflen - n + 1;

  std::map<size_t, size_t> count;
  while (predlen > 0) {
    size_t w = bleu_hash(n, pred++);
    count[w]++;
    predlen--;
  }

  while (reflen > 0) {
    size_t w = bleu_hash(n, ref++);
    if (count[w] > 0) {
      (*nmatch)++;
      count[w] -= 1;
    }
    reflen--;
  }
}

extern "C" {

#ifdef _WIN64
__declspec(dllexport)
#endif
void bleu_zero_init(bleu_stat* stat) {
  std::memset(stat, 0, sizeof(bleu_stat));
}

#ifdef _WIN64
__declspec(dllexport)
#endif
void bleu_one_init(bleu_stat* stat) {
  bleu_zero_init(stat);
  stat->count1 = 0;
  stat->count2 = 1;
  stat->count3 = 1;
  stat->count4 = 1;
  stat->match1 = 0;
  stat->match2 = 1;
  stat->match3 = 1;
  stat->match4 = 1;
}

#ifdef _WIN64
__declspec(dllexport)
#endif
void bleu_add(
    bleu_stat* stat,
    size_t reflen,
    int* ref,
    size_t predlen,
    int* pred,
    int pad,
    int eos) {

  bleu_trim(&reflen, &ref, pad, eos);
  bleu_trim(&predlen, &pred, pad, eos);
  stat->reflen += reflen;
  stat->predlen += predlen;

  bleu_addngram(&stat->count1, &stat->match1, 1, reflen, ref, predlen, pred);
  bleu_addngram(&stat->count2, &stat->match2, 2, reflen, ref, predlen, pred);
  bleu_addngram(&stat->count3, &stat->match3, 3, reflen, ref, predlen, pred);
  bleu_addngram(&stat->count4, &stat->match4, 4, reflen, ref, predlen, pred);
}
}
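
bleu_addngram() accumulates clipped n-gram matches: each predicted n-gram is
credited at most as many times as it occurs in the reference. An equivalent
sketch in Python (not part of the diff):

    from collections import Counter

    def clipped_ngram_matches(ref, pred, n):
        pred_ngrams = Counter(tuple(pred[i : i + n]) for i in range(len(pred) - n + 1))
        ref_ngrams = Counter(tuple(ref[i : i + n]) for i in range(len(ref) - n + 1))
        total = sum(pred_ngrams.values())
        match = sum(min(c, ref_ngrams[ng]) for ng, c in pred_ngrams.items())
        return match, total
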
fairseq/clib/libbleu/module.cpp
ADDED
@@ -0,0 +1,33 @@
/**
 * Copyright 2017-present, Facebook, Inc.
 * All rights reserved.
 *
 * This source code is licensed under the license found in the
 * LICENSE file in the root directory of this source tree.
 */

#include <Python.h>

static PyMethodDef method_def[] = {{NULL, NULL, 0, NULL}}; // NOLINT

static struct PyModuleDef module_def = {
    PyModuleDef_HEAD_INIT,
    "libbleu", /* name of module */
    // NOLINTNEXTLINE
    NULL, /* module documentation, may be NULL */
    -1, /* size of per-interpreter state of the module,
           or -1 if the module keeps state in global variables. */
    method_def}; // NOLINT

#if PY_MAJOR_VERSION == 2
PyMODINIT_FUNC init_libbleu()
#else
PyMODINIT_FUNC PyInit_libbleu()
#endif
{
  PyObject* m = PyModule_Create(&module_def);
  if (!m) {
    return NULL;
  }
  return m;
}
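
Note that this module exports no Python-level methods; the bleu_* routines are
plain C symbols intended to be reached through ctypes, roughly as in the
sketch below (not part of the diff; the shared-library filename depends on the
platform and build):

    import ctypes

    class BleuStat(ctypes.Structure):
        _fields_ = [
            (name, ctypes.c_size_t)
            for name in (
                "reflen", "predlen", "match1", "count1", "match2",
                "count2", "match3", "count3", "match4", "count4",
            )
        ]

    C = ctypes.cdll.LoadLibrary("libbleu.so")  # hypothetical path
    stat = BleuStat()
    C.bleu_zero_init(ctypes.byref(stat))
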
fairseq/clib/libnat/edit_dist.cpp
ADDED
@@ -0,0 +1,231 @@
/**
 * Copyright 2017-present, Facebook, Inc.
 * All rights reserved.
 *
 * This source code is licensed under the license found in the
 * LICENSE file in the root directory of this source tree.
 */

#include <pybind11/detail/common.h>
#include <pybind11/pybind11.h>
#include <torch/torch.h> // @manual=//caffe2:torch_extension
#include <algorithm>
#include <cstdint>
#include <iosfwd>
#include <memory>
#include <new>
#include <string>
#include <utility>
#include <vector>

using namespace ::std;

vector<vector<uint32_t>> edit_distance2_with_dp(
    vector<uint32_t>& x,
    vector<uint32_t>& y) {
  uint32_t lx = x.size();
  uint32_t ly = y.size();
  vector<vector<uint32_t>> d(lx + 1, vector<uint32_t>(ly + 1));
  for (uint32_t i = 0; i < lx + 1; i++) {
    d[i][0] = i;
  }
  for (uint32_t j = 0; j < ly + 1; j++) {
    d[0][j] = j;
  }
  for (uint32_t i = 1; i < lx + 1; i++) {
    for (uint32_t j = 1; j < ly + 1; j++) {
      d[i][j] =
          min(min(d[i - 1][j], d[i][j - 1]) + 1,
              d[i - 1][j - 1] + 2 * (x.at(i - 1) == y.at(j - 1) ? 0 : 1));
    }
  }
  return d;
}

vector<vector<uint32_t>> edit_distance2_backtracking(
    vector<vector<uint32_t>>& d,
    vector<uint32_t>& x,
    vector<uint32_t>& y,
    uint32_t terminal_symbol) {
  vector<uint32_t> seq;
  vector<vector<uint32_t>> edit_seqs(x.size() + 2, vector<uint32_t>());
  /*
  edit_seqs:
  cells 0..x.size() hold the insertion sequences;
  the last cell (index x.size() + 1) holds the keep(0)/delete(1) sequence
  */

  if (x.size() == 0) {
    edit_seqs.at(0) = y;
    return edit_seqs;
  }

  uint32_t i = d.size() - 1;
  uint32_t j = d.at(0).size() - 1;

  while ((i >= 0) && (j >= 0)) {
    if ((i == 0) && (j == 0)) {
      break;
    }

    if ((j > 0) && (d.at(i).at(j - 1) < d.at(i).at(j))) {
      seq.push_back(1); // insert
      seq.push_back(y.at(j - 1));
      j--;
    } else if ((i > 0) && (d.at(i - 1).at(j) < d.at(i).at(j))) {
      seq.push_back(2); // delete
      seq.push_back(x.at(i - 1));
      i--;
    } else {
      seq.push_back(3); // keep
      seq.push_back(x.at(i - 1));
      i--;
      j--;
    }
  }

  uint32_t prev_op, op, s, word;
  prev_op = 0, s = 0;
  for (uint32_t k = 0; k < seq.size() / 2; k++) {
    op = seq.at(seq.size() - 2 * k - 2);
    word = seq.at(seq.size() - 2 * k - 1);
    if (prev_op != 1) {
      s++;
    }
    if (op == 1) // insert
    {
      edit_seqs.at(s - 1).push_back(word);
    } else if (op == 2) // delete
    {
      edit_seqs.at(x.size() + 1).push_back(1);
    } else {
      edit_seqs.at(x.size() + 1).push_back(0);
    }

    prev_op = op;
  }

  for (uint32_t k = 0; k < edit_seqs.size(); k++) {
    if (edit_seqs[k].size() == 0) {
      edit_seqs[k].push_back(terminal_symbol);
    }
  }
  return edit_seqs;
}

vector<vector<uint32_t>> edit_distance2_backtracking_with_delete(
    vector<vector<uint32_t>>& d,
    vector<uint32_t>& x,
    vector<uint32_t>& y,
    uint32_t terminal_symbol,
    uint32_t deletion_symbol) {
  vector<uint32_t> seq;
  vector<vector<uint32_t>> edit_seqs(x.size() + 1, vector<uint32_t>());
  /*
  edit_seqs:
  cells 0..x.size() hold the insertion sequences;
  deletions are written inline as deletion_symbol
  */

  if (x.size() == 0) {
    edit_seqs.at(0) = y;
    return edit_seqs;
  }

  uint32_t i = d.size() - 1;
  uint32_t j = d.at(0).size() - 1;

  while ((i >= 0) && (j >= 0)) {
    if ((i == 0) && (j == 0)) {
      break;
    }

    if ((j > 0) && (d.at(i).at(j - 1) < d.at(i).at(j))) {
      seq.push_back(1); // insert
      seq.push_back(y.at(j - 1));
      j--;
    } else if ((i > 0) && (d.at(i - 1).at(j) < d.at(i).at(j))) {
      seq.push_back(2); // delete
      seq.push_back(x.at(i - 1));
      i--;
    } else {
      seq.push_back(3); // keep
      seq.push_back(x.at(i - 1));
      i--;
      j--;
    }
  }

  uint32_t prev_op, op, s, word;
  prev_op = 0, s = 0;
  for (uint32_t k = 0; k < seq.size() / 2; k++) {
    op = seq.at(seq.size() - 2 * k - 2);
    word = seq.at(seq.size() - 2 * k - 1);
    if (prev_op != 1) {
      s++;
    }
    if (op == 1) // insert
    {
      edit_seqs.at(s - 1).push_back(word);
    } else if (op == 2) // delete
    {
      edit_seqs.at(s - 1).push_back(deletion_symbol);
    }

    prev_op = op;
  }

  for (uint32_t k = 0; k < edit_seqs.size(); k++) {
    if (edit_seqs.at(k).size() == 0) {
      edit_seqs.at(k).push_back(terminal_symbol);
    }
  }
  return edit_seqs;
}

vector<uint32_t> compute_ed2(
    vector<vector<uint32_t>>& xs,
    vector<vector<uint32_t>>& ys) {
  vector<uint32_t> distances(xs.size());
  for (uint32_t i = 0; i < xs.size(); i++) {
    vector<vector<uint32_t>> d = edit_distance2_with_dp(xs.at(i), ys.at(i));
    distances.at(i) = d.at(xs.at(i).size()).at(ys.at(i).size());
  }
  return distances;
}

vector<vector<vector<uint32_t>>> suggested_ed2_path(
    vector<vector<uint32_t>>& xs,
    vector<vector<uint32_t>>& ys,
    uint32_t terminal_symbol) {
  vector<vector<vector<uint32_t>>> seq(xs.size());
  for (uint32_t i = 0; i < xs.size(); i++) {
    vector<vector<uint32_t>> d = edit_distance2_with_dp(xs.at(i), ys.at(i));
    seq.at(i) =
        edit_distance2_backtracking(d, xs.at(i), ys.at(i), terminal_symbol);
  }
  return seq;
}

vector<vector<vector<uint32_t>>> suggested_ed2_path_with_delete(
    vector<vector<uint32_t>>& xs,
    vector<vector<uint32_t>>& ys,
    uint32_t terminal_symbol,
    uint32_t deletion_symbol) {
  vector<vector<vector<uint32_t>>> seq(xs.size());
  for (uint32_t i = 0; i < xs.size(); i++) {
    vector<vector<uint32_t>> d = edit_distance2_with_dp(xs.at(i), ys.at(i));
    seq.at(i) = edit_distance2_backtracking_with_delete(
        d, xs.at(i), ys.at(i), terminal_symbol, deletion_symbol);
  }
  return seq;
}

PYBIND11_MODULE(libnat, m) {
  m.def("compute_ed2", &compute_ed2, "compute_ed2");
  m.def("suggested_ed2_path", &suggested_ed2_path, "suggested_ed2_path");
  m.def(
      "suggested_ed2_path_with_delete",
      &suggested_ed2_path_with_delete,
      "suggested_ed2_path_with_delete");
}
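Because insertions and deletions cost 1 while substitutions cost 2, a substitution is never cheaper than a delete plus an insert, so the backtrace only ever emits insert (1), delete (2), and keep (3) operations. A pure-Python rendering of the same DP, useful for sanity-checking the C++; a reference sketch, not part of the commit:

def edit_distance2_with_dp_ref(x, y):
    # Mirrors the C++ recurrence: insert/delete cost 1, substitution cost 2.
    lx, ly = len(x), len(y)
    d = [[0] * (ly + 1) for _ in range(lx + 1)]
    for i in range(lx + 1):
        d[i][0] = i
    for j in range(ly + 1):
        d[0][j] = j
    for i in range(1, lx + 1):
        for j in range(1, ly + 1):
            d[i][j] = min(
                min(d[i - 1][j], d[i][j - 1]) + 1,
                d[i - 1][j - 1] + 2 * (0 if x[i - 1] == y[j - 1] else 1),
            )
    return d

# d[len(x)][len(y)] is what compute_ed2 returns for each pair:
x, y = [1, 2, 3], [1, 3, 4]
assert edit_distance2_with_dp_ref(x, y)[len(x)][len(y)] == 2  # drop 2, add 4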
fairseq/clib/libnat_cuda/binding.cpp
ADDED
@@ -0,0 +1,67 @@
/**
 * Copyright 2017-present, Facebook, Inc.
 * All rights reserved.
 *
 * This source code is licensed under the license found in the
 * LICENSE file in the root directory of this source tree.
 */

/*
This code is partially adopted from
https://github.com/1ytic/pytorch-edit-distance
*/

#include <torch/types.h>
#include "edit_dist.h"

#ifndef TORCH_CHECK
#define TORCH_CHECK AT_CHECK
#endif

#define CHECK_CUDA(x) \
  TORCH_CHECK(x.type().is_cuda(), #x " must be a CUDA tensor")
#define CHECK_CONTIGUOUS(x) \
  TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
#define CHECK_INPUT(x) \
  CHECK_CUDA(x);       \
  CHECK_CONTIGUOUS(x)

torch::Tensor LevenshteinDistance(
    torch::Tensor source,
    torch::Tensor target,
    torch::Tensor source_length,
    torch::Tensor target_length) {
  CHECK_INPUT(source);
  CHECK_INPUT(target);
  CHECK_INPUT(source_length);
  CHECK_INPUT(target_length);
  return LevenshteinDistanceCuda(source, target, source_length, target_length);
}

torch::Tensor GenerateDeletionLabel(
    torch::Tensor source,
    torch::Tensor operations) {
  CHECK_INPUT(source);
  CHECK_INPUT(operations);
  return GenerateDeletionLabelCuda(source, operations);
}

std::pair<torch::Tensor, torch::Tensor> GenerateInsertionLabel(
    torch::Tensor target,
    torch::Tensor operations) {
  CHECK_INPUT(target);
  CHECK_INPUT(operations);
  return GenerateInsertionLabelCuda(target, operations);
}

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("levenshtein_distance", &LevenshteinDistance, "Levenshtein distance");
  m.def(
      "generate_deletion_labels",
      &GenerateDeletionLabel,
      "Generate Deletion Label");
  m.def(
      "generate_insertion_labels",
      &GenerateInsertionLabel,
      "Generate Insertion Label");
}
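The wrapper only validates that inputs are contiguous CUDA tensors before dispatching to the .cu implementations below. One plausible way to build and call it is PyTorch's JIT extension loader; the tensor contents here are illustrative assumptions:

import torch
from torch.utils.cpp_extension import load

# Compile the binding plus kernels on the fly (requires a CUDA toolchain).
libnat_cuda = load(
    name="libnat_cuda",
    sources=[
        "fairseq/clib/libnat_cuda/binding.cpp",
        "fairseq/clib/libnat_cuda/edit_dist.cu",
    ],
)

src = torch.tensor([[4, 5, 6, 0]], dtype=torch.long, device="cuda")
tgt = torch.tensor([[4, 6, 7, 0]], dtype=torch.long, device="cuda")
src_len = torch.tensor([3], dtype=torch.int, device="cuda")
tgt_len = torch.tensor([3], dtype=torch.int, device="cuda")

# One operation sequence per batch row:
# 0 = padding, 1 = insert, 2 = delete, 3 = keep.
ops = libnat_cuda.levenshtein_distance(src, tgt, src_len, tgt_len)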
fairseq/clib/libnat_cuda/edit_dist.cu
ADDED
@@ -0,0 +1,344 @@
/**
 * Copyright 2017-present, Facebook, Inc.
 * All rights reserved.
 *
 * This source code is licensed under the license found in the
 * LICENSE file in the root directory of this source tree.
 */

#include "edit_dist.h"

#include <c10/cuda/CUDAStream.h>
#include <cuda.h>
#include <cuda_runtime.h>
#include <device_launch_parameters.h>
#include <utility> // std::pair

template <typename scalar_t>
__global__ void generate_deletion_label_kernel(
    const scalar_t* __restrict__ source,
    const size_t source_size,
    const size_t operation_size,
    int* __restrict__ operations,
    int* __restrict__ labels) {
  const int index = blockIdx.x;
  const int offset = index * operation_size;
  const int offset_label = index * source_size;

  for (int i = 0; i < source_size; i++) {
    labels[offset_label + i] = 0;
  }

  int k = 0;
  for (int i = 0; i < operation_size; i++) {
    if (operations[offset + i] == 0) {
      break;
    } else if (operations[offset + i] == 1) {
      continue;
    } else {
      labels[offset_label + k] = 3 - operations[offset + i];
      k++;
    }
  }
}

template <typename scalar_t>
__global__ void generate_insertion_label_kernel(
    const scalar_t* __restrict__ target,
    const size_t target_size,
    const size_t operation_size,
    int* __restrict__ operations,
    int* __restrict__ labels,
    int* __restrict__ masks) {
  const int index = blockIdx.x;
  const int offset = index * operation_size;
  const int offset_label = index * target_size;

  int k = 0;
  int u = 0;
  int m = 0;

  for (int i = 0; i < target_size; i++) {
    labels[offset_label + i] = 0;
    masks[offset_label + i] = 0;
  }

  for (int i = 0; i < operation_size - 1; i++) {
    if (operations[offset + i] == 0) {
      break;
    } else if (operations[offset + i] == 2) {
      continue;
    } else if (operations[offset + i] == 1) {
      masks[offset_label + m] = 1;
      u++;
      m++;
    } else {
      labels[offset_label + k] = u;
      masks[offset_label + m] = 0;
      k++;
      m++;
      u = 0;
    }
  }
}

template <typename scalar_t>
__global__ void levenshtein_distance_kernel(
    const scalar_t* __restrict__ source,
    const scalar_t* __restrict__ target,
    const int* __restrict__ source_length,
    const int* __restrict__ target_length,
    const size_t source_size,
    const size_t target_size,
    int* __restrict__ operations,
    int* __restrict__ errors_curr) {
  const int index = blockIdx.x;
  const int offset = index * (source_size + target_size);
  const int d = index * (source_size + 1) * (target_size + 1);
  const int t = target_size + 1;

  auto err_idx = [d, t](int i, int j) { return d + i * t + j; };
  auto opt_idx = [offset](int k) { return offset + k; };

  const int hyp_len = source_length[index];
  const int ref_len = target_length[index];
  const scalar_t* hyp_begin = source + index * source_size;
  const scalar_t* ref_begin = target + index * target_size;

  // dynamic programming
  for (int i = 0; i <= hyp_len; i++) {
    errors_curr[err_idx(i, 0)] = i;
  }
  for (int j = 0; j <= ref_len; j++) {
    errors_curr[err_idx(0, j)] = j;
  }
  for (int i = 1; i <= hyp_len; i++) {
    for (int j = 1; j <= ref_len; j++) {
      errors_curr[err_idx(i, j)] = min(
          min(errors_curr[err_idx(i - 1, j)], errors_curr[err_idx(i, j - 1)]) +
              1,
          errors_curr[err_idx(i - 1, j - 1)] +
              2 * (*(hyp_begin + i - 1) == *(ref_begin + j - 1) ? 0 : 1));
    }
  }

  // back-tracing
  int i = hyp_len;
  int j = ref_len;
  int o = hyp_len + ref_len;

  for (int k = 0; k < source_size + target_size; k++) {
    operations[opt_idx(k)] = 0;
  }

  while ((i >= 0) && (j >= 0)) {
    if ((i == 0) && (j == 0)) {
      break;
    }

    if ((j > 0) &&
        (errors_curr[err_idx(i, j - 1)] < errors_curr[err_idx(i, j)])) {
      o--;
      operations[opt_idx(o)] = 1;
      j--; // insertion
    } else if (
        (i > 0) &&
        (errors_curr[err_idx(i - 1, j)] < errors_curr[err_idx(i, j)])) {
      o--;
      operations[opt_idx(o)] = 2;
      i--; // deletion
    } else {
      o--;
      operations[opt_idx(o)] = 3;
      i--;
      j--; // do nothing
    }
  }

  // moving to the left
  for (int k = 0; k < hyp_len + ref_len; k++) {
    if (k + o < hyp_len + ref_len) {
      operations[opt_idx(k)] = operations[opt_idx(k + o)];
    } else {
      operations[opt_idx(k)] = 0; // padding
    }
  }
}

template <typename scalar_t>
__global__ void faster_levenshtein_distance_kernel(
    const scalar_t* __restrict__ source,
    const scalar_t* __restrict__ target,
    const int* __restrict__ source_length,
    const int* __restrict__ target_length,
    const size_t source_size,
    const size_t target_size,
    int* __restrict__ operations) {
  extern __shared__ short errors[];
  auto errors_curr = errors;

  const int index = blockIdx.x;
  const int offset = index * (source_size + target_size);
  const int t = target_size + 1;

  auto err_idx = [t](int i, int j) { return i * t + j; };
  auto opt_idx = [offset](int k) { return offset + k; };

  const int hyp_len = source_length[index];
  const int ref_len = target_length[index];
  const scalar_t* hyp_begin = source + index * source_size;
  const scalar_t* ref_begin = target + index * target_size;

  // dynamic programming
  for (int i = 0; i <= hyp_len; i++) {
    errors_curr[err_idx(i, 0)] = i;
  }
  for (int j = 0; j <= ref_len; j++) {
    errors_curr[err_idx(0, j)] = j;
  }
  for (int i = 1; i <= hyp_len; i++) {
    for (int j = 1; j <= ref_len; j++) {
      errors_curr[err_idx(i, j)] = min(
          min(errors_curr[err_idx(i - 1, j)], errors_curr[err_idx(i, j - 1)]) +
              1,
          errors_curr[err_idx(i - 1, j - 1)] +
              2 * (*(hyp_begin + i - 1) == *(ref_begin + j - 1) ? 0 : 1));
    }
  }

  // back-tracing
  int i = hyp_len;
  int j = ref_len;
  int o = hyp_len + ref_len;

  for (int k = 0; k < source_size + target_size; k++) {
    operations[opt_idx(k)] = 0;
  }

  while ((i >= 0) && (j >= 0)) {
    if ((i == 0) && (j == 0)) {
      break;
    }

    if ((j > 0) &&
        (errors_curr[err_idx(i, j - 1)] < errors_curr[err_idx(i, j)])) {
      o--;
      operations[opt_idx(o)] = 1;
      j--; // insertion
    } else if (
        (i > 0) &&
        (errors_curr[err_idx(i - 1, j)] < errors_curr[err_idx(i, j)])) {
      o--;
      operations[opt_idx(o)] = 2;
      i--; // deletion
    } else {
      o--;
      operations[opt_idx(o)] = 3;
      i--;
      j--; // do nothing
    }
  }

  // moving to the left
  for (int k = 0; k < hyp_len + ref_len; k++) {
    if (k + o < hyp_len + ref_len) {
      operations[opt_idx(k)] = operations[opt_idx(k + o)];
    } else {
      operations[opt_idx(k)] = 0; // padding
    }
  }
}

torch::Tensor GenerateDeletionLabelCuda(
    torch::Tensor source,
    torch::Tensor operations) {
  const auto batch_size = source.size(0);
  at::TensorOptions options(source.device());
  options = options.dtype(at::ScalarType::Int);
  auto labels = torch::empty({batch_size, source.size(1)}, options);
  auto stream = at::cuda::getCurrentCUDAStream(source.device().index());

  AT_DISPATCH_ALL_TYPES(source.scalar_type(), "generate_deletion_labels", ([&] {
                          generate_deletion_label_kernel<scalar_t>
                              <<<batch_size, 1, 0, stream>>>(
                                  source.data_ptr<scalar_t>(),
                                  source.size(1),
                                  operations.size(1),
                                  operations.data_ptr<int>(),
                                  labels.data_ptr<int>());
                        }));

  return labels;
}

std::pair<torch::Tensor, torch::Tensor> GenerateInsertionLabelCuda(
    torch::Tensor target,
    torch::Tensor operations) {
  const auto batch_size = target.size(0);
  at::TensorOptions options(target.device());
  options = options.dtype(at::ScalarType::Int);
  auto labels = torch::empty({batch_size, target.size(1)}, options);
  auto masks = torch::empty({batch_size, target.size(1)}, options);
  auto stream = at::cuda::getCurrentCUDAStream(target.device().index());

  AT_DISPATCH_ALL_TYPES(
      target.scalar_type(), "generate_insertion_labels", ([&] {
        generate_insertion_label_kernel<scalar_t><<<batch_size, 1, 0, stream>>>(
            target.data_ptr<scalar_t>(),
            target.size(1),
            operations.size(1),
            operations.data_ptr<int>(),
            labels.data_ptr<int>(),
            masks.data_ptr<int>());
      }));

  return std::make_pair(labels, masks);
}

torch::Tensor LevenshteinDistanceCuda(
    torch::Tensor source,
    torch::Tensor target,
    torch::Tensor source_length,
    torch::Tensor target_length) {
  const auto batch_size = source.size(0);
  const auto shared_size =
      (source.size(1) + 1) * (target.size(1) + 1) * sizeof(short);

  at::TensorOptions options(source.device());
  options = options.dtype(at::ScalarType::Int);
  auto operations =
      torch::empty({batch_size, source.size(1) + target.size(1)}, options);
  auto stream = at::cuda::getCurrentCUDAStream(source.device().index());

  if (shared_size > 40000) {
    auto distances = torch::empty(
        {batch_size, (source.size(1) + 1) * (target.size(1) + 1)}, options);
    AT_DISPATCH_ALL_TYPES(source.scalar_type(), "levenshtein_distance", ([&] {
                            levenshtein_distance_kernel<scalar_t>
                                <<<batch_size, 1, 0, stream>>>(
                                    source.data_ptr<scalar_t>(),
                                    target.data_ptr<scalar_t>(),
                                    source_length.data_ptr<int>(),
                                    target_length.data_ptr<int>(),
                                    source.size(1),
                                    target.size(1),
                                    operations.data_ptr<int>(),
                                    distances.data_ptr<int>());
                          }));
  } else {
    AT_DISPATCH_ALL_TYPES(
        source.scalar_type(), "faster_levenshtein_distance", ([&] {
          faster_levenshtein_distance_kernel<scalar_t>
              <<<batch_size, 1, shared_size, stream>>>(
                  source.data_ptr<scalar_t>(),
                  target.data_ptr<scalar_t>(),
                  source_length.data_ptr<int>(),
                  target_length.data_ptr<int>(),
                  source.size(1),
                  target.size(1),
                  operations.data_ptr<int>());
        }));
  }

  return operations;
}
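Each sentence pair is handled by a single thread in its own block; the faster variant keeps the (source+1) x (target+1) DP table in shared memory, and LevenshteinDistanceCuda falls back to the global-memory kernel once that table would exceed roughly 40 KB. The backtrace writes operations from the right end of the buffer and then shifts them to the front; a pure-Python mirror of that bookkeeping (a reference sketch, not part of the commit):

def backtrace_ops_ref(d, hyp_len, ref_len, buf_len):
    # d is the (hyp_len+1) x (ref_len+1) DP table from the recurrence above.
    ops = [0] * buf_len
    i, j, o = hyp_len, ref_len, hyp_len + ref_len
    while not (i == 0 and j == 0):
        if j > 0 and d[i][j - 1] < d[i][j]:
            o -= 1; ops[o] = 1; j -= 1            # insertion
        elif i > 0 and d[i - 1][j] < d[i][j]:
            o -= 1; ops[o] = 2; i -= 1            # deletion
        else:
            o -= 1; ops[o] = 3; i -= 1; j -= 1    # keep / do nothing
    # "moving to the left": shift the trace forward, zero-pad the tail
    n = hyp_len + ref_len
    for k in range(n):
        ops[k] = ops[k + o] if k + o < n else 0
    return ops

d = edit_distance2_with_dp_ref([4, 5, 6], [4, 6, 7])  # from the earlier sketch
print(backtrace_ops_ref(d, 3, 3, 8))  # [3, 2, 3, 1, 0, 0, 0, 0]
# i.e. keep 4, delete 5, keep 6, insert 7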
fairseq/clib/libnat_cuda/edit_dist.h
ADDED
@@ -0,0 +1,25 @@
/**
 * Copyright 2017-present, Facebook, Inc.
 * All rights reserved.
 *
 * This source code is licensed under the license found in the
 * LICENSE file in the root directory of this source tree.
 */

#pragma once

#include <torch/extension.h>

torch::Tensor LevenshteinDistanceCuda(
    torch::Tensor source,
    torch::Tensor target,
    torch::Tensor source_length,
    torch::Tensor target_length);

torch::Tensor GenerateDeletionLabelCuda(
    torch::Tensor source,
    torch::Tensor operations);

std::pair<torch::Tensor, torch::Tensor> GenerateInsertionLabelCuda(
    torch::Tensor source,
    torch::Tensor operations);
fairseq/config/__init__.py
ADDED
@@ -0,0 +1,4 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
fairseq/config/config.yaml
ADDED
@@ -0,0 +1,19 @@
# @package _group_

hydra:
  run:
    dir: .

defaults:
  - _self_
  - task: null
  - model: null
  - criterion: cross_entropy
  - optimizer: null
  - lr_scheduler: fixed
  - bpe: null
  - tokenizer: null
  - scoring: null
  - generation: null
  - common_eval: null
  - eval_lm: null
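This defaults list is what Hydra composes against at launch; groups set to null are filled in per run. A sketch of exercising the composition programmatically with Hydra's compose API; the relative config_path and the override value are assumptions for illustration:

from hydra import compose, initialize

# Assumes this file lives at ./config.yaml relative to the caller.
with initialize(config_path="."):
    cfg = compose(
        config_name="config",
        overrides=["task=language_modeling"],  # hypothetical group choice
    )
    print(cfg.criterion)  # filled from "- criterion: cross_entropy" above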
fairseq/config/fb_run_config/slurm.yaml
ADDED
@@ -0,0 +1,29 @@
# @package _global_

hydra:
  job:
    config:
      override_dirname:
        kv_sep: ':'
        item_sep: '__'
        exclude_keys:
          - fb_run_config
          - distributed_training.distributed_port
  sweep:
    dir: /checkpoint/${env:USER}/${env:PREFIX}/${hydra.job.config_name}_${hydra.launcher.gpus_per_node}/${hydra.job.override_dirname}
  launcher:
    cpus_per_task: 60
    gpus_per_node: ???
    tasks_per_node: 1
    nodes: 1
    partition: learnfair
    mem_gb: 400
    timeout_min: 4320
    max_num_timeout: 10
    name: ${env:PREFIX}_${hydra.job.config_name}
    submitit_folder: ${hydra.sweep.dir}

distributed_training:
  ddp_backend: c10d
  distributed_world_size: ???
  distributed_port: ???
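The ??? entries are OmegaConf mandatory-value markers: a launch with this config fails unless gpus_per_node, distributed_world_size, and distributed_port are supplied as overrides. A small OmegaConf illustration of that behavior:

from omegaconf import OmegaConf
from omegaconf.errors import MissingMandatoryValue

cfg = OmegaConf.create(
    {"distributed_training": {"distributed_world_size": "???"}}
)
try:
    _ = cfg.distributed_training.distributed_world_size
except MissingMandatoryValue:
    print("override required, e.g. "
          "distributed_training.distributed_world_size=8")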