|  |  | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | """Neural machine translation model decoding script.""" | 
					
						
						|  |  | 
					
						
						|  | import configargparse | 
					
						
						|  | import logging | 
					
						
						|  | import os | 
					
						
						|  | import random | 
					
						
						|  | import sys | 
					
						
						|  |  | 
					
						
						|  | import numpy as np | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | def get_parser(): | 
					
						
						|  | """Get default arguments.""" | 
					
						
						|  | parser = configargparse.ArgumentParser( | 
					
						
						|  | description="Translate text from speech " | 
					
						
						|  | "using a speech translation model on one CPU or GPU", | 
					
						
						|  | config_file_parser_class=configargparse.YAMLConfigFileParser, | 
					
						
						|  | formatter_class=configargparse.ArgumentDefaultsHelpFormatter, | 
					
						
						|  | ) | 
					
						
						|  |  | 
					
						
						|  | parser.add("--config", is_config_file=True, help="Config file path") | 
					
						
						|  | parser.add( | 
					
						
						|  | "--config2", | 
					
						
						|  | is_config_file=True, | 
					
						
						|  | help="Second config file path that overwrites the settings in `--config`", | 
					
						
						|  | ) | 
					
						
						|  | parser.add( | 
					
						
						|  | "--config3", | 
					
						
						|  | is_config_file=True, | 
					
						
						|  | help="Third config file path " | 
					
						
						|  | "that overwrites the settings in `--config` and `--config2`", | 
					
						
						|  | ) | 
					
						
						|  |  | 
					
						
						|  | parser.add_argument("--ngpu", type=int, default=0, help="Number of GPUs") | 
					
						
						|  | parser.add_argument( | 
					
						
						|  | "--dtype", | 
					
						
						|  | choices=("float16", "float32", "float64"), | 
					
						
						|  | default="float32", | 
					
						
						|  | help="Float precision (only available in --api v2)", | 
					
						
						|  | ) | 
					
						
						|  | parser.add_argument( | 
					
						
						|  | "--backend", | 
					
						
						|  | type=str, | 
					
						
						|  | default="chainer", | 
					
						
						|  | choices=["chainer", "pytorch"], | 
					
						
						|  | help="Backend library", | 
					
						
						|  | ) | 
					
						
						|  | parser.add_argument("--debugmode", type=int, default=1, help="Debugmode") | 
					
						
						|  | parser.add_argument("--seed", type=int, default=1, help="Random seed") | 
					
						
						|  | parser.add_argument("--verbose", "-V", type=int, default=1, help="Verbose option") | 
					
						
						|  | parser.add_argument( | 
					
						
						|  | "--batchsize", | 
					
						
						|  | type=int, | 
					
						
						|  | default=1, | 
					
						
						|  | help="Batch size for beam search (0: means no batch processing)", | 
					
						
						|  | ) | 
					
						
						|  | parser.add_argument( | 
					
						
						|  | "--preprocess-conf", | 
					
						
						|  | type=str, | 
					
						
						|  | default=None, | 
					
						
						|  | help="The configuration file for the pre-processing", | 
					
						
						|  | ) | 
					
						
						|  | parser.add_argument( | 
					
						
						|  | "--api", | 
					
						
						|  | default="v1", | 
					
						
						|  | choices=["v1", "v2"], | 
					
						
						|  | help="Beam search APIs " | 
					
						
						|  | "v1: Default API. It only supports " | 
					
						
						|  | "the ASRInterface.recognize method and DefaultRNNLM. " | 
					
						
						|  | "v2: Experimental API. " | 
					
						
						|  | "It supports any models that implements ScorerInterface.", | 
					
						
						|  | ) | 
					
						
						|  |  | 
					
						
						|  | parser.add_argument( | 
					
						
						|  | "--trans-json", type=str, help="Filename of translation data (json)" | 
					
						
						|  | ) | 
					
						
						|  | parser.add_argument( | 
					
						
						|  | "--result-label", | 
					
						
						|  | type=str, | 
					
						
						|  | required=True, | 
					
						
						|  | help="Filename of result label data (json)", | 
					
						
						|  | ) | 
					
						
						|  |  | 
					
						
						|  | parser.add_argument( | 
					
						
						|  | "--model", type=str, required=True, help="Model file parameters to read" | 
					
						
						|  | ) | 
					
						
						|  | parser.add_argument( | 
					
						
						|  | "--model-conf", type=str, default=None, help="Model config file" | 
					
						
						|  | ) | 
					
						
						|  |  | 
					
						
						|  | parser.add_argument("--nbest", type=int, default=1, help="Output N-best hypotheses") | 
					
						
						|  | parser.add_argument("--beam-size", type=int, default=1, help="Beam size") | 
					
						
						|  | parser.add_argument("--penalty", type=float, default=0.1, help="Incertion penalty") | 
					
						
						|  | parser.add_argument( | 
					
						
						|  | "--maxlenratio", | 
					
						
						|  | type=float, | 
					
						
						|  | default=3.0, | 
					
						
						|  | help="""Input length ratio to obtain max output length. | 
					
						
						|  | If maxlenratio=0.0 (default), it uses a end-detect function | 
					
						
						|  | to automatically find maximum hypothesis lengths""", | 
					
						
						|  | ) | 
					
						
						|  | parser.add_argument( | 
					
						
						|  | "--minlenratio", | 
					
						
						|  | type=float, | 
					
						
						|  | default=0.0, | 
					
						
						|  | help="Input length ratio to obtain min output length", | 
					
						
						|  | ) | 
					
						
						|  |  | 
					
						
						|  | parser.add_argument( | 
					
						
						|  | "--tgt-lang", | 
					
						
						|  | default=False, | 
					
						
						|  | type=str, | 
					
						
						|  | help="target language ID (e.g., <en>, <de>, and <fr> etc.)", | 
					
						
						|  | ) | 
					
						
						|  | return parser | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | def main(args): | 
					
						
						|  | """Run the main decoding function.""" | 
					
						
						|  | parser = get_parser() | 
					
						
						|  | args = parser.parse_args(args) | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | if args.verbose == 1: | 
					
						
						|  | logging.basicConfig( | 
					
						
						|  | level=logging.INFO, | 
					
						
						|  | format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s", | 
					
						
						|  | ) | 
					
						
						|  | elif args.verbose == 2: | 
					
						
						|  | logging.basicConfig( | 
					
						
						|  | level=logging.DEBUG, | 
					
						
						|  | format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s", | 
					
						
						|  | ) | 
					
						
						|  | else: | 
					
						
						|  | logging.basicConfig( | 
					
						
						|  | level=logging.WARN, | 
					
						
						|  | format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s", | 
					
						
						|  | ) | 
					
						
						|  | logging.warning("Skip DEBUG/INFO messages") | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | if args.ngpu > 0: | 
					
						
						|  | cvd = os.environ.get("CUDA_VISIBLE_DEVICES") | 
					
						
						|  | if cvd is None: | 
					
						
						|  | logging.warning("CUDA_VISIBLE_DEVICES is not set.") | 
					
						
						|  | elif args.ngpu != len(cvd.split(",")): | 
					
						
						|  | logging.error("#gpus is not matched with CUDA_VISIBLE_DEVICES.") | 
					
						
						|  | sys.exit(1) | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | if args.ngpu > 1: | 
					
						
						|  | logging.error("The program only supports ngpu=1.") | 
					
						
						|  | sys.exit(1) | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | logging.info("python path = " + os.environ.get("PYTHONPATH", "(None)")) | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | random.seed(args.seed) | 
					
						
						|  | np.random.seed(args.seed) | 
					
						
						|  | logging.info("set random seed = %d" % args.seed) | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | logging.info("backend = " + args.backend) | 
					
						
						|  | if args.backend == "pytorch": | 
					
						
						|  |  | 
					
						
						|  | from espnet.mt.pytorch_backend.mt import trans | 
					
						
						|  |  | 
					
						
						|  | if args.dtype != "float32": | 
					
						
						|  | raise NotImplementedError( | 
					
						
						|  | f"`--dtype {args.dtype}` is only available with `--api v2`" | 
					
						
						|  | ) | 
					
						
						|  | trans(args) | 
					
						
						|  | else: | 
					
						
						|  | raise ValueError("Only pytorch are supported.") | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | if __name__ == "__main__": | 
					
						
						|  | main(sys.argv[1:]) | 
					
						
						|  |  |