Upload 313 files
This view is limited to 50 files because it contains too many changes.
- e_smiles.py +0 -0
- infer.sh +10 -0
- inference.py +5 -0
- onmt/__init__.py +24 -0
- onmt/__pycache__/__init__.cpython-311.pyc +0 -0
- onmt/__pycache__/__init__.cpython-37.pyc +0 -0
- onmt/__pycache__/__init__.cpython-38.pyc +0 -0
- onmt/__pycache__/constants.cpython-311.pyc +0 -0
- onmt/__pycache__/constants.cpython-38.pyc +0 -0
- onmt/__pycache__/inference_engine.cpython-38.pyc +0 -0
- onmt/__pycache__/model_builder.cpython-311.pyc +0 -0
- onmt/__pycache__/model_builder.cpython-38.pyc +0 -0
- onmt/__pycache__/opts.cpython-311.pyc +0 -0
- onmt/__pycache__/opts.cpython-38.pyc +0 -0
- onmt/__pycache__/train_single.cpython-38.pyc +0 -0
- onmt/__pycache__/trainer.cpython-38.pyc +0 -0
- onmt/bin/__init__.py +0 -0
- onmt/bin/__pycache__/__init__.cpython-311.pyc +0 -0
- onmt/bin/__pycache__/__init__.cpython-38.pyc +0 -0
- onmt/bin/__pycache__/average_models.cpython-38.pyc +0 -0
- onmt/bin/__pycache__/build_vocab.cpython-38.pyc +0 -0
- onmt/bin/__pycache__/release_model.cpython-38.pyc +0 -0
- onmt/bin/__pycache__/server.cpython-38.pyc +0 -0
- onmt/bin/__pycache__/train.cpython-38.pyc +0 -0
- onmt/bin/__pycache__/translate.cpython-311.pyc +0 -0
- onmt/bin/__pycache__/translate.cpython-38.pyc +0 -0
- onmt/bin/average_models.py +60 -0
- onmt/bin/build_vocab.py +287 -0
- onmt/bin/release_model.py +39 -0
- onmt/bin/server.py +167 -0
- onmt/bin/train.py +71 -0
- onmt/bin/translate.py +60 -0
- onmt/constants.py +41 -0
- onmt/decoders/__init__.py +63 -0
- onmt/decoders/__pycache__/__init__.cpython-311.pyc +0 -0
- onmt/decoders/__pycache__/__init__.cpython-38.pyc +0 -0
- onmt/decoders/__pycache__/cnn_decoder.cpython-311.pyc +0 -0
- onmt/decoders/__pycache__/cnn_decoder.cpython-38.pyc +0 -0
- onmt/decoders/__pycache__/decoder.cpython-311.pyc +0 -0
- onmt/decoders/__pycache__/decoder.cpython-38.pyc +0 -0
- onmt/decoders/__pycache__/ensemble.cpython-311.pyc +0 -0
- onmt/decoders/__pycache__/ensemble.cpython-38.pyc +0 -0
- onmt/decoders/__pycache__/transformer.cpython-311.pyc +0 -0
- onmt/decoders/__pycache__/transformer.cpython-38.pyc +0 -0
- onmt/decoders/cnn_decoder.py +141 -0
- onmt/decoders/decoder.py +405 -0
- onmt/decoders/ensemble.py +150 -0
- onmt/decoders/transformer.py +835 -0
- onmt/encoders/__init__.py +67 -0
- onmt/encoders/__pycache__/__init__.cpython-311.pyc +0 -0
e_smiles.py
ADDED
The diff for this file is too large to render.

infer.sh
ADDED
@@ -0,0 +1,10 @@
+python inference.py \
+    --model trained_models/retrosnyhesis_ReactSeq_prompt_model_on_50k_aug100.pt \
+    --src ./tmp_data/src.txt \
+    --output ./tmp_data/tgt.txt \
+    --beam_size 10 \
+    --n_best 10 \
+    --batch_size 16384 \
+    --batch_type tokens \
+    --max_length 500 \
+    --seed 0

inference.py
ADDED
@@ -0,0 +1,5 @@
+#!/usr/bin/env python
+from onmt.bin.translate import main
+
+if __name__ == "__main__":
+    main()

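inference.py is only a thin wrapper around the `onmt.bin.translate` entry point, so the flags passed in infer.sh are ordinary OpenNMT-py translate options. As a hedged sketch (the paths below are just the placeholders taken from infer.sh, not verified files), the same run can be driven programmatically by populating `sys.argv` before calling `main()`:

# Sketch: invoking the translate entry point from Python instead of infer.sh.
# The model/src/output paths are the placeholders used in infer.sh.
import sys
from onmt.bin.translate import main

sys.argv = [
    "inference.py",
    "--model", "trained_models/retrosnyhesis_ReactSeq_prompt_model_on_50k_aug100.pt",
    "--src", "./tmp_data/src.txt",
    "--output", "./tmp_data/tgt.txt",
    "--beam_size", "10",
    "--n_best", "10",
    "--batch_size", "16384",
    "--batch_type", "tokens",
    "--max_length", "500",
    "--seed", "0",
]
main()  # parses sys.argv exactly as the shell invocation would
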
onmt/__init__.py
ADDED
@@ -0,0 +1,24 @@
+""" Main entry point of the ONMT library """
+import onmt.inputters
+import onmt.encoders
+import onmt.decoders
+import onmt.models
+import onmt.utils
+import onmt.modules
+import sys
+import onmt.utils.optimizers
+
+onmt.utils.optimizers.Optim = onmt.utils.optimizers.Optimizer
+sys.modules["onmt.Optim"] = onmt.utils.optimizers
+
+# For Flake
+__all__ = [
+    onmt.inputters,
+    onmt.encoders,
+    onmt.decoders,
+    onmt.models,
+    onmt.utils,
+    onmt.modules,
+]
+
+__version__ = "3.4.1"

onmt/__pycache__/__init__.cpython-311.pyc
ADDED
Binary file (892 Bytes).

onmt/__pycache__/__init__.cpython-37.pyc
ADDED
Binary file (605 Bytes).

onmt/__pycache__/__init__.cpython-38.pyc
ADDED
Binary file (603 Bytes).

onmt/__pycache__/constants.cpython-311.pyc
ADDED
Binary file (2.06 kB).

onmt/__pycache__/constants.cpython-38.pyc
ADDED
Binary file (1.61 kB).

onmt/__pycache__/inference_engine.cpython-38.pyc
ADDED
Binary file (3.22 kB).

onmt/__pycache__/model_builder.cpython-311.pyc
ADDED
Binary file (19.4 kB).

onmt/__pycache__/model_builder.cpython-38.pyc
ADDED
Binary file (10.6 kB).

onmt/__pycache__/opts.cpython-311.pyc
ADDED
Binary file (58 kB).

onmt/__pycache__/opts.cpython-38.pyc
ADDED
Binary file (38.4 kB).

onmt/__pycache__/train_single.cpython-38.pyc
ADDED
Binary file (6.41 kB).

onmt/__pycache__/trainer.cpython-38.pyc
ADDED
Binary file (14.5 kB).

onmt/bin/__init__.py
ADDED
File without changes.

onmt/bin/__pycache__/__init__.cpython-311.pyc
ADDED
Binary file (171 Bytes).

onmt/bin/__pycache__/__init__.cpython-38.pyc
ADDED
Binary file (145 Bytes).

onmt/bin/__pycache__/average_models.cpython-38.pyc
ADDED
Binary file (1.48 kB).

onmt/bin/__pycache__/build_vocab.cpython-38.pyc
ADDED
Binary file (8.77 kB).

onmt/bin/__pycache__/release_model.cpython-38.pyc
ADDED
Binary file (1.17 kB).

onmt/bin/__pycache__/server.cpython-38.pyc
ADDED
Binary file (5.08 kB).

onmt/bin/__pycache__/train.cpython-38.pyc
ADDED
Binary file (1.84 kB).

onmt/bin/__pycache__/translate.cpython-311.pyc
ADDED
Binary file (2.89 kB).

onmt/bin/__pycache__/translate.cpython-38.pyc
ADDED
Binary file (1.77 kB).

onmt/bin/average_models.py
ADDED
@@ -0,0 +1,60 @@
+#!/usr/bin/env python
+import argparse
+import torch
+
+
+def average_models(model_files, fp32=False):
+    vocab = None
+    opt = None
+    avg_model = None
+    avg_generator = None
+
+    for i, model_file in enumerate(model_files):
+        m = torch.load(model_file, map_location="cpu")
+        model_weights = m["model"]
+        generator_weights = m["generator"]
+
+        if fp32:
+            for k, v in model_weights.items():
+                model_weights[k] = v.float()
+            for k, v in generator_weights.items():
+                generator_weights[k] = v.float()
+
+        if i == 0:
+            vocab, opt = m["vocab"], m["opt"]
+            avg_model = model_weights
+            avg_generator = generator_weights
+        else:
+            for k, v in avg_model.items():
+                avg_model[k].mul_(i).add_(model_weights[k]).div_(i + 1)
+
+            for k, v in avg_generator.items():
+                avg_generator[k].mul_(i).add_(generator_weights[k]).div_(i + 1)
+
+    final = {
+        "vocab": vocab,
+        "opt": opt,
+        "optim": None,
+        "generator": avg_generator,
+        "model": avg_model,
+    }
+    return final
+
+
+def main():
+    parser = argparse.ArgumentParser(description="")
+    parser.add_argument(
+        "-models", "-m", nargs="+", required=True, help="List of models"
+    )
+    parser.add_argument("-output", "-o", required=True, help="Output file")
+    parser.add_argument(
+        "-fp32", "-f", action="store_true", help="Cast params to float32"
+    )
+    opt = parser.parse_args()
+
+    final = average_models(opt.models, opt.fp32)
+    torch.save(final, opt.output)
+
+
+if __name__ == "__main__":
+    main()

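The update `avg_model[k].mul_(i).add_(model_weights[k]).div_(i + 1)` is an in-place running mean: after folding in checkpoint `i` (0-indexed), each parameter holds the average of the first `i + 1` checkpoints. A minimal self-contained sketch of the same arithmetic on toy tensors:

# Sketch: the incremental mean used by average_models, on toy tensors.
import torch

checkpoints = [torch.full((2, 2), float(v)) for v in (1.0, 4.0, 7.0)]

avg = checkpoints[0].clone()
for i, w in enumerate(checkpoints[1:], start=1):
    # running mean: avg_new = (avg * i + w) / (i + 1)
    avg.mul_(i).add_(w).div_(i + 1)

# after the loop, avg equals the plain arithmetic mean of all checkpoints
assert torch.allclose(avg, torch.stack(checkpoints).mean(dim=0))
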
onmt/bin/build_vocab.py
ADDED
@@ -0,0 +1,287 @@
+#!/usr/bin/env python
+"""Get vocabulary counts from transformed corpora samples."""
+import os
+import copy
+import multiprocessing as mp
+import pyonmttok
+from functools import partial
+from onmt.utils.logging import init_logger, logger
+from onmt.utils.misc import set_random_seed, check_path
+from onmt.utils.parse import ArgumentParser
+from onmt.opts import dynamic_prepare_opts
+from onmt.inputters.text_corpus import build_corpora_iters, get_corpora
+from onmt.inputters.text_utils import process, append_features_to_text
+from onmt.transforms import make_transforms, get_transforms_cls
+from onmt.constants import CorpusName, CorpusTask
+from collections import Counter
+
+
+MAXBUCKETSIZE = 256000
+
+
+def write_files_from_queues(sample_path, queues):
+    """
+    Standalone process that reads data from
+    queues in order and write to sample files.
+    """
+    os.makedirs(sample_path, exist_ok=True)
+    for c_name in queues.keys():
+        dest_base = os.path.join(sample_path, "{}.{}".format(c_name, CorpusName.SAMPLE))
+        with open(dest_base + ".src", "w", encoding="utf-8") as f_src, open(
+            dest_base + ".tgt", "w", encoding="utf-8"
+        ) as f_tgt:
+            while True:
+                _next = False
+                for q in queues[c_name]:
+                    item = q.get()
+                    if item == "blank":
+                        continue
+                    if item == "break":
+                        _next = True
+                        break
+                    _, src_line, tgt_line = item
+                    f_src.write(src_line + "\n")
+                    f_tgt.write(tgt_line + "\n")
+                if _next:
+                    break
+
+
+def build_sub_vocab(corpora, transforms, opts, n_sample, stride, offset):
+    """Build vocab on (strided) subpart of the data."""
+    sub_counter_src = Counter()
+    sub_counter_tgt = Counter()
+    sub_counter_src_feats = [Counter() for _ in range(opts.n_src_feats)]
+    datasets_iterables = build_corpora_iters(
+        corpora,
+        transforms,
+        opts.data,
+        skip_empty_level=opts.skip_empty_level,
+        stride=stride,
+        offset=offset,
+    )
+    for c_name, c_iter in datasets_iterables.items():
+        for i, item in enumerate(c_iter):
+            maybe_example = process(CorpusTask.TRAIN, [item])
+            if maybe_example is not None:
+                maybe_example = maybe_example[0]
+            else:
+                if opts.dump_samples:
+                    build_sub_vocab.queues[c_name][offset].put("blank")
+                continue
+            src_line, tgt_line = (
+                maybe_example["src"]["src"],
+                maybe_example["tgt"]["tgt"],
+            )
+            sub_counter_src.update(src_line.split(" "))
+            sub_counter_tgt.update(tgt_line.split(" "))
+
+            if "feats" in maybe_example["src"]:
+                src_feats_lines = maybe_example["src"]["feats"]
+                for k in range(opts.n_src_feats):
+                    sub_counter_src_feats[k].update(src_feats_lines[k].split(" "))
+            else:
+                src_feats_lines = []
+
+            if opts.dump_samples:
+                src_pretty_line = append_features_to_text(src_line, src_feats_lines)
+                build_sub_vocab.queues[c_name][offset].put(
+                    (i, src_pretty_line, tgt_line)
+                )
+            if n_sample > 0 and ((i + 1) * stride + offset) >= n_sample:
+                if opts.dump_samples:
+                    build_sub_vocab.queues[c_name][offset].put("break")
+                break
+        if opts.dump_samples:
+            build_sub_vocab.queues[c_name][offset].put("break")
+    return sub_counter_src, sub_counter_tgt, sub_counter_src_feats
+
+
+def init_pool(queues):
+    """Add the queues as attribute of the pooled function."""
+    build_sub_vocab.queues = queues
+
+
+def build_vocab(opts, transforms, n_sample=3):
+    """Build vocabulary from data."""
+
+    if n_sample == -1:
+        logger.info(f"n_sample={n_sample}: Build vocab on full datasets.")
+    elif n_sample > 0:
+        logger.info(f"Build vocab on {n_sample} transformed examples/corpus.")
+    else:
+        raise ValueError(f"n_sample should > 0 or == -1, get {n_sample}.")
+
+    if opts.dump_samples:
+        logger.info(
+            "The samples on which the vocab is built will be "
+            "dumped to disk. It may slow down the process."
+        )
+    corpora = get_corpora(opts, task=CorpusTask.TRAIN)
+    counter_src = Counter()
+    counter_tgt = Counter()
+    counter_src_feats = [Counter() for _ in range(opts.n_src_feats)]
+
+    queues = {
+        c_name: [
+            mp.Queue(opts.vocab_sample_queue_size) for i in range(opts.num_threads)
+        ]
+        for c_name in corpora.keys()
+    }
+    sample_path = os.path.join(os.path.dirname(opts.save_data), CorpusName.SAMPLE)
+    if opts.dump_samples:
+        write_process = mp.Process(
+            target=write_files_from_queues, args=(sample_path, queues), daemon=True
+        )
+        write_process.start()
+    with mp.Pool(opts.num_threads, init_pool, [queues]) as p:
+        func = partial(
+            build_sub_vocab, corpora, transforms, opts, n_sample, opts.num_threads
+        )
+        for sub_counter_src, sub_counter_tgt, sub_counter_src_feats in p.imap(
+            func, range(0, opts.num_threads)
+        ):
+            counter_src.update(sub_counter_src)
+            counter_tgt.update(sub_counter_tgt)
+            for i in range(opts.n_src_feats):
+                counter_src_feats[i].update(sub_counter_src_feats[i])
+    if opts.dump_samples:
+        write_process.join()
+    return counter_src, counter_tgt, counter_src_feats
+
+
+def ingest_tokens(opts, transforms, n_sample, learner, stride, offset):
+    def _mp_ingest(data):
+        func = partial(process, CorpusName.TRAIN)
+        chunk = len(data) // opts.num_threads
+        with mp.Pool(opts.num_threads) as pool:
+            buckets = pool.map(
+                func,
+                [data[i * chunk : (i + 1) * chunk] for i in range(0, opts.num_threads)],
+            )
+            for bucket in buckets:
+                for ex in bucket:
+                    if ex is not None:
+                        src_line, tgt_line = (ex["src"]["src"], ex["tgt"]["tgt"])
+                        learner.ingest(src_line)
+                        learner.ingest(tgt_line)
+
+    corpora = get_corpora(opts, task=CorpusTask.TRAIN)
+    datasets_iterables = build_corpora_iters(
+        corpora,
+        transforms,
+        opts.data,
+        skip_empty_level=opts.skip_empty_level,
+        stride=stride,
+        offset=offset,
+    )
+    to_ingest = []
+    for c_name, c_iter in datasets_iterables.items():
+        for i, item in enumerate(c_iter):
+            if n_sample >= 0 and i >= n_sample:
+                break
+            if len(to_ingest) >= MAXBUCKETSIZE:
+                _mp_ingest(to_ingest)
+                to_ingest = []
+            to_ingest.append(item)
+    _mp_ingest(to_ingest)
+
+
+def make_learner(tokenization_type, symbols):
+    if tokenization_type == "bpe":
+        # BPE training
+        learner = pyonmttok.BPELearner(tokenizer=None, symbols=symbols)
+    elif tokenization_type == "sentencepiece":
+        # SentencePiece training
+        learner = pyonmttok.SentencePieceLearner(
+            vocab_size=symbols, character_coverage=0.98
+        )
+    return learner
+
+
+def build_vocab_main(opts):
+    """Apply transforms to samples of specified data and build vocab from it.
+
+    Transforms that need vocab will be disabled in this.
+    Built vocab is saved in plain text format as follows and can be passed as
+    `-src_vocab` (and `-tgt_vocab`) when training:
+    ```
+    <tok_0>\t<count_0>
+    <tok_1>\t<count_1>
+    ```
+    """
+
+    ArgumentParser.validate_prepare_opts(opts, build_vocab_only=True)
+    assert (
+        opts.n_sample == -1 or opts.n_sample > 1
+    ), f"Illegal argument n_sample={opts.n_sample}."
+
+    logger = init_logger()
+    set_random_seed(opts.seed, False)
+    transforms_cls = get_transforms_cls(opts._all_transform)
+
+    if opts.learn_subwords:
+        logger.info(f"Ingesting {opts.src_subword_type} model from corpus")
+        learner = make_learner(opts.src_subword_type, opts.learn_subwords_size)
+        if opts.src_subword_model is not None:
+            tok_path = opts.src_subword_model
+        else:
+            data_dir = os.path.split(opts.save_data)[0]
+            if not os.path.exists(data_dir):
+                os.makedirs(data_dir)
+            tok_path = os.path.join(data_dir, f"{opts.src_subword_type}.model")
+        save_opts = copy.deepcopy(opts)
+        opts.src_subword_type = "none"
+        opts.tgt_subword_type = "none"
+        opts.src_onmttok_kwargs["joiner_annotate"] = False
+        opts.tgt_onmttok_kwargs["joiner_annotate"] = False
+        transforms = make_transforms(opts, transforms_cls, None)
+        ingest_tokens(opts, transforms, opts.n_sample, learner, 1, 0)
+        logger.info(f"Learning {tok_path} model, patience")
+        learner.learn(tok_path)
+        opts = save_opts
+
+    transforms = make_transforms(opts, transforms_cls, None)
+
+    logger.info(f"Counter vocab from {opts.n_sample} samples.")
+    src_counter, tgt_counter, src_feats_counter = build_vocab(
+        opts, transforms, n_sample=opts.n_sample
+    )
+
+    logger.info(f"Counters src: {len(src_counter)}")
+    logger.info(f"Counters tgt: {len(tgt_counter)}")
+    for i, feat_counter in enumerate(src_feats_counter):
+        logger.info(f"Counters src feat_{i}: {len(feat_counter)}")
+
+    def save_counter(counter, save_path):
+        check_path(save_path, exist_ok=opts.overwrite, log=logger.warning)
+        with open(save_path, "w", encoding="utf8") as fo:
+            for tok, count in counter.most_common():
+                fo.write(tok + "\t" + str(count) + "\n")
+
+    if opts.share_vocab:
+        src_counter += tgt_counter
+        tgt_counter = src_counter
+        logger.info(f"Counters after share:{len(src_counter)}")
+        save_counter(src_counter, opts.src_vocab)
+    else:
+        save_counter(src_counter, opts.src_vocab)
+        save_counter(tgt_counter, opts.tgt_vocab)
+
+    for i, c in enumerate(src_feats_counter):
+        save_counter(c, f"{opts.src_vocab}_feat{i}")
+
+
+def _get_parser():
+    parser = ArgumentParser(description="build_vocab.py")
+    dynamic_prepare_opts(parser, build_vocab_only=True)
+    return parser
+
+
+def main():
+    parser = _get_parser()
+    opts, unknown = parser.parse_known_args()
+    build_vocab_main(opts)
+
+
+if __name__ == "__main__":
+    main()

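The vocab files written by `save_counter` are plain token/count pairs, one per line, tab-separated and sorted by frequency, which is the format expected by `-src_vocab`/`-tgt_vocab` at training time. A small sketch of producing and reading that format on toy data (nothing here is tied to a real corpus):

# Sketch: the <tok>\t<count> vocab format written by save_counter, on toy data.
from collections import Counter

counter = Counter("the cat sat on the mat".split())
with open("toy.vocab", "w", encoding="utf8") as fo:
    for tok, count in counter.most_common():
        fo.write(tok + "\t" + str(count) + "\n")

# Reading it back: one (token, count) pair per line, most frequent first.
with open("toy.vocab", encoding="utf8") as fi:
    vocab = [line.rstrip("\n").split("\t") for line in fi]
print(vocab)  # e.g. [['the', '2'], ['cat', '1'], ...]
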
onmt/bin/release_model.py
ADDED
@@ -0,0 +1,39 @@
+#!/usr/bin/env python
+import argparse
+import torch
+
+
+def main():
+    parser = argparse.ArgumentParser(
+        description="Release an OpenNMT-py model for inference"
+    )
+    parser.add_argument("--model", "-m", help="The model path", required=True)
+    parser.add_argument("--output", "-o", help="The output path", required=True)
+    parser.add_argument(
+        "--format",
+        choices=["pytorch", "ctranslate2"],
+        default="pytorch",
+        help="The format of the released model",
+    )
+    parser.add_argument(
+        "--quantization",
+        "-q",
+        choices=["int8", "int16", "float16", "int8_float16"],
+        default=None,
+        help="Quantization type for CT2 model.",
+    )
+    opt = parser.parse_args()
+
+    model = torch.load(opt.model, map_location=torch.device("cpu"))
+    if opt.format == "pytorch":
+        model["optim"] = None
+        torch.save(model, opt.output)
+    elif opt.format == "ctranslate2":
+        import ctranslate2
+
+        converter = ctranslate2.converters.OpenNMTPyConverter(opt.model)
+        converter.convert(opt.output, force=True, quantization=opt.quantization)
+
+
+if __name__ == "__main__":
+    main()

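For the default `pytorch` format, releasing a model simply drops the optimizer state from the checkpoint dict, which is usually the bulk of the file. A hedged sketch of inspecting a released checkpoint (the file name is a placeholder):

# Sketch: inspecting a released checkpoint. "released_model.pt" is a placeholder name.
import torch

ckpt = torch.load("released_model.pt", map_location="cpu")
# A released checkpoint keeps vocab/opt/model/generator but no optimizer state.
assert ckpt["optim"] is None
print(sorted(ckpt.keys()))  # e.g. ['generator', 'model', 'opt', 'optim', 'vocab']
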
onmt/bin/server.py
ADDED
@@ -0,0 +1,167 @@
+#!/usr/bin/env python
+import configargparse
+
+from flask import Flask, jsonify, request
+from waitress import serve
+from onmt.translate import TranslationServer, ServerModelError
+import logging
+from logging.handlers import RotatingFileHandler
+
+STATUS_OK = "ok"
+STATUS_ERROR = "error"
+
+
+def start(config_file, url_root="./translator", host="0.0.0.0", port=5000, debug=False):
+    def prefix_route(route_function, prefix="", mask="{0}{1}"):
+        def newroute(route, *args, **kwargs):
+            return route_function(mask.format(prefix, route), *args, **kwargs)
+
+        return newroute
+
+    if debug:
+        logger = logging.getLogger("main")
+        log_format = logging.Formatter("[%(asctime)s %(levelname)s] %(message)s")
+        file_handler = RotatingFileHandler(
+            "debug_requests.log", maxBytes=1000000, backupCount=10
+        )
+        file_handler.setFormatter(log_format)
+        logger.addHandler(file_handler)
+
+    app = Flask(__name__)
+    app.route = prefix_route(app.route, url_root)
+    translation_server = TranslationServer()
+    translation_server.start(config_file)
+
+    @app.route("/models", methods=["GET"])
+    def get_models():
+        out = translation_server.list_models()
+        return jsonify(out)
+
+    @app.route("/health", methods=["GET"])
+    def health():
+        out = {}
+        out["status"] = STATUS_OK
+        return jsonify(out)
+
+    @app.route("/clone_model/<int:model_id>", methods=["POST"])
+    def clone_model(model_id):
+        out = {}
+        data = request.get_json(force=True)
+        timeout = -1
+        if "timeout" in data:
+            timeout = data["timeout"]
+            del data["timeout"]
+
+        opt = data.get("opt", None)
+        try:
+            model_id, load_time = translation_server.clone_model(model_id, opt, timeout)
+        except ServerModelError as e:
+            out["status"] = STATUS_ERROR
+            out["error"] = str(e)
+        else:
+            out["status"] = STATUS_OK
+            out["model_id"] = model_id
+            out["load_time"] = load_time
+
+        return jsonify(out)
+
+    @app.route("/unload_model/<int:model_id>", methods=["GET"])
+    def unload_model(model_id):
+        out = {"model_id": model_id}
+
+        try:
+            translation_server.unload_model(model_id)
+            out["status"] = STATUS_OK
+        except Exception as e:
+            out["status"] = STATUS_ERROR
+            out["error"] = str(e)
+
+        return jsonify(out)
+
+    @app.route("/translate", methods=["POST"])
+    def translate():
+        inputs = request.get_json(force=True)
+        if debug:
+            logger.info(inputs)
+        out = {}
+        try:
+            trans, scores, n_best, _, aligns, align_scores = translation_server.run(
+                inputs
+            )
+            assert len(trans) == len(inputs) * n_best
+            assert len(scores) == len(inputs) * n_best
+            assert len(aligns) == len(inputs) * n_best
+
+            out = [[] for _ in range(n_best)]
+            for i in range(len(trans)):
+                response = {
+                    "src": inputs[i // n_best]["src"],
+                    "tgt": trans[i],
+                    "n_best": n_best,
+                    "pred_score": scores[i],
+                }
+                if len(aligns[i]) > 0 and aligns[i][0] is not None:
+                    response["align"] = aligns[i]
+                    response["align_score"] = align_scores[i]
+                out[i % n_best].append(response)
+        except ServerModelError as e:
+            model_id = inputs[0].get("id")
+            if debug:
+                logger.warning(
+                    "Unload model #{} " "because of an error".format(model_id)
+                )
+            translation_server.models[model_id].unload()
+            out["error"] = str(e)
+            out["status"] = STATUS_ERROR
+        if debug:
+            logger.info(out)
+        return jsonify(out)
+
+    @app.route("/to_cpu/<int:model_id>", methods=["GET"])
+    def to_cpu(model_id):
+        out = {"model_id": model_id}
+        translation_server.models[model_id].to_cpu()
+
+        out["status"] = STATUS_OK
+        return jsonify(out)
+
+    @app.route("/to_gpu/<int:model_id>", methods=["GET"])
+    def to_gpu(model_id):
+        out = {"model_id": model_id}
+        translation_server.models[model_id].to_gpu()
+
+        out["status"] = STATUS_OK
+        return jsonify(out)
+
+    serve(app, host=host, port=port)
+
+
+def _get_parser():
+    parser = configargparse.ArgumentParser(
+        config_file_parser_class=configargparse.YAMLConfigFileParser,
+        description="OpenNMT-py REST Server",
+    )
+    parser.add_argument("--ip", type=str, default="0.0.0.0")
+    parser.add_argument("--port", type=int, default="5000")
+    parser.add_argument("--url_root", type=str, default="/translator")
+    parser.add_argument("--debug", "-d", action="store_true")
+    parser.add_argument(
+        "--config", "-c", type=str, default="./available_models/conf.json"
+    )
+    return parser
+
+
+def main():
+    parser = _get_parser()
+    args = parser.parse_args()
+    start(
+        args.config,
+        url_root=args.url_root,
+        host=args.ip,
+        port=args.port,
+        debug=args.debug,
+    )
+
+
+if __name__ == "__main__":
+    main()

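With the defaults above (`--url_root /translator`, port 5000), the `/translate` route takes a JSON list of segments and returns one list of hypotheses per n_best rank. A hedged client sketch using `requests` (host, model id, and source text are placeholders, and a server configured via conf.json must already be running):

# Sketch: querying the REST server started by onmt/bin/server.py.
# Assumes the server runs locally with the default url_root and port.
import requests

payload = [{"id": 0, "src": "C C O"}]  # placeholder model id and source tokens
resp = requests.post("http://localhost:5000/translator/translate", json=payload)
for hyp_list in resp.json():          # one list per n_best rank
    for hyp in hyp_list:
        print(hyp["pred_score"], hyp["tgt"])
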
onmt/bin/train.py
ADDED
@@ -0,0 +1,71 @@
+#!/usr/bin/env python
+"""Train models with dynamic data."""
+import torch
+from functools import partial
+from onmt.utils.distributed import ErrorHandler, spawned_train
+from onmt.utils.misc import set_random_seed
+from onmt.utils.logging import init_logger, logger
+from onmt.utils.parse import ArgumentParser
+from onmt.opts import train_opts
+from onmt.train_single import main as single_main
+
+
+# Set sharing strategy manually instead of default based on the OS.
+# torch.multiprocessing.set_sharing_strategy('file_system')
+
+
+def train(opt):
+    init_logger(opt.log_file)
+
+    ArgumentParser.validate_train_opts(opt)
+    ArgumentParser.update_model_opts(opt)
+    ArgumentParser.validate_model_opts(opt)
+
+    set_random_seed(opt.seed, False)
+
+    train_process = partial(single_main)
+
+    nb_gpu = len(opt.gpu_ranks)
+
+    if opt.world_size > 1:
+        mp = torch.multiprocessing.get_context("spawn")
+        # Create a thread to listen for errors in the child processes.
+        error_queue = mp.SimpleQueue()
+        error_handler = ErrorHandler(error_queue)
+        # Train with multiprocessing.
+        procs = []
+        for device_id in range(nb_gpu):
+            procs.append(
+                mp.Process(
+                    target=spawned_train,
+                    args=(train_process, opt, device_id, error_queue),
+                    daemon=False,
+                )
+            )
+            procs[device_id].start()
+            logger.info(" Starting process pid: %d " % procs[device_id].pid)
+            error_handler.add_child(procs[device_id].pid)
+        for p in procs:
+            p.join()
+
+    elif nb_gpu == 1:  # case 1 GPU only
+        train_process(opt, device_id=0)
+    else:  # case only CPU
+        train_process(opt, device_id=-1)
+
+
+def _get_parser():
+    parser = ArgumentParser(description="train.py")
+    train_opts(parser)
+    return parser
+
+
+def main():
+    parser = _get_parser()
+
+    opt, unknown = parser.parse_known_args()
+    train(opt)
+
+
+if __name__ == "__main__":
+    main()

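When `world_size > 1`, `train()` uses the standard spawn pattern: one process per GPU rank, each reporting failures through a shared queue watched by the parent via `ErrorHandler`. A stripped-down sketch of that pattern with a dummy worker (the worker body is illustrative only, not OpenNMT-py's trainer):

# Sketch: the spawn + error-queue pattern used by train(), with a dummy worker.
import traceback
import torch


def worker(device_id, error_queue):
    try:
        # a real worker would call single_main(opt, device_id) here
        print(f"rank {device_id} running")
    except Exception:
        # mirror ErrorHandler: ship the traceback back to the parent
        error_queue.put((device_id, traceback.format_exc()))


if __name__ == "__main__":
    mp = torch.multiprocessing.get_context("spawn")
    error_queue = mp.SimpleQueue()
    procs = [mp.Process(target=worker, args=(rank, error_queue)) for rank in range(2)]
    for p in procs:
        p.start()
    for p in procs:
        p.join()
    if not error_queue.empty():
        print("child failed:", error_queue.get())
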
onmt/bin/translate.py
ADDED
@@ -0,0 +1,60 @@
+#!/usr/bin/env python
+# -*- coding: utf-8 -*-
+from onmt.utils.logging import init_logger
+from onmt.translate.translator import build_translator
+from onmt.inputters.dynamic_iterator import build_dynamic_dataset_iter
+from onmt.inputters.inputter import IterOnDevice
+from onmt.transforms import get_transforms_cls
+from onmt.constants import CorpusTask
+import onmt.opts as opts
+from onmt.utils.parse import ArgumentParser
+from onmt.utils.misc import use_gpu, set_random_seed
+
+
+def translate(opt):
+    ArgumentParser.validate_translate_opts(opt)
+    ArgumentParser._get_all_transform_translate(opt)
+    ArgumentParser._validate_transforms_opts(opt)
+    ArgumentParser.validate_translate_opts_dynamic(opt)
+    logger = init_logger(opt.log_file)
+
+    set_random_seed(opt.seed, use_gpu(opt))
+
+    translator = build_translator(opt, logger=logger, report_score=True)
+
+    transforms_cls = get_transforms_cls(opt._all_transform)
+
+    infer_iter = build_dynamic_dataset_iter(
+        opt,
+        transforms_cls,
+        translator.vocabs,
+        task=CorpusTask.INFER,
+        copy=translator.copy_attn,
+    )
+
+    infer_iter = IterOnDevice(infer_iter, opt.gpu)
+
+    _, _ = translator._translate(
+        infer_iter,
+        transform=infer_iter.transform,
+        attn_debug=opt.attn_debug,
+        align_debug=opt.align_debug,
+    )
+
+
+def _get_parser():
+    parser = ArgumentParser(description="translate.py")
+
+    opts.config_opts(parser)
+    opts.translate_opts(parser, dynamic=True)
+    return parser
+
+
+def main():
+    parser = _get_parser()
+    opt = parser.parse_args()
+    translate(opt)
+
+
+if __name__ == "__main__":
+    main()

onmt/constants.py
ADDED
@@ -0,0 +1,41 @@
+"""Define constant values used across the project."""
+
+
+class DefaultTokens(object):
+    PAD = "<blank>"
+    BOS = "<s>"
+    EOS = "</s>"
+    UNK = "<unk>"
+    MASK = "<mask>"
+    VOCAB_PAD = "averyunlikelytoken"
+    SENT_FULL_STOPS = [".", "?", "!"]
+    PHRASE_TABLE_SEPARATOR = "|||"
+    ALIGNMENT_SEPARATOR = " ||| "
+    SEP = "⦅newline⦆"
+    MASK_BEFORE = "⦅_mask_before_⦆"
+
+
+class CorpusName(object):
+    VALID = "valid"
+    TRAIN = "train"
+    SAMPLE = "sample"
+    INFER = "infer"
+
+
+class CorpusTask(object):
+    TRAIN = "train"
+    VALID = "valid"
+    INFER = "infer"
+
+
+class SubwordMarker(object):
+    SPACER = "▁"
+    JOINER = "■"
+    BEGIN_UPPERCASE = "⦅mrk_begin_case_region_U⦆"
+    END_UPPERCASE = "⦅mrk_end_case_region_U⦆"
+    BEGIN_CASED = "⦅mrk_case_modifier_C⦆"
+
+
+class ModelTask(object):
+    LANGUAGE_MODEL = "lm"
+    SEQ2SEQ = "seq2seq"

onmt/decoders/__init__.py
ADDED
@@ -0,0 +1,63 @@
+"""Module defining decoders."""
+import os
+import importlib
+from onmt.decoders.decoder import DecoderBase, InputFeedRNNDecoder, StdRNNDecoder
+from onmt.decoders.transformer import TransformerDecoder, TransformerLMDecoder
+from onmt.decoders.cnn_decoder import CNNDecoder
+
+
+str2dec = {
+    "rnn": StdRNNDecoder,
+    "ifrnn": InputFeedRNNDecoder,
+    "cnn": CNNDecoder,
+    "transformer": TransformerDecoder,
+    "transformer_lm": TransformerLMDecoder,
+}
+
+__all__ = [
+    "DecoderBase",
+    "TransformerDecoder",
+    "StdRNNDecoder",
+    "CNNDecoder",
+    "InputFeedRNNDecoder",
+    "str2dec",
+    "TransformerLMDecoder",
+]
+
+
+def get_decoders_cls(decoders_names):
+    """Return valid decoder classes indicated in `decoders_names`."""
+    decoders_cls = {}
+    for name in decoders_names:
+        if name not in str2dec:
+            raise ValueError("%s decoder not supported!" % name)
+        decoders_cls[name] = str2dec[name]
+    return decoders_cls
+
+
+def register_decoder(name):
+    """Decoder register that can be used to add new decoder class."""
+
+    def register_decoder_cls(cls):
+        if name in str2dec:
+            raise ValueError("Cannot register duplicate decoder ({})".format(name))
+        if not issubclass(cls, DecoderBase):
+            raise ValueError(f"decoder ({name}: {cls.__name__}) must extend DecoderBase")
+        str2dec[name] = cls
+        __all__.append(cls.__name__)  # added to be complete
+        return cls
+
+    return register_decoder_cls
+
+
+# Auto import python files in this directory
+decoder_dir = os.path.dirname(__file__)
+for file in os.listdir(decoder_dir):
+    path = os.path.join(decoder_dir, file)
+    if (
+        not file.startswith("_")
+        and not file.startswith(".")
+        and (file.endswith(".py") or os.path.isdir(path))
+    ):
+        file_name = file[: file.find(".py")] if file.endswith(".py") else file
+        module = importlib.import_module("onmt.decoders." + file_name)

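`register_decoder` is the extension point for custom decoders: the decorator adds the class to `str2dec`, and the auto-import loop at the bottom of the module picks up any file dropped into `onmt/decoders/`. A hedged sketch of registering a custom decoder (`MyDecoder` and `"my_dec"` are made-up names, and the class body is only a stub):

# Sketch: registering a hypothetical custom decoder with the decorator above.
from onmt.decoders import register_decoder
from onmt.decoders.decoder import DecoderBase


@register_decoder("my_dec")
class MyDecoder(DecoderBase):
    @classmethod
    def from_opt(cls, opt, embeddings):
        return cls()

    def forward(self, tgt, enc_out, step=None, **kwargs):
        raise NotImplementedError  # real decoding logic would go here
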
onmt/decoders/__pycache__/__init__.cpython-311.pyc
ADDED
Binary file (3 kB).

onmt/decoders/__pycache__/__init__.cpython-38.pyc
ADDED
Binary file (1.84 kB).

onmt/decoders/__pycache__/cnn_decoder.cpython-311.pyc
ADDED
Binary file (7.32 kB).

onmt/decoders/__pycache__/cnn_decoder.cpython-38.pyc
ADDED
Binary file (4.02 kB).

onmt/decoders/__pycache__/decoder.cpython-311.pyc
ADDED
Binary file (18.4 kB).

onmt/decoders/__pycache__/decoder.cpython-38.pyc
ADDED
Binary file (11.5 kB).

onmt/decoders/__pycache__/ensemble.cpython-311.pyc
ADDED
Binary file (11 kB).

onmt/decoders/__pycache__/ensemble.cpython-38.pyc
ADDED
Binary file (7.13 kB).

onmt/decoders/__pycache__/transformer.cpython-311.pyc
ADDED
Binary file (32.9 kB).

onmt/decoders/__pycache__/transformer.cpython-38.pyc
ADDED
Binary file (20.4 kB).

onmt/decoders/cnn_decoder.py
ADDED
@@ -0,0 +1,141 @@
+"""Implementation of the CNN Decoder part of
+"Convolutional Sequence to Sequence Learning"
+"""
+import torch
+import torch.nn as nn
+
+from onmt.modules import ConvMultiStepAttention, GlobalAttention
+from onmt.utils.cnn_factory import shape_transform, GatedConv
+from onmt.decoders.decoder import DecoderBase
+
+SCALE_WEIGHT = 0.5**0.5
+
+
+class CNNDecoder(DecoderBase):
+    """Decoder based on "Convolutional Sequence to Sequence Learning"
+    :cite:`DBLP:journals/corr/GehringAGYD17`.
+
+    Consists of residual convolutional layers, with ConvMultiStepAttention.
+    """
+
+    def __init__(
+        self,
+        num_layers,
+        hidden_size,
+        attn_type,
+        copy_attn,
+        cnn_kernel_width,
+        dropout,
+        embeddings,
+        copy_attn_type,
+    ):
+        super(CNNDecoder, self).__init__()
+
+        self.cnn_kernel_width = cnn_kernel_width
+        self.embeddings = embeddings
+
+        # Decoder State
+        self.state = {}
+
+        input_size = self.embeddings.embedding_size
+        self.linear = nn.Linear(input_size, hidden_size)
+        self.conv_layers = nn.ModuleList(
+            [
+                GatedConv(hidden_size, cnn_kernel_width, dropout, True)
+                for i in range(num_layers)
+            ]
+        )
+        self.attn_layers = nn.ModuleList(
+            [ConvMultiStepAttention(hidden_size) for i in range(num_layers)]
+        )
+
+        # CNNDecoder has its own attention mechanism.
+        # Set up a separate copy attention layer if needed.
+        assert not copy_attn, "Copy mechanism not yet tested in conv2conv"
+        if copy_attn:
+            self.copy_attn = GlobalAttention(hidden_size, attn_type=copy_attn_type)
+        else:
+            self.copy_attn = None
+
+    @classmethod
+    def from_opt(cls, opt, embeddings):
+        """Alternate constructor."""
+        return cls(
+            opt.dec_layers,
+            opt.dec_hid_size,
+            opt.global_attention,
+            opt.copy_attn,
+            opt.cnn_kernel_width,
+            opt.dropout[0] if type(opt.dropout) is list else opt.dropout,
+            embeddings,
+            opt.copy_attn_type,
+        )
+
+    def init_state(self, _, enc_out, enc_hidden):
+        """Init decoder state."""
+        self.state["src"] = (enc_out + enc_hidden) * SCALE_WEIGHT
+        self.state["previous_input"] = None
+
+    def map_state(self, fn):
+        self.state["src"] = fn(self.state["src"], 0)
+        if self.state["previous_input"] is not None:
+            self.state["previous_input"] = fn(self.state["previous_input"], 0)
+
+    def detach_state(self):
+        self.state["previous_input"] = self.state["previous_input"].detach()
+
+    def forward(self, tgt, enc_out, step=None, **kwargs):
+        """See :obj:`onmt.modules.RNNDecoderBase.forward()`"""
+
+        if self.state["previous_input"] is not None:
+            tgt = torch.cat([self.state["previous_input"], tgt], 1)
+
+        dec_outs = []
+        attns = {"std": []}
+        if self.copy_attn is not None:
+            attns["copy"] = []
+
+        emb = self.embeddings(tgt)
+        assert emb.dim() == 3  # batch x len x embedding_dim
+
+        tgt_emb = emb
+        # The output of CNNEncoder.
+        enc_out_t = enc_out
+        # The combination of output of CNNEncoder and source embeddings.
+        enc_out_c = self.state["src"]
+
+        emb_reshape = tgt_emb.view(tgt_emb.size(0) * tgt_emb.size(1), -1)
+        linear_out = self.linear(emb_reshape)
+        x = linear_out.view(tgt_emb.size(0), tgt_emb.size(1), -1)
+        x = shape_transform(x)
+
+        pad = torch.zeros(x.size(0), x.size(1), self.cnn_kernel_width - 1, 1)
+
+        pad = pad.type_as(x)
+        base_target_emb = x
+
+        for conv, attention in zip(self.conv_layers, self.attn_layers):
+            new_target_input = torch.cat([pad, x], 2)
+            out = conv(new_target_input)
+            c, attn = attention(base_target_emb, out, enc_out_t, enc_out_c)
+            x = (x + (c + out) * SCALE_WEIGHT) * SCALE_WEIGHT
+
+        dec_outs = x.squeeze(3).transpose(1, 2)
+
+        # Process the result and update the attentions.
+        if self.state["previous_input"] is not None:
+            dec_outs = dec_outs[:, self.state["previous_input"].size(1) :, :]
+            attn = attn[:, self.state["previous_input"].size(1) :].squeeze()
+            attn = torch.stack([attn])
+        attns["std"] = attn
+        if self.copy_attn is not None:
+            attns["copy"] = attn
+
+        # Update the state.
+        self.state["previous_input"] = tgt
+        # TODO change the way attns is returned dict => list or tuple (onnx)
+        return dec_outs, attns
+
+    def update_dropout(self, dropout, attention_dropout=None):
+        for layer in self.conv_layers:
+            layer.dropout.p = dropout

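`SCALE_WEIGHT = 0.5**0.5` follows the ConvS2S convention of scaling the sum of two roughly independent, equal-variance terms by sqrt(0.5), so the residual connections keep the activation variance roughly constant across layers. A quick numerical sketch of that property (a toy check, not part of the decoder):

# Toy check: scaling a sum of two unit-variance tensors by 0.5**0.5
# keeps the result close to unit variance, which is why ConvS2S-style
# residual sums use this constant.
import torch

SCALE_WEIGHT = 0.5**0.5
a = torch.randn(100_000)
b = torch.randn(100_000)
combined = (a + b) * SCALE_WEIGHT
print(round(a.var().item(), 2), round(combined.var().item(), 2))  # both ~1.0
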
onmt/decoders/decoder.py
ADDED
|
@@ -0,0 +1,405 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.nn as nn
|
| 3 |
+
|
| 4 |
+
from onmt.modules.stacked_rnn import StackedLSTM, StackedGRU
|
| 5 |
+
from onmt.modules import context_gate_factory, GlobalAttention
|
| 6 |
+
from onmt.utils.rnn_factory import rnn_factory
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
class DecoderBase(nn.Module):
|
| 10 |
+
"""Abstract class for decoders.
|
| 11 |
+
|
| 12 |
+
Args:
|
| 13 |
+
attentional (bool): The decoder returns non-empty attention.
|
| 14 |
+
"""
|
| 15 |
+
|
| 16 |
+
def __init__(self, attentional=True):
|
| 17 |
+
super(DecoderBase, self).__init__()
|
| 18 |
+
self.attentional = attentional
|
| 19 |
+
|
| 20 |
+
@classmethod
|
| 21 |
+
def from_opt(cls, opt, embeddings):
|
| 22 |
+
"""Alternate constructor.
|
| 23 |
+
|
| 24 |
+
Subclasses should override this method.
|
| 25 |
+
"""
|
| 26 |
+
|
| 27 |
+
raise NotImplementedError
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
class RNNDecoderBase(DecoderBase):
|
| 31 |
+
"""Base recurrent attention-based decoder class.
|
| 32 |
+
|
| 33 |
+
Specifies the interface used by different decoder types
|
| 34 |
+
and required by :class:`~onmt.models.NMTModel`.
|
| 35 |
+
|
| 36 |
+
Args:
|
| 37 |
+
rnn_type (str):
|
| 38 |
+
style of recurrent unit to use, one of [RNN, LSTM, GRU, SRU]
|
| 39 |
+
bidirectional_encoder (bool) : use with a bidirectional encoder
|
| 40 |
+
num_layers (int) : number of stacked layers
|
| 41 |
+
hidden_size (int) : hidden size of each layer
|
| 42 |
+
attn_type (str) : see :class:`~onmt.modules.GlobalAttention`
|
| 43 |
+
attn_func (str) : see :class:`~onmt.modules.GlobalAttention`
|
| 44 |
+
coverage_attn (str): see :class:`~onmt.modules.GlobalAttention`
|
| 45 |
+
context_gate (str): see :class:`~onmt.modules.ContextGate`
|
| 46 |
+
copy_attn (bool): setup a separate copy attention mechanism
|
| 47 |
+
dropout (float) : dropout value for :class:`torch.nn.Dropout`
|
| 48 |
+
embeddings (onmt.modules.Embeddings): embedding module to use
|
| 49 |
+
reuse_copy_attn (bool): reuse the attention for copying
|
| 50 |
+
copy_attn_type (str): The copy attention style. See
|
| 51 |
+
:class:`~onmt.modules.GlobalAttention`.
|
| 52 |
+
"""
|
| 53 |
+
|
| 54 |
+
def __init__(
|
| 55 |
+
self,
|
| 56 |
+
rnn_type,
|
| 57 |
+
bidirectional_encoder,
|
| 58 |
+
num_layers,
|
| 59 |
+
hidden_size,
|
| 60 |
+
attn_type="general",
|
| 61 |
+
attn_func="softmax",
|
| 62 |
+
coverage_attn=False,
|
| 63 |
+
context_gate=None,
|
| 64 |
+
copy_attn=False,
|
| 65 |
+
dropout=0.0,
|
| 66 |
+
embeddings=None,
|
| 67 |
+
reuse_copy_attn=False,
|
| 68 |
+
copy_attn_type="general",
|
| 69 |
+
):
|
| 70 |
+
super(RNNDecoderBase, self).__init__(
|
| 71 |
+
attentional=attn_type != "none" and attn_type is not None
|
| 72 |
+
)
|
| 73 |
+
|
| 74 |
+
self.bidirectional_encoder = bidirectional_encoder
|
| 75 |
+
self.num_layers = num_layers
|
| 76 |
+
self.hidden_size = hidden_size
|
| 77 |
+
self.embeddings = embeddings
|
| 78 |
+
self.dropout = nn.Dropout(dropout)
|
| 79 |
+
|
| 80 |
+
# Decoder state
|
| 81 |
+
self.state = {}
|
| 82 |
+
|
| 83 |
+
# Build the RNN.
|
| 84 |
+
self.rnn = self._build_rnn(
|
| 85 |
+
rnn_type,
|
| 86 |
+
input_size=self._input_size,
|
| 87 |
+
hidden_size=hidden_size,
|
| 88 |
+
num_layers=num_layers,
|
| 89 |
+
dropout=dropout,
|
| 90 |
+
)
|
| 91 |
+
|
| 92 |
+
# Set up the context gate.
|
| 93 |
+
self.context_gate = None
|
| 94 |
+
if context_gate is not None:
|
| 95 |
+
self.context_gate = context_gate_factory(
|
| 96 |
+
context_gate, self._input_size, hidden_size, hidden_size, hidden_size
|
| 97 |
+
)
|
| 98 |
+
|
| 99 |
+
# Set up the standard attention.
|
| 100 |
+
self._coverage = coverage_attn
|
| 101 |
+
if not self.attentional:
|
| 102 |
+
if self._coverage:
|
| 103 |
+
raise ValueError("Cannot use coverage term with no attention.")
|
| 104 |
+
self.attn = None
|
| 105 |
+
else:
|
| 106 |
+
self.attn = GlobalAttention(
|
| 107 |
+
hidden_size,
|
| 108 |
+
coverage=coverage_attn,
|
| 109 |
+
attn_type=attn_type,
|
| 110 |
+
attn_func=attn_func,
|
| 111 |
+
)
|
| 112 |
+
|
| 113 |
+
if copy_attn and not reuse_copy_attn:
|
| 114 |
+
if copy_attn_type == "none" or copy_attn_type is None:
|
| 115 |
+
raise ValueError("Cannot use copy_attn with copy_attn_type none")
|
| 116 |
+
self.copy_attn = GlobalAttention(
|
| 117 |
+
hidden_size, attn_type=copy_attn_type, attn_func=attn_func
|
| 118 |
+
)
|
| 119 |
+
else:
|
| 120 |
+
self.copy_attn = None
|
| 121 |
+
|
| 122 |
+
self._reuse_copy_attn = reuse_copy_attn and copy_attn
|
| 123 |
+
if self._reuse_copy_attn and not self.attentional:
|
| 124 |
+
raise ValueError("Cannot reuse copy attention with no attention.")
|
| 125 |
+
|
| 126 |
+
@classmethod
|
| 127 |
+
def from_opt(cls, opt, embeddings):
|
| 128 |
+
"""Alternate constructor."""
|
| 129 |
+
return cls(
|
| 130 |
+
opt.rnn_type,
|
| 131 |
+
opt.brnn,
|
| 132 |
+
opt.dec_layers,
|
| 133 |
+
opt.dec_hid_size,
|
| 134 |
+
opt.global_attention,
|
| 135 |
+
opt.global_attention_function,
|
| 136 |
+
opt.coverage_attn,
|
| 137 |
+
opt.context_gate,
|
| 138 |
+
opt.copy_attn,
|
| 139 |
+
opt.dropout[0] if type(opt.dropout) is list else opt.dropout,
|
| 140 |
+
embeddings,
|
| 141 |
+
opt.reuse_copy_attn,
|
| 142 |
+
opt.copy_attn_type,
|
| 143 |
+
)
|
| 144 |
+
|
| 145 |
+
def init_state(self, src, _, enc_final_hs):
|
| 146 |
+
"""Initialize decoder state with last state of the encoder."""
|
| 147 |
+
|
| 148 |
+
def _fix_enc_hidden(hidden):
|
| 149 |
+
# The encoder hidden is (layers*directions) x batch x dim.
|
| 150 |
+
# We need to convert it to layers x batch x (directions*dim).
|
| 151 |
+
if self.bidirectional_encoder:
|
| 152 |
+
hidden = torch.cat(
|
| 153 |
+
[hidden[0 : hidden.size(0) : 2], hidden[1 : hidden.size(0) : 2]], 2
|
| 154 |
+
)
|
| 155 |
+
return hidden
|
| 156 |
+
|
| 157 |
+
if isinstance(enc_final_hs, tuple): # LSTM
|
| 158 |
+
self.state["hidden"] = tuple(
|
| 159 |
+
_fix_enc_hidden(enc_hid) for enc_hid in enc_final_hs
|
| 160 |
+
)
|
| 161 |
+
else: # GRU
|
| 162 |
+
self.state["hidden"] = (_fix_enc_hidden(enc_final_hs),)
|
| 163 |
+
|
| 164 |
+
# Init the input feed.
|
| 165 |
+
batch_size = self.state["hidden"][0].size(1)
|
| 166 |
+
|
| 167 |
+
h_size = (batch_size, self.hidden_size)
|
| 168 |
+
self.state["input_feed"] = (
|
| 169 |
+
self.state["hidden"][0].data.new(*h_size).zero_().unsqueeze(0)
|
| 170 |
+
)
|
| 171 |
+
|
| 172 |
+
self.state["coverage"] = None
|
| 173 |
+
|
| 174 |
+
def map_state(self, fn):
|
| 175 |
+
self.state["hidden"] = tuple(fn(h, 1) for h in self.state["hidden"])
|
| 176 |
+
self.state["input_feed"] = fn(self.state["input_feed"], 1)
|
| 177 |
+
if self._coverage and self.state["coverage"] is not None:
|
| 178 |
+
self.state["coverage"] = fn(self.state["coverage"], 1)
|
| 179 |
+
|
| 180 |
+
def detach_state(self):
|
| 181 |
+
self.state["hidden"] = tuple(h.detach() for h in self.state["hidden"])
|
| 182 |
+
self.state["input_feed"] = self.state["input_feed"].detach()
|
| 183 |
+
if self._coverage and self.state["coverage"] is not None:
|
| 184 |
+
self.state["coverage"] = self.state["coverage"].detach()
|
| 185 |
+
|
| 186 |
+
def forward(self, tgt, enc_out, src_len=None, step=None, **kwargs):
|
| 187 |
+
"""
|
| 188 |
+
Args:
|
| 189 |
+
tgt (LongTensor): sequences of padded tokens
|
| 190 |
+
``(batch, tgt_len, nfeats)``.
|
| 191 |
+
enc_out (FloatTensor): vectors from the encoder
|
| 192 |
+
``(batch, src_len, hidden)``.
|
| 193 |
+
src_len (LongTensor): the padded source lengths
|
| 194 |
+
``(batch,)``.
|
| 195 |
+
|
| 196 |
+
Returns:
|
| 197 |
+
(FloatTensor, dict[str, FloatTensor]):
|
| 198 |
+
|
| 199 |
+
* dec_outs: output from the decoder (after attn)
|
| 200 |
+
``(batch, tgt_len, hidden)``.
|
| 201 |
+
* attns: distribution over src at each tgt
|
| 202 |
+
``(batch, tgt_len, src_len)``.
|
| 203 |
+
"""
|
| 204 |
+
dec_state, dec_outs, attns = self._run_forward_pass(
|
| 205 |
+
tgt, enc_out, src_len=src_len
|
| 206 |
+
)
|
| 207 |
+
|
| 208 |
+
# Update the state with the result.
|
| 209 |
+
if not isinstance(dec_state, tuple):
|
| 210 |
+
dec_state = (dec_state,)
|
| 211 |
+
self.state["hidden"] = dec_state
|
| 212 |
+
|
| 213 |
+
# Concatenates sequence of tensors along a new dimension.
|
| 214 |
+
# NOTE: v0.3 to 0.4: dec_outs / attns[*] may not be list
|
| 215 |
+
# (in particular in case of SRU) it was not raising error in 0.3
|
| 216 |
+
# since stack(Variable) was allowed.
|
| 217 |
+
# In 0.4, SRU returns a tensor that shouldn't be stacked
|
| 218 |
+
if type(dec_outs) == list:
|
| 219 |
+
dec_outs = torch.stack(dec_outs, dim=1)
|
| 220 |
+
for k in attns:
|
| 221 |
+
if type(attns[k]) == list:
|
| 222 |
+
attns[k] = torch.stack(attns[k])
|
| 223 |
+
|
| 224 |
+
self.state["input_feed"] = dec_outs[:, -1, :].unsqueeze(0)
|
| 225 |
+
self.state["coverage"] = None
|
| 226 |
+
if "coverage" in attns:
|
| 227 |
+
self.state["coverage"] = attns["coverage"][-1, :, :].unsqueeze(0)
|
| 228 |
+
|
| 229 |
+
return dec_outs, attns
|
| 230 |
+
|
| 231 |
+
def update_dropout(self, dropout, attention_dropout=None):
|
| 232 |
+
self.dropout.p = dropout
|
| 233 |
+
self.embeddings.update_dropout(dropout)
|
| 234 |
+
|
| 235 |
+
|
| 236 |
+
class StdRNNDecoder(RNNDecoderBase):
|
| 237 |
+
"""Standard fully batched RNN decoder with attention.
|
| 238 |
+
|
| 239 |
+
Faster implementation that uses CuDNN.
|
| 240 |
+
See :class:`~onmt.decoders.decoder.RNNDecoderBase` for options.
|
| 241 |
+
|
| 242 |
+
|
| 243 |
+
Based around the approach from
|
| 244 |
+
"Neural Machine Translation By Jointly Learning To Align and Translate"
|
| 245 |
+
:cite:`Bahdanau2015`
|
| 246 |
+
|
| 247 |
+
|
| 248 |
+
Implemented without input_feeding and currently with no `coverage_attn`
|
| 249 |
+
or `copy_attn` support.
|
| 250 |
+
"""
|
| 251 |
+
|
| 252 |
+
def _run_forward_pass(self, tgt, enc_out, src_len=None):
|
| 253 |
+
"""
|
| 254 |
+
Private helper for running the specific RNN forward pass.
|
| 255 |
+
Must be overridden by all subclasses.
|
| 256 |
+
|
| 257 |
+
Args:
|
| 258 |
+
tgt (LongTensor): a sequence of input tokens tensors
|
| 259 |
+
``(batch, tgt_len, nfeats)``.
|
| 260 |
+
enc_out (FloatTensor): output(tensor sequence) from the
|
| 261 |
+
encoder RNN of size ``(batch, src_len, hidden_size)``.
|
| 262 |
+
src_len (LongTensor): the source enc_out lengths.
|
| 263 |
+
|
| 264 |
+
Returns:
|
| 265 |
+
(Tensor, List[FloatTensor], Dict[str, List[FloatTensor]]):
|
| 266 |
+
|
| 267 |
+
* dec_state: final hidden state from the decoder.
|
| 268 |
+
* dec_outs: an array of output of every time
|
| 269 |
+
step from the decoder.
|
| 270 |
+
* attns: a dictionary of different
|
| 271 |
+
type of attention Tensor array of every time
|
| 272 |
+
step from the decoder.
|
| 273 |
+
"""
|
| 274 |
+
|
| 275 |
+
assert self.copy_attn is None # TODO, no support yet.
|
| 276 |
+
assert not self._coverage # TODO, no support yet.
|
| 277 |
+
|
| 278 |
+
attns = {}
|
| 279 |
+
emb = self.embeddings(tgt)
|
| 280 |
+
|
| 281 |
+
if isinstance(self.rnn, nn.GRU):
|
| 282 |
+
rnn_out, dec_state = self.rnn(emb, self.state["hidden"][0])
|
| 283 |
+
else:
|
| 284 |
+
rnn_out, dec_state = self.rnn(emb, self.state["hidden"])
|
| 285 |
+
|
| 286 |
+
tgt_batch, tgt_len, _ = tgt.size()
|
| 287 |
+
|
| 288 |
+
# Calculate the attention.
|
| 289 |
+
if not self.attentional:
|
| 290 |
+
dec_outs = rnn_out
|
| 291 |
+
else:
|
| 292 |
+
dec_outs, p_attn = self.attn(rnn_out, enc_out, src_len=src_len)
|
| 293 |
+
attns["std"] = p_attn
|
| 294 |
+
|
| 295 |
+
# Calculate the context gate.
|
| 296 |
+
if self.context_gate is not None:
|
| 297 |
+
dec_outs = self.context_gate(
|
| 298 |
+
emb.view(-1, emb.size(2)),
|
| 299 |
+
rnn_out.view(-1, rnn_out.size(2)),
|
| 300 |
+
dec_outs.view(-1, dec_outs.size(2)),
|
| 301 |
+
)
|
| 302 |
+
dec_outs = dec_outs.view(tgt_batch, tgt_len, self.hidden_size)
|
| 303 |
+
|
| 304 |
+
dec_outs = self.dropout(dec_outs)
|
| 305 |
+
|
| 306 |
+
return dec_state, dec_outs, attns
|
| 307 |
+
|
| 308 |
+
def _build_rnn(self, rnn_type, **kwargs):
|
| 309 |
+
rnn, _ = rnn_factory(rnn_type, **kwargs)
|
| 310 |
+
return rnn
|
| 311 |
+
|
| 312 |
+
@property
|
| 313 |
+
def _input_size(self):
|
| 314 |
+
return self.embeddings.embedding_size
|
| 315 |
+
|
| 316 |
+
|
| 317 |
+
class InputFeedRNNDecoder(RNNDecoderBase):
|
| 318 |
+
"""Input feeding based decoder.
|
| 319 |
+
|
| 320 |
+
See :class:`~onmt.decoders.decoder.RNNDecoderBase` for options.
|
| 321 |
+
|
| 322 |
+
Based around the input feeding approach from
|
| 323 |
+
"Effective Approaches to Attention-based Neural Machine Translation"
|
| 324 |
+
:cite:`Luong2015`
|
| 325 |
+
|
| 326 |
+
"""
|
| 327 |
+
|
| 328 |
+
def _run_forward_pass(self, tgt, enc_out, src_len=None):
|
| 329 |
+
"""
|
| 330 |
+
See StdRNNDecoder._run_forward_pass() for description
|
| 331 |
+
of arguments and return values.
|
| 332 |
+
"""
|
| 333 |
+
# Additional args check.
|
| 334 |
+
input_feed = self.state["input_feed"].squeeze(0)
|
| 335 |
+
|
| 336 |
+
dec_outs = []
|
| 337 |
+
attns = {}
|
| 338 |
+
if self.attn is not None:
|
| 339 |
+
attns["std"] = []
|
| 340 |
+
if self.copy_attn is not None or self._reuse_copy_attn:
|
| 341 |
+
attns["copy"] = []
|
| 342 |
+
if self._coverage:
|
| 343 |
+
attns["coverage"] = []
|
| 344 |
+
|
| 345 |
+
emb = self.embeddings(tgt)
|
| 346 |
+
assert emb.dim() == 3 # batch x len x embedding_dim
|
| 347 |
+
|
| 348 |
+
dec_state = self.state["hidden"]
|
| 349 |
+
|
| 350 |
+
coverage = (
|
| 351 |
+
self.state["coverage"].squeeze(0)
|
| 352 |
+
if self.state["coverage"] is not None
|
| 353 |
+
else None
|
| 354 |
+
)
|
| 355 |
+
|
| 356 |
+
# Input feed concatenates hidden state with
|
| 357 |
+
# input at every time step.
|
| 358 |
+
for emb_t in emb.split(1, dim=1):
|
| 359 |
+
dec_in = torch.cat([emb_t.squeeze(1), input_feed], 1)
|
| 360 |
+
rnn_out, dec_state = self.rnn(dec_in, dec_state)
|
| 361 |
+
if self.attentional:
|
| 362 |
+
dec_out, p_attn = self.attn(rnn_out, enc_out, src_len=src_len)
|
| 363 |
+
attns["std"].append(p_attn)
|
| 364 |
+
else:
|
| 365 |
+
dec_out = rnn_out
|
| 366 |
+
if self.context_gate is not None:
|
| 367 |
+
# TODO: context gate should be employed
|
| 368 |
+
# instead of second RNN transform.
|
| 369 |
+
dec_out = self.context_gate(dec_in, rnn_out, dec_out)
|
| 370 |
+
dec_out = self.dropout(dec_out)
|
| 371 |
+
input_feed = dec_out
|
| 372 |
+
|
| 373 |
+
dec_outs += [dec_out]
|
| 374 |
+
|
| 375 |
+
# Update the coverage attention.
|
| 376 |
+
# attns["coverage"] is actually c^(t+1) of See et al(2017)
|
| 377 |
+
# 1-index shifted
|
| 378 |
+
if self._coverage:
|
| 379 |
+
coverage = p_attn if coverage is None else p_attn + coverage
|
| 380 |
+
attns["coverage"] += [coverage]
|
| 381 |
+
|
| 382 |
+
if self.copy_attn is not None:
|
| 383 |
+
_, copy_attn = self.copy_attn(dec_out, enc_out)
|
| 384 |
+
attns["copy"] += [copy_attn]
|
| 385 |
+
elif self._reuse_copy_attn:
|
| 386 |
+
attns["copy"] = attns["std"]
|
| 387 |
+
|
| 388 |
+
return dec_state, dec_outs, attns
|
| 389 |
+
|
| 390 |
+
def _build_rnn(self, rnn_type, input_size, hidden_size, num_layers, dropout):
|
| 391 |
+
assert rnn_type != "SRU", (
|
| 392 |
+
"SRU doesn't support input feed! " "Please set -input_feed 0!"
|
| 393 |
+
)
|
| 394 |
+
stacked_cell = StackedLSTM if rnn_type == "LSTM" else StackedGRU
|
| 395 |
+
return stacked_cell(num_layers, input_size, hidden_size, dropout)
|
| 396 |
+
|
| 397 |
+
@property
|
| 398 |
+
def _input_size(self):
|
| 399 |
+
"""Using input feed by concatenating input with attention vectors."""
|
| 400 |
+
return self.embeddings.embedding_size + self.hidden_size
|
| 401 |
+
|
| 402 |
+
def update_dropout(self, dropout, attention_dropout=None):
|
| 403 |
+
self.dropout.p = dropout
|
| 404 |
+
self.rnn.dropout.p = dropout
|
| 405 |
+
self.embeddings.update_dropout(dropout)
|
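The InputFeedRNNDecoder above concatenates the attentional output of step t-1 with the target embedding of step t before it enters the stacked RNN cell. Below is a minimal, self-contained sketch of that input-feeding loop using a plain nn.LSTMCell and dot-product attention instead of the repo's StackedLSTM and GlobalAttention modules; the class name and tensor sizes are illustrative assumptions, not part of this upload.
import torch
import torch.nn as nn
import torch.nn.functional as F

class TinyInputFeedDecoder(nn.Module):
    # Sketch only: one LSTM cell plus dot-product attention over encoder states.
    def __init__(self, emb_size, hidden_size):
        super().__init__()
        self.cell = nn.LSTMCell(emb_size + hidden_size, hidden_size)
        self.attn_out = nn.Linear(2 * hidden_size, hidden_size)

    def forward(self, emb, enc_out, state):
        # emb: (batch, tgt_len, emb_size), enc_out: (batch, src_len, hidden_size)
        h, c = state
        input_feed = torch.zeros_like(h)
        outs = []
        for emb_t in emb.split(1, dim=1):
            # input feeding: previous attentional output joins the current embedding
            dec_in = torch.cat([emb_t.squeeze(1), input_feed], dim=1)
            h, c = self.cell(dec_in, (h, c))
            scores = torch.bmm(enc_out, h.unsqueeze(2)).squeeze(2)       # (batch, src_len)
            align = F.softmax(scores, dim=1)
            context = torch.bmm(align.unsqueeze(1), enc_out).squeeze(1)  # (batch, hidden_size)
            input_feed = torch.tanh(self.attn_out(torch.cat([context, h], dim=1)))
            outs.append(input_feed)
        return torch.stack(outs, dim=1), (h, c)

dec = TinyInputFeedDecoder(emb_size=16, hidden_size=32)
emb = torch.randn(2, 5, 16)
enc_out = torch.randn(2, 7, 32)
dec_outs, _ = dec(emb, enc_out, (torch.zeros(2, 32), torch.zeros(2, 32)))
print(dec_outs.shape)  # torch.Size([2, 5, 32])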
onmt/decoders/ensemble.py
ADDED
|
@@ -0,0 +1,150 @@
|
| 1 |
+
"""Ensemble decoding.
|
| 2 |
+
|
| 3 |
+
Decodes using multiple models simultaneously,
|
| 4 |
+
combining their prediction distributions by averaging.
|
| 5 |
+
All models in the ensemble must share a target vocabulary.
|
| 6 |
+
"""
|
| 7 |
+
|
| 8 |
+
import torch
|
| 9 |
+
import torch.nn as nn
|
| 10 |
+
|
| 11 |
+
from onmt.encoders.encoder import EncoderBase
|
| 12 |
+
from onmt.decoders.decoder import DecoderBase
|
| 13 |
+
from onmt.models import NMTModel
|
| 14 |
+
import onmt.model_builder
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
class EnsembleDecoderOutput(object):
|
| 18 |
+
"""Wrapper around multiple decoder final hidden states."""
|
| 19 |
+
|
| 20 |
+
def __init__(self, model_dec_outs):
|
| 21 |
+
self.model_dec_outs = tuple(model_dec_outs)
|
| 22 |
+
|
| 23 |
+
def squeeze(self, dim=None):
|
| 24 |
+
"""Delegate squeeze to avoid modifying
|
| 25 |
+
:func:`onmt.translate.translator.Translator.translate_batch()`
|
| 26 |
+
"""
|
| 27 |
+
return EnsembleDecoderOutput([x.squeeze(dim) for x in self.model_dec_outs])
|
| 28 |
+
|
| 29 |
+
def __getitem__(self, index):
|
| 30 |
+
return self.model_dec_outs[index]
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
class EnsembleEncoder(EncoderBase):
|
| 34 |
+
"""Dummy Encoder that delegates to individual real Encoders."""
|
| 35 |
+
|
| 36 |
+
def __init__(self, model_encoders):
|
| 37 |
+
super(EnsembleEncoder, self).__init__()
|
| 38 |
+
self.model_encoders = nn.ModuleList(model_encoders)
|
| 39 |
+
|
| 40 |
+
def forward(self, src, src_len=None):
|
| 41 |
+
enc_out, enc_final_hs, _ = zip(
|
| 42 |
+
*[model_encoder(src, src_len) for model_encoder in self.model_encoders]
|
| 43 |
+
)
|
| 44 |
+
return enc_out, enc_final_hs, src_len
|
| 45 |
+
|
| 46 |
+
|
| 47 |
+
class EnsembleDecoder(DecoderBase):
|
| 48 |
+
"""Dummy Decoder that delegates to individual real Decoders."""
|
| 49 |
+
|
| 50 |
+
def __init__(self, model_decoders):
|
| 51 |
+
model_decoders = nn.ModuleList(model_decoders)
|
| 52 |
+
attentional = any([dec.attentional for dec in model_decoders])
|
| 53 |
+
super(EnsembleDecoder, self).__init__(attentional)
|
| 54 |
+
self.model_decoders = model_decoders
|
| 55 |
+
|
| 56 |
+
def forward(self, tgt, enc_out, src_len=None, step=None, **kwargs):
|
| 57 |
+
"""See :func:`onmt.decoders.decoder.DecoderBase.forward()`."""
|
| 58 |
+
# src_len is a single tensor shared between all models.
|
| 59 |
+
# This assumption will not hold if Translator is modified
|
| 60 |
+
# to calculate src_len as something other than the length
|
| 61 |
+
# of the input.
|
| 62 |
+
dec_outs, attns = zip(
|
| 63 |
+
*[
|
| 64 |
+
model_decoder(tgt, enc_out[i], src_len=src_len, step=step, **kwargs)
|
| 65 |
+
for i, model_decoder in enumerate(self.model_decoders)
|
| 66 |
+
]
|
| 67 |
+
)
|
| 68 |
+
mean_attns = self.combine_attns(attns)
|
| 69 |
+
return EnsembleDecoderOutput(dec_outs), mean_attns
|
| 70 |
+
|
| 71 |
+
def combine_attns(self, attns):
|
| 72 |
+
result = {}
|
| 73 |
+
for key in attns[0].keys():
|
| 74 |
+
result[key] = torch.stack(
|
| 75 |
+
[attn[key] for attn in attns if attn[key] is not None]
|
| 76 |
+
).mean(0)
|
| 77 |
+
return result
|
| 78 |
+
|
| 79 |
+
def init_state(self, src, enc_out, enc_hidden):
|
| 80 |
+
"""See :obj:`RNNDecoderBase.init_state()`"""
|
| 81 |
+
for i, model_decoder in enumerate(self.model_decoders):
|
| 82 |
+
model_decoder.init_state(src, enc_out[i], enc_hidden[i])
|
| 83 |
+
|
| 84 |
+
def map_state(self, fn):
|
| 85 |
+
for model_decoder in self.model_decoders:
|
| 86 |
+
model_decoder.map_state(fn)
|
| 87 |
+
|
| 88 |
+
|
| 89 |
+
class EnsembleGenerator(nn.Module):
|
| 90 |
+
"""
|
| 91 |
+
Dummy Generator that delegates to individual real Generators,
|
| 92 |
+
and then averages the resulting target distributions.
|
| 93 |
+
"""
|
| 94 |
+
|
| 95 |
+
def __init__(self, model_generators, raw_probs=False):
|
| 96 |
+
super(EnsembleGenerator, self).__init__()
|
| 97 |
+
self.model_generators = nn.ModuleList(model_generators)
|
| 98 |
+
self._raw_probs = raw_probs
|
| 99 |
+
|
| 100 |
+
def forward(self, hidden, attn=None, src_map=None):
|
| 101 |
+
"""
|
| 102 |
+
Compute a distribution over the target dictionary
|
| 103 |
+
by averaging distributions from models in the ensemble.
|
| 104 |
+
All models in the ensemble must share a target vocabulary.
|
| 105 |
+
"""
|
| 106 |
+
distributions = torch.stack(
|
| 107 |
+
[
|
| 108 |
+
mg(h) if attn is None else mg(h, attn, src_map)
|
| 109 |
+
for h, mg in zip(hidden, self.model_generators)
|
| 110 |
+
]
|
| 111 |
+
)
|
| 112 |
+
if self._raw_probs:
|
| 113 |
+
return torch.log(torch.exp(distributions).mean(0))
|
| 114 |
+
else:
|
| 115 |
+
return distributions.mean(0)
|
| 116 |
+
|
| 117 |
+
|
| 118 |
+
class EnsembleModel(NMTModel):
|
| 119 |
+
"""Dummy NMTModel wrapping individual real NMTModels."""
|
| 120 |
+
|
| 121 |
+
def __init__(self, models, raw_probs=False):
|
| 122 |
+
encoder = EnsembleEncoder(model.encoder for model in models)
|
| 123 |
+
decoder = EnsembleDecoder(model.decoder for model in models)
|
| 124 |
+
super(EnsembleModel, self).__init__(encoder, decoder)
|
| 125 |
+
self.generator = EnsembleGenerator(
|
| 126 |
+
[model.generator for model in models], raw_probs
|
| 127 |
+
)
|
| 128 |
+
self.models = nn.ModuleList(models)
|
| 129 |
+
|
| 130 |
+
|
| 131 |
+
def load_test_model(opt, device_id=0):
|
| 132 |
+
"""Read in multiple models for ensemble."""
|
| 133 |
+
shared_vocabs = None
|
| 134 |
+
shared_model_opt = None
|
| 135 |
+
models = []
|
| 136 |
+
for model_path in opt.models:
|
| 137 |
+
vocabs, model, model_opt = onmt.model_builder.load_test_model(
|
| 138 |
+
opt, device_id, model_path=model_path
|
| 139 |
+
)
|
| 140 |
+
if shared_vocabs is None:
|
| 141 |
+
shared_vocabs = vocabs
|
| 142 |
+
else:
|
| 143 |
+
assert (
|
| 144 |
+
shared_vocabs["src"].tokens_to_ids == vocabs["src"].tokens_to_ids
|
| 145 |
+
), "Ensemble models must use the same vocabs "
|
| 146 |
+
models.append(model)
|
| 147 |
+
if shared_model_opt is None:
|
| 148 |
+
shared_model_opt = model_opt
|
| 149 |
+
ensemble_model = EnsembleModel(models, opt.avg_raw_probs)
|
| 150 |
+
return shared_vocabs, ensemble_model, shared_model_opt
|
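EnsembleGenerator above combines per-model distributions in one of two ways, controlled by avg_raw_probs: averaging in probability space and returning to log space, or averaging the log-probabilities directly. A small standalone illustration of the two modes (the toy scores are made-up numbers, not the output of any checkpoint in this upload):
import torch

# Two toy log-probability distributions over a 4-token vocabulary,
# standing in for the generator outputs of two ensemble members.
log_p1 = torch.log_softmax(torch.tensor([[2.0, 1.0, 0.5, 0.1]]), dim=-1)
log_p2 = torch.log_softmax(torch.tensor([[0.5, 2.5, 0.2, 0.3]]), dim=-1)
distributions = torch.stack([log_p1, log_p2])

# avg_raw_probs=True: arithmetic mean of the probabilities, then back to log space.
avg_raw = torch.log(torch.exp(distributions).mean(0))

# avg_raw_probs=False: mean of the log-probabilities
# (a geometric mean of the probabilities, up to renormalization).
avg_log = distributions.mean(0)

print(avg_raw)
print(avg_log)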
onmt/decoders/transformer.py
ADDED
|
@@ -0,0 +1,835 @@
|
| 1 |
+
"""
|
| 2 |
+
Implementation of "Attention is All You Need" and of
|
| 3 |
+
subsequent transformer based architectures
|
| 4 |
+
"""
|
| 5 |
+
|
| 6 |
+
import torch
|
| 7 |
+
import torch.nn as nn
|
| 8 |
+
from onmt.decoders.decoder import DecoderBase
|
| 9 |
+
from onmt.modules import MultiHeadedAttention, AverageAttention
|
| 10 |
+
from onmt.modules.position_ffn import PositionwiseFeedForward
|
| 11 |
+
from onmt.modules.position_ffn import ActivationFunction
|
| 12 |
+
from onmt.utils.misc import sequence_mask
|
| 13 |
+
from onmt.modules.rmsnorm import RMSNorm
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
class TransformerDecoderLayerBase(nn.Module):
|
| 17 |
+
def __init__(
|
| 18 |
+
self,
|
| 19 |
+
d_model,
|
| 20 |
+
heads,
|
| 21 |
+
d_ff,
|
| 22 |
+
dropout,
|
| 23 |
+
attention_dropout,
|
| 24 |
+
self_attn_type="scaled-dot",
|
| 25 |
+
max_relative_positions=0,
|
| 26 |
+
relative_positions_buckets=0,
|
| 27 |
+
aan_useffn=False,
|
| 28 |
+
full_context_alignment=False,
|
| 29 |
+
alignment_heads=0,
|
| 30 |
+
pos_ffn_activation_fn=ActivationFunction.relu,
|
| 31 |
+
add_qkvbias=False,
|
| 32 |
+
num_kv=0,
|
| 33 |
+
add_ffnbias=True,
|
| 34 |
+
parallel_residual=False,
|
| 35 |
+
shared_layer_norm=False,
|
| 36 |
+
layer_norm="standard",
|
| 37 |
+
norm_eps=1e-6,
|
| 38 |
+
use_ckpting=[],
|
| 39 |
+
parallel_gpu=1,
|
| 40 |
+
):
|
| 41 |
+
"""
|
| 42 |
+
Args:
|
| 43 |
+
d_model (int): the dimension of keys/values/queries in
|
| 44 |
+
:class:`MultiHeadedAttention`, also the input size of
|
| 45 |
+
the first-layer of the :class:`PositionwiseFeedForward`.
|
| 46 |
+
heads (int): the number of heads for MultiHeadedAttention.
|
| 47 |
+
d_ff (int): the second-layer of the
|
| 48 |
+
:class:`PositionwiseFeedForward`.
|
| 49 |
+
dropout (float): dropout in residual, self-attn(dot) and
|
| 50 |
+
feed-forward
|
| 51 |
+
attention_dropout (float): dropout in context_attn (and
|
| 52 |
+
self-attn(avg))
|
| 53 |
+
self_attn_type (string): type of self-attention scaled-dot,
|
| 54 |
+
average
|
| 55 |
+
max_relative_positions (int):
|
| 56 |
+
Max distance between inputs in relative positions
|
| 57 |
+
representations
|
| 58 |
+
aan_useffn (bool): Turn on the FFN layer in the AAN decoder
|
| 59 |
+
full_context_alignment (bool):
|
| 60 |
+
whether enable an extra full context decoder forward for
|
| 61 |
+
alignment
|
| 62 |
+
alignment_heads (int):
|
| 63 |
+
N. of cross attention heads to use for alignment guiding
|
| 64 |
+
pos_ffn_activation_fn (ActivationFunction):
|
| 65 |
+
activation function choice for PositionwiseFeedForward layer
|
| 66 |
+
add_qkvbias (bool): whether to add bias to the Key/Value nn.Linear
|
| 67 |
+
layer_norm (string): type of layer normalization standard/rms
|
| 68 |
+
norm_eps (float): layer norm epsilon
|
| 69 |
+
|
| 70 |
+
"""
|
| 71 |
+
super(TransformerDecoderLayerBase, self).__init__()
|
| 72 |
+
|
| 73 |
+
self.self_attn_type = self_attn_type
|
| 74 |
+
if self_attn_type == "scaled-dot":
|
| 75 |
+
self.self_attn = MultiHeadedAttention(
|
| 76 |
+
heads,
|
| 77 |
+
d_model,
|
| 78 |
+
dropout=attention_dropout,
|
| 79 |
+
max_relative_positions=max_relative_positions,
|
| 80 |
+
relative_positions_buckets=relative_positions_buckets,
|
| 81 |
+
attn_type="self",
|
| 82 |
+
add_qkvbias=add_qkvbias,
|
| 83 |
+
num_kv=num_kv,
|
| 84 |
+
use_ckpting=use_ckpting,
|
| 85 |
+
parallel_gpu=parallel_gpu,
|
| 86 |
+
)
|
| 87 |
+
elif self_attn_type == "average":
|
| 88 |
+
self.self_attn = AverageAttention(
|
| 89 |
+
d_model, dropout=attention_dropout, aan_useffn=aan_useffn
|
| 90 |
+
)
|
| 91 |
+
|
| 92 |
+
self.feed_forward = PositionwiseFeedForward(
|
| 93 |
+
d_model,
|
| 94 |
+
d_ff,
|
| 95 |
+
dropout,
|
| 96 |
+
pos_ffn_activation_fn,
|
| 97 |
+
add_ffnbias,
|
| 98 |
+
parallel_residual,
|
| 99 |
+
layer_norm,
|
| 100 |
+
norm_eps,
|
| 101 |
+
use_ckpting=use_ckpting,
|
| 102 |
+
parallel_gpu=parallel_gpu,
|
| 103 |
+
)
|
| 104 |
+
self.parallel_residual = parallel_residual
|
| 105 |
+
self.shared_layer_norm = shared_layer_norm
|
| 106 |
+
if layer_norm == "standard":
|
| 107 |
+
self.layer_norm_1 = nn.LayerNorm(d_model, eps=norm_eps)
|
| 108 |
+
if parallel_residual and not shared_layer_norm:
|
| 109 |
+
self.layer_norm_res = nn.LayerNorm(d_model, eps=norm_eps)
|
| 110 |
+
elif layer_norm == "rms":
|
| 111 |
+
self.layer_norm_1 = RMSNorm(d_model, eps=norm_eps)
|
| 112 |
+
if parallel_residual and not shared_layer_norm:
|
| 113 |
+
self.layer_norm_res = RMSNorm(d_model, eps=norm_eps)
|
| 114 |
+
else:
|
| 115 |
+
raise ValueError(f"{layer_norm} layer norm type is not supported")
|
| 116 |
+
|
| 117 |
+
self.dropout = nn.Dropout(dropout)
|
| 118 |
+
self.full_context_alignment = full_context_alignment
|
| 119 |
+
self.alignment_heads = alignment_heads
|
| 120 |
+
|
| 121 |
+
def forward(self, *args, **kwargs):
|
| 122 |
+
"""Extend `_forward` for (possibly) multiple decoder pass:
|
| 123 |
+
Always a default (future masked) decoder forward pass,
|
| 124 |
+
Possibly a second future-aware decoder pass to jointly learn
|
| 125 |
+
full context alignment, :cite:`garg2019jointly`.
|
| 126 |
+
|
| 127 |
+
Args:
|
| 128 |
+
* All arguments of _forward, of which
|
| 129 |
+
with_align (bool): needed to compute attn_align
|
| 130 |
+
return_attn (bool): to force MHA to return attns
|
| 131 |
+
|
| 132 |
+
Returns:
|
| 133 |
+
(FloatTensor, FloatTensor, FloatTensor or None):
|
| 134 |
+
|
| 135 |
+
* layer_out ``(batch_size, T, model_dim)``
|
| 136 |
+
* top_attn ``(batch_size, T, src_len)``
|
| 137 |
+
* attn_align ``(batch_size, T, src_len)`` or None
|
| 138 |
+
"""
|
| 139 |
+
with_align = kwargs.pop("with_align", False)
|
| 140 |
+
layer_out, attns = self._forward(*args, **kwargs)
|
| 141 |
+
top_attn = None if attns is None else attns[:, 0, :, :].contiguous()
|
| 142 |
+
attn_align = None
|
| 143 |
+
if with_align:
|
| 144 |
+
if self.full_context_alignment:
|
| 145 |
+
# return _, (B, Q_len, K_len)
|
| 146 |
+
_, attns = self._forward(*args, **kwargs, future=True)
|
| 147 |
+
|
| 148 |
+
if self.alignment_heads > 0:
|
| 149 |
+
attns = attns[:, : self.alignment_heads, :, :].contiguous()
|
| 150 |
+
# layer average attention across heads, get ``(B, Q, K)``
|
| 151 |
+
# Case 1: no full_context, no align heads -> layer avg baseline
|
| 152 |
+
# Case 2: no full_context, 1 align heads -> guided align
|
| 153 |
+
# Case 3: full_context, 1 align heads -> full cte guided align
|
| 154 |
+
attn_align = attns.mean(dim=1)
|
| 155 |
+
return layer_out, top_attn, attn_align
|
| 156 |
+
|
| 157 |
+
def update_dropout(self, dropout, attention_dropout):
|
| 158 |
+
self.self_attn.update_dropout(attention_dropout)
|
| 159 |
+
self.feed_forward.update_dropout(dropout)
|
| 160 |
+
self.dropout.p = dropout
|
| 161 |
+
|
| 162 |
+
def _forward(self, *args, **kwargs):
|
| 163 |
+
raise NotImplementedError
|
| 164 |
+
|
| 165 |
+
def _compute_dec_mask(self, tgt_pad_mask, future):
|
| 166 |
+
tgt_len = tgt_pad_mask.size(-1)
|
| 167 |
+
if not future: # apply future_mask, result mask in (B, T, T)
|
| 168 |
+
future_mask = torch.ones(
|
| 169 |
+
[tgt_len, tgt_len],
|
| 170 |
+
device=tgt_pad_mask.device,
|
| 171 |
+
dtype=torch.uint8,
|
| 172 |
+
)
|
| 173 |
+
future_mask = future_mask.triu_(1).view(1, tgt_len, tgt_len)
|
| 174 |
+
# BoolTensor was introduced in pytorch 1.2
|
| 175 |
+
try:
|
| 176 |
+
future_mask = future_mask.bool()
|
| 177 |
+
except AttributeError:
|
| 178 |
+
pass
|
| 179 |
+
dec_mask = torch.gt(tgt_pad_mask + future_mask, 0)
|
| 180 |
+
else: # only mask padding, result mask in (B, 1, T)
|
| 181 |
+
dec_mask = tgt_pad_mask
|
| 182 |
+
return dec_mask
|
| 183 |
+
|
| 184 |
+
def _forward_self_attn(self, norm_layer_in, dec_mask, step, return_attn=False):
|
| 185 |
+
if self.self_attn_type == "scaled-dot":
|
| 186 |
+
return self.self_attn(
|
| 187 |
+
norm_layer_in,
|
| 188 |
+
norm_layer_in,
|
| 189 |
+
norm_layer_in,
|
| 190 |
+
mask=dec_mask,
|
| 191 |
+
step=step,
|
| 192 |
+
return_attn=return_attn,
|
| 193 |
+
)
|
| 194 |
+
elif self.self_attn_type == "average":
|
| 195 |
+
return self.self_attn(norm_layer_in, mask=dec_mask, step=step)
|
| 196 |
+
else:
|
| 197 |
+
raise ValueError(f"self attention {type(self.self_attn)} not supported")
|
| 198 |
+
|
| 199 |
+
|
| 200 |
+
class TransformerDecoderLayer(TransformerDecoderLayerBase):
|
| 201 |
+
"""Transformer Decoder layer block in Pre-Norm style.
|
| 202 |
+
Pre-Norm style is an improvement w.r.t. the original paper's Post-Norm style,
|
| 203 |
+
providing better convergence speed and performance. This is also the actual
|
| 204 |
+
implementation in tensor2tensor and is also available in fairseq.
|
| 205 |
+
See https://tunz.kr/post/4 and :cite:`DeeperTransformer`.
|
| 206 |
+
|
| 207 |
+
"""
|
| 208 |
+
|
| 209 |
+
def __init__(
|
| 210 |
+
self,
|
| 211 |
+
d_model,
|
| 212 |
+
heads,
|
| 213 |
+
d_ff,
|
| 214 |
+
dropout,
|
| 215 |
+
attention_dropout,
|
| 216 |
+
self_attn_type="scaled-dot",
|
| 217 |
+
max_relative_positions=0,
|
| 218 |
+
relative_positions_buckets=0,
|
| 219 |
+
aan_useffn=False,
|
| 220 |
+
full_context_alignment=False,
|
| 221 |
+
alignment_heads=0,
|
| 222 |
+
pos_ffn_activation_fn=ActivationFunction.relu,
|
| 223 |
+
add_qkvbias=False,
|
| 224 |
+
num_kv=0,
|
| 225 |
+
add_ffnbias=True,
|
| 226 |
+
parallel_residual=False,
|
| 227 |
+
shared_layer_norm=False,
|
| 228 |
+
layer_norm="standard",
|
| 229 |
+
norm_eps=1e-6,
|
| 230 |
+
use_ckpting=[],
|
| 231 |
+
parallel_gpu=1,
|
| 232 |
+
):
|
| 233 |
+
"""
|
| 234 |
+
Args:
|
| 235 |
+
See TransformerDecoderLayerBase
|
| 236 |
+
"""
|
| 237 |
+
super(TransformerDecoderLayer, self).__init__(
|
| 238 |
+
d_model,
|
| 239 |
+
heads,
|
| 240 |
+
d_ff,
|
| 241 |
+
dropout,
|
| 242 |
+
attention_dropout,
|
| 243 |
+
self_attn_type,
|
| 244 |
+
max_relative_positions,
|
| 245 |
+
relative_positions_buckets,
|
| 246 |
+
aan_useffn,
|
| 247 |
+
full_context_alignment,
|
| 248 |
+
alignment_heads,
|
| 249 |
+
pos_ffn_activation_fn=pos_ffn_activation_fn,
|
| 250 |
+
add_qkvbias=add_qkvbias,
|
| 251 |
+
num_kv=num_kv,
|
| 252 |
+
add_ffnbias=add_ffnbias,
|
| 253 |
+
parallel_residual=parallel_residual,
|
| 254 |
+
shared_layer_norm=shared_layer_norm,
|
| 255 |
+
layer_norm=layer_norm,
|
| 256 |
+
norm_eps=norm_eps,
|
| 257 |
+
use_ckpting=use_ckpting,
|
| 258 |
+
parallel_gpu=parallel_gpu,
|
| 259 |
+
)
|
| 260 |
+
self.context_attn = MultiHeadedAttention(
|
| 261 |
+
heads,
|
| 262 |
+
d_model,
|
| 263 |
+
dropout=attention_dropout,
|
| 264 |
+
attn_type="context",
|
| 265 |
+
add_qkvbias=add_qkvbias,
|
| 266 |
+
num_kv=num_kv,
|
| 267 |
+
use_ckpting=use_ckpting,
|
| 268 |
+
parallel_gpu=parallel_gpu,
|
| 269 |
+
)
|
| 270 |
+
if layer_norm == "standard":
|
| 271 |
+
self.layer_norm_2 = nn.LayerNorm(d_model, eps=norm_eps)
|
| 272 |
+
elif layer_norm == "rms":
|
| 273 |
+
self.layer_norm_2 = RMSNorm(d_model, eps=norm_eps)
|
| 274 |
+
else:
|
| 275 |
+
raise ValueError(f"{layer_norm} layer norm type is not supported")
|
| 276 |
+
|
| 277 |
+
def update_dropout(self, dropout, attention_dropout):
|
| 278 |
+
super(TransformerDecoderLayer, self).update_dropout(dropout, attention_dropout)
|
| 279 |
+
self.context_attn.update_dropout(attention_dropout)
|
| 280 |
+
|
| 281 |
+
def _forward(
|
| 282 |
+
self,
|
| 283 |
+
layer_in,
|
| 284 |
+
enc_out,
|
| 285 |
+
src_pad_mask,
|
| 286 |
+
tgt_pad_mask,
|
| 287 |
+
step=None,
|
| 288 |
+
future=False,
|
| 289 |
+
return_attn=False,
|
| 290 |
+
):
|
| 291 |
+
"""A naive forward pass for transformer decoder.
|
| 292 |
+
|
| 293 |
+
# T: could be 1 in the case of stepwise decoding or tgt_len
|
| 294 |
+
|
| 295 |
+
Args:
|
| 296 |
+
layer_in (FloatTensor): ``(batch_size, T, model_dim)``
|
| 297 |
+
enc_out (FloatTensor): ``(batch_size, src_len, model_dim)``
|
| 298 |
+
src_pad_mask (bool): ``(batch_size, 1, src_len)``
|
| 299 |
+
tgt_pad_mask (bool): ``(batch_size, 1, T)``
|
| 300 |
+
step (int or None): stepwise decoding counter
|
| 301 |
+
future (bool): If set True, do not apply future_mask.
|
| 302 |
+
return_attn (bool) : if set True requires attns output
|
| 303 |
+
|
| 304 |
+
Returns:
|
| 305 |
+
(FloatTensor, FloatTensor):
|
| 306 |
+
|
| 307 |
+
* layer_out ``(batch_size, T, model_dim)``
|
| 308 |
+
* attns ``(batch_size, head, T, src_len)``
|
| 309 |
+
|
| 310 |
+
"""
|
| 311 |
+
dec_mask = None
|
| 312 |
+
src_pad_mask = src_pad_mask.unsqueeze(1) # [B,1,1,slen]
|
| 313 |
+
|
| 314 |
+
if layer_in.size(1) > 1:
|
| 315 |
+
# masking is necessary when sequence length is greater than one
|
| 316 |
+
dec_mask = self._compute_dec_mask(tgt_pad_mask, future)
|
| 317 |
+
dec_mask = dec_mask.unsqueeze(1)
|
| 318 |
+
dec_mask = dec_mask.expand(-1, -1, dec_mask.size(3), -1)
|
| 319 |
+
src_pad_mask = src_pad_mask.expand(-1, -1, dec_mask.size(3), -1)
|
| 320 |
+
# mask now are (batch x 1 x tlen x s or t len)
|
| 321 |
+
# 1 = heads to be expanded in MHA
|
| 322 |
+
|
| 323 |
+
norm_layer_in = self.layer_norm_1(layer_in)
|
| 324 |
+
|
| 325 |
+
self_attn, _ = self._forward_self_attn(norm_layer_in, dec_mask, step)
|
| 326 |
+
|
| 327 |
+
if self.parallel_residual:
|
| 328 |
+
ctx_attn, attns = self.context_attn(
|
| 329 |
+
enc_out,
|
| 330 |
+
enc_out,
|
| 331 |
+
norm_layer_in,
|
| 332 |
+
mask=src_pad_mask,
|
| 333 |
+
return_attn=return_attn,
|
| 334 |
+
)
|
| 335 |
+
# feed_forward applies its own residual, so remove it and re-apply the residual with the un-normed input
|
| 336 |
+
layer_out = (
|
| 337 |
+
self.feed_forward(norm_layer_in)
|
| 338 |
+
- norm_layer_in
|
| 339 |
+
+ layer_in
|
| 340 |
+
+ self.dropout(self_attn)
|
| 341 |
+
+ ctx_attn
|
| 342 |
+
)
|
| 343 |
+
else:
|
| 344 |
+
query = self.dropout(self_attn) + layer_in
|
| 345 |
+
norm_query = self.layer_norm_2(query)
|
| 346 |
+
ctx_attn, attns = self.context_attn(
|
| 347 |
+
enc_out, enc_out, norm_query, mask=src_pad_mask, return_attn=return_attn
|
| 348 |
+
)
|
| 349 |
+
layer_out = self.feed_forward(self.dropout(ctx_attn) + query)
|
| 350 |
+
|
| 351 |
+
return layer_out, attns
|
| 352 |
+
|
| 353 |
+
|
| 354 |
+
class TransformerDecoderBase(DecoderBase):
|
| 355 |
+
def __init__(
|
| 356 |
+
self, d_model, copy_attn, embeddings, alignment_layer, layer_norm, norm_eps
|
| 357 |
+
):
|
| 358 |
+
super(TransformerDecoderBase, self).__init__()
|
| 359 |
+
|
| 360 |
+
self.embeddings = embeddings
|
| 361 |
+
|
| 362 |
+
# Decoder State
|
| 363 |
+
self.state = {}
|
| 364 |
+
|
| 365 |
+
# previously, there was a GlobalAttention module here for copy
|
| 366 |
+
# attention. But it was never actually used -- the "copy" attention
|
| 367 |
+
# just reuses the context attention.
|
| 368 |
+
self._copy = copy_attn
|
| 369 |
+
if layer_norm == "standard":
|
| 370 |
+
self.layer_norm = nn.LayerNorm(d_model, eps=norm_eps)
|
| 371 |
+
elif layer_norm == "rms":
|
| 372 |
+
self.layer_norm = RMSNorm(d_model, eps=norm_eps)
|
| 373 |
+
else:
|
| 374 |
+
raise ValueError(f"{layer_norm} layer norm type is not supported")
|
| 375 |
+
|
| 376 |
+
self.alignment_layer = alignment_layer
|
| 377 |
+
|
| 378 |
+
@classmethod
|
| 379 |
+
def from_opt(cls, opt, embeddings):
|
| 380 |
+
"""Alternate constructor."""
|
| 381 |
+
return cls(
|
| 382 |
+
opt.dec_layers,
|
| 383 |
+
opt.dec_hid_size,
|
| 384 |
+
opt.heads,
|
| 385 |
+
opt.transformer_ff,
|
| 386 |
+
opt.copy_attn,
|
| 387 |
+
opt.self_attn_type,
|
| 388 |
+
opt.dropout[0] if type(opt.dropout) is list else opt.dropout,
|
| 389 |
+
opt.attention_dropout[0]
|
| 390 |
+
if type(opt.attention_dropout) is list
|
| 391 |
+
else opt.attention_dropout,
|
| 392 |
+
embeddings,
|
| 393 |
+
opt.max_relative_positions,
|
| 394 |
+
opt.relative_positions_buckets,
|
| 395 |
+
opt.aan_useffn,
|
| 396 |
+
opt.full_context_alignment,
|
| 397 |
+
opt.alignment_layer,
|
| 398 |
+
alignment_heads=opt.alignment_heads,
|
| 399 |
+
pos_ffn_activation_fn=opt.pos_ffn_activation_fn,
|
| 400 |
+
add_qkvbias=opt.add_qkvbias,
|
| 401 |
+
num_kv=opt.num_kv,
|
| 402 |
+
add_ffnbias=opt.add_ffnbias,
|
| 403 |
+
parallel_residual=opt.parallel_residual,
|
| 404 |
+
shared_layer_norm=opt.shared_layer_norm,
|
| 405 |
+
layer_norm=opt.layer_norm,
|
| 406 |
+
norm_eps=opt.norm_eps,
|
| 407 |
+
use_ckpting=opt.use_ckpting,
|
| 408 |
+
parallel_gpu=opt.world_size
|
| 409 |
+
if opt.parallel_mode == "tensor_parallel"
|
| 410 |
+
else 1,
|
| 411 |
+
)
|
| 412 |
+
|
| 413 |
+
def init_state(self, src, enc_out, enc_final_hs):
|
| 414 |
+
"""Initialize decoder state."""
|
| 415 |
+
self.state["src"] = src
|
| 416 |
+
|
| 417 |
+
def map_state(self, fn):
|
| 418 |
+
if self.state["src"] is not None:
|
| 419 |
+
self.state["src"] = fn(self.state["src"], 0)
|
| 420 |
+
for layer in self.transformer_layers:
|
| 421 |
+
if hasattr(layer, "context_attn"):
|
| 422 |
+
if layer.context_attn.layer_cache[1]["keys"].numel() != 0:
|
| 423 |
+
x = fn(layer.context_attn.layer_cache[1]["keys"], 0)
|
| 424 |
+
y = fn(layer.context_attn.layer_cache[1]["values"], 0)
|
| 425 |
+
layer.context_attn.layer_cache = True, {"keys": x, "values": y}
|
| 426 |
+
if isinstance(layer.self_attn, AverageAttention):
|
| 427 |
+
if layer.self_attn.layer_cache[1]["prev_g"].numel() != 0:
|
| 428 |
+
x = fn(layer.self_attn.layer_cache[1]["prev_g"], 0)
|
| 429 |
+
layer.self_attn.layer_cache = True, {"prev_g": x}
|
| 430 |
+
else:
|
| 431 |
+
if layer.self_attn.layer_cache[1]["keys"].numel() != 0:
|
| 432 |
+
x = fn(layer.self_attn.layer_cache[1]["keys"], 0)
|
| 433 |
+
y = fn(layer.self_attn.layer_cache[1]["values"], 0)
|
| 434 |
+
layer.self_attn.layer_cache = True, {"keys": x, "values": y}
|
| 435 |
+
|
| 436 |
+
def detach_state(self):
|
| 437 |
+
raise NotImplementedError
|
| 438 |
+
|
| 439 |
+
def forward(self, *args, **kwargs):
|
| 440 |
+
raise NotImplementedError
|
| 441 |
+
|
| 442 |
+
def update_dropout(self, dropout, attention_dropout):
|
| 443 |
+
self.embeddings.update_dropout(dropout)
|
| 444 |
+
for layer in self.transformer_layers:
|
| 445 |
+
layer.update_dropout(dropout, attention_dropout)
|
| 446 |
+
|
| 447 |
+
|
| 448 |
+
class TransformerDecoder(TransformerDecoderBase):
|
| 449 |
+
"""The Transformer decoder from "Attention is All You Need".
|
| 450 |
+
:cite:`DBLP:journals/corr/VaswaniSPUJGKP17`
|
| 451 |
+
|
| 452 |
+
Args:
|
| 453 |
+
num_layers (int): number of decoder layers.
|
| 454 |
+
d_model (int): size of the model
|
| 455 |
+
heads (int): number of heads
|
| 456 |
+
d_ff (int): size of the inner FF layer
|
| 457 |
+
copy_attn (bool): if using a separate copy attention
|
| 458 |
+
self_attn_type (str): type of self-attention scaled-dot, average
|
| 459 |
+
dropout (float): dropout in residual, self-attn(dot) and feed-forward
|
| 460 |
+
attention_dropout (float): dropout in context_attn (and self-attn(avg))
|
| 461 |
+
embeddings (onmt.modules.Embeddings):
|
| 462 |
+
embeddings to use, should have positional encodings
|
| 463 |
+
max_relative_positions (int):
|
| 464 |
+
Max distance between inputs in relative positions representations
|
| 465 |
+
relative_positions_buckets (int):
|
| 466 |
+
Number of buckets when using relative position bias
|
| 467 |
+
aan_useffn (bool): Turn on the FFN layer in the AAN decoder
|
| 468 |
+
full_context_alignment (bool):
|
| 469 |
+
whether enable an extra full context decoder forward for alignment
|
| 470 |
+
alignment_layer (int): N° Layer to supervise with for alignment guiding
|
| 471 |
+
alignment_heads (int):
|
| 472 |
+
N. of cross attention heads to use for alignment guiding
|
| 473 |
+
add_qkvbias (bool): whether to add bias to the Key/Value nn.Linear
|
| 474 |
+
layer_norm (string): type of layer normalization standard/rms
|
| 475 |
+
"""
|
| 476 |
+
|
| 477 |
+
def __init__(
|
| 478 |
+
self,
|
| 479 |
+
num_layers,
|
| 480 |
+
d_model,
|
| 481 |
+
heads,
|
| 482 |
+
d_ff,
|
| 483 |
+
copy_attn,
|
| 484 |
+
self_attn_type,
|
| 485 |
+
dropout,
|
| 486 |
+
attention_dropout,
|
| 487 |
+
embeddings,
|
| 488 |
+
max_relative_positions,
|
| 489 |
+
relative_positions_buckets,
|
| 490 |
+
aan_useffn,
|
| 491 |
+
full_context_alignment,
|
| 492 |
+
alignment_layer,
|
| 493 |
+
alignment_heads,
|
| 494 |
+
pos_ffn_activation_fn=ActivationFunction.relu,
|
| 495 |
+
add_qkvbias=False,
|
| 496 |
+
num_kv=0,
|
| 497 |
+
add_ffnbias=True,
|
| 498 |
+
parallel_residual=False,
|
| 499 |
+
shared_layer_norm=False,
|
| 500 |
+
layer_norm="standard",
|
| 501 |
+
norm_eps=1e-6,
|
| 502 |
+
use_ckpting=[],
|
| 503 |
+
parallel_gpu=1,
|
| 504 |
+
):
|
| 505 |
+
super(TransformerDecoder, self).__init__(
|
| 506 |
+
d_model, copy_attn, embeddings, alignment_layer, layer_norm, norm_eps
|
| 507 |
+
)
|
| 508 |
+
|
| 509 |
+
self.transformer_layers = nn.ModuleList(
|
| 510 |
+
[
|
| 511 |
+
TransformerDecoderLayer(
|
| 512 |
+
d_model,
|
| 513 |
+
heads,
|
| 514 |
+
d_ff,
|
| 515 |
+
dropout,
|
| 516 |
+
attention_dropout,
|
| 517 |
+
self_attn_type=self_attn_type,
|
| 518 |
+
max_relative_positions=max_relative_positions,
|
| 519 |
+
relative_positions_buckets=relative_positions_buckets,
|
| 520 |
+
aan_useffn=aan_useffn,
|
| 521 |
+
full_context_alignment=full_context_alignment,
|
| 522 |
+
alignment_heads=alignment_heads,
|
| 523 |
+
pos_ffn_activation_fn=pos_ffn_activation_fn,
|
| 524 |
+
add_qkvbias=add_qkvbias,
|
| 525 |
+
num_kv=num_kv,
|
| 526 |
+
add_ffnbias=add_ffnbias,
|
| 527 |
+
parallel_residual=parallel_residual,
|
| 528 |
+
shared_layer_norm=shared_layer_norm,
|
| 529 |
+
layer_norm=layer_norm,
|
| 530 |
+
norm_eps=norm_eps,
|
| 531 |
+
use_ckpting=use_ckpting,
|
| 532 |
+
parallel_gpu=parallel_gpu,
|
| 533 |
+
)
|
| 534 |
+
for i in range(num_layers)
|
| 535 |
+
]
|
| 536 |
+
)
|
| 537 |
+
|
| 538 |
+
def detach_state(self):
|
| 539 |
+
self.state["src"] = self.state["src"].detach()
|
| 540 |
+
|
| 541 |
+
def forward(self, tgt, enc_out=None, step=None, **kwargs):
|
| 542 |
+
"""
|
| 543 |
+
Decode, possibly stepwise.
|
| 544 |
+
During training, step is always None; during decoding, step increases.
|
| 545 |
+
tgt (Tensor): batch x tlen x feats
|
| 546 |
+
enc_out (Tensor): encoder output (batch x slen x model_dim)
|
| 547 |
+
"""
|
| 548 |
+
if enc_out is None:
|
| 549 |
+
enc_out = self.embeddings(tgt)
|
| 550 |
+
if step == 0:
|
| 551 |
+
self._init_cache(enc_out)
|
| 552 |
+
elif step is None:
|
| 553 |
+
for layer in self.transformer_layers:
|
| 554 |
+
if isinstance(layer.self_attn, AverageAttention):
|
| 555 |
+
layer.self_attn.layer_cache = False, {"prev_g": torch.tensor([])}
|
| 556 |
+
else:
|
| 557 |
+
layer.self_attn.layer_cache = (
|
| 558 |
+
False,
|
| 559 |
+
{"keys": torch.tensor([]), "values": torch.tensor([])},
|
| 560 |
+
)
|
| 561 |
+
layer.context_attn.layer_cache = (
|
| 562 |
+
False,
|
| 563 |
+
{"keys": torch.tensor([]), "values": torch.tensor([])},
|
| 564 |
+
)
|
| 565 |
+
|
| 566 |
+
emb = self.embeddings(tgt, step=step)
|
| 567 |
+
dec_out = emb
|
| 568 |
+
assert emb.dim() == 3 # len x batch x embedding_dim
|
| 569 |
+
|
| 570 |
+
pad_idx = self.embeddings.word_padding_idx
|
| 571 |
+
src_lens = kwargs["src_len"]
|
| 572 |
+
src_max_len = self.state["src"].shape[1]
|
| 573 |
+
src_pad_mask = ~sequence_mask(src_lens, src_max_len) # [B x slen]
|
| 574 |
+
src_pad_mask = src_pad_mask.unsqueeze(1) # [B x 1 x slen]
|
| 575 |
+
tgt_pad_mask = tgt[:, :, 0].eq(pad_idx).unsqueeze(1) # [B, 1, T_tgt]
|
| 576 |
+
|
| 577 |
+
with_align = kwargs.pop("with_align", False)
|
| 578 |
+
return_attn = with_align or self._copy
|
| 579 |
+
attn_aligns = []
|
| 580 |
+
|
| 581 |
+
for layer in self.transformer_layers:
|
| 582 |
+
dec_out, attn, attn_align = layer(
|
| 583 |
+
dec_out,
|
| 584 |
+
enc_out,
|
| 585 |
+
src_pad_mask,
|
| 586 |
+
tgt_pad_mask,
|
| 587 |
+
step=step,
|
| 588 |
+
with_align=with_align,
|
| 589 |
+
return_attn=return_attn,
|
| 590 |
+
)
|
| 591 |
+
if attn_align is not None:
|
| 592 |
+
attn_aligns.append(attn_align)
|
| 593 |
+
|
| 594 |
+
dec_out = self.layer_norm(dec_out)
|
| 595 |
+
|
| 596 |
+
attns = {"std": attn}
|
| 597 |
+
if self._copy:
|
| 598 |
+
attns["copy"] = attn
|
| 599 |
+
if with_align:
|
| 600 |
+
attns["align"] = attn_aligns[self.alignment_layer] # `(B, Q, K)`
|
| 601 |
+
# attns["align"] = torch.stack(attn_aligns, 0).mean(0) # All avg
|
| 602 |
+
|
| 603 |
+
# TODO change the way attns is returned dict => list or tuple (onnx)
|
| 604 |
+
return dec_out, attns
|
| 605 |
+
|
| 606 |
+
def _init_cache(self, enc_out):
|
| 607 |
+
batch_size = enc_out.size(0)
|
| 608 |
+
depth = enc_out.size(-1)
|
| 609 |
+
|
| 610 |
+
for layer in self.transformer_layers:
|
| 611 |
+
# first value set to True triggered by the beginning of decoding
|
| 612 |
+
# layer_cache becomes active in the MultiHeadedAttention fwd
|
| 613 |
+
layer.context_attn.layer_cache = (
|
| 614 |
+
True,
|
| 615 |
+
{
|
| 616 |
+
"keys": torch.tensor([], device=enc_out.device),
|
| 617 |
+
"values": torch.tensor([], device=enc_out.device),
|
| 618 |
+
},
|
| 619 |
+
)
|
| 620 |
+
if isinstance(layer.self_attn, AverageAttention):
|
| 621 |
+
layer.self_attn.layer_cache = True, {
|
| 622 |
+
"prev_g": torch.zeros(
|
| 623 |
+
(batch_size, 1, depth), device=enc_out.device
|
| 624 |
+
).to(enc_out.dtype)
|
| 625 |
+
}
|
| 626 |
+
else:
|
| 627 |
+
layer.self_attn.layer_cache = (
|
| 628 |
+
True,
|
| 629 |
+
{
|
| 630 |
+
"keys": torch.tensor([], device=enc_out.device),
|
| 631 |
+
"values": torch.tensor([], device=enc_out.device),
|
| 632 |
+
},
|
| 633 |
+
)
|
| 634 |
+
|
| 635 |
+
|
| 636 |
+
class TransformerLMDecoderLayer(TransformerDecoderLayerBase):
|
| 637 |
+
"""Transformer Decoder only layer block in GPT style.
|
| 638 |
+
Args:
|
| 639 |
+
See TransformerDecoderLayerBase
|
| 640 |
+
"""
|
| 641 |
+
|
| 642 |
+
def _forward(
|
| 643 |
+
self, layer_in, tgt_pad_mask, step=None, future=False, return_attn=False
|
| 644 |
+
):
|
| 645 |
+
"""A naive forward pass for transformer decoder.
|
| 646 |
+
|
| 647 |
+
# T: could be 1 in the case of stepwise decoding or tgt_len
|
| 648 |
+
|
| 649 |
+
Args:
|
| 650 |
+
layer_in (FloatTensor): ``(batch_size, T, model_dim)``
|
| 651 |
+
tgt_pad_mask (bool): ``(batch_size, 1, T)``
|
| 652 |
+
layer_cache (dict or None): cached layer info when stepwise decode
|
| 653 |
+
step (int or None): stepwise decoding counter
|
| 654 |
+
future (bool): If set True, do not apply future_mask.
|
| 655 |
+
return_attn (bool): If set True return attn
|
| 656 |
+
|
| 657 |
+
Returns:
|
| 658 |
+
(FloatTensor, FloatTensor):
|
| 659 |
+
|
| 660 |
+
* layer_out ``(batch_size, T, model_dim)``
|
| 661 |
+
* attns ``(batch_size, head, T, T)``
|
| 662 |
+
|
| 663 |
+
"""
|
| 664 |
+
dec_mask = None
|
| 665 |
+
|
| 666 |
+
if layer_in.size(1) > 1:
|
| 667 |
+
# masking is necessary when sequence length is greater than one
|
| 668 |
+
dec_mask = self._compute_dec_mask(tgt_pad_mask, future)
|
| 669 |
+
dec_mask = dec_mask.unsqueeze(1)
|
| 670 |
+
dec_mask = dec_mask.expand(-1, -1, dec_mask.size(3), -1)
|
| 671 |
+
# mask now are (batch x 1 x tlen x tlen)
|
| 672 |
+
# 1 = heads to be expanded in MHA
|
| 673 |
+
|
| 674 |
+
norm_layer_in = self.layer_norm_1(layer_in)
|
| 675 |
+
|
| 676 |
+
attn_output, attns = self._forward_self_attn(
|
| 677 |
+
norm_layer_in, dec_mask, step, return_attn=return_attn
|
| 678 |
+
)
|
| 679 |
+
|
| 680 |
+
if self.parallel_residual:
|
| 681 |
+
# feed_forward applies its own residual, so remove it and re-apply the residual with the un-normed input
|
| 682 |
+
if not self.shared_layer_norm:
|
| 683 |
+
norm_res_layer_in = self.layer_norm_res(layer_in)
|
| 684 |
+
ff_in = norm_res_layer_in
|
| 685 |
+
else:
|
| 686 |
+
ff_in = norm_layer_in
|
| 687 |
+
layer_out = (
|
| 688 |
+
self.feed_forward(ff_in) - ff_in + layer_in + self.dropout(attn_output)
|
| 689 |
+
)
|
| 690 |
+
else:
|
| 691 |
+
layer_out = self.dropout(attn_output) + layer_in
|
| 692 |
+
layer_out = self.feed_forward(layer_out)
|
| 693 |
+
|
| 694 |
+
return layer_out, attns
|
| 695 |
+
|
| 696 |
+
|
| 697 |
+
class TransformerLMDecoder(TransformerDecoderBase):
|
| 698 |
+
"""The Transformer decoder from GPT-2
|
| 699 |
+
Args:
|
| 700 |
+
num_layers (int): number of decoder layers.
|
| 701 |
+
d_model (int): size of the model
|
| 702 |
+
heads (int): number of heads
|
| 703 |
+
d_ff (int): size of the inner FF layer
|
| 704 |
+
copy_attn (bool): if using a separate copy attention
|
| 705 |
+
self_attn_type (str): type of self-attention scaled-dot, average
|
| 706 |
+
dropout (float): dropout in residual, self-attn(dot) and feed-forward
|
| 707 |
+
attention_dropout (float): dropout in context_attn (and self-attn(avg))
|
| 708 |
+
embeddings (onmt.modules.Embeddings):
|
| 709 |
+
embeddings to use, should have positional encodings
|
| 710 |
+
max_relative_positions (int):
|
| 711 |
+
Max distance between inputs in relative positions representations
|
| 712 |
+
relative_positions_buckets (int):
|
| 713 |
+
Number of buckets when using Relative positions bias
|
| 714 |
+
aan_useffn (bool): Turn on the FFN layer in the AAN decoder
|
| 715 |
+
add_qkvbias (bool): whether to add bias to the Key/Value nn.Linear
|
| 716 |
+
"""
|
| 717 |
+
|
| 718 |
+
def __init__(
|
| 719 |
+
self,
|
| 720 |
+
num_layers,
|
| 721 |
+
d_model,
|
| 722 |
+
heads,
|
| 723 |
+
d_ff,
|
| 724 |
+
copy_attn,
|
| 725 |
+
self_attn_type,
|
| 726 |
+
dropout,
|
| 727 |
+
attention_dropout,
|
| 728 |
+
embeddings,
|
| 729 |
+
max_relative_positions,
|
| 730 |
+
relative_positions_buckets,
|
| 731 |
+
aan_useffn,
|
| 732 |
+
full_context_alignment=None,
|
| 733 |
+
alignment_layer=None,
|
| 734 |
+
alignment_heads=None,
|
| 735 |
+
pos_ffn_activation_fn=ActivationFunction.relu,
|
| 736 |
+
add_qkvbias=False,
|
| 737 |
+
num_kv=0,
|
| 738 |
+
add_ffnbias=True,
|
| 739 |
+
parallel_residual=False,
|
| 740 |
+
shared_layer_norm=False,
|
| 741 |
+
layer_norm="standard",
|
| 742 |
+
norm_eps=1e-6,
|
| 743 |
+
use_ckpting=[],
|
| 744 |
+
parallel_gpu=1,
|
| 745 |
+
):
|
| 746 |
+
super(TransformerLMDecoder, self).__init__(
|
| 747 |
+
d_model, copy_attn, embeddings, alignment_layer, layer_norm, norm_eps
|
| 748 |
+
)
|
| 749 |
+
self.transformer_layers = nn.ModuleList(
|
| 750 |
+
[
|
| 751 |
+
TransformerLMDecoderLayer(
|
| 752 |
+
d_model,
|
| 753 |
+
heads,
|
| 754 |
+
d_ff,
|
| 755 |
+
dropout,
|
| 756 |
+
attention_dropout,
|
| 757 |
+
self_attn_type=self_attn_type,
|
| 758 |
+
max_relative_positions=max_relative_positions,
|
| 759 |
+
relative_positions_buckets=relative_positions_buckets,
|
| 760 |
+
aan_useffn=aan_useffn,
|
| 761 |
+
full_context_alignment=None,
|
| 762 |
+
alignment_heads=None,
|
| 763 |
+
pos_ffn_activation_fn=pos_ffn_activation_fn,
|
| 764 |
+
add_qkvbias=add_qkvbias,
|
| 765 |
+
num_kv=num_kv,
|
| 766 |
+
add_ffnbias=add_ffnbias,
|
| 767 |
+
parallel_residual=parallel_residual,
|
| 768 |
+
shared_layer_norm=shared_layer_norm,
|
| 769 |
+
layer_norm=layer_norm,
|
| 770 |
+
norm_eps=norm_eps,
|
| 771 |
+
use_ckpting=use_ckpting,
|
| 772 |
+
parallel_gpu=parallel_gpu,
|
| 773 |
+
)
|
| 774 |
+
for i in range(num_layers)
|
| 775 |
+
]
|
| 776 |
+
)
|
| 777 |
+
|
| 778 |
+
def init_state(self, src=None, enc_out=None, enc_final_hs=None):
|
| 779 |
+
super(TransformerLMDecoder, self).init_state(None, None, None)
|
| 780 |
+
|
| 781 |
+
def detach_state(self):
|
| 782 |
+
pass
|
| 783 |
+
|
| 784 |
+
def forward(self, tgt, enc_out=None, step=None, **kwargs):
|
| 785 |
+
"""Decode, possibly stepwise."""
|
| 786 |
+
if step == 0:
|
| 787 |
+
self._init_cache(tgt)
|
| 788 |
+
elif step is None:
|
| 789 |
+
for layer in self.transformer_layers:
|
| 790 |
+
layer.self_attn.layer_cache = (
|
| 791 |
+
False,
|
| 792 |
+
{"keys": torch.tensor([]), "values": torch.tensor([])},
|
| 793 |
+
)
|
| 794 |
+
|
| 795 |
+
dec_out = self.embeddings(tgt, step=step)
|
| 796 |
+
|
| 797 |
+
assert dec_out.dim() == 3 # batch x len x embedding_dim
|
| 798 |
+
|
| 799 |
+
pad_idx = self.embeddings.word_padding_idx
|
| 800 |
+
tgt_pad_mask = tgt[:, :, 0].eq(pad_idx).unsqueeze(1) # [B, 1, T_tgt]
|
| 801 |
+
|
| 802 |
+
with_align = kwargs.pop("with_align", False)
|
| 803 |
+
return_attn = with_align or self._copy
|
| 804 |
+
assert not with_align, "TransformerLMDecoder does not support align"
|
| 805 |
+
|
| 806 |
+
for layer in self.transformer_layers:
|
| 807 |
+
dec_out, attn, _ = layer(
|
| 808 |
+
dec_out,
|
| 809 |
+
tgt_pad_mask,
|
| 810 |
+
step=step,
|
| 811 |
+
with_align=with_align,
|
| 812 |
+
return_attn=return_attn,
|
| 813 |
+
)
|
| 814 |
+
|
| 815 |
+
dec_out = self.layer_norm(dec_out)
|
| 816 |
+
|
| 817 |
+
attns = {"std": attn}
|
| 818 |
+
if self._copy:
|
| 819 |
+
attns["copy"] = attn
|
| 820 |
+
|
| 821 |
+
# TODO change the way attns is returned dict => list or tuple (onnx)
|
| 822 |
+
return dec_out, attns
|
| 823 |
+
|
| 824 |
+
def _init_cache(self, tgt=None):
|
| 825 |
+
for layer in self.transformer_layers:
|
| 826 |
+
if isinstance(layer.self_attn, AverageAttention):
|
| 827 |
+
raise NotImplementedError
|
| 828 |
+
else:
|
| 829 |
+
layer.self_attn.layer_cache = (
|
| 830 |
+
True,
|
| 831 |
+
{
|
| 832 |
+
"keys": torch.tensor([], device=tgt.device),
|
| 833 |
+
"values": torch.tensor([], device=tgt.device),
|
| 834 |
+
},
|
| 835 |
+
)
|
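_compute_dec_mask in the layers above merges the target padding mask with an upper-triangular future mask, so each position can only attend to itself and earlier, non-padded positions. A standalone sketch of that construction with arbitrary toy lengths (pad_idx and the token ids are illustrative assumptions):
import torch

pad_idx = 0
tgt = torch.tensor([[4, 7, 2, 0, 0],
                    [5, 3, 8, 9, 1]])                       # (batch, tgt_len), zeros are padding
tgt_len = tgt.size(1)
tgt_pad_mask = tgt.eq(pad_idx).unsqueeze(1)                  # (batch, 1, tgt_len), True on pads

future_mask = torch.ones(tgt_len, tgt_len, dtype=torch.uint8)
future_mask = future_mask.triu_(1).bool().view(1, tgt_len, tgt_len)  # True strictly above the diagonal

# True wherever attention must be blocked: future positions or padded keys.
dec_mask = tgt_pad_mask | future_mask                        # broadcasts to (batch, tgt_len, tgt_len)
print(dec_mask[0].int())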
onmt/encoders/__init__.py
ADDED
|
@@ -0,0 +1,67 @@
|
| 1 |
+
"""Module defining encoders."""
|
| 2 |
+
import os
|
| 3 |
+
import importlib
|
| 4 |
+
from onmt.encoders.encoder import EncoderBase
|
| 5 |
+
from onmt.encoders.transformer import TransformerEncoder
|
| 6 |
+
from onmt.encoders.ggnn_encoder import GGNNEncoder
|
| 7 |
+
from onmt.encoders.rnn_encoder import RNNEncoder
|
| 8 |
+
from onmt.encoders.cnn_encoder import CNNEncoder
|
| 9 |
+
from onmt.encoders.mean_encoder import MeanEncoder
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
str2enc = {
|
| 13 |
+
"ggnn": GGNNEncoder,
|
| 14 |
+
"rnn": RNNEncoder,
|
| 15 |
+
"brnn": RNNEncoder,
|
| 16 |
+
"cnn": CNNEncoder,
|
| 17 |
+
"transformer": TransformerEncoder,
|
| 18 |
+
"mean": MeanEncoder,
|
| 19 |
+
}
|
| 20 |
+
|
| 21 |
+
__all__ = [
|
| 22 |
+
"EncoderBase",
|
| 23 |
+
"TransformerEncoder",
|
| 24 |
+
"GGNNEncoder",
|
| 25 |
+
"RNNEncoder",
|
| 26 |
+
"CNNEncoder",
|
| 27 |
+
"MeanEncoder",
|
| 28 |
+
"str2enc",
|
| 29 |
+
]
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
def get_encoders_cls(encoder_names):
|
| 33 |
+
"""Return valid encoder class indicated in `encoder_names`."""
|
| 34 |
+
encoders_cls = {}
|
| 35 |
+
for name in encoder_names:
|
| 36 |
+
if name not in str2enc:
|
| 37 |
+
raise ValueError("%s encoder not supported!" % name)
|
| 38 |
+
encoders_cls[name] = str2enc[name]
|
| 39 |
+
return encoders_cls
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
def register_encoder(name):
|
| 43 |
+
"""Encoder register that can be used to add new encoder class."""
|
| 44 |
+
|
| 45 |
+
def register_encoder_cls(cls):
|
| 46 |
+
if name in str2enc:
|
| 47 |
+
raise ValueError("Cannot register duplicate encoder ({})".format(name))
|
| 48 |
+
if not issubclass(cls, EncoderBase):
|
| 49 |
+
raise ValueError(f"encoder ({name}: {cls.__name_}) must extend EncoderBase")
|
| 50 |
+
str2enc[name] = cls
|
| 51 |
+
__all__.append(cls.__name__) # added to be complete
|
| 52 |
+
return cls
|
| 53 |
+
|
| 54 |
+
return register_encoder_cls
|
| 55 |
+
|
| 56 |
+
|
| 57 |
+
# Auto import python files in this directory
|
| 58 |
+
encoder_dir = os.path.dirname(__file__)
|
| 59 |
+
for file in os.listdir(encoder_dir):
|
| 60 |
+
path = os.path.join(encoder_dir, file)
|
| 61 |
+
if (
|
| 62 |
+
not file.startswith("_")
|
| 63 |
+
and not file.startswith(".")
|
| 64 |
+
and (file.endswith(".py") or os.path.isdir(path))
|
| 65 |
+
):
|
| 66 |
+
file_name = file[: file.find(".py")] if file.endswith(".py") else file
|
| 67 |
+
module = importlib.import_module("onmt.encoders." + file_name)
|
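The register_encoder decorator above extends str2enc at import time, so a new encoder only needs to subclass EncoderBase and register itself. A hypothetical usage sketch (IdentityEncoder is an invented example class, not part of this upload):
from onmt.encoders import register_encoder, str2enc
from onmt.encoders.encoder import EncoderBase

@register_encoder("identity")
class IdentityEncoder(EncoderBase):
    # Hypothetical encoder that passes its input through unchanged.
    def forward(self, src, src_len=None):
        return src, None, src_len

assert str2enc["identity"] is IdentityEncoder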
onmt/encoders/__pycache__/__init__.cpython-311.pyc
ADDED
|
Binary file (3.13 kB).
|
|
|