Upload 313 files
Browse files. This view is limited to 50 files because it contains too many changes.
							See raw diff
- e_smiles.py +0 -0
- infer.sh +10 -0
- inference.py +5 -0
- onmt/__init__.py +24 -0
- onmt/__pycache__/__init__.cpython-311.pyc +0 -0
- onmt/__pycache__/__init__.cpython-37.pyc +0 -0
- onmt/__pycache__/__init__.cpython-38.pyc +0 -0
- onmt/__pycache__/constants.cpython-311.pyc +0 -0
- onmt/__pycache__/constants.cpython-38.pyc +0 -0
- onmt/__pycache__/inference_engine.cpython-38.pyc +0 -0
- onmt/__pycache__/model_builder.cpython-311.pyc +0 -0
- onmt/__pycache__/model_builder.cpython-38.pyc +0 -0
- onmt/__pycache__/opts.cpython-311.pyc +0 -0
- onmt/__pycache__/opts.cpython-38.pyc +0 -0
- onmt/__pycache__/train_single.cpython-38.pyc +0 -0
- onmt/__pycache__/trainer.cpython-38.pyc +0 -0
- onmt/bin/__init__.py +0 -0
- onmt/bin/__pycache__/__init__.cpython-311.pyc +0 -0
- onmt/bin/__pycache__/__init__.cpython-38.pyc +0 -0
- onmt/bin/__pycache__/average_models.cpython-38.pyc +0 -0
- onmt/bin/__pycache__/build_vocab.cpython-38.pyc +0 -0
- onmt/bin/__pycache__/release_model.cpython-38.pyc +0 -0
- onmt/bin/__pycache__/server.cpython-38.pyc +0 -0
- onmt/bin/__pycache__/train.cpython-38.pyc +0 -0
- onmt/bin/__pycache__/translate.cpython-311.pyc +0 -0
- onmt/bin/__pycache__/translate.cpython-38.pyc +0 -0
- onmt/bin/average_models.py +60 -0
- onmt/bin/build_vocab.py +287 -0
- onmt/bin/release_model.py +39 -0
- onmt/bin/server.py +167 -0
- onmt/bin/train.py +71 -0
- onmt/bin/translate.py +60 -0
- onmt/constants.py +41 -0
- onmt/decoders/__init__.py +63 -0
- onmt/decoders/__pycache__/__init__.cpython-311.pyc +0 -0
- onmt/decoders/__pycache__/__init__.cpython-38.pyc +0 -0
- onmt/decoders/__pycache__/cnn_decoder.cpython-311.pyc +0 -0
- onmt/decoders/__pycache__/cnn_decoder.cpython-38.pyc +0 -0
- onmt/decoders/__pycache__/decoder.cpython-311.pyc +0 -0
- onmt/decoders/__pycache__/decoder.cpython-38.pyc +0 -0
- onmt/decoders/__pycache__/ensemble.cpython-311.pyc +0 -0
- onmt/decoders/__pycache__/ensemble.cpython-38.pyc +0 -0
- onmt/decoders/__pycache__/transformer.cpython-311.pyc +0 -0
- onmt/decoders/__pycache__/transformer.cpython-38.pyc +0 -0
- onmt/decoders/cnn_decoder.py +141 -0
- onmt/decoders/decoder.py +405 -0
- onmt/decoders/ensemble.py +150 -0
- onmt/decoders/transformer.py +835 -0
- onmt/encoders/__init__.py +67 -0
- onmt/encoders/__pycache__/__init__.cpython-311.pyc +0 -0
    	
        e_smiles.py
    ADDED
    
    | The diff for this file is too large to render. 
		See raw diff | 
|  | 
    	
        infer.sh
    ADDED
    
    | @@ -0,0 +1,10 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            python inference.py \
         | 
| 2 | 
            +
            --model trained_models/retrosnyhesis_ReactSeq_prompt_model_on_50k_aug100.pt \
         | 
| 3 | 
            +
            --src ./tmp_data/src.txt \
         | 
| 4 | 
            +
            --output ./tmp_data/tgt.txt \
         | 
| 5 | 
            +
            --beam_size 10 \
         | 
| 6 | 
            +
            --n_best 10 \
         | 
| 7 | 
            +
            --batch_size 16384 \
         | 
| 8 | 
            +
            --batch_type tokens \
         | 
| 9 | 
            +
            --max_length 500 \
         | 
| 10 | 
            +
            --seed 0
         | 
    	
        inference.py
    ADDED
    
    | @@ -0,0 +1,5 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            #!/usr/bin/env python
         | 
| 2 | 
            +
            from onmt.bin.translate import main
         | 
| 3 | 
            +
             | 
| 4 | 
            +
            if __name__ == "__main__":
         | 
| 5 | 
            +
                main()
         | 
    	
        onmt/__init__.py
    ADDED
    
    | @@ -0,0 +1,24 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            """ Main entry point of the ONMT library """
         | 
| 2 | 
            +
            import onmt.inputters
         | 
| 3 | 
            +
            import onmt.encoders
         | 
| 4 | 
            +
            import onmt.decoders
         | 
| 5 | 
            +
            import onmt.models
         | 
| 6 | 
            +
            import onmt.utils
         | 
| 7 | 
            +
            import onmt.modules
         | 
| 8 | 
            +
            import sys
         | 
| 9 | 
            +
            import onmt.utils.optimizers
         | 
| 10 | 
            +
             | 
| 11 | 
            +
            onmt.utils.optimizers.Optim = onmt.utils.optimizers.Optimizer
         | 
| 12 | 
            +
            sys.modules["onmt.Optim"] = onmt.utils.optimizers
         | 
| 13 | 
            +
             | 
| 14 | 
            +
            # For flake8
         | 
| 15 | 
            +
            __all__ = [
         | 
| 16 | 
            +
                onmt.inputters,
         | 
| 17 | 
            +
                onmt.encoders,
         | 
| 18 | 
            +
                onmt.decoders,
         | 
| 19 | 
            +
                onmt.models,
         | 
| 20 | 
            +
                onmt.utils,
         | 
| 21 | 
            +
                onmt.modules,
         | 
| 22 | 
            +
            ]
         | 
| 23 | 
            +
             | 
| 24 | 
            +
            __version__ = "3.4.1"
         | 
    	
        onmt/__pycache__/__init__.cpython-311.pyc
    ADDED
    
    | Binary file (892 Bytes). View file | 
|  | 
    	
        onmt/__pycache__/__init__.cpython-37.pyc
    ADDED
    
    | Binary file (605 Bytes). View file | 
|  | 
    	
        onmt/__pycache__/__init__.cpython-38.pyc
    ADDED
    
    | Binary file (603 Bytes). View file | 
|  | 
    	
        onmt/__pycache__/constants.cpython-311.pyc
    ADDED
    
    | Binary file (2.06 kB). View file | 
|  | 
    	
        onmt/__pycache__/constants.cpython-38.pyc
    ADDED
    
    | Binary file (1.61 kB). View file | 
|  | 
    	
        onmt/__pycache__/inference_engine.cpython-38.pyc
    ADDED
    
    | Binary file (3.22 kB). View file | 
|  | 
    	
        onmt/__pycache__/model_builder.cpython-311.pyc
    ADDED
    
    | Binary file (19.4 kB). View file | 
|  | 
    	
        onmt/__pycache__/model_builder.cpython-38.pyc
    ADDED
    
    | Binary file (10.6 kB). View file | 
|  | 
    	
        onmt/__pycache__/opts.cpython-311.pyc
    ADDED
    
    | Binary file (58 kB). View file | 
|  | 
    	
        onmt/__pycache__/opts.cpython-38.pyc
    ADDED
    
    | Binary file (38.4 kB). View file | 
|  | 
    	
        onmt/__pycache__/train_single.cpython-38.pyc
    ADDED
    
    | Binary file (6.41 kB). View file | 
|  | 
    	
        onmt/__pycache__/trainer.cpython-38.pyc
    ADDED
    
    | Binary file (14.5 kB). View file | 
|  | 
    	
        onmt/bin/__init__.py
    ADDED
    
    | 
            File without changes
         | 
    	
        onmt/bin/__pycache__/__init__.cpython-311.pyc
    ADDED
    
    | Binary file (171 Bytes). View file | 
|  | 
    	
        onmt/bin/__pycache__/__init__.cpython-38.pyc
    ADDED
    
    | Binary file (145 Bytes). View file | 
|  | 
    	
        onmt/bin/__pycache__/average_models.cpython-38.pyc
    ADDED
    
    | Binary file (1.48 kB). View file | 
|  | 
    	
        onmt/bin/__pycache__/build_vocab.cpython-38.pyc
    ADDED
    
    | Binary file (8.77 kB). View file | 
|  | 
    	
        onmt/bin/__pycache__/release_model.cpython-38.pyc
    ADDED
    
    | Binary file (1.17 kB). View file | 
|  | 
    	
        onmt/bin/__pycache__/server.cpython-38.pyc
    ADDED
    
    | Binary file (5.08 kB). View file | 
|  | 
    	
        onmt/bin/__pycache__/train.cpython-38.pyc
    ADDED
    
    | Binary file (1.84 kB). View file | 
|  | 
    	
        onmt/bin/__pycache__/translate.cpython-311.pyc
    ADDED
    
    | Binary file (2.89 kB). View file | 
|  | 
    	
        onmt/bin/__pycache__/translate.cpython-38.pyc
    ADDED
    
    | Binary file (1.77 kB). View file | 
|  | 
    	
        onmt/bin/average_models.py
    ADDED
    
    | @@ -0,0 +1,60 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            #!/usr/bin/env python
         | 
| 2 | 
            +
            import argparse
         | 
| 3 | 
            +
            import torch
         | 
| 4 | 
            +
             | 
| 5 | 
            +
             | 
| 6 | 
            +
            def average_models(model_files, fp32=False):
         | 
| 7 | 
            +
                vocab = None
         | 
| 8 | 
            +
                opt = None
         | 
| 9 | 
            +
                avg_model = None
         | 
| 10 | 
            +
                avg_generator = None
         | 
| 11 | 
            +
             | 
| 12 | 
            +
                for i, model_file in enumerate(model_files):
         | 
| 13 | 
            +
                    m = torch.load(model_file, map_location="cpu")
         | 
| 14 | 
            +
                    model_weights = m["model"]
         | 
| 15 | 
            +
                    generator_weights = m["generator"]
         | 
| 16 | 
            +
             | 
| 17 | 
            +
                    if fp32:
         | 
| 18 | 
            +
                        for k, v in model_weights.items():
         | 
| 19 | 
            +
                            model_weights[k] = v.float()
         | 
| 20 | 
            +
                        for k, v in generator_weights.items():
         | 
| 21 | 
            +
                            generator_weights[k] = v.float()
         | 
| 22 | 
            +
             | 
| 23 | 
            +
                    if i == 0:
         | 
| 24 | 
            +
                        vocab, opt = m["vocab"], m["opt"]
         | 
| 25 | 
            +
                        avg_model = model_weights
         | 
| 26 | 
            +
                        avg_generator = generator_weights
         | 
| 27 | 
            +
                    else:
         | 
| 28 | 
            +
                        for k, v in avg_model.items():
         | 
| 29 | 
            +
                            avg_model[k].mul_(i).add_(model_weights[k]).div_(i + 1)
         | 
| 30 | 
            +
             | 
| 31 | 
            +
                        for k, v in avg_generator.items():
         | 
| 32 | 
            +
                            avg_generator[k].mul_(i).add_(generator_weights[k]).div_(i + 1)
         | 
| 33 | 
            +
             | 
| 34 | 
            +
                final = {
         | 
| 35 | 
            +
                    "vocab": vocab,
         | 
| 36 | 
            +
                    "opt": opt,
         | 
| 37 | 
            +
                    "optim": None,
         | 
| 38 | 
            +
                    "generator": avg_generator,
         | 
| 39 | 
            +
                    "model": avg_model,
         | 
| 40 | 
            +
                }
         | 
| 41 | 
            +
                return final
         | 
| 42 | 
            +
             | 
| 43 | 
            +
             | 
| 44 | 
            +
            def main():
         | 
| 45 | 
            +
                parser = argparse.ArgumentParser(description="")
         | 
| 46 | 
            +
                parser.add_argument(
         | 
| 47 | 
            +
                    "-models", "-m", nargs="+", required=True, help="List of models"
         | 
| 48 | 
            +
                )
         | 
| 49 | 
            +
                parser.add_argument("-output", "-o", required=True, help="Output file")
         | 
| 50 | 
            +
                parser.add_argument(
         | 
| 51 | 
            +
                    "-fp32", "-f", action="store_true", help="Cast params to float32"
         | 
| 52 | 
            +
                )
         | 
| 53 | 
            +
                opt = parser.parse_args()
         | 
| 54 | 
            +
             | 
| 55 | 
            +
                final = average_models(opt.models, opt.fp32)
         | 
| 56 | 
            +
                torch.save(final, opt.output)
         | 
| 57 | 
            +
             | 
| 58 | 
            +
             | 
| 59 | 
            +
            if __name__ == "__main__":
         | 
| 60 | 
            +
                main()
         | 
    	
        onmt/bin/build_vocab.py
    ADDED
    
    | @@ -0,0 +1,287 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            #!/usr/bin/env python
         | 
| 2 | 
            +
            """Get vocabulary counts from transformed corpora samples."""
         | 
| 3 | 
            +
            import os
         | 
| 4 | 
            +
            import copy
         | 
| 5 | 
            +
            import multiprocessing as mp
         | 
| 6 | 
            +
            import pyonmttok
         | 
| 7 | 
            +
            from functools import partial
         | 
| 8 | 
            +
            from onmt.utils.logging import init_logger, logger
         | 
| 9 | 
            +
            from onmt.utils.misc import set_random_seed, check_path
         | 
| 10 | 
            +
            from onmt.utils.parse import ArgumentParser
         | 
| 11 | 
            +
            from onmt.opts import dynamic_prepare_opts
         | 
| 12 | 
            +
            from onmt.inputters.text_corpus import build_corpora_iters, get_corpora
         | 
| 13 | 
            +
            from onmt.inputters.text_utils import process, append_features_to_text
         | 
| 14 | 
            +
            from onmt.transforms import make_transforms, get_transforms_cls
         | 
| 15 | 
            +
            from onmt.constants import CorpusName, CorpusTask
         | 
| 16 | 
            +
            from collections import Counter
         | 
| 17 | 
            +
             | 
| 18 | 
            +
             | 
| 19 | 
            +
            MAXBUCKETSIZE = 256000
         | 
| 20 | 
            +
             | 
| 21 | 
            +
             | 
| 22 | 
            +
            def write_files_from_queues(sample_path, queues):
         | 
| 23 | 
            +
                """
         | 
| 24 | 
            +
                Standalone process that reads data from
         | 
| 25 | 
            +
                queues in order and writes them to sample files.
         | 
| 26 | 
            +
                """
         | 
| 27 | 
            +
                os.makedirs(sample_path, exist_ok=True)
         | 
| 28 | 
            +
                for c_name in queues.keys():
         | 
| 29 | 
            +
                    dest_base = os.path.join(sample_path, "{}.{}".format(c_name, CorpusName.SAMPLE))
         | 
| 30 | 
            +
                    with open(dest_base + ".src", "w", encoding="utf-8") as f_src, open(
         | 
| 31 | 
            +
                        dest_base + ".tgt", "w", encoding="utf-8"
         | 
| 32 | 
            +
                    ) as f_tgt:
         | 
| 33 | 
            +
                        while True:
         | 
| 34 | 
            +
                            _next = False
         | 
| 35 | 
            +
                            for q in queues[c_name]:
         | 
| 36 | 
            +
                                item = q.get()
         | 
| 37 | 
            +
                                if item == "blank":
         | 
| 38 | 
            +
                                    continue
         | 
| 39 | 
            +
                                if item == "break":
         | 
| 40 | 
            +
                                    _next = True
         | 
| 41 | 
            +
                                    break
         | 
| 42 | 
            +
                                _, src_line, tgt_line = item
         | 
| 43 | 
            +
                                f_src.write(src_line + "\n")
         | 
| 44 | 
            +
                                f_tgt.write(tgt_line + "\n")
         | 
| 45 | 
            +
                            if _next:
         | 
| 46 | 
            +
                                break
         | 
| 47 | 
            +
             | 
| 48 | 
            +
             | 
| 49 | 
            +
            def build_sub_vocab(corpora, transforms, opts, n_sample, stride, offset):
         | 
| 50 | 
            +
                """Build vocab on (strided) subpart of the data."""
         | 
| 51 | 
            +
                sub_counter_src = Counter()
         | 
| 52 | 
            +
                sub_counter_tgt = Counter()
         | 
| 53 | 
            +
                sub_counter_src_feats = [Counter() for _ in range(opts.n_src_feats)]
         | 
| 54 | 
            +
                datasets_iterables = build_corpora_iters(
         | 
| 55 | 
            +
                    corpora,
         | 
| 56 | 
            +
                    transforms,
         | 
| 57 | 
            +
                    opts.data,
         | 
| 58 | 
            +
                    skip_empty_level=opts.skip_empty_level,
         | 
| 59 | 
            +
                    stride=stride,
         | 
| 60 | 
            +
                    offset=offset,
         | 
| 61 | 
            +
                )
         | 
| 62 | 
            +
                for c_name, c_iter in datasets_iterables.items():
         | 
| 63 | 
            +
                    for i, item in enumerate(c_iter):
         | 
| 64 | 
            +
                        maybe_example = process(CorpusTask.TRAIN, [item])
         | 
| 65 | 
            +
                        if maybe_example is not None:
         | 
| 66 | 
            +
                            maybe_example = maybe_example[0]
         | 
| 67 | 
            +
                        else:
         | 
| 68 | 
            +
                            if opts.dump_samples:
         | 
| 69 | 
            +
                                build_sub_vocab.queues[c_name][offset].put("blank")
         | 
| 70 | 
            +
                            continue
         | 
| 71 | 
            +
                        src_line, tgt_line = (
         | 
| 72 | 
            +
                            maybe_example["src"]["src"],
         | 
| 73 | 
            +
                            maybe_example["tgt"]["tgt"],
         | 
| 74 | 
            +
                        )
         | 
| 75 | 
            +
                        sub_counter_src.update(src_line.split(" "))
         | 
| 76 | 
            +
                        sub_counter_tgt.update(tgt_line.split(" "))
         | 
| 77 | 
            +
             | 
| 78 | 
            +
                        if "feats" in maybe_example["src"]:
         | 
| 79 | 
            +
                            src_feats_lines = maybe_example["src"]["feats"]
         | 
| 80 | 
            +
                            for k in range(opts.n_src_feats):
         | 
| 81 | 
            +
                                sub_counter_src_feats[k].update(src_feats_lines[k].split(" "))
         | 
| 82 | 
            +
                        else:
         | 
| 83 | 
            +
                            src_feats_lines = []
         | 
| 84 | 
            +
             | 
| 85 | 
            +
                        if opts.dump_samples:
         | 
| 86 | 
            +
                            src_pretty_line = append_features_to_text(src_line, src_feats_lines)
         | 
| 87 | 
            +
                            build_sub_vocab.queues[c_name][offset].put(
         | 
| 88 | 
            +
                                (i, src_pretty_line, tgt_line)
         | 
| 89 | 
            +
                            )
         | 
| 90 | 
            +
                        if n_sample > 0 and ((i + 1) * stride + offset) >= n_sample:
         | 
| 91 | 
            +
                            if opts.dump_samples:
         | 
| 92 | 
            +
                                build_sub_vocab.queues[c_name][offset].put("break")
         | 
| 93 | 
            +
                            break
         | 
| 94 | 
            +
                    if opts.dump_samples:
         | 
| 95 | 
            +
                        build_sub_vocab.queues[c_name][offset].put("break")
         | 
| 96 | 
            +
                return sub_counter_src, sub_counter_tgt, sub_counter_src_feats
         | 
| 97 | 
            +
             | 
| 98 | 
            +
             | 
| 99 | 
            +
            def init_pool(queues):
         | 
| 100 | 
            +
                """Add the queues as attribute of the pooled function."""
         | 
| 101 | 
            +
                build_sub_vocab.queues = queues
         | 
| 102 | 
            +
             | 
| 103 | 
            +
             | 
| 104 | 
            +
            def build_vocab(opts, transforms, n_sample=3):
         | 
| 105 | 
            +
                """Build vocabulary from data."""
         | 
| 106 | 
            +
             | 
| 107 | 
            +
                if n_sample == -1:
         | 
| 108 | 
            +
                    logger.info(f"n_sample={n_sample}: Build vocab on full datasets.")
         | 
| 109 | 
            +
                elif n_sample > 0:
         | 
| 110 | 
            +
                    logger.info(f"Build vocab on {n_sample} transformed examples/corpus.")
         | 
| 111 | 
            +
                else:
         | 
| 112 | 
            +
                    raise ValueError(f"n_sample should > 0 or == -1, get {n_sample}.")
         | 
| 113 | 
            +
             | 
| 114 | 
            +
                if opts.dump_samples:
         | 
| 115 | 
            +
                    logger.info(
         | 
| 116 | 
            +
                        "The samples on which the vocab is built will be "
         | 
| 117 | 
            +
                        "dumped to disk. It may slow down the process."
         | 
| 118 | 
            +
                    )
         | 
| 119 | 
            +
                corpora = get_corpora(opts, task=CorpusTask.TRAIN)
         | 
| 120 | 
            +
                counter_src = Counter()
         | 
| 121 | 
            +
                counter_tgt = Counter()
         | 
| 122 | 
            +
                counter_src_feats = [Counter() for _ in range(opts.n_src_feats)]
         | 
| 123 | 
            +
             | 
| 124 | 
            +
                queues = {
         | 
| 125 | 
            +
                    c_name: [
         | 
| 126 | 
            +
                        mp.Queue(opts.vocab_sample_queue_size) for i in range(opts.num_threads)
         | 
| 127 | 
            +
                    ]
         | 
| 128 | 
            +
                    for c_name in corpora.keys()
         | 
| 129 | 
            +
                }
         | 
| 130 | 
            +
                sample_path = os.path.join(os.path.dirname(opts.save_data), CorpusName.SAMPLE)
         | 
| 131 | 
            +
                if opts.dump_samples:
         | 
| 132 | 
            +
                    write_process = mp.Process(
         | 
| 133 | 
            +
                        target=write_files_from_queues, args=(sample_path, queues), daemon=True
         | 
| 134 | 
            +
                    )
         | 
| 135 | 
            +
                    write_process.start()
         | 
| 136 | 
            +
                with mp.Pool(opts.num_threads, init_pool, [queues]) as p:
         | 
| 137 | 
            +
                    func = partial(
         | 
| 138 | 
            +
                        build_sub_vocab, corpora, transforms, opts, n_sample, opts.num_threads
         | 
| 139 | 
            +
                    )
         | 
| 140 | 
            +
                    for sub_counter_src, sub_counter_tgt, sub_counter_src_feats in p.imap(
         | 
| 141 | 
            +
                        func, range(0, opts.num_threads)
         | 
| 142 | 
            +
                    ):
         | 
| 143 | 
            +
                        counter_src.update(sub_counter_src)
         | 
| 144 | 
            +
                        counter_tgt.update(sub_counter_tgt)
         | 
| 145 | 
            +
                        for i in range(opts.n_src_feats):
         | 
| 146 | 
            +
                            counter_src_feats[i].update(sub_counter_src_feats[i])
         | 
| 147 | 
            +
                if opts.dump_samples:
         | 
| 148 | 
            +
                    write_process.join()
         | 
| 149 | 
            +
                return counter_src, counter_tgt, counter_src_feats
         | 
| 150 | 
            +
             | 
| 151 | 
            +
             | 
| 152 | 
            +
            def ingest_tokens(opts, transforms, n_sample, learner, stride, offset):
         | 
| 153 | 
            +
                def _mp_ingest(data):
         | 
| 154 | 
            +
                    func = partial(process, CorpusName.TRAIN)
         | 
| 155 | 
            +
                    chunk = len(data) // opts.num_threads
         | 
| 156 | 
            +
                    with mp.Pool(opts.num_threads) as pool:
         | 
| 157 | 
            +
                        buckets = pool.map(
         | 
| 158 | 
            +
                            func,
         | 
| 159 | 
            +
                            [data[i * chunk : (i + 1) * chunk] for i in range(0, opts.num_threads)],
         | 
| 160 | 
            +
                        )
         | 
| 161 | 
            +
                    for bucket in buckets:
         | 
| 162 | 
            +
                        for ex in bucket:
         | 
| 163 | 
            +
                            if ex is not None:
         | 
| 164 | 
            +
                                src_line, tgt_line = (ex["src"]["src"], ex["tgt"]["tgt"])
         | 
| 165 | 
            +
                                learner.ingest(src_line)
         | 
| 166 | 
            +
                                learner.ingest(tgt_line)
         | 
| 167 | 
            +
             | 
| 168 | 
            +
                corpora = get_corpora(opts, task=CorpusTask.TRAIN)
         | 
| 169 | 
            +
                datasets_iterables = build_corpora_iters(
         | 
| 170 | 
            +
                    corpora,
         | 
| 171 | 
            +
                    transforms,
         | 
| 172 | 
            +
                    opts.data,
         | 
| 173 | 
            +
                    skip_empty_level=opts.skip_empty_level,
         | 
| 174 | 
            +
                    stride=stride,
         | 
| 175 | 
            +
                    offset=offset,
         | 
| 176 | 
            +
                )
         | 
| 177 | 
            +
                to_ingest = []
         | 
| 178 | 
            +
                for c_name, c_iter in datasets_iterables.items():
         | 
| 179 | 
            +
                    for i, item in enumerate(c_iter):
         | 
| 180 | 
            +
                        if n_sample >= 0 and i >= n_sample:
         | 
| 181 | 
            +
                            break
         | 
| 182 | 
            +
                        if len(to_ingest) >= MAXBUCKETSIZE:
         | 
| 183 | 
            +
                            _mp_ingest(to_ingest)
         | 
| 184 | 
            +
                            to_ingest = []
         | 
| 185 | 
            +
                        to_ingest.append(item)
         | 
| 186 | 
            +
                    _mp_ingest(to_ingest)
         | 
| 187 | 
            +
             | 
| 188 | 
            +
             | 
| 189 | 
            +
            def make_learner(tokenization_type, symbols):
         | 
| 190 | 
            +
                if tokenization_type == "bpe":
         | 
| 191 | 
            +
                    # BPE training
         | 
| 192 | 
            +
                    learner = pyonmttok.BPELearner(tokenizer=None, symbols=symbols)
         | 
| 193 | 
            +
                elif tokenization_type == "sentencepiece":
         | 
| 194 | 
            +
                    # SentencePiece training
         | 
| 195 | 
            +
                    learner = pyonmttok.SentencePieceLearner(
         | 
| 196 | 
            +
                        vocab_size=symbols, character_coverage=0.98
         | 
| 197 | 
            +
                    )
         | 
| 198 | 
            +
                return learner
         | 
| 199 | 
            +
             | 
| 200 | 
            +
             | 
| 201 | 
            +
            def build_vocab_main(opts):
         | 
| 202 | 
            +
                """Apply transforms to samples of specified data and build vocab from it.
         | 
| 203 | 
            +
             | 
| 204 | 
            +
                Transforms that need vocab will be disabled in this.
         | 
| 205 | 
            +
                Built vocab is saved in plain text format as follows and can be passed as
         | 
| 206 | 
            +
                `-src_vocab` (and `-tgt_vocab`) when training:
         | 
| 207 | 
            +
                ```
         | 
| 208 | 
            +
                <tok_0>\t<count_0>
         | 
| 209 | 
            +
                <tok_1>\t<count_1>
         | 
| 210 | 
            +
                ```
         | 
| 211 | 
            +
                """
         | 
| 212 | 
            +
             | 
| 213 | 
            +
                ArgumentParser.validate_prepare_opts(opts, build_vocab_only=True)
         | 
| 214 | 
            +
                assert (
         | 
| 215 | 
            +
                    opts.n_sample == -1 or opts.n_sample > 1
         | 
| 216 | 
            +
                ), f"Illegal argument n_sample={opts.n_sample}."
         | 
| 217 | 
            +
             | 
| 218 | 
            +
                logger = init_logger()
         | 
| 219 | 
            +
                set_random_seed(opts.seed, False)
         | 
| 220 | 
            +
                transforms_cls = get_transforms_cls(opts._all_transform)
         | 
| 221 | 
            +
             | 
| 222 | 
            +
                if opts.learn_subwords:
         | 
| 223 | 
            +
                    logger.info(f"Ingesting {opts.src_subword_type} model from corpus")
         | 
| 224 | 
            +
                    learner = make_learner(opts.src_subword_type, opts.learn_subwords_size)
         | 
| 225 | 
            +
                    if opts.src_subword_model is not None:
         | 
| 226 | 
            +
                        tok_path = opts.src_subword_model
         | 
| 227 | 
            +
                    else:
         | 
| 228 | 
            +
                        data_dir = os.path.split(opts.save_data)[0]
         | 
| 229 | 
            +
                        if not os.path.exists(data_dir):
         | 
| 230 | 
            +
                            os.makedirs(data_dir)
         | 
| 231 | 
            +
                        tok_path = os.path.join(data_dir, f"{opts.src_subword_type}.model")
         | 
| 232 | 
            +
                    save_opts = copy.deepcopy(opts)
         | 
| 233 | 
            +
                    opts.src_subword_type = "none"
         | 
| 234 | 
            +
                    opts.tgt_subword_type = "none"
         | 
| 235 | 
            +
                    opts.src_onmttok_kwargs["joiner_annotate"] = False
         | 
| 236 | 
            +
                    opts.tgt_onmttok_kwargs["joiner_annotate"] = False
         | 
| 237 | 
            +
                    transforms = make_transforms(opts, transforms_cls, None)
         | 
| 238 | 
            +
                    ingest_tokens(opts, transforms, opts.n_sample, learner, 1, 0)
         | 
| 239 | 
            +
                    logger.info(f"Learning {tok_path} model, patience")
         | 
| 240 | 
            +
                    learner.learn(tok_path)
         | 
| 241 | 
            +
                    opts = save_opts
         | 
| 242 | 
            +
             | 
| 243 | 
            +
                transforms = make_transforms(opts, transforms_cls, None)
         | 
| 244 | 
            +
             | 
| 245 | 
            +
                logger.info(f"Counter vocab from {opts.n_sample} samples.")
         | 
| 246 | 
            +
                src_counter, tgt_counter, src_feats_counter = build_vocab(
         | 
| 247 | 
            +
                    opts, transforms, n_sample=opts.n_sample
         | 
| 248 | 
            +
                )
         | 
| 249 | 
            +
             | 
| 250 | 
            +
                logger.info(f"Counters src: {len(src_counter)}")
         | 
| 251 | 
            +
                logger.info(f"Counters tgt: {len(tgt_counter)}")
         | 
| 252 | 
            +
                for i, feat_counter in enumerate(src_feats_counter):
         | 
| 253 | 
            +
                    logger.info(f"Counters src feat_{i}: {len(feat_counter)}")
         | 
| 254 | 
            +
             | 
| 255 | 
            +
                def save_counter(counter, save_path):
         | 
| 256 | 
            +
                    check_path(save_path, exist_ok=opts.overwrite, log=logger.warning)
         | 
| 257 | 
            +
                    with open(save_path, "w", encoding="utf8") as fo:
         | 
| 258 | 
            +
                        for tok, count in counter.most_common():
         | 
| 259 | 
            +
                            fo.write(tok + "\t" + str(count) + "\n")
         | 
| 260 | 
            +
             | 
| 261 | 
            +
                if opts.share_vocab:
         | 
| 262 | 
            +
                    src_counter += tgt_counter
         | 
| 263 | 
            +
                    tgt_counter = src_counter
         | 
| 264 | 
            +
                    logger.info(f"Counters after share:{len(src_counter)}")
         | 
| 265 | 
            +
                    save_counter(src_counter, opts.src_vocab)
         | 
| 266 | 
            +
                else:
         | 
| 267 | 
            +
                    save_counter(src_counter, opts.src_vocab)
         | 
| 268 | 
            +
                    save_counter(tgt_counter, opts.tgt_vocab)
         | 
| 269 | 
            +
             | 
| 270 | 
            +
                for i, c in enumerate(src_feats_counter):
         | 
| 271 | 
            +
                    save_counter(c, f"{opts.src_vocab}_feat{i}")
         | 
| 272 | 
            +
             | 
| 273 | 
            +
             | 
| 274 | 
            +
            def _get_parser():
                """Return the argument parser for the vocabulary-building script."""
                vocab_parser = ArgumentParser(description="build_vocab.py")
                # Register the data-preparation options in build-vocab-only mode.
                dynamic_prepare_opts(vocab_parser, build_vocab_only=True)
                return vocab_parser
         | 
| 278 | 
            +
             | 
| 279 | 
            +
             | 
| 280 | 
            +
            def main():
                """Parse the command line and run the vocabulary build."""
                # Unrecognised arguments are tolerated and discarded.
                parsed_opts, _unknown = _get_parser().parse_known_args()
                build_vocab_main(parsed_opts)


            if __name__ == "__main__":
                main()
         | 
    	
        onmt/bin/release_model.py
    ADDED
    
    | @@ -0,0 +1,39 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            #!/usr/bin/env python
         | 
| 2 | 
            +
            import argparse
         | 
| 3 | 
            +
            import torch
         | 
| 4 | 
            +
             | 
| 5 | 
            +
             | 
| 6 | 
            +
            def main():
                """Release an OpenNMT-py checkpoint for inference.

                Command-line options:
                    --model/-m: path of the checkpoint to release (required).
                    --output/-o: path of the released model (required).
                    --format: ``pytorch`` (default) re-saves the checkpoint with its
                        ``"optim"`` entry set to ``None``; ``ctranslate2`` converts it
                        with ``ctranslate2.converters.OpenNMTPyConverter``.
                    --quantization/-q: quantization type passed to the CTranslate2
                        converter.
                """
                parser = argparse.ArgumentParser(
                    description="Release an OpenNMT-py model for inference"
                )
                parser.add_argument("--model", "-m", help="The model path", required=True)
                parser.add_argument("--output", "-o", help="The output path", required=True)
                parser.add_argument(
                    "--format",
                    choices=["pytorch", "ctranslate2"],
                    default="pytorch",
                    help="The format of the released model",
                )
                parser.add_argument(
                    "--quantization",
                    "-q",
                    choices=["int8", "int16", "float16", "int8_float16"],
                    default=None,
                    help="Quantization type for CT2 model.",
                )
                opt = parser.parse_args()

                if opt.format == "pytorch":
                    # Load only in this branch: the CTranslate2 converter takes the
                    # checkpoint path itself, so loading it up front was wasted work.
                    # NOTE(security): torch.load may unpickle arbitrary objects; only
                    # release checkpoints that come from a trusted source.
                    model = torch.load(opt.model, map_location=torch.device("cpu"))
                    model["optim"] = None
                    torch.save(model, opt.output)
                elif opt.format == "ctranslate2":
                    # Imported lazily so the dependency is only needed for this format.
                    import ctranslate2

                    converter = ctranslate2.converters.OpenNMTPyConverter(opt.model)
                    converter.convert(opt.output, force=True, quantization=opt.quantization)


            if __name__ == "__main__":
                main()
         | 
    	
        onmt/bin/server.py
    ADDED
    
    | @@ -0,0 +1,167 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            #!/usr/bin/env python
         | 
| 2 | 
            +
            import configargparse
         | 
| 3 | 
            +
             | 
| 4 | 
            +
            from flask import Flask, jsonify, request
         | 
| 5 | 
            +
            from waitress import serve
         | 
| 6 | 
            +
            from onmt.translate import TranslationServer, ServerModelError
         | 
| 7 | 
            +
            import logging
         | 
| 8 | 
            +
            from logging.handlers import RotatingFileHandler
         | 
| 9 | 
            +
             | 
| 10 | 
            +
            STATUS_OK = "ok"
         | 
| 11 | 
            +
            STATUS_ERROR = "error"
         | 
| 12 | 
            +
             | 
| 13 | 
            +
             | 
| 14 | 
            +
            def start(config_file, url_root="./translator", host="0.0.0.0", port=5000, debug=False):
                """Start the REST translation server and serve requests with waitress.

                Builds a Flask app whose routes are all prefixed with ``url_root``,
                starts a ``TranslationServer`` from ``config_file`` and registers the
                ``/models``, ``/health``, ``/clone_model``, ``/unload_model``,
                ``/translate``, ``/to_cpu`` and ``/to_gpu`` endpoints.

                Args:
                    config_file: path passed to ``TranslationServer.start``.
                    url_root: prefix prepended verbatim to every route.
                        NOTE(review): the default here is "./translator" while the
                        command-line default is "/translator" -- confirm the relative
                        form is accepted by the router.
                    host: host passed to ``waitress.serve``.
                    port: port passed to ``waitress.serve``.
                    debug: when true, ``/translate`` requests and responses are logged
                        through a rotating ``debug_requests.log`` file handler.
                """

                def prefix_route(route_function, prefix="", mask="{0}{1}"):
                    """Wrap ``route_function`` so each route is formatted with ``prefix``."""

                    def newroute(route, *args, **kwargs):
                        return route_function(mask.format(prefix, route), *args, **kwargs)

                    return newroute

                if debug:
                    # `logger` only exists in debug mode; every use below is guarded
                    # by the same flag.
                    logger = logging.getLogger("main")
                    log_format = logging.Formatter("[%(asctime)s %(levelname)s] %(message)s")
                    file_handler = RotatingFileHandler(
                        "debug_requests.log", maxBytes=1000000, backupCount=10
                    )
                    file_handler.setFormatter(log_format)
                    logger.addHandler(file_handler)

                app = Flask(__name__)
                app.route = prefix_route(app.route, url_root)
                translation_server = TranslationServer()
                translation_server.start(config_file)

                @app.route("/models", methods=["GET"])
                def get_models():
                    out = translation_server.list_models()
                    return jsonify(out)

                @app.route("/health", methods=["GET"])
                def health():
                    out = {}
                    out["status"] = STATUS_OK
                    return jsonify(out)

                @app.route("/clone_model/<int:model_id>", methods=["POST"])
                def clone_model(model_id):
                    out = {}
                    data = request.get_json(force=True)
                    timeout = -1
                    if "timeout" in data:
                        timeout = data["timeout"]
                        del data["timeout"]

                    opt = data.get("opt", None)
                    try:
                        model_id, load_time = translation_server.clone_model(model_id, opt, timeout)
                    except ServerModelError as e:
                        out["status"] = STATUS_ERROR
                        out["error"] = str(e)
                    else:
                        out["status"] = STATUS_OK
                        out["model_id"] = model_id
                        out["load_time"] = load_time

                    return jsonify(out)

                @app.route("/unload_model/<int:model_id>", methods=["GET"])
                def unload_model(model_id):
                    out = {"model_id": model_id}

                    try:
                        translation_server.unload_model(model_id)
                        out["status"] = STATUS_OK
                    except Exception as e:
                        out["status"] = STATUS_ERROR
                        out["error"] = str(e)

                    return jsonify(out)

                @app.route("/translate", methods=["POST"])
                def translate():
                    inputs = request.get_json(force=True)
                    if debug:
                        logger.info(inputs)
                    out = {}
                    try:
                        trans, scores, n_best, _, aligns, align_scores = translation_server.run(
                            inputs
                        )
                        assert len(trans) == len(inputs) * n_best
                        assert len(scores) == len(inputs) * n_best
                        assert len(aligns) == len(inputs) * n_best

                        # One list per n-best rank; result i belongs to input
                        # i // n_best and rank i % n_best.
                        out = [[] for _ in range(n_best)]
                        for i in range(len(trans)):
                            response = {
                                "src": inputs[i // n_best]["src"],
                                "tgt": trans[i],
                                "n_best": n_best,
                                "pred_score": scores[i],
                            }
                            if len(aligns[i]) > 0 and aligns[i][0] is not None:
                                response["align"] = aligns[i]
                                response["align_score"] = align_scores[i]
                            out[i % n_best].append(response)
                    except ServerModelError as e:
                        model_id = inputs[0].get("id")
                        if debug:
                            logger.warning(
                                "Unload model #{} because of an error".format(model_id)
                            )
                        # The request may carry no id, or one with no model behind it.
                        # In that case there is nothing to unload, and the lookup must
                        # not replace the error being reported with a new exception.
                        try:
                            failed_model = translation_server.models[model_id]
                        except (KeyError, IndexError, TypeError):
                            failed_model = None
                        if failed_model is not None:
                            failed_model.unload()
                        # Build a fresh dict so the error is reported even if `out`
                        # had already been rebound to the n-best list.
                        out = {"error": str(e), "status": STATUS_ERROR}
                    if debug:
                        logger.info(out)
                    return jsonify(out)

                @app.route("/to_cpu/<int:model_id>", methods=["GET"])
                def to_cpu(model_id):
                    out = {"model_id": model_id}
                    # Report failures as JSON, the same way /unload_model does.
                    try:
                        translation_server.models[model_id].to_cpu()
                        out["status"] = STATUS_OK
                    except Exception as e:
                        out["status"] = STATUS_ERROR
                        out["error"] = str(e)

                    return jsonify(out)

                @app.route("/to_gpu/<int:model_id>", methods=["GET"])
                def to_gpu(model_id):
                    out = {"model_id": model_id}
                    # Report failures as JSON, the same way /unload_model does.
                    try:
                        translation_server.models[model_id].to_gpu()
                        out["status"] = STATUS_OK
                    except Exception as e:
                        out["status"] = STATUS_ERROR
                        out["error"] = str(e)

                    return jsonify(out)

                serve(app, host=host, port=port)
         | 
| 137 | 
            +
             | 
| 138 | 
            +
             | 
| 139 | 
            +
            def _get_parser():
                """Build the command-line / YAML-config parser for the REST server.

                Returns:
                    A ``configargparse.ArgumentParser`` exposing ``--ip``, ``--port``,
                    ``--url_root``, ``--debug``/``-d`` and ``--config``/``-c``.
                """
                parser = configargparse.ArgumentParser(
                    config_file_parser_class=configargparse.YAMLConfigFileParser,
                    description="OpenNMT-py REST Server",
                )
                parser.add_argument("--ip", type=str, default="0.0.0.0")
                # Integer default rather than the string "5000": no longer relies on
                # the parser converting string defaults through `type`.
                parser.add_argument("--port", type=int, default=5000)
                parser.add_argument("--url_root", type=str, default="/translator")
                parser.add_argument("--debug", "-d", action="store_true")
                parser.add_argument(
                    "--config", "-c", type=str, default="./available_models/conf.json"
                )
                return parser
         | 
| 152 | 
            +
             | 
| 153 | 
            +
             | 
| 154 | 
            +
            def main():
                """Parse the server options and launch the REST server."""
                cli = _get_parser().parse_args()
                start(
                    cli.config,
                    url_root=cli.url_root,
                    host=cli.ip,
                    port=cli.port,
                    debug=cli.debug,
                )


            if __name__ == "__main__":
                main()
         | 
    	
        onmt/bin/train.py
    ADDED
    
    | @@ -0,0 +1,71 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            #!/usr/bin/env python
         | 
| 2 | 
            +
            """Train models with dynamic data."""
         | 
| 3 | 
            +
            import torch
         | 
| 4 | 
            +
            from functools import partial
         | 
| 5 | 
            +
            from onmt.utils.distributed import ErrorHandler, spawned_train
         | 
| 6 | 
            +
            from onmt.utils.misc import set_random_seed
         | 
| 7 | 
            +
            from onmt.utils.logging import init_logger, logger
         | 
| 8 | 
            +
            from onmt.utils.parse import ArgumentParser
         | 
| 9 | 
            +
            from onmt.opts import train_opts
         | 
| 10 | 
            +
            from onmt.train_single import main as single_main
         | 
| 11 | 
            +
             | 
| 12 | 
            +
             | 
| 13 | 
            +
            # Set sharing strategy manually instead of default based on the OS.
         | 
| 14 | 
            +
            # torch.multiprocessing.set_sharing_strategy('file_system')
         | 
| 15 | 
            +
             | 
| 16 | 
            +
             | 
| 17 | 
            +
            def train(opt):
                """Validate the options, seed the RNG and launch training.

                With ``opt.world_size > 1`` one spawned process is started per entry
                of ``opt.gpu_ranks`` and all of them are joined. Otherwise training
                runs in this process, on device 0 when exactly one GPU rank is
                configured and on device -1 (CPU) when not.
                """
                init_logger(opt.log_file)

                ArgumentParser.validate_train_opts(opt)
                ArgumentParser.update_model_opts(opt)
                ArgumentParser.validate_model_opts(opt)

                set_random_seed(opt.seed, False)

                train_process = partial(single_main)
                nb_gpu = len(opt.gpu_ranks)

                if opt.world_size > 1:
                    mp = torch.multiprocessing.get_context("spawn")
                    # The error handler listens on this queue for failures reported
                    # by the child processes.
                    error_queue = mp.SimpleQueue()
                    error_handler = ErrorHandler(error_queue)
                    # Train with multiprocessing: one worker per GPU rank.
                    workers = []
                    for device_id in range(nb_gpu):
                        worker = mp.Process(
                            target=spawned_train,
                            args=(train_process, opt, device_id, error_queue),
                            daemon=False,
                        )
                        workers.append(worker)
                        worker.start()
                        logger.info(" Starting process pid: %d  " % worker.pid)
                        error_handler.add_child(worker.pid)
                    for worker in workers:
                        worker.join()
                    return

                # Single-process training: one GPU if exactly one rank, else CPU.
                if nb_gpu == 1:
                    train_process(opt, device_id=0)
                else:
                    train_process(opt, device_id=-1)
         | 
| 55 | 
            +
             | 
| 56 | 
            +
             | 
| 57 | 
            +
            def _get_parser():
                """Return the argument parser for the training script."""
                train_parser = ArgumentParser(description="train.py")
                # Register all training options on the parser.
                train_opts(train_parser)
                return train_parser
         | 
| 61 | 
            +
             | 
| 62 | 
            +
             | 
| 63 | 
            +
            def main():
                """Parse the command line and start training."""
                # Unrecognised arguments are tolerated and discarded.
                parsed_opt, _unknown = _get_parser().parse_known_args()
                train(parsed_opt)


            if __name__ == "__main__":
                main()
         | 
    	
        onmt/bin/translate.py
    ADDED
    
    | @@ -0,0 +1,60 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            #!/usr/bin/env python
         | 
| 2 | 
            +
            # -*- coding: utf-8 -*-
         | 
| 3 | 
            +
            from onmt.utils.logging import init_logger
         | 
| 4 | 
            +
            from onmt.translate.translator import build_translator
         | 
| 5 | 
            +
            from onmt.inputters.dynamic_iterator import build_dynamic_dataset_iter
         | 
| 6 | 
            +
            from onmt.inputters.inputter import IterOnDevice
         | 
| 7 | 
            +
            from onmt.transforms import get_transforms_cls
         | 
| 8 | 
            +
            from onmt.constants import CorpusTask
         | 
| 9 | 
            +
            import onmt.opts as opts
         | 
| 10 | 
            +
            from onmt.utils.parse import ArgumentParser
         | 
| 11 | 
            +
            from onmt.utils.misc import use_gpu, set_random_seed
         | 
| 12 | 
            +
             | 
| 13 | 
            +
             | 
| 14 | 
            +
            def translate(opt):
                """Run translation for the inference corpus described by ``opt``.

                Validates the options, builds the translator and a dynamic inference
                iterator wrapped in ``IterOnDevice``, then translates it. The values
                returned by the translator are discarded.
                """
                # Option validation and transform resolution, in the required order.
                ArgumentParser.validate_translate_opts(opt)
                ArgumentParser._get_all_transform_translate(opt)
                ArgumentParser._validate_transforms_opts(opt)
                ArgumentParser.validate_translate_opts_dynamic(opt)
                log = init_logger(opt.log_file)

                set_random_seed(opt.seed, use_gpu(opt))

                engine = build_translator(opt, logger=log, report_score=True)

                transform_classes = get_transforms_cls(opt._all_transform)

                dataset_iter = build_dynamic_dataset_iter(
                    opt,
                    transform_classes,
                    engine.vocabs,
                    task=CorpusTask.INFER,
                    copy=engine.copy_attn,
                )
                device_iter = IterOnDevice(dataset_iter, opt.gpu)

                # The transform is read from the device-wrapped iterator.
                _, _ = engine._translate(
                    device_iter,
                    transform=device_iter.transform,
                    attn_debug=opt.attn_debug,
                    align_debug=opt.align_debug,
                )
         | 
| 43 | 
            +
             | 
| 44 | 
            +
             | 
| 45 | 
            +
            def _get_parser():
                """Build the command-line parser for this script.

                Returns:
                    An ``ArgumentParser`` described as "translate.py" on which the
                    config options and the translate options (``dynamic=True``) have
                    been registered.
                """
                arg_parser = ArgumentParser(description="translate.py")

                # Registration order: config options first, then translate options.
                opts.config_opts(arg_parser)
                opts.translate_opts(arg_parser, dynamic=True)
                return arg_parser
         | 
| 51 | 
            +
             | 
| 52 | 
            +
             | 
| 53 | 
            +
            def main():
                """Parse the command-line arguments and hand them to ``translate``."""
                translate(_get_parser().parse_args())


            if __name__ == "__main__":
                main()
         | 
    	
        onmt/constants.py
    ADDED
    
    | @@ -0,0 +1,41 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            """Define constant values used across the project."""
         | 
| 2 | 
            +
             | 
| 3 | 
            +
             | 
| 4 | 
            +
            class DefaultTokens:
                """Namespace of the default special-token and separator strings."""

                # Special vocabulary tokens.
                PAD = "<blank>"
                BOS = "<s>"
                EOS = "</s>"
                UNK = "<unk>"
                MASK = "<mask>"
                VOCAB_PAD = "averyunlikelytoken"

                # Sentence-final punctuation marks.
                SENT_FULL_STOPS = [".", "?", "!"]

                # Field separators (note the surrounding spaces in the alignment one).
                PHRASE_TABLE_SEPARATOR = "|||"
                ALIGNMENT_SEPARATOR = " ||| "

                # Placeholder-style markers.
                SEP = "⦅newline⦆"
                MASK_BEFORE = "⦅_mask_before_⦆"
         | 
| 16 | 
            +
             | 
| 17 | 
            +
             | 
| 18 | 
            +
            class CorpusName:
                """Namespace of the reserved corpus-name strings."""

                VALID = "valid"
                TRAIN = "train"
                SAMPLE = "sample"
                INFER = "infer"
         | 
| 23 | 
            +
             | 
| 24 | 
            +
             | 
| 25 | 
            +
            class CorpusTask:
                """Namespace of the task strings a corpus iterator can be built for."""

                TRAIN = "train"
                VALID = "valid"
                INFER = "infer"
         | 
| 29 | 
            +
             | 
| 30 | 
            +
             | 
| 31 | 
            +
            class SubwordMarker:
                """Namespace of the marker strings related to subword tokenization."""

                # Single-character markers.
                SPACER = "▁"
                JOINER = "■"

                # Case-region / case-modifier placeholder markers.
                BEGIN_UPPERCASE = "⦅mrk_begin_case_region_U⦆"
                END_UPPERCASE = "⦅mrk_end_case_region_U⦆"
                BEGIN_CASED = "⦅mrk_case_modifier_C⦆"
         | 
| 37 | 
            +
             | 
| 38 | 
            +
             | 
| 39 | 
            +
            class ModelTask:
                """Namespace of the model-task identifier strings."""

                LANGUAGE_MODEL = "lm"
                SEQ2SEQ = "seq2seq"
         | 
    	
        onmt/decoders/__init__.py
    ADDED
    
    | @@ -0,0 +1,63 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            """Module defining decoders."""
         | 
| 2 | 
            +
            import os
         | 
| 3 | 
            +
            import importlib
         | 
| 4 | 
            +
            from onmt.decoders.decoder import DecoderBase, InputFeedRNNDecoder, StdRNNDecoder
         | 
| 5 | 
            +
            from onmt.decoders.transformer import TransformerDecoder, TransformerLMDecoder
         | 
| 6 | 
            +
            from onmt.decoders.cnn_decoder import CNNDecoder
         | 
| 7 | 
            +
             | 
| 8 | 
            +
             | 
| 9 | 
            +
            # Registry mapping a decoder-type string to the class implementing it.
            # ``register_decoder`` adds further entries at import time.
            str2dec = {
                "rnn": StdRNNDecoder,
                "ifrnn": InputFeedRNNDecoder,
                "cnn": CNNDecoder,
                "transformer": TransformerDecoder,
                "transformer_lm": TransformerLMDecoder,
            }

            # Public names of this package; registered decoder class names are
            # appended to this list by ``register_decoder``.
            __all__ = [
                "DecoderBase",
                "TransformerDecoder",
                "StdRNNDecoder",
                "CNNDecoder",
                "InputFeedRNNDecoder",
                "str2dec",
                "TransformerLMDecoder",
            ]
         | 
| 26 | 
            +
             | 
| 27 | 
            +
             | 
| 28 | 
            +
            def get_decoders_cls(decoders_names):
                """Look up the decoder classes for the given decoder names.

                Args:
                    decoders_names: iterable of decoder-type strings.

                Returns:
                    dict mapping each requested name to its class from ``str2dec``.

                Raises:
                    ValueError: on the first name that is not a key of ``str2dec``.
                """
                selected = {}
                for dec_name in decoders_names:
                    if dec_name in str2dec:
                        selected[dec_name] = str2dec[dec_name]
                        continue
                    raise ValueError("%s decoder not supported!" % dec_name)
                return selected
         | 
| 36 | 
            +
             | 
| 37 | 
            +
             | 
| 38 | 
            +
            def register_decoder(name):
                """Return a class decorator that registers a decoder class as ``name``.

                When the decorator is applied, the class is stored in ``str2dec``
                under ``name`` and its class name is appended to ``__all__``.

                Args:
                    name (str): key under which the decorated class is registered.

                Returns:
                    The decorator; it returns the decorated class unchanged.

                Raises:
                    ValueError: raised when the decorator is applied, if ``name`` is
                        already a key of ``str2dec`` or if the decorated class does
                        not extend ``DecoderBase``.
                """

                def register_decoder_cls(cls):
                    if name in str2dec:
                        raise ValueError("Cannot register duplicate decoder ({})".format(name))
                    if not issubclass(cls, DecoderBase):
                        # Must be ``__name__``: a misspelt attribute here would raise
                        # AttributeError and hide the intended ValueError.
                        raise ValueError(
                            f"decoder ({name}: {cls.__name__}) must extend DecoderBase"
                        )
                    str2dec[name] = cls
                    __all__.append(cls.__name__)  # added to be complete
                    return cls

                return register_decoder_cls
         | 
| 51 | 
            +
             | 
| 52 | 
            +
             | 
| 53 | 
            +
            # Auto import python files in this directory
         | 
| 54 | 
            +
            decoder_dir = os.path.dirname(__file__)
            for file in os.listdir(decoder_dir):
                path = os.path.join(decoder_dir, file)
                # Entries starting with "_" or "." are never imported.
                if file.startswith("_") or file.startswith("."):
                    continue
                # Only ``.py`` files and sub-directories are candidates.
                if not file.endswith(".py") and not os.path.isdir(path):
                    continue
                # Strip the extension for files; directories are imported by name.
                file_name = file[: file.find(".py")] if file.endswith(".py") else file
                module = importlib.import_module("onmt.decoders." + file_name)
         | 
    	
        onmt/decoders/__pycache__/__init__.cpython-311.pyc
    ADDED
    
    | Binary file (3 kB). View file | 
|  | 
    	
        onmt/decoders/__pycache__/__init__.cpython-38.pyc
    ADDED
    
    | Binary file (1.84 kB). View file | 
|  | 
    	
        onmt/decoders/__pycache__/cnn_decoder.cpython-311.pyc
    ADDED
    
    | Binary file (7.32 kB). View file | 
|  | 
    	
        onmt/decoders/__pycache__/cnn_decoder.cpython-38.pyc
    ADDED
    
    | Binary file (4.02 kB). View file | 
|  | 
    	
        onmt/decoders/__pycache__/decoder.cpython-311.pyc
    ADDED
    
    | Binary file (18.4 kB). View file | 
|  | 
    	
        onmt/decoders/__pycache__/decoder.cpython-38.pyc
    ADDED
    
    | Binary file (11.5 kB). View file | 
|  | 
    	
        onmt/decoders/__pycache__/ensemble.cpython-311.pyc
    ADDED
    
    | Binary file (11 kB). View file | 
|  | 
    	
        onmt/decoders/__pycache__/ensemble.cpython-38.pyc
    ADDED
    
    | Binary file (7.13 kB). View file | 
|  | 
    	
        onmt/decoders/__pycache__/transformer.cpython-311.pyc
    ADDED
    
    | Binary file (32.9 kB). View file | 
|  | 
    	
        onmt/decoders/__pycache__/transformer.cpython-38.pyc
    ADDED
    
    | Binary file (20.4 kB). View file | 
|  | 
    	
        onmt/decoders/cnn_decoder.py
    ADDED
    
    | @@ -0,0 +1,141 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            """Implementation of the CNN Decoder part of
         | 
| 2 | 
            +
            "Convolutional Sequence to Sequence Learning"
         | 
| 3 | 
            +
            """
         | 
| 4 | 
            +
            import torch
         | 
| 5 | 
            +
            import torch.nn as nn
         | 
| 6 | 
            +
             | 
| 7 | 
            +
            from onmt.modules import ConvMultiStepAttention, GlobalAttention
         | 
| 8 | 
            +
            from onmt.utils.cnn_factory import shape_transform, GatedConv
         | 
| 9 | 
            +
            from onmt.decoders.decoder import DecoderBase
         | 
| 10 | 
            +
             | 
| 11 | 
            +
            # Scaling factor (square root of 0.5) applied when two terms are summed.
            SCALE_WEIGHT = 0.5**0.5


            class CNNDecoder(DecoderBase):
                """Decoder based on "Convolutional Sequence to Sequence Learning"
                :cite:`DBLP:journals/corr/GehringAGYD17`.

                Consists of residual convolutional layers, with ConvMultiStepAttention.

                Args:
                    num_layers (int): number of GatedConv / ConvMultiStepAttention
                        layer pairs.
                    hidden_size (int): output size of the input projection; also
                        passed to every convolution and attention layer.
                    attn_type (str): accepted for interface compatibility; not used
                        by this class.
                    copy_attn (bool): must be falsy; the constructor asserts so.
                    cnn_kernel_width (int): kernel width passed to ``GatedConv``;
                        also determines the amount of left padding in ``forward``.
                    dropout (float): dropout value passed to ``GatedConv``.
                    embeddings: embedding module; must expose ``embedding_size``.
                    copy_attn_type (str): attention type of the copy attention
                        layer (only built when ``copy_attn`` is truthy).
                """

                def __init__(
                    self,
                    num_layers,
                    hidden_size,
                    attn_type,
                    copy_attn,
                    cnn_kernel_width,
                    dropout,
                    embeddings,
                    copy_attn_type,
                ):
                    super(CNNDecoder, self).__init__()

                    self.cnn_kernel_width = cnn_kernel_width
                    self.embeddings = embeddings

                    # Decoder State
                    self.state = {}

                    input_size = self.embeddings.embedding_size
                    self.linear = nn.Linear(input_size, hidden_size)
                    self.conv_layers = nn.ModuleList(
                        [
                            GatedConv(hidden_size, cnn_kernel_width, dropout, True)
                            for _ in range(num_layers)
                        ]
                    )
                    self.attn_layers = nn.ModuleList(
                        [ConvMultiStepAttention(hidden_size) for _ in range(num_layers)]
                    )

                    # CNNDecoder has its own attention mechanism.
                    # Set up a separate copy attention layer if needed.
                    assert not copy_attn, "Copy mechanism not yet tested in conv2conv"
                    if copy_attn:
                        # Only reachable when assertions are disabled (python -O).
                        self.copy_attn = GlobalAttention(hidden_size, attn_type=copy_attn_type)
                    else:
                        self.copy_attn = None

                @classmethod
                def from_opt(cls, opt, embeddings):
                    """Alternate constructor reading the hyper-parameters from ``opt``.

                    ``opt.dropout`` may be a single value or a list; for a list its
                    first element is used.
                    """
                    dropout = opt.dropout[0] if isinstance(opt.dropout, list) else opt.dropout
                    return cls(
                        opt.dec_layers,
                        opt.dec_hid_size,
                        opt.global_attention,
                        opt.copy_attn,
                        opt.cnn_kernel_width,
                        dropout,
                        embeddings,
                        opt.copy_attn_type,
                    )

                def init_state(self, _, enc_out, enc_hidden):
                    """Init decoder state.

                    Stores the scaled sum of ``enc_out`` and ``enc_hidden`` under
                    ``"src"`` and clears ``"previous_input"``. The first positional
                    argument is ignored.
                    """
                    self.state["src"] = (enc_out + enc_hidden) * SCALE_WEIGHT
                    self.state["previous_input"] = None

                def map_state(self, fn):
                    """Apply ``fn(tensor, 0)`` to every tensor held in the state."""
                    self.state["src"] = fn(self.state["src"], 0)
                    if self.state["previous_input"] is not None:
                        self.state["previous_input"] = fn(self.state["previous_input"], 0)

                def detach_state(self):
                    """Detach the cached previous input from the autograd graph.

                    Does nothing while no previous input is cached, i.e. between
                    ``init_state`` and the first ``forward`` call.
                    """
                    previous_input = self.state["previous_input"]
                    if previous_input is not None:
                        self.state["previous_input"] = previous_input.detach()

                def forward(self, tgt, enc_out, step=None, **kwargs):
                    """Decode ``tgt`` against the encoder output.

                    Any input cached in the state by a previous call is prepended
                    to ``tgt`` along dimension 1, and the outputs for those cached
                    positions are dropped again before returning.

                    Args:
                        tgt: target input fed to ``self.embeddings``; its embedding
                            must be 3-dimensional (batch x len x embedding_dim).
                        enc_out: encoder output handed to every attention layer.
                        step: unused by this decoder.
                        **kwargs: unused by this decoder.

                    Returns:
                        tuple ``(dec_outs, attns)``; ``attns`` is a dict with key
                        ``"std"`` and, when a copy attention layer exists,
                        ``"copy"``, both holding the last layer's attention.
                    """

                    if self.state["previous_input"] is not None:
                        tgt = torch.cat([self.state["previous_input"], tgt], 1)

                    attns = {"std": []}
                    if self.copy_attn is not None:
                        attns["copy"] = []

                    emb = self.embeddings(tgt)
                    assert emb.dim() == 3  # batch x len x embedding_dim

                    tgt_emb = emb
                    # The output of CNNEncoder.
                    enc_out_t = enc_out
                    # The combination of output of CNNEncoder and source embeddings.
                    enc_out_c = self.state["src"]

                    # Project the embeddings to hidden_size, one row per position.
                    emb_reshape = tgt_emb.view(tgt_emb.size(0) * tgt_emb.size(1), -1)
                    linear_out = self.linear(emb_reshape)
                    x = linear_out.view(tgt_emb.size(0), tgt_emb.size(1), -1)
                    x = shape_transform(x)

                    # Zero block of length (kernel width - 1), prepended on dim 2
                    # before every convolution.
                    pad = torch.zeros(x.size(0), x.size(1), self.cnn_kernel_width - 1, 1)

                    pad = pad.type_as(x)
                    base_target_emb = x

                    for conv, attention in zip(self.conv_layers, self.attn_layers):
                        new_target_input = torch.cat([pad, x], 2)
                        out = conv(new_target_input)
                        c, attn = attention(base_target_emb, out, enc_out_t, enc_out_c)
                        # Residual connection, rescaled after each sum.
                        x = (x + (c + out) * SCALE_WEIGHT) * SCALE_WEIGHT

                    dec_outs = x.squeeze(3).transpose(1, 2)

                    # Process the result and update the attentions.
                    if self.state["previous_input"] is not None:
                        dec_outs = dec_outs[:, self.state["previous_input"].size(1) :, :]
                        attn = attn[:, self.state["previous_input"].size(1) :].squeeze()
                        attn = torch.stack([attn])
                    attns["std"] = attn
                    if self.copy_attn is not None:
                        attns["copy"] = attn

                    # Update the state.
                    self.state["previous_input"] = tgt
                    # TODO change the way attns is returned dict => list or tuple (onnx)
                    return dec_outs, attns

                def update_dropout(self, dropout, attention_dropout=None):
                    """Set the dropout probability of every convolution layer.

                    ``attention_dropout`` is accepted but not used.
                    """
                    for layer in self.conv_layers:
                        layer.dropout.p = dropout
         | 
    	
        onmt/decoders/decoder.py
    ADDED
    
    | @@ -0,0 +1,405 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            import torch
         | 
| 2 | 
            +
            import torch.nn as nn
         | 
| 3 | 
            +
             | 
| 4 | 
            +
            from onmt.modules.stacked_rnn import StackedLSTM, StackedGRU
         | 
| 5 | 
            +
            from onmt.modules import context_gate_factory, GlobalAttention
         | 
| 6 | 
            +
            from onmt.utils.rnn_factory import rnn_factory
         | 
| 7 | 
            +
             | 
| 8 | 
            +
             | 
| 9 | 
            +
            class DecoderBase(nn.Module):
                """Common abstract parent of all decoders.

                Args:
                    attentional (bool): The decoder returns non-empty attention.
                """

                def __init__(self, attentional=True):
                    super().__init__()
                    self.attentional = attentional

                @classmethod
                def from_opt(cls, opt, embeddings):
                    """Alternate constructor; concrete decoders must override it.

                    Raises:
                        NotImplementedError: always, in this base class.
                    """
                    raise NotImplementedError()
         | 
| 28 | 
            +
             | 
| 29 | 
            +
             | 
| 30 | 
            +
            class RNNDecoderBase(DecoderBase):
         | 
| 31 | 
            +
                """Base recurrent attention-based decoder class.
         | 
| 32 | 
            +
             | 
| 33 | 
            +
                Specifies the interface used by different decoder types
         | 
| 34 | 
            +
                and required by :class:`~onmt.models.NMTModel`.
         | 
| 35 | 
            +
             | 
| 36 | 
            +
                Args:
         | 
| 37 | 
            +
                   rnn_type (str):
         | 
| 38 | 
            +
                      style of recurrent unit to use, one of [RNN, LSTM, GRU, SRU]
         | 
| 39 | 
            +
                   bidirectional_encoder (bool) : use with a bidirectional encoder
         | 
| 40 | 
            +
                   num_layers (int) : number of stacked layers
         | 
| 41 | 
            +
                   hidden_size (int) : hidden size of each layer
         | 
| 42 | 
            +
                   attn_type (str) : see :class:`~onmt.modules.GlobalAttention`
         | 
| 43 | 
            +
                   attn_func (str) : see :class:`~onmt.modules.GlobalAttention`
         | 
| 44 | 
            +
                   coverage_attn (str): see :class:`~onmt.modules.GlobalAttention`
         | 
| 45 | 
            +
                   context_gate (str): see :class:`~onmt.modules.ContextGate`
         | 
| 46 | 
            +
                   copy_attn (bool): setup a separate copy attention mechanism
         | 
| 47 | 
            +
                   dropout (float) : dropout value for :class:`torch.nn.Dropout`
         | 
| 48 | 
            +
                   embeddings (onmt.modules.Embeddings): embedding module to use
         | 
| 49 | 
            +
                   reuse_copy_attn (bool): reuse the attention for copying
         | 
| 50 | 
            +
                   copy_attn_type (str): The copy attention style. See
         | 
| 51 | 
            +
                    :class:`~onmt.modules.GlobalAttention`.
         | 
| 52 | 
            +
                """
         | 
| 53 | 
            +
             | 
| 54 | 
            +
                def __init__(
                    self,
                    rnn_type,
                    bidirectional_encoder,
                    num_layers,
                    hidden_size,
                    attn_type="general",
                    attn_func="softmax",
                    coverage_attn=False,
                    context_gate=None,
                    copy_attn=False,
                    dropout=0.0,
                    embeddings=None,
                    reuse_copy_attn=False,
                    copy_attn_type="general",
                ):
                    """Assemble the RNN, the optional context gate and the attention layers.

                    Raises:
                        ValueError: if coverage or copy-attention reuse is requested
                            while attention is disabled, or if a separate copy
                            attention is requested with ``copy_attn_type`` set to
                            ``"none"`` / ``None``.
                    """
                    # Attention is considered disabled for both None and the string "none".
                    has_attention = attn_type is not None and attn_type != "none"
                    super(RNNDecoderBase, self).__init__(attentional=has_attention)

                    self.bidirectional_encoder = bidirectional_encoder
                    self.num_layers = num_layers
                    self.hidden_size = hidden_size
                    self.embeddings = embeddings
                    self.dropout = nn.Dropout(dropout)

                    # Per-sequence decoder state; populated by init_state().
                    self.state = {}

                    # The concrete subclass supplies both _build_rnn() and _input_size.
                    rnn_kwargs = dict(
                        input_size=self._input_size,
                        hidden_size=hidden_size,
                        num_layers=num_layers,
                        dropout=dropout,
                    )
                    self.rnn = self._build_rnn(rnn_type, **rnn_kwargs)

                    # Optional context gate.
                    if context_gate is None:
                        self.context_gate = None
                    else:
                        self.context_gate = context_gate_factory(
                            context_gate, self._input_size, hidden_size, hidden_size, hidden_size
                        )

                    # Standard attention; coverage is only meaningful on top of it.
                    self._coverage = coverage_attn
                    if self.attentional:
                        self.attn = GlobalAttention(
                            hidden_size,
                            coverage=coverage_attn,
                            attn_type=attn_type,
                            attn_func=attn_func,
                        )
                    elif self._coverage:
                        raise ValueError("Cannot use coverage term with no attention.")
                    else:
                        self.attn = None

                    # A dedicated copy attention is only built when copying is enabled
                    # and the standard attention is not being reused for it.
                    wants_own_copy_attn = copy_attn and not reuse_copy_attn
                    if not wants_own_copy_attn:
                        self.copy_attn = None
                    else:
                        if copy_attn_type == "none" or copy_attn_type is None:
                            raise ValueError("Cannot use copy_attn with copy_attn_type none")
                        self.copy_attn = GlobalAttention(
                            hidden_size, attn_type=copy_attn_type, attn_func=attn_func
                        )

                    self._reuse_copy_attn = reuse_copy_attn and copy_attn
                    if self._reuse_copy_attn and not self.attentional:
                        raise ValueError("Cannot reuse copy attention with no attention.")
         | 
| 125 | 
            +
             | 
| 126 | 
            +
                @classmethod
                def from_opt(cls, opt, embeddings):
                    """Alternate constructor that builds the decoder from an options object.

                    Args:
                        opt: options namespace. The attributes read are ``rnn_type``,
                            ``brnn``, ``dec_layers``, ``dec_hid_size``,
                            ``global_attention``, ``global_attention_function``,
                            ``coverage_attn``, ``context_gate``, ``copy_attn``,
                            ``dropout``, ``reuse_copy_attn`` and ``copy_attn_type``.
                        embeddings: embedding module, passed to the constructor as is.

                    Returns:
                        A new instance of ``cls``.
                    """
                    # ``opt.dropout`` may be a single value or a sequence of values; only
                    # the first entry of a sequence is used. isinstance() (rather than an
                    # exact type comparison) also accepts tuples and list subclasses.
                    dropout = opt.dropout
                    if isinstance(dropout, (list, tuple)):
                        dropout = dropout[0]

                    return cls(
                        opt.rnn_type,
                        opt.brnn,
                        opt.dec_layers,
                        opt.dec_hid_size,
                        opt.global_attention,
                        opt.global_attention_function,
                        opt.coverage_attn,
                        opt.context_gate,
                        opt.copy_attn,
                        dropout,
                        embeddings,
                        opt.reuse_copy_attn,
                        opt.copy_attn_type,
                    )
         | 
| 144 | 
            +
             | 
| 145 | 
            +
                def init_state(self, src, _, enc_final_hs):
                    """Seed the decoder state from the encoder's final hidden state.

                    Args:
                        src: not used by this method.
                        _: not used by this method.
                        enc_final_hs: final encoder hidden state, either a tuple of
                            tensors (LSTM) or a single tensor (GRU).
                    """

                    def _merge_directions(hidden):
                        # The encoder hidden is (layers*directions) x batch x dim;
                        # the decoder needs layers x batch x (directions*dim), so the
                        # even-indexed and odd-indexed rows are joined on the last axis.
                        if not self.bidirectional_encoder:
                            return hidden
                        total = hidden.size(0)
                        even_rows = hidden[0:total:2]
                        odd_rows = hidden[1:total:2]
                        return torch.cat([even_rows, odd_rows], 2)

                    if isinstance(enc_final_hs, tuple):  # LSTM
                        hidden = tuple(map(_merge_directions, enc_final_hs))
                    else:  # GRU
                        hidden = (_merge_directions(enc_final_hs),)
                    self.state["hidden"] = hidden

                    # The input feed starts as zeros of shape 1 x batch x hidden_size,
                    # with the dtype and device of the hidden state.
                    first_hidden = hidden[0]
                    batch_size = first_hidden.size(1)
                    self.state["input_feed"] = first_hidden.new_zeros(
                        (1, batch_size, self.hidden_size)
                    )

                    self.state["coverage"] = None
         | 
| 173 | 
            +
             | 
| 174 | 
            +
                def map_state(self, fn):
                    """Replace every stored state tensor with ``fn(tensor, 1)``.

                    The coverage entry is only mapped when coverage attention is
                    enabled and a coverage value is currently stored.
                    """
                    state = self.state
                    state["hidden"] = tuple(fn(hidden, 1) for hidden in state["hidden"])
                    state["input_feed"] = fn(state["input_feed"], 1)

                    if not self._coverage:
                        return
                    coverage = state["coverage"]
                    if coverage is not None:
                        state["coverage"] = fn(coverage, 1)
         | 
| 179 | 
            +
             | 
| 180 | 
            +
                def detach_state(self):
                    """Detach every stored state tensor from the autograd graph.

                    The coverage entry is only detached when coverage attention is
                    enabled and a coverage value is currently stored.
                    """
                    state = self.state
                    state["hidden"] = tuple(hidden.detach() for hidden in state["hidden"])
                    state["input_feed"] = state["input_feed"].detach()

                    if not self._coverage:
                        return
                    coverage = state["coverage"]
                    if coverage is not None:
                        state["coverage"] = coverage.detach()
         | 
| 185 | 
            +
             | 
| 186 | 
            +
                def forward(self, tgt, enc_out, src_len=None, step=None, **kwargs):
                    """Run the decoder and refresh the stored decoder state.

                    Args:
                        tgt (LongTensor): sequences of padded tokens
                             ``(batch, tgt_len, nfeats)``.
                        enc_out (FloatTensor): vectors from the encoder
                             ``(batch, src_len, hidden)``.
                        src_len (LongTensor): the padded source lengths
                            ``(batch,)``.
                        step: accepted for interface compatibility; not used here.
                        **kwargs: accepted for interface compatibility; not used here.

                    Returns:
                        (FloatTensor, dict[str, FloatTensor]):

                        * dec_outs: output from the decoder (after attn)
                          ``(batch, tgt_len, hidden)``.
                        * attns: distribution over src at each tgt
                          ``(batch, tgt_len, src_len)``.
                    """
                    dec_state, dec_outs, attns = self._run_forward_pass(
                        tgt, enc_out, src_len=src_len
                    )

                    # The stored hidden state is always a tuple, even for a single tensor.
                    if not isinstance(dec_state, tuple):
                        dec_state = (dec_state,)
                    self.state["hidden"] = dec_state

                    # Per-step lists are concatenated along a new dimension.
                    # NOTE: v0.3 to 0.4: dec_outs / attns[*] may not be a list
                    #       (in particular in the case of SRU); this did not raise an
                    #       error in 0.3 since stack(Variable) was allowed.
                    #       In 0.4, SRU returns a tensor that shouldn't be stacked.
                    # isinstance() is used so that list subclasses are stacked as well.
                    if isinstance(dec_outs, list):
                        dec_outs = torch.stack(dec_outs, dim=1)
                        for key in attns:
                            if isinstance(attns[key], list):
                                attns[key] = torch.stack(attns[key])

                    # The last time step becomes the input feed for the next call.
                    self.state["input_feed"] = dec_outs[:, -1, :].unsqueeze(0)
                    self.state["coverage"] = None
                    if "coverage" in attns:
                        self.state["coverage"] = attns["coverage"][-1, :, :].unsqueeze(0)

                    return dec_outs, attns
         | 
| 230 | 
            +
             | 
| 231 | 
            +
                def update_dropout(self, dropout, attention_dropout=None):
                    """Apply a new dropout probability to the decoder and its embeddings.

                    ``attention_dropout`` is accepted but not used by this method.
                    """
                    new_p = dropout
                    # Output dropout first, then the embedding module's own dropout.
                    setattr(self.dropout, "p", new_p)
                    self.embeddings.update_dropout(new_p)
         | 
| 234 | 
            +
             | 
| 235 | 
            +
             | 
| 236 | 
            +
            class StdRNNDecoder(RNNDecoderBase):
                """Standard fully batched RNN decoder with attention.

                Faster implementation, uses CuDNN for implementation.
                See :class:`~onmt.decoders.decoder.RNNDecoderBase` for options.

                Based around the approach from
                "Neural Machine Translation By Jointly Learning To Align and Translate"
                :cite:`Bahdanau2015`

                Implemented without input_feeding and currently with no `coverage_attn`
                or `copy_attn` support.
                """

                def _run_forward_pass(self, tgt, enc_out, src_len=None):
                    """Run the RNN over the whole target sequence in one call.

                    Private helper for running the specific RNN forward pass.
                    Must be overridden by all subclasses.

                    Args:
                        tgt (LongTensor): a sequence of input tokens tensors
                            ``(batch, tgt_len, nfeats)``.
                        enc_out (FloatTensor): output(tensor sequence) from the
                            encoder RNN of size ``(batch, src_len, hidden_size)``.
                        src_len (LongTensor): the source enc_out lengths.

                    Returns:
                        (Tensor, List[FloatTensor], Dict[str, List[FloatTensor]):

                        * dec_state: final hidden state from the decoder.
                        * dec_outs: an array of output of every time
                          step from the decoder.
                        * attns: a dictionary of different
                          type of attention Tensor array of every time
                          step from the decoder.
                    """
                    # Neither copy attention nor coverage is implemented for this decoder.
                    assert self.copy_attn is None  # TODO, no support yet.
                    assert not self._coverage  # TODO, no support yet.

                    attns = {}
                    emb = self.embeddings(tgt)

                    # The stored hidden state is a tuple; a GRU takes only its single tensor.
                    prev_hidden = self.state["hidden"]
                    if isinstance(self.rnn, nn.GRU):
                        prev_hidden = prev_hidden[0]
                    rnn_out, dec_state = self.rnn(emb, prev_hidden)

                    tgt_batch, tgt_len, _ = tgt.size()

                    # Attention over the encoder output, when enabled.
                    if self.attentional:
                        dec_outs, p_attn = self.attn(rnn_out, enc_out, src_len=src_len)
                        attns["std"] = p_attn
                    else:
                        dec_outs = rnn_out

                    # The context gate works on flattened (batch*len, dim) inputs, so its
                    # result is reshaped back afterwards.
                    if self.context_gate is not None:
                        gated = self.context_gate(
                            emb.view(-1, emb.size(2)),
                            rnn_out.view(-1, rnn_out.size(2)),
                            dec_outs.view(-1, dec_outs.size(2)),
                        )
                        dec_outs = gated.view(tgt_batch, tgt_len, self.hidden_size)

                    return dec_state, self.dropout(dec_outs), attns

                def _build_rnn(self, rnn_type, **kwargs):
                    """Create the RNN via ``rnn_factory``, keeping only the module itself."""
                    built_rnn, _unused = rnn_factory(rnn_type, **kwargs)
                    return built_rnn

                @property
                def _input_size(self):
                    """RNN input width: the embedding size."""
                    return self.embeddings.embedding_size
         | 
| 315 | 
            +
             | 
| 316 | 
            +
             | 
| 317 | 
            +
            class InputFeedRNNDecoder(RNNDecoderBase):
                """Input feeding based decoder.

                See :class:`~onmt.decoders.decoder.RNNDecoderBase` for options.

                Based around the input feeding approach from
                "Effective Approaches to Attention-based Neural Machine Translation"
                :cite:`Luong2015`
                """

                def _run_forward_pass(self, tgt, enc_out, src_len=None):
                    """Step through the target one position at a time with input feeding.

                    See StdRNNDecoder._run_forward_pass() for description
                    of arguments and return values. Here ``dec_outs`` and the attention
                    entries are Python lists holding one element per time step.
                    """
                    prev_out = self.state["input_feed"].squeeze(0)

                    step_outs = []
                    attns = {}
                    if self.attn is not None:
                        attns["std"] = []
                    if self.copy_attn is not None or self._reuse_copy_attn:
                        attns["copy"] = []
                    if self._coverage:
                        attns["coverage"] = []

                    emb = self.embeddings(tgt)
                    assert emb.dim() == 3  # batch x len x embedding_dim

                    dec_state = self.state["hidden"]

                    stored_coverage = self.state["coverage"]
                    coverage = None if stored_coverage is None else stored_coverage.squeeze(0)

                    # Input feed: the previous output is concatenated with the current
                    # embedding at every time step.
                    for emb_t in emb.split(1, dim=1):
                        dec_in = torch.cat([emb_t.squeeze(1), prev_out], 1)
                        rnn_out, dec_state = self.rnn(dec_in, dec_state)

                        if not self.attentional:
                            dec_out = rnn_out
                        else:
                            dec_out, p_attn = self.attn(rnn_out, enc_out, src_len=src_len)
                            attns["std"].append(p_attn)

                        if self.context_gate is not None:
                            # TODO: context gate should be employed
                            # instead of second RNN transform.
                            dec_out = self.context_gate(dec_in, rnn_out, dec_out)

                        dec_out = self.dropout(dec_out)
                        prev_out = dec_out
                        step_outs.append(dec_out)

                        # Update the coverage attention.
                        # attns["coverage"] is actually c^(t+1) of See et al(2017)
                        # 1-index shifted
                        if self._coverage:
                            coverage = p_attn if coverage is None else p_attn + coverage
                            attns["coverage"].append(coverage)

                        if self.copy_attn is not None:
                            _, copy_attn = self.copy_attn(dec_out, enc_out)
                            attns["copy"].append(copy_attn)
                        elif self._reuse_copy_attn:
                            # Reuse: the copy entry is the very same list as "std".
                            attns["copy"] = attns["std"]

                    return dec_state, step_outs, attns

                def _build_rnn(self, rnn_type, input_size, hidden_size, num_layers, dropout):
                    """Build a stacked LSTM (for ``"LSTM"``) or stacked GRU (otherwise)."""
                    assert rnn_type != "SRU", (
                        "SRU doesn't support input feed! Please set -input_feed 0!"
                    )
                    if rnn_type == "LSTM":
                        stacked_cell = StackedLSTM
                    else:
                        stacked_cell = StackedGRU
                    return stacked_cell(num_layers, input_size, hidden_size, dropout)

                @property
                def _input_size(self):
                    """Using input feed by concatenating input with attention vectors."""
                    emb_width = self.embeddings.embedding_size
                    return emb_width + self.hidden_size

                def update_dropout(self, dropout, attention_dropout=None):
                    """Apply a new dropout probability to the output, RNN and embeddings.

                    ``attention_dropout`` is accepted but not used by this method.
                    """
                    new_p = dropout
                    self.dropout.p = new_p
                    self.rnn.dropout.p = new_p
                    self.embeddings.update_dropout(new_p)
         | 
    	
        onmt/decoders/ensemble.py
    ADDED
    
    | @@ -0,0 +1,150 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            """Ensemble decoding.
         | 
| 2 | 
            +
             | 
| 3 | 
            +
            Decodes using multiple models simultaneously,
         | 
| 4 | 
            +
            combining their prediction distributions by averaging.
         | 
| 5 | 
            +
            All models in the ensemble must share a target vocabulary.
         | 
| 6 | 
            +
            """
         | 
| 7 | 
            +
             | 
| 8 | 
            +
            import torch
         | 
| 9 | 
            +
            import torch.nn as nn
         | 
| 10 | 
            +
             | 
| 11 | 
            +
            from onmt.encoders.encoder import EncoderBase
         | 
| 12 | 
            +
            from onmt.decoders.decoder import DecoderBase
         | 
| 13 | 
            +
            from onmt.models import NMTModel
         | 
| 14 | 
            +
            import onmt.model_builder
         | 
| 15 | 
            +
             | 
| 16 | 
            +
             | 
| 17 | 
            +
            class EnsembleDecoderOutput(object):
                """Container for the decoder outputs of every model in an ensemble.

                Indexing the wrapper returns the output stored at that position.
                """

                def __init__(self, model_dec_outs):
                    # Any iterable is accepted; it is frozen into a tuple.
                    self.model_dec_outs = tuple(model_dec_outs)

                def squeeze(self, dim=None):
                    """Squeeze each wrapped output along ``dim`` and wrap the results.

                    Delegating squeeze avoids modifying
                    :func:`onmt.translate.translator.Translator.translate_batch()`
                    """
                    squeezed = (out.squeeze(dim) for out in self.model_dec_outs)
                    return EnsembleDecoderOutput(squeezed)

                def __getitem__(self, index):
                    return self.model_dec_outs[index]
         | 
| 31 | 
            +
             | 
| 32 | 
            +
             | 
| 33 | 
            +
            class EnsembleEncoder(EncoderBase):
                """Dummy Encoder that delegates to individual real Encoders."""

                def __init__(self, model_encoders):
                    super(EnsembleEncoder, self).__init__()
                    # Held in a ModuleList so the wrapped encoders are registered submodules.
                    self.model_encoders = nn.ModuleList(model_encoders)

                def forward(self, src, src_len=None):
                    """Run every wrapped encoder on the same input.

                    Returns:
                        A tuple ``(enc_out, enc_final_hs, src_len)`` where the first two
                        items are tuples with one entry per wrapped encoder and
                        ``src_len`` is the argument passed in, unchanged. The third
                        value returned by each encoder is discarded.
                    """
                    per_model = [encoder(src, src_len) for encoder in self.model_encoders]
                    enc_out, enc_final_hs, _ = zip(*per_model)
                    return enc_out, enc_final_hs, src_len
         | 
| 45 | 
            +
             | 
| 46 | 
            +
             | 
| 47 | 
            +
            class EnsembleDecoder(DecoderBase):
                """Dummy Decoder that delegates to individual real Decoders."""

                def __init__(self, model_decoders):
                    """Wrap ``model_decoders``; the ensemble is attentional if any member is."""
                    model_decoders = nn.ModuleList(model_decoders)
                    attentional = any(dec.attentional for dec in model_decoders)
                    super(EnsembleDecoder, self).__init__(attentional)
                    # Assigned after the parent constructor has run, since the
                    # decoders are registered as submodules on assignment.
                    self.model_decoders = model_decoders

                def forward(self, tgt, enc_out, src_len=None, step=None, **kwargs):
                    """See :func:`onmt.decoders.decoder.DecoderBase.forward()`.

                    ``enc_out`` is indexed per model (``enc_out[i]`` goes to the i-th
                    decoder); ``tgt``, ``src_len``, ``step`` and ``kwargs`` are shared.

                    Returns:
                        tuple: ``(EnsembleDecoderOutput, dict)`` holding the per-model
                        decoder outputs and the attentions combined by
                        :meth:`combine_attns`.
                    """
                    # src_len is a single tensor shared between all models.
                    # This assumption will not hold if Translator is modified
                    # to calculate src_len as something other than the length
                    # of the input.
                    dec_outs, attns = zip(
                        *[
                            model_decoder(tgt, enc_out[i], src_len=src_len, step=step, **kwargs)
                            for i, model_decoder in enumerate(self.model_decoders)
                        ]
                    )
                    mean_attns = self.combine_attns(attns)
                    return EnsembleDecoderOutput(dec_outs), mean_attns

                def combine_attns(self, attns):
                    """Average the attention dicts returned by the individual decoders.

                    Args:
                        attns (sequence of dict): one attention dict per decoder. The
                            keys of the first dict decide which entries are combined.

                    Returns:
                        dict: for each key, the mean over the models whose value is
                        not ``None``, or ``None`` when no model provided a value.
                    """
                    result = {}
                    for key in attns[0].keys():
                        present = [attn[key] for attn in attns if attn[key] is not None]
                        # torch.stack raises on an empty list, so a key that every
                        # model left as None stays None instead of crashing.
                        result[key] = torch.stack(present).mean(0) if present else None
                    return result

                def init_state(self, src, enc_out, enc_hidden):
                    """See :obj:`RNNDecoderBase.init_state()`

                    Each decoder is initialised with its own ``enc_out[i]`` and
                    ``enc_hidden[i]``; ``src`` is shared.
                    """
                    for i, model_decoder in enumerate(self.model_decoders):
                        model_decoder.init_state(src, enc_out[i], enc_hidden[i])

                def map_state(self, fn):
                    """Forward ``map_state(fn)`` to every wrapped decoder."""
                    for model_decoder in self.model_decoders:
                        model_decoder.map_state(fn)
         | 
| 87 | 
            +
             | 
| 88 | 
            +
             | 
| 89 | 
            +
            class EnsembleGenerator(nn.Module):
                """
                Dummy Generator that delegates to individual real Generators,
                and then averages the resulting target distributions.
                """

                def __init__(self, model_generators, raw_probs=False):
                    """
                    Args:
                        model_generators (iterable of nn.Module): one generator per
                            ensemble member.
                        raw_probs (bool): if True, average in probability space and
                            return the log of that average; otherwise average the
                            generator outputs directly.
                    """
                    super(EnsembleGenerator, self).__init__()
                    self.model_generators = nn.ModuleList(model_generators)
                    self._raw_probs = raw_probs

                def forward(self, hidden, attn=None, src_map=None):
                    """
                    Compute a distribution over the target dictionary
                    by averaging distributions from models in the ensemble.
                    All models in the ensemble must share a target vocabulary.

                    Args:
                        hidden (sequence): one decoder output per generator, in the
                            same order as the generators.
                        attn: when not ``None``, passed with ``src_map`` to every
                            generator as ``mg(h, attn, src_map)``.
                        src_map: see ``attn``.

                    Returns:
                        Tensor: the generator outputs combined over the model axis.
                    """
                    import math

                    distributions = torch.stack(
                        [
                            mg(h) if attn is None else mg(h, attn, src_map)
                            for h, mg in zip(hidden, self.model_generators)
                        ]
                    )
                    if self._raw_probs:
                        # log(mean(exp(x))) == logsumexp(x) - log(n). The logsumexp
                        # form does not underflow to -inf when every model assigns a
                        # very small (but non-zero) probability.
                        n_models = distributions.size(0)
                        return torch.logsumexp(distributions, dim=0) - math.log(n_models)
                    else:
                        return distributions.mean(0)
         | 
| 116 | 
            +
             | 
| 117 | 
            +
             | 
| 118 | 
            +
            class EnsembleModel(NMTModel):
                """NMTModel facade assembled from the parts of several real NMTModels."""

                def __init__(self, models, raw_probs=False):
                    """Build ensemble encoder, decoder and generator from ``models``.

                    Args:
                        models: the individual models; their ``encoder``, ``decoder``
                            and ``generator`` attributes are wrapped.
                        raw_probs (bool): forwarded to :class:`EnsembleGenerator`.
                    """
                    ensemble_encoder = EnsembleEncoder(member.encoder for member in models)
                    ensemble_decoder = EnsembleDecoder(member.decoder for member in models)
                    super().__init__(ensemble_encoder, ensemble_decoder)
                    member_generators = [member.generator for member in models]
                    self.generator = EnsembleGenerator(member_generators, raw_probs)
                    # Keep the complete models registered as submodules as well.
                    self.models = nn.ModuleList(models)
         | 
| 129 | 
            +
             | 
| 130 | 
            +
             | 
| 131 | 
            +
            def load_test_model(opt, device_id=0):
                """Load every checkpoint listed in ``opt.models`` and wrap them in an ensemble.

                The vocabs and model options of the first checkpoint are the ones
                returned; every later checkpoint must have the same source
                token-to-id mapping.

                Returns:
                    tuple: ``(vocabs, ensemble_model, model_opt)``.
                """
                reference_vocabs = None
                reference_model_opt = None
                members = []
                for checkpoint_path in opt.models:
                    loaded = onmt.model_builder.load_test_model(
                        opt, device_id, model_path=checkpoint_path
                    )
                    vocabs, model, model_opt = loaded
                    if reference_vocabs is not None:
                        assert (
                            reference_vocabs["src"].tokens_to_ids == vocabs["src"].tokens_to_ids
                        ), "Ensemble models must use the same vocabs "
                    else:
                        reference_vocabs = vocabs
                    members.append(model)
                    if reference_model_opt is None:
                        reference_model_opt = model_opt
                wrapped = EnsembleModel(members, opt.avg_raw_probs)
                return reference_vocabs, wrapped, reference_model_opt
         | 
    	
        onmt/decoders/transformer.py
    ADDED
    
    | @@ -0,0 +1,835 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            """
         | 
| 2 | 
            +
            Implementation of "Attention is All You Need" and of
         | 
| 3 | 
            +
            subsequent transformer based architectures
         | 
| 4 | 
            +
            """
         | 
| 5 | 
            +
             | 
| 6 | 
            +
            import torch
         | 
| 7 | 
            +
            import torch.nn as nn
         | 
| 8 | 
            +
            from onmt.decoders.decoder import DecoderBase
         | 
| 9 | 
            +
            from onmt.modules import MultiHeadedAttention, AverageAttention
         | 
| 10 | 
            +
            from onmt.modules.position_ffn import PositionwiseFeedForward
         | 
| 11 | 
            +
            from onmt.modules.position_ffn import ActivationFunction
         | 
| 12 | 
            +
            from onmt.utils.misc import sequence_mask
         | 
| 13 | 
            +
            from onmt.modules.rmsnorm import RMSNorm
         | 
| 14 | 
            +
             | 
| 15 | 
            +
             | 
| 16 | 
            +
            class TransformerDecoderLayerBase(nn.Module):
                """Shared building blocks for transformer decoder layers.

                Builds the self-attention module, the position-wise feed-forward
                module, the layer norm(s) and the dropout, and implements the public
                ``forward`` on top of an abstract ``_forward`` that subclasses provide.
                """

                def __init__(
                    self,
                    d_model,
                    heads,
                    d_ff,
                    dropout,
                    attention_dropout,
                    self_attn_type="scaled-dot",
                    max_relative_positions=0,
                    relative_positions_buckets=0,
                    aan_useffn=False,
                    full_context_alignment=False,
                    alignment_heads=0,
                    pos_ffn_activation_fn=ActivationFunction.relu,
                    add_qkvbias=False,
                    num_kv=0,
                    add_ffnbias=True,
                    parallel_residual=False,
                    shared_layer_norm=False,
                    layer_norm="standard",
                    norm_eps=1e-6,
                    use_ckpting=None,
                    parallel_gpu=1,
                ):
                    """
                    Args:
                        d_model (int): the dimension of keys/values/queries in
                            :class:`MultiHeadedAttention`, also the input size of
                            the first-layer of the :class:`PositionwiseFeedForward`.
                        heads (int): the number of heads for MultiHeadedAttention.
                        d_ff (int): the second-layer of the
                            :class:`PositionwiseFeedForward`.
                        dropout (float): dropout in residual, self-attn(dot) and
                            feed-forward
                        attention_dropout (float): dropout in context_attn  (and
                            self-attn(avg))
                        self_attn_type (string): type of self-attention scaled-dot,
                            average
                        max_relative_positions (int):
                            Max distance between inputs in relative positions
                            representations
                        relative_positions_buckets (int): forwarded to
                            :class:`MultiHeadedAttention`
                        aan_useffn (bool): Turn on the FFN layer in the AAN decoder
                        full_context_alignment (bool):
                            whether enable an extra full context decoder forward for
                            alignment
                        alignment_heads (int):
                            N. of cross attention heads to use for alignment guiding
                        pos_ffn_activation_fn (ActivationFunction):
                            activation function choice for PositionwiseFeedForward layer
                        add_qkvbias (bool): whether to add bias to the Key/Value nn.Linear
                        num_kv (int): forwarded to :class:`MultiHeadedAttention`
                        add_ffnbias (bool): forwarded to
                            :class:`PositionwiseFeedForward`
                        parallel_residual (bool): forwarded to
                            :class:`PositionwiseFeedForward`; together with
                            ``shared_layer_norm=False`` it also creates the extra
                            ``layer_norm_res`` norm
                        shared_layer_norm (bool): when True no separate
                            ``layer_norm_res`` is created for the parallel residual
                        layer_norm (string): type of layer normalization standard/rms
                        norm_eps (float): layer norm epsilon
                        use_ckpting (list, optional): forwarded to the attention and
                            feed-forward modules; ``None`` means an empty list
                        parallel_gpu (int): forwarded to the attention and
                            feed-forward modules

                    Raises:
                        ValueError: if ``layer_norm`` is neither "standard" nor "rms".
                    """
                    super(TransformerDecoderLayerBase, self).__init__()

                    # Use a fresh list per instance rather than a mutable default
                    # shared by every layer ever constructed.
                    if use_ckpting is None:
                        use_ckpting = []

                    self.self_attn_type = self_attn_type
                    if self_attn_type == "scaled-dot":
                        self.self_attn = MultiHeadedAttention(
                            heads,
                            d_model,
                            dropout=attention_dropout,
                            max_relative_positions=max_relative_positions,
                            relative_positions_buckets=relative_positions_buckets,
                            attn_type="self",
                            add_qkvbias=add_qkvbias,
                            num_kv=num_kv,
                            use_ckpting=use_ckpting,
                            parallel_gpu=parallel_gpu,
                        )
                    elif self_attn_type == "average":
                        self.self_attn = AverageAttention(
                            d_model, dropout=attention_dropout, aan_useffn=aan_useffn
                        )

                    self.feed_forward = PositionwiseFeedForward(
                        d_model,
                        d_ff,
                        dropout,
                        pos_ffn_activation_fn,
                        add_ffnbias,
                        parallel_residual,
                        layer_norm,
                        norm_eps,
                        use_ckpting=use_ckpting,
                        parallel_gpu=parallel_gpu,
                    )
                    self.parallel_residual = parallel_residual
                    self.shared_layer_norm = shared_layer_norm
                    if layer_norm == "standard":
                        self.layer_norm_1 = nn.LayerNorm(d_model, eps=norm_eps)
                        if parallel_residual and not shared_layer_norm:
                            self.layer_norm_res = nn.LayerNorm(d_model, eps=norm_eps)
                    elif layer_norm == "rms":
                        self.layer_norm_1 = RMSNorm(d_model, eps=norm_eps)
                        if parallel_residual and not shared_layer_norm:
                            self.layer_norm_res = RMSNorm(d_model, eps=norm_eps)
                    else:
                        raise ValueError(f"{layer_norm} layer norm type is not supported")

                    self.dropout = nn.Dropout(dropout)
                    self.full_context_alignment = full_context_alignment
                    self.alignment_heads = alignment_heads

                def forward(self, *args, **kwargs):
                    """Extend `_forward` for (possibly) multiple decoder pass:
                    Always a default (future masked) decoder forward pass,
                    Possibly a second future aware decoder pass for joint learn
                    full context alignment, :cite:`garg2019jointly`.

                    Args:
                        * All arguments of _forward, of which
                        with_align (bool): needed to compute attn_align
                        return_attn (bool): to force MHA to return attns

                    Returns:
                        (FloatTensor, FloatTensor, FloatTensor or None):

                        * layer_out ``(batch_size, T, model_dim)``
                        * top_attn ``(batch_size, T, src_len)``
                        * attn_align ``(batch_size, T, src_len)`` or None
                    """
                    # with_align is consumed here and never reaches _forward.
                    with_align = kwargs.pop("with_align", False)
                    layer_out, attns = self._forward(*args, **kwargs)
                    # First head only; stays None when _forward returned no attention.
                    top_attn = None if attns is None else attns[:, 0, :, :].contiguous()
                    attn_align = None
                    if with_align:
                        if self.full_context_alignment:
                            # return _, (B, Q_len, K_len)
                            _, attns = self._forward(*args, **kwargs, future=True)

                        if self.alignment_heads > 0:
                            attns = attns[:, : self.alignment_heads, :, :].contiguous()
                        # layer average attention across heads, get ``(B, Q, K)``
                        # Case 1: no full_context, no align heads -> layer avg baseline
                        # Case 2: no full_context, 1 align heads -> guided align
                        # Case 3: full_context, 1 align heads -> full cte guided align
                        attn_align = attns.mean(dim=1)
                    return layer_out, top_attn, attn_align

                def update_dropout(self, dropout, attention_dropout):
                    """Set new dropout rates on the attention, feed-forward and residual dropout."""
                    self.self_attn.update_dropout(attention_dropout)
                    self.feed_forward.update_dropout(dropout)
                    self.dropout.p = dropout

                def _forward(self, *args, **kwargs):
                    """Single decoder pass; must be implemented by subclasses."""
                    raise NotImplementedError

                def _compute_dec_mask(self, tgt_pad_mask, future):
                    """Build the decoder self-attention mask.

                    Args:
                        tgt_pad_mask: padding mask whose last dimension is the
                            target length.
                        future (bool): when False the strictly-upper-triangular
                            future mask is combined with the padding mask; when True
                            the padding mask is returned as is.
                    """
                    tgt_len = tgt_pad_mask.size(-1)
                    if not future:  # apply future_mask, result mask in (B, T, T)
                        future_mask = torch.ones(
                            [tgt_len, tgt_len],
                            device=tgt_pad_mask.device,
                            dtype=torch.uint8,
                        )
                        # Keep only entries above the main diagonal (later positions).
                        future_mask = future_mask.triu_(1).view(1, tgt_len, tgt_len)
                        # BoolTensor was introduced in pytorch 1.2
                        try:
                            future_mask = future_mask.bool()
                        except AttributeError:
                            pass
                        # A position is masked if either mask marks it.
                        dec_mask = torch.gt(tgt_pad_mask + future_mask, 0)
                    else:  # only mask padding, result mask in (B, 1, T)
                        dec_mask = tgt_pad_mask
                    return dec_mask

                def _forward_self_attn(self, norm_layer_in, dec_mask, step, return_attn=False):
                    """Run the configured self-attention on ``norm_layer_in``.

                    ``return_attn`` is only forwarded for the "scaled-dot" type.

                    Raises:
                        ValueError: if ``self.self_attn_type`` is not supported.
                    """
                    if self.self_attn_type == "scaled-dot":
                        return self.self_attn(
                            norm_layer_in,
                            norm_layer_in,
                            norm_layer_in,
                            mask=dec_mask,
                            step=step,
                            return_attn=return_attn,
                        )
                    elif self.self_attn_type == "average":
                        return self.self_attn(norm_layer_in, mask=dec_mask, step=step)
                    else:
                        # Report the configured type: ``self.self_attn`` is never
                        # created for an unsupported type, so reading it here would
                        # raise AttributeError instead of this ValueError.
                        raise ValueError(
                            f"self attention {self.self_attn_type} not supported"
                        )
         | 
| 198 | 
            +
             | 
| 199 | 
            +
             | 
| 200 | 
            +
            class TransformerDecoderLayer(TransformerDecoderLayerBase):
         | 
| 201 | 
            +
                """Transformer Decoder layer block in Pre-Norm style.
         | 
| 202 | 
            +
                Pre-Norm style is an improvement w.r.t. Original paper's Post-Norm style,
         | 
| 203 | 
            +
                providing better convergence speed and performance. This is also the actual
         | 
| 204 | 
            +
                implementation in tensor2tensor and also available in fairseq.
         | 
| 205 | 
            +
                See https://tunz.kr/post/4 and :cite:`DeeperTransformer`.
         | 
| 206 | 
            +
             | 
| 207 | 
            +
                """
         | 
| 208 | 
            +
             | 
| 209 | 
            +
                def __init__(
         | 
| 210 | 
            +
                    self,
         | 
| 211 | 
            +
                    d_model,
         | 
| 212 | 
            +
                    heads,
         | 
| 213 | 
            +
                    d_ff,
         | 
| 214 | 
            +
                    dropout,
         | 
| 215 | 
            +
                    attention_dropout,
         | 
| 216 | 
            +
                    self_attn_type="scaled-dot",
         | 
| 217 | 
            +
                    max_relative_positions=0,
         | 
| 218 | 
            +
                    relative_positions_buckets=0,
         | 
| 219 | 
            +
                    aan_useffn=False,
         | 
| 220 | 
            +
                    full_context_alignment=False,
         | 
| 221 | 
            +
                    alignment_heads=0,
         | 
| 222 | 
            +
                    pos_ffn_activation_fn=ActivationFunction.relu,
         | 
| 223 | 
            +
                    add_qkvbias=False,
         | 
| 224 | 
            +
                    num_kv=0,
         | 
| 225 | 
            +
                    add_ffnbias=True,
         | 
| 226 | 
            +
                    parallel_residual=False,
         | 
| 227 | 
            +
                    shared_layer_norm=False,
         | 
| 228 | 
            +
                    layer_norm="standard",
         | 
| 229 | 
            +
                    norm_eps=1e-6,
         | 
| 230 | 
            +
                    use_ckpting=[],
         | 
| 231 | 
            +
                    parallel_gpu=1,
         | 
| 232 | 
            +
                ):
         | 
| 233 | 
            +
                    """
         | 
| 234 | 
            +
                    Args:
         | 
| 235 | 
            +
                        See TransformerDecoderLayerBase
         | 
| 236 | 
            +
                    """
         | 
| 237 | 
            +
                    super(TransformerDecoderLayer, self).__init__(
         | 
| 238 | 
            +
                        d_model,
         | 
| 239 | 
            +
                        heads,
         | 
| 240 | 
            +
                        d_ff,
         | 
| 241 | 
            +
                        dropout,
         | 
| 242 | 
            +
                        attention_dropout,
         | 
| 243 | 
            +
                        self_attn_type,
         | 
| 244 | 
            +
                        max_relative_positions,
         | 
| 245 | 
            +
                        relative_positions_buckets,
         | 
| 246 | 
            +
                        aan_useffn,
         | 
| 247 | 
            +
                        full_context_alignment,
         | 
| 248 | 
            +
                        alignment_heads,
         | 
| 249 | 
            +
                        pos_ffn_activation_fn=pos_ffn_activation_fn,
         | 
| 250 | 
            +
                        add_qkvbias=add_qkvbias,
         | 
| 251 | 
            +
                        num_kv=num_kv,
         | 
| 252 | 
            +
                        add_ffnbias=add_ffnbias,
         | 
| 253 | 
            +
                        parallel_residual=parallel_residual,
         | 
| 254 | 
            +
                        shared_layer_norm=shared_layer_norm,
         | 
| 255 | 
            +
                        layer_norm=layer_norm,
         | 
| 256 | 
            +
                        norm_eps=norm_eps,
         | 
| 257 | 
            +
                        use_ckpting=use_ckpting,
         | 
| 258 | 
            +
                        parallel_gpu=parallel_gpu,
         | 
| 259 | 
            +
                    )
         | 
| 260 | 
            +
                    self.context_attn = MultiHeadedAttention(
         | 
| 261 | 
            +
                        heads,
         | 
| 262 | 
            +
                        d_model,
         | 
| 263 | 
            +
                        dropout=attention_dropout,
         | 
| 264 | 
            +
                        attn_type="context",
         | 
| 265 | 
            +
                        add_qkvbias=add_qkvbias,
         | 
| 266 | 
            +
                        num_kv=num_kv,
         | 
| 267 | 
            +
                        use_ckpting=use_ckpting,
         | 
| 268 | 
            +
                        parallel_gpu=parallel_gpu,
         | 
| 269 | 
            +
                    )
         | 
| 270 | 
            +
                    if layer_norm == "standard":
         | 
| 271 | 
            +
                        self.layer_norm_2 = nn.LayerNorm(d_model, eps=norm_eps)
         | 
| 272 | 
            +
                    elif layer_norm == "rms":
         | 
| 273 | 
            +
                        self.layer_norm_2 = RMSNorm(d_model, eps=norm_eps)
         | 
| 274 | 
            +
                    else:
         | 
| 275 | 
            +
                        raise ValueError(f"{layer_norm} layer norm type is not supported")
         | 
| 276 | 
            +
             | 
| 277 | 
            +
                def update_dropout(self, dropout, attention_dropout):
         | 
| 278 | 
            +
                    super(TransformerDecoderLayer, self).update_dropout(dropout, attention_dropout)
         | 
| 279 | 
            +
                    self.context_attn.update_dropout(attention_dropout)
         | 
| 280 | 
            +
             | 
| 281 | 
            +
                def _forward(
         | 
| 282 | 
            +
                    self,
         | 
| 283 | 
            +
                    layer_in,
         | 
| 284 | 
            +
                    enc_out,
         | 
| 285 | 
            +
                    src_pad_mask,
         | 
| 286 | 
            +
                    tgt_pad_mask,
         | 
| 287 | 
            +
                    step=None,
         | 
| 288 | 
            +
                    future=False,
         | 
| 289 | 
            +
                    return_attn=False,
         | 
| 290 | 
            +
                ):
         | 
| 291 | 
            +
                    """A naive forward pass for transformer decoder.
         | 
| 292 | 
            +
             | 
| 293 | 
            +
                    # T: could be 1 in the case of stepwise decoding or tgt_len
         | 
| 294 | 
            +
             | 
| 295 | 
            +
                    Args:
         | 
| 296 | 
            +
                        layer_in (FloatTensor): ``(batch_size, T, model_dim)``
         | 
| 297 | 
            +
                        enc_out (FloatTensor): ``(batch_size, src_len, model_dim)``
         | 
| 298 | 
            +
                        src_pad_mask (bool): ``(batch_size, 1, src_len)``
         | 
| 299 | 
            +
                        tgt_pad_mask (bool): ``(batch_size, 1, T)``
         | 
| 300 | 
            +
                        step (int or None): stepwise decoding counter
         | 
| 301 | 
            +
                        future (bool): If set True, do not apply future_mask.
         | 
| 302 | 
            +
                        return_attn (bool) : if set True requires attns output
         | 
| 303 | 
            +
             | 
| 304 | 
            +
                    Returns:
         | 
| 305 | 
            +
                        (FloatTensor, FloatTensor):
         | 
| 306 | 
            +
             | 
| 307 | 
            +
                        * layer_out ``(batch_size, T, model_dim)``
         | 
| 308 | 
            +
                        * attns ``(batch_size, head, T, src_len)``
         | 
| 309 | 
            +
             | 
| 310 | 
            +
                    """
         | 
| 311 | 
            +
                    dec_mask = None
         | 
| 312 | 
            +
                    src_pad_mask = src_pad_mask.unsqueeze(1)  # [B,1,1,slen]
         | 
| 313 | 
            +
             | 
| 314 | 
            +
                    if layer_in.size(1) > 1:
         | 
| 315 | 
            +
                        # masking is necessary when sequence length is greater than one
         | 
| 316 | 
            +
                        dec_mask = self._compute_dec_mask(tgt_pad_mask, future)
         | 
| 317 | 
            +
                        dec_mask = dec_mask.unsqueeze(1)
         | 
| 318 | 
            +
                        dec_mask = dec_mask.expand(-1, -1, dec_mask.size(3), -1)
         | 
| 319 | 
            +
                        src_pad_mask = src_pad_mask.expand(-1, -1, dec_mask.size(3), -1)
         | 
| 320 | 
            +
                        # mask now are (batch x 1 x tlen x s or t len)
         | 
| 321 | 
            +
                        # 1 = heads to be expanded in MHA
         | 
| 322 | 
            +
             | 
| 323 | 
            +
                    norm_layer_in = self.layer_norm_1(layer_in)
         | 
| 324 | 
            +
             | 
| 325 | 
            +
                    self_attn, _ = self._forward_self_attn(norm_layer_in, dec_mask, step)
         | 
| 326 | 
            +
             | 
| 327 | 
            +
                    if self.parallel_residual:
         | 
| 328 | 
            +
                        ctx_attn, attns = self.context_attn(
         | 
| 329 | 
            +
                            enc_out,
         | 
| 330 | 
            +
                            enc_out,
         | 
| 331 | 
            +
                            norm_layer_in,
         | 
| 332 | 
            +
                            mask=src_pad_mask,
         | 
| 333 | 
            +
                            return_attn=return_attn,
         | 
| 334 | 
            +
                        )
         | 
| 335 | 
            +
                        # feed_forward applies residual, so we remove and apply residual with un-normed
         | 
| 336 | 
            +
                        layer_out = (
         | 
| 337 | 
            +
                            self.feed_forward(norm_layer_in)
         | 
| 338 | 
            +
                            - norm_layer_in
         | 
| 339 | 
            +
                            + layer_in
         | 
| 340 | 
            +
                            + self.dropout(self_attn)
         | 
| 341 | 
            +
                            + ctx_attn
         | 
| 342 | 
            +
                        )
         | 
| 343 | 
            +
                    else:
         | 
| 344 | 
            +
                        query = self.dropout(self_attn) + layer_in
         | 
| 345 | 
            +
                        norm_query = self.layer_norm_2(query)
         | 
| 346 | 
            +
                        ctx_attn, attns = self.context_attn(
         | 
| 347 | 
            +
                            enc_out, enc_out, norm_query, mask=src_pad_mask, return_attn=return_attn
         | 
| 348 | 
            +
                        )
         | 
| 349 | 
            +
                        layer_out = self.feed_forward(self.dropout(ctx_attn) + query)
         | 
| 350 | 
            +
             | 
| 351 | 
            +
                    return layer_out, attns
         | 
| 352 | 
            +
             | 
| 353 | 
            +
             | 
| 354 | 
            +
            class TransformerDecoderBase(DecoderBase):
                """Common state and cache handling shared by the transformer decoders.

                ``map_state`` and ``update_dropout`` iterate ``self.transformer_layers``,
                which this base class does not create itself; ``forward`` and
                ``detach_state`` raise ``NotImplementedError`` here.
                """

                def __init__(
                    self, d_model, copy_attn, embeddings, alignment_layer, layer_norm, norm_eps
                ):
                    """
                    Args:
                        d_model: size handed to the final normalization layer
                        copy_attn: stored as ``self._copy``
                        embeddings: stored as ``self.embeddings``
                        alignment_layer: stored as ``self.alignment_layer``
                        layer_norm (str): "standard" or "rms"
                        norm_eps: ``eps`` of the normalization layer

                    Raises:
                        ValueError: for any other ``layer_norm`` value.
                    """
                    super(TransformerDecoderBase, self).__init__()

                    self.embeddings = embeddings

                    # Decoder State
                    self.state = {}

                    # A GlobalAttention module used to live here for copy attention,
                    # but it was never actually used: the "copy" attention simply
                    # reuses the context attention.
                    self._copy = copy_attn

                    if layer_norm not in ("standard", "rms"):
                        raise ValueError(f"{layer_norm} layer norm type is not supported")
                    norm_cls = nn.LayerNorm if layer_norm == "standard" else RMSNorm
                    self.layer_norm = norm_cls(d_model, eps=norm_eps)

                    self.alignment_layer = alignment_layer

                @classmethod
                def from_opt(cls, opt, embeddings):
                    """Alternate constructor: build ``cls`` from the attributes of ``opt``.

                    The argument list passed to ``cls`` is longer than this base class's
                    own ``__init__`` signature, so it only fits constructors that accept
                    it (e.g. ``TransformerDecoder.__init__``).
                    """
                    # dropout options may be lists; only the first entry is used
                    dropout = opt.dropout
                    if type(dropout) is list:
                        dropout = dropout[0]

                    attention_dropout = opt.attention_dropout
                    if type(attention_dropout) is list:
                        attention_dropout = attention_dropout[0]

                    if opt.parallel_mode == "tensor_parallel":
                        parallel_gpu = opt.world_size
                    else:
                        parallel_gpu = 1

                    return cls(
                        opt.dec_layers,
                        opt.dec_hid_size,
                        opt.heads,
                        opt.transformer_ff,
                        opt.copy_attn,
                        opt.self_attn_type,
                        dropout,
                        attention_dropout,
                        embeddings,
                        opt.max_relative_positions,
                        opt.relative_positions_buckets,
                        opt.aan_useffn,
                        opt.full_context_alignment,
                        opt.alignment_layer,
                        alignment_heads=opt.alignment_heads,
                        pos_ffn_activation_fn=opt.pos_ffn_activation_fn,
                        add_qkvbias=opt.add_qkvbias,
                        num_kv=opt.num_kv,
                        add_ffnbias=opt.add_ffnbias,
                        parallel_residual=opt.parallel_residual,
                        shared_layer_norm=opt.shared_layer_norm,
                        layer_norm=opt.layer_norm,
                        norm_eps=opt.norm_eps,
                        use_ckpting=opt.use_ckpting,
                        parallel_gpu=parallel_gpu,
                    )

                def init_state(self, src, enc_out, enc_final_hs):
                    """Initialize decoder state (only ``src`` is stored)."""
                    self.state["src"] = src

                def map_state(self, fn):
                    """Apply ``fn(tensor, 0)`` to the stored ``src`` and to non-empty layer caches.

                    Args:
                        fn: callable invoked as ``fn(tensor, 0)``; its result replaces
                            the tensor it was given.
                    """

                    def remap(attn, names):
                        # An empty first entry means the cache holds nothing yet.
                        cached = attn.layer_cache[1]
                        if cached[names[0]].numel() != 0:
                            attn.layer_cache = True, {
                                name: fn(cached[name], 0) for name in names
                            }

                    src = self.state["src"]
                    if src is not None:
                        self.state["src"] = fn(src, 0)

                    for layer in self.transformer_layers:
                        if hasattr(layer, "context_attn"):
                            remap(layer.context_attn, ("keys", "values"))
                        if isinstance(layer.self_attn, AverageAttention):
                            remap(layer.self_attn, ("prev_g",))
                        else:
                            remap(layer.self_attn, ("keys", "values"))

                def detach_state(self):
                    """Not provided by the base class."""
                    raise NotImplementedError

                def forward(self, *args, **kwargs):
                    """Not provided by the base class."""
                    raise NotImplementedError

                def update_dropout(self, dropout, attention_dropout):
                    """Push new dropout rates to the embeddings and to every layer."""
                    self.embeddings.update_dropout(dropout)
                    for decoder_layer in self.transformer_layers:
                        decoder_layer.update_dropout(dropout, attention_dropout)
         | 
| 446 | 
            +
             | 
| 447 | 
            +
             | 
| 448 | 
            +
            class TransformerDecoder(TransformerDecoderBase):
         | 
| 449 | 
            +
                """The Transformer decoder from "Attention is All You Need".
         | 
| 450 | 
            +
                :cite:`DBLP:journals/corr/VaswaniSPUJGKP17`
         | 
| 451 | 
            +
             | 
| 452 | 
            +
                Args:
         | 
| 453 | 
            +
                    num_layers (int): number of decoder layers.
         | 
| 454 | 
            +
                    d_model (int): size of the model
         | 
| 455 | 
            +
                    heads (int): number of heads
         | 
| 456 | 
            +
                    d_ff (int): size of the inner FF layer
         | 
| 457 | 
            +
                    copy_attn (bool): if using a separate copy attention
         | 
| 458 | 
            +
                    self_attn_type (str): type of self-attention scaled-dot, average
         | 
| 459 | 
            +
                    dropout (float): dropout in residual, self-attn(dot) and feed-forward
         | 
| 460 | 
            +
                    attention_dropout (float): dropout in context_attn (and self-attn(avg))
         | 
| 461 | 
            +
                    embeddings (onmt.modules.Embeddings):
         | 
| 462 | 
            +
                        embeddings to use, should have positional encodings
         | 
| 463 | 
            +
                    max_relative_positions (int):
         | 
| 464 | 
            +
                        Max distance between inputs in relative positions representations
         | 
| 465 | 
            +
                    relative_positions_buckets (int):
         | 
| 466 | 
            +
                        Number of buckets when using relative position bias
         | 
| 467 | 
            +
                    aan_useffn (bool): Turn on the FFN layer in the AAN decoder
         | 
| 468 | 
            +
                    full_context_alignment (bool):
         | 
| 469 | 
            +
                        whether enable an extra full context decoder forward for alignment
         | 
| 470 | 
            +
                    alignment_layer (int): N° Layer to supervise with for alignment guiding
         | 
| 471 | 
            +
                    alignment_heads (int):
         | 
| 472 | 
            +
                        N. of cross attention heads to use for alignment guiding
         | 
| 473 | 
            +
                    add_qkvbias (bool): whether to add bias to the Key/Value nn.Linear
         | 
| 474 | 
            +
                    layer_norm (string): type of layer normalization standard/rms
         | 
| 475 | 
            +
                """
         | 
| 476 | 
            +
             | 
| 477 | 
            +
                def __init__(
                    self,
                    num_layers,
                    d_model,
                    heads,
                    d_ff,
                    copy_attn,
                    self_attn_type,
                    dropout,
                    attention_dropout,
                    embeddings,
                    max_relative_positions,
                    relative_positions_buckets,
                    aan_useffn,
                    full_context_alignment,
                    alignment_layer,
                    alignment_heads,
                    pos_ffn_activation_fn=ActivationFunction.relu,
                    add_qkvbias=False,
                    num_kv=0,
                    add_ffnbias=True,
                    parallel_residual=False,
                    shared_layer_norm=False,
                    layer_norm="standard",
                    norm_eps=1e-6,
                    use_ckpting=None,
                    parallel_gpu=1,
                ):
                    """Set up the shared decoder state and ``num_layers`` decoder layers.

                    ``d_model``, ``copy_attn``, ``embeddings``, ``alignment_layer``,
                    ``layer_norm`` and ``norm_eps`` go to the base-class constructor;
                    the remaining settings are forwarded unchanged to every
                    ``TransformerDecoderLayer``.

                    ``use_ckpting`` defaults to ``None``, which is replaced by a fresh
                    empty list so that no list object is shared between decoders.
                    """
                    # A literal ``[]`` default would be created once at definition time
                    # and shared by every decoder built without an explicit value.
                    if use_ckpting is None:
                        use_ckpting = []

                    super(TransformerDecoder, self).__init__(
                        d_model, copy_attn, embeddings, alignment_layer, layer_norm, norm_eps
                    )

                    self.transformer_layers = nn.ModuleList(
                        [
                            TransformerDecoderLayer(
                                d_model,
                                heads,
                                d_ff,
                                dropout,
                                attention_dropout,
                                self_attn_type=self_attn_type,
                                max_relative_positions=max_relative_positions,
                                relative_positions_buckets=relative_positions_buckets,
                                aan_useffn=aan_useffn,
                                full_context_alignment=full_context_alignment,
                                alignment_heads=alignment_heads,
                                pos_ffn_activation_fn=pos_ffn_activation_fn,
                                add_qkvbias=add_qkvbias,
                                num_kv=num_kv,
                                add_ffnbias=add_ffnbias,
                                parallel_residual=parallel_residual,
                                shared_layer_norm=shared_layer_norm,
                                layer_norm=layer_norm,
                                norm_eps=norm_eps,
                                use_ckpting=use_ckpting,
                                parallel_gpu=parallel_gpu,
                            )
                            for _ in range(num_layers)
                        ]
                    )
         | 
| 537 | 
            +
             | 
| 538 | 
            +
                def detach_state(self):
                    """Replace ``self.state["src"]`` with the result of its ``detach()`` call."""
                    current_src = self.state["src"]
                    self.state["src"] = current_src.detach()
         | 
| 540 | 
            +
             | 
| 541 | 
            +
                def forward(self, tgt, enc_out=None, step=None, **kwargs):
                    """Decode, possibly stepwise.

                    When training, ``step`` is always None; when decoding, ``step``
                    increases. ``step == 0`` (re)initializes the layer caches through
                    ``_init_cache``; ``step is None`` resets them to inactive, empty
                    caches.

                    Args:
                        tgt (Tensor): batch x tlen x feats
                        enc_out (Tensor): encoder output (batch x slen x model_dim);
                            when None it is replaced by ``self.embeddings(tgt)``
                        step (int or None): stepwise decoding counter
                        **kwargs: must contain ``src_len``; ``with_align`` is optional
                            and defaults to False.

                    Returns:
                        (dec_out, attns): the normalized output of the last layer and
                        a dict with key ``"std"``, plus ``"copy"`` when copy attention
                        is enabled and ``"align"`` when ``with_align`` is set.
                    """
                    if enc_out is None:
                        enc_out = self.embeddings(tgt)

                    if step == 0:
                        self._init_cache(enc_out)
                    elif step is None:
                        # no stepwise decoding: every cache becomes inactive and empty
                        for layer in self.transformer_layers:
                            if isinstance(layer.self_attn, AverageAttention):
                                empty_self_cache = {"prev_g": torch.tensor([])}
                            else:
                                empty_self_cache = {
                                    "keys": torch.tensor([]),
                                    "values": torch.tensor([]),
                                }
                            layer.self_attn.layer_cache = False, empty_self_cache
                            layer.context_attn.layer_cache = (
                                False,
                                {"keys": torch.tensor([]), "values": torch.tensor([])},
                            )

                    emb = self.embeddings(tgt, step=step)
                    dec_out = emb
                    # 3-D embedding output expected; presumably batch x tlen x
                    # embedding_dim since tgt is documented as batch-first -- TODO confirm
                    assert emb.dim() == 3

                    pad_idx = self.embeddings.word_padding_idx
                    src_pad_mask = ~sequence_mask(
                        kwargs["src_len"], self.state["src"].shape[1]
                    )  # [B x slen]
                    src_pad_mask = src_pad_mask.unsqueeze(1)  # [B x 1 x slen]
                    tgt_is_pad = tgt[:, :, 0].eq(pad_idx)
                    tgt_pad_mask = tgt_is_pad.unsqueeze(1)  # [B, 1, T_tgt]

                    with_align = kwargs.pop("with_align", False)
                    # attention weights are needed for alignment as well as for copy
                    return_attn = with_align or self._copy
                    attn_aligns = []

                    for layer in self.transformer_layers:
                        dec_out, attn, attn_align = layer(
                            dec_out,
                            enc_out,
                            src_pad_mask,
                            tgt_pad_mask,
                            step=step,
                            with_align=with_align,
                            return_attn=return_attn,
                        )
                        if attn_align is None:
                            continue
                        attn_aligns.append(attn_align)

                    dec_out = self.layer_norm(dec_out)

                    # "std" (and "copy") expose the attention returned by the last layer
                    attns = {"std": attn}
                    if self._copy:
                        attns["copy"] = attn
                    if with_align:
                        # `(B, Q, K)`; alternative kept from the original notes:
                        # average over all layers with torch.stack(attn_aligns, 0).mean(0)
                        attns["align"] = attn_aligns[self.alignment_layer]

                    # TODO change the way attns is returned dict => list or tuple (onnx)
                    return dec_out, attns
         | 
| 605 | 
            +
             | 
| 606 | 
            +
                def _init_cache(self, enc_out):
                    """Install fresh, active attention caches on every decoder layer.

                    Each layer's ``context_attn.layer_cache`` and ``self_attn.layer_cache``
                    is replaced by a ``(True, dict)`` pair.  The dict holds empty ``keys``
                    and ``values`` tensors, except for ``AverageAttention`` self-attention,
                    which instead gets a zero-filled ``prev_g`` tensor.

                    Args:
                        enc_out (Tensor): encoder output; only its device, dtype, first
                            dimension and last dimension are read here.
                    """
                    device = enc_out.device
                    batch_size, depth = enc_out.size(0), enc_out.size(-1)

                    def _empty_kv():
                        # Build new tensors on every call so no two caches share storage.
                        return {
                            "keys": torch.tensor([], device=device),
                            "values": torch.tensor([], device=device),
                        }

                    for layer in self.transformer_layers:
                        # The leading True marks the cache as active from the beginning of
                        # decoding; it is consumed in the MultiHeadedAttention forward.
                        layer.context_attn.layer_cache = (True, _empty_kv())

                        if not isinstance(layer.self_attn, AverageAttention):
                            layer.self_attn.layer_cache = (True, _empty_kv())
                            continue

                        # Average attention keeps a running state instead of keys/values.
                        prev_g = torch.zeros((batch_size, 1, depth), device=device)
                        layer.self_attn.layer_cache = (
                            True,
                            {"prev_g": prev_g.to(enc_out.dtype)},
                        )
         | 
| 634 | 
            +
             | 
| 635 | 
            +
             | 
| 636 | 
            +
            class TransformerLMDecoderLayer(TransformerDecoderLayerBase):
                """Decoder-only (GPT style) Transformer layer block.

                Args:
                    See TransformerDecoderLayerBase
                """

                def _forward(
                    self, layer_in, tgt_pad_mask, step=None, future=False, return_attn=False
                ):
                    """Run self-attention followed by the feed-forward sub-layer.

                    T below is 1 for stepwise decoding, otherwise the target length.

                    Args:
                        layer_in (FloatTensor): ``(batch_size, T, model_dim)``
                        tgt_pad_mask (bool): ``(batch_size, 1, T)``
                        step (int or None): stepwise decoding counter
                        future (bool): If set True, do not apply future_mask.
                        return_attn (bool): If set True return attn

                    Returns:
                        (FloatTensor, FloatTensor):

                        * layer_out ``(batch_size, T, model_dim)``
                        * attns ``(batch_size, head, T, T)``
                    """
                    if layer_in.size(1) > 1:
                        # A mask is only needed when more than one position is processed.
                        mask = self._compute_dec_mask(tgt_pad_mask, future).unsqueeze(1)
                        # Result is (batch x 1 x tlen x tlen); the singleton dimension
                        # stands for the heads and is expanded inside the attention module.
                        dec_mask = mask.expand(-1, -1, mask.size(3), -1)
                    else:
                        dec_mask = None

                    normed_in = self.layer_norm_1(layer_in)
                    attn_output, attns = self._forward_self_attn(
                        normed_in, dec_mask, step, return_attn=return_attn
                    )

                    if not self.parallel_residual:
                        # Sequential residual: attention residual first, then feed-forward.
                        return self.feed_forward(self.dropout(attn_output) + layer_in), attns

                    # Parallel residual: feed_forward adds its own input back, so that
                    # input is subtracted and the un-normed layer input is added instead.
                    ff_in = (
                        normed_in if self.shared_layer_norm else self.layer_norm_res(layer_in)
                    )
                    ff_out = self.feed_forward(ff_in)
                    layer_out = ff_out - ff_in + layer_in + self.dropout(attn_output)
                    return layer_out, attns
         | 
| 695 | 
            +
             | 
| 696 | 
            +
             | 
| 697 | 
            +
            class TransformerLMDecoder(TransformerDecoderBase):
                """The Transformer decoder from GPT-2

                Args:
                    num_layers (int): number of decoder layers.
                    d_model (int): size of the model
                    heads (int): number of heads
                    d_ff (int): size of the inner FF layer
                    copy_attn (bool): if using a separate copy attention
                    self_attn_type (str): type of self-attention scaled-dot, average
                    dropout (float): dropout in residual, self-attn(dot) and feed-forward
                    attention_dropout (float): dropout in context_attn (and self-attn(avg))
                    embeddings (onmt.modules.Embeddings):
                        embeddings to use, should have positional encodings
                    max_relative_positions (int):
                        Max distance between inputs in relative positions representations
                    relative_positions_buckets (int):
                        Number of buckets when using Relative positions bias
                    aan_useffn (bool): Turn on the FFN layer in the AAN decoder
                    full_context_alignment: accepted for signature compatibility; the
                        layers are always built with ``full_context_alignment=None``.
                    alignment_layer: forwarded to the base class constructor.
                    alignment_heads: accepted for signature compatibility; the layers
                        are always built with ``alignment_heads=None``.
                    pos_ffn_activation_fn: forwarded to every layer.
                    add_qkvbias (bool): whether to add bias to the Key/Value nn.Linear
                    num_kv, add_ffnbias, parallel_residual, shared_layer_norm,
                    parallel_gpu: forwarded unchanged to every layer.
                    layer_norm, norm_eps: forwarded to the base class and to every layer.
                    use_ckpting (list or None): forwarded to every layer; ``None``
                        (the default) is replaced by a new empty list.
                """

                def __init__(
                    self,
                    num_layers,
                    d_model,
                    heads,
                    d_ff,
                    copy_attn,
                    self_attn_type,
                    dropout,
                    attention_dropout,
                    embeddings,
                    max_relative_positions,
                    relative_positions_buckets,
                    aan_useffn,
                    full_context_alignment=None,
                    alignment_layer=None,
                    alignment_heads=None,
                    pos_ffn_activation_fn=ActivationFunction.relu,
                    add_qkvbias=False,
                    num_kv=0,
                    add_ffnbias=True,
                    parallel_residual=False,
                    shared_layer_norm=False,
                    layer_norm="standard",
                    norm_eps=1e-6,
                    use_ckpting=None,
                    parallel_gpu=1,
                ):
                    super().__init__(
                        d_model, copy_attn, embeddings, alignment_layer, layer_norm, norm_eps
                    )
                    # A literal ``[]`` default would be one list shared by every decoder
                    # ever constructed; create a new one per instance instead.
                    if use_ckpting is None:
                        use_ckpting = []

                    self.transformer_layers = nn.ModuleList(
                        [
                            TransformerLMDecoderLayer(
                                d_model,
                                heads,
                                d_ff,
                                dropout,
                                attention_dropout,
                                self_attn_type=self_attn_type,
                                max_relative_positions=max_relative_positions,
                                relative_positions_buckets=relative_positions_buckets,
                                aan_useffn=aan_useffn,
                                # alignment is not supported by this decoder (see forward)
                                full_context_alignment=None,
                                alignment_heads=None,
                                pos_ffn_activation_fn=pos_ffn_activation_fn,
                                add_qkvbias=add_qkvbias,
                                num_kv=num_kv,
                                add_ffnbias=add_ffnbias,
                                parallel_residual=parallel_residual,
                                shared_layer_norm=shared_layer_norm,
                                layer_norm=layer_norm,
                                norm_eps=norm_eps,
                                use_ckpting=use_ckpting,
                                parallel_gpu=parallel_gpu,
                            )
                            for _ in range(num_layers)
                        ]
                    )

                def init_state(self, src=None, enc_out=None, enc_final_hs=None):
                    """Initialise decoder state; all three arguments are ignored.

                    The parent implementation is always called with ``None`` values.
                    """
                    super().init_state(None, None, None)

                def detach_state(self):
                    """No-op: this decoder has nothing to detach."""
                    pass

                def forward(self, tgt, enc_out=None, step=None, **kwargs):
                    """Decode, possibly stepwise.

                    Args:
                        tgt (Tensor): target indices, indexed as ``tgt[:, :, 0]`` to build
                            the padding mask (presumably ``(batch, tgt_len, nfeats)`` --
                            confirm against the caller).
                        enc_out: unused by this decoder.
                        step (int or None): ``0`` (re)initialises the self-attention
                            caches; ``None`` marks the caches as inactive; any other
                            value reuses the existing caches.
                        **kwargs: only ``with_align`` is read, and it must be falsy.

                    Returns:
                        (Tensor, dict): the layer-normalised decoder output and a dict
                        holding the last layer's attention under ``"std"`` (and also
                        under ``"copy"`` when copy attention is enabled).  The attention
                        is ``None`` when the decoder has no layers.
                    """
                    if step == 0:
                        self._init_cache(tgt)
                    elif step is None:
                        for layer in self.transformer_layers:
                            # False flags the cache as inactive for this pass.
                            layer.self_attn.layer_cache = (
                                False,
                                {"keys": torch.tensor([]), "values": torch.tensor([])},
                            )

                    dec_out = self.embeddings(tgt, step=step)

                    assert dec_out.dim() == 3  # batch x len x embedding_dim

                    pad_idx = self.embeddings.word_padding_idx
                    tgt_pad_mask = tgt[:, :, 0].eq(pad_idx).unsqueeze(1)  # [B, 1, T_tgt]

                    with_align = kwargs.pop("with_align", False)
                    assert not with_align, "TransformerLMDecoder does not support align"
                    return_attn = with_align or self._copy

                    # Pre-bind so that a decoder without layers does not raise
                    # UnboundLocalError when the attention dict is built below.
                    attn = None
                    for layer in self.transformer_layers:
                        dec_out, attn, _ = layer(
                            dec_out,
                            tgt_pad_mask,
                            step=step,
                            with_align=with_align,
                            return_attn=return_attn,
                        )

                    dec_out = self.layer_norm(dec_out)

                    attns = {"std": attn}
                    if self._copy:
                        attns["copy"] = attn

                    # TODO change the way attns is returned dict => list or tuple (onnx)
                    return dec_out, attns

                def _init_cache(self, tgt=None):
                    """Install fresh, active self-attention caches on every layer.

                    Args:
                        tgt (Tensor or None): tensor whose device the empty cache tensors
                            are created on; with ``None`` torch's default device is used.

                    Raises:
                        NotImplementedError: if a layer uses ``AverageAttention``.
                    """
                    # The signature defaults to None, so do not dereference it blindly.
                    device = tgt.device if tgt is not None else None
                    for layer in self.transformer_layers:
                        if isinstance(layer.self_attn, AverageAttention):
                            raise NotImplementedError
                        layer.self_attn.layer_cache = (
                            True,
                            {
                                "keys": torch.tensor([], device=device),
                                "values": torch.tensor([], device=device),
                            },
                        )
         | 
    	
        onmt/encoders/__init__.py
    ADDED
    
    | @@ -0,0 +1,67 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            """Module defining encoders."""
         | 
| 2 | 
            +
            import os
         | 
| 3 | 
            +
            import importlib
         | 
| 4 | 
            +
            from onmt.encoders.encoder import EncoderBase
         | 
| 5 | 
            +
            from onmt.encoders.transformer import TransformerEncoder
         | 
| 6 | 
            +
            from onmt.encoders.ggnn_encoder import GGNNEncoder
         | 
| 7 | 
            +
            from onmt.encoders.rnn_encoder import RNNEncoder
         | 
| 8 | 
            +
            from onmt.encoders.cnn_encoder import CNNEncoder
         | 
| 9 | 
            +
            from onmt.encoders.mean_encoder import MeanEncoder
         | 
| 10 | 
            +
             | 
| 11 | 
            +
             | 
| 12 | 
            +
            # Registry mapping an encoder-type string to the class implementing it.
            # "rnn" and "brnn" both resolve to RNNEncoder.
            str2enc = dict(
                ggnn=GGNNEncoder,
                rnn=RNNEncoder,
                brnn=RNNEncoder,
                cnn=CNNEncoder,
                transformer=TransformerEncoder,
                mean=MeanEncoder,
            )

            # Public names of this package; register_encoder appends to this list.
            __all__ = [
                "EncoderBase",
                "TransformerEncoder", "GGNNEncoder", "RNNEncoder",
                "CNNEncoder", "MeanEncoder",
                "str2enc",
            ]
         | 
| 30 | 
            +
             | 
| 31 | 
            +
             | 
| 32 | 
            +
            def get_encoders_cls(encoder_names):
                """Map each requested encoder name to its registered class.

                Args:
                    encoder_names: iterable of encoder-type strings.

                Returns:
                    dict: ``{name: encoder class}`` for every requested name.

                Raises:
                    ValueError: at the first name that is not a key of ``str2enc``.
                """
                selected = {}
                for enc_name in encoder_names:
                    if enc_name in str2enc:
                        selected[enc_name] = str2enc[enc_name]
                        continue
                    raise ValueError("%s encoder not supported!" % enc_name)
                return selected
         | 
| 40 | 
            +
             | 
| 41 | 
            +
             | 
| 42 | 
            +
            def register_encoder(name):
                """Encoder register that can be used to add new encoder class.

                Args:
                    name (str): key under which the decorated class is stored in
                        ``str2enc``.

                Returns:
                    A class decorator.  When applied it stores the class in ``str2enc``,
                    appends the class name to ``__all__`` and returns the class unchanged.

                Raises:
                    ValueError: raised by the decorator when ``name`` is already
                        registered or the class does not extend ``EncoderBase``.
                """

                def register_encoder_cls(cls):
                    if name in str2enc:
                        raise ValueError("Cannot register duplicate encoder ({})".format(name))
                    if not issubclass(cls, EncoderBase):
                        # ``__name__`` (the original read ``__name_``, which raised
                        # AttributeError and hid the intended ValueError).
                        raise ValueError(
                            f"encoder ({name}: {cls.__name__}) must extend EncoderBase"
                        )
                    str2enc[name] = cls
                    __all__.append(cls.__name__)  # added to be complete
                    return cls

                return register_encoder_cls
         | 
| 55 | 
            +
             | 
| 56 | 
            +
             | 
| 57 | 
            +
            # Auto import python files in this directory.
            # Every non-private, non-hidden ``.py`` file or sub-directory next to this
            # file is imported as ``onmt.encoders.<name>`` (presumably so that modules
            # using ``register_encoder`` get registered -- confirm).
            encoder_dir = os.path.dirname(__file__)
            for file in os.listdir(encoder_dir):
                path = os.path.join(encoder_dir, file)
                # Skip private ("_...") and hidden (".") entries.
                if file.startswith(("_", ".")):
                    continue
                # Anything that is neither a .py file nor a directory is ignored.
                if not (file.endswith(".py") or os.path.isdir(path)):
                    continue
                file_name = file[: file.find(".py")] if file.endswith(".py") else file
                module = importlib.import_module("onmt.encoders." + file_name)
         | 
    	
        onmt/encoders/__pycache__/__init__.cpython-311.pyc
    ADDED
    
    | Binary file (3.13 kB). View file | 
|  |