#!/usr/bin/env python3 -u
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

"""
Generate n-best translations using a trained model.
"""

import os
import subprocess
from contextlib import redirect_stdout

from fairseq import options
from fairseq_cli import generate, preprocess

from examples.noisychannel import rerank_options, rerank_utils


def _run_preprocess(params):
    """Parse ``params`` with the fairseq preprocessing parser and run preprocessing."""
    preprocess_parser = options.get_preprocessing_parser()
    input_args = preprocess_parser.parse_args(params)
    preprocess.main(input_args)


def _apply_bpe(bpe_code, input_path, output_path):
    """Run ``subword-nmt/subword_nmt/apply_bpe.py`` (located relative to this file).

    Raises:
        subprocess.CalledProcessError: if the script exits with a non-zero
            status, so a failed BPE step is reported where it happens instead
            of surfacing later as a missing input file.
    """
    subprocess.check_call(
        [
            "python",
            os.path.join(
                os.path.dirname(__file__), "subword-nmt/subword_nmt/apply_bpe.py"
            ),
            "-c",
            bpe_code,
            "--input",
            input_path,
            "--output",
            output_path,
        ],
        shell=False,
    )


def gen_and_reprocess_nbest(args):
    """Generate (or load) an n-best list and prepare it for rescoring.

    The function runs up to three steps, each announced on stdout:

    1. generate predictions with ``args.gen_model`` into
       ``<pre_gen>/generate_output_bpe.txt`` (skipped when that file already
       exists or when ``args.nbest_list`` is given);
    2. write the source/hypothesis/reference text files used for rescoring;
    3. binarize those files with fairseq's ``preprocess``.

    Steps 2 and 3 only run when a score file for scorer 1 or scorer 2 is
    missing and that scorer is not the generation model itself.

    Args:
        args: parsed reranking options (see ``rerank_options``). If
            ``args.score_dict_dir`` is ``None`` it is set to ``args.data``
            in place.

    Returns:
        The ``rerank_utils.BitextOutputFromGen`` object built from the
        predictions file.

    Raises:
        AssertionError: if incompatible options are combined (prefix length
            with right-to-left scorers, backwards right-to-left scorers,
            prefix length together with target prefix fraction, or an n-best
            list together with a second scoring model).
        subprocess.CalledProcessError: if re-applying BPE (``args.diff_bpe``)
            fails.
    """
    if args.score_dict_dir is None:
        args.score_dict_dir = args.data
    if args.prefix_len is not None:
        assert (
            args.right_to_left1 is False
        ), "prefix length not compatible with right to left models"
        assert (
            args.right_to_left2 is False
        ), "prefix length not compatible with right to left models"

    if args.nbest_list is not None:
        assert args.score_model2 is None

    # A backwards scorer models the source given the target, so its
    # language pair is swapped.
    if args.backwards1:
        scorer1_src = args.target_lang
        scorer1_tgt = args.source_lang
    else:
        scorer1_src = args.source_lang
        scorer1_tgt = args.target_lang

    store_data = os.path.dirname(__file__) + "/rerank_data/" + args.data_dir_name
    os.makedirs(store_data, exist_ok=True)

    (
        pre_gen,
        left_to_right_preprocessed_dir,
        right_to_left_preprocessed_dir,
        backwards_preprocessed_dir,
        lm_preprocessed_dir,
    ) = rerank_utils.get_directories(
        args.data_dir_name,
        args.num_rescore,
        args.gen_subset,
        args.gen_model_name,
        args.shard_id,
        args.num_shards,
        args.sampling,
        args.prefix_len,
        args.target_prefix_frac,
        args.source_prefix_frac,
    )
    assert not (
        args.right_to_left1 and args.backwards1
    ), "backwards right to left not supported"
    assert not (
        args.right_to_left2 and args.backwards2
    ), "backwards right to left not supported"
    assert not (
        args.prefix_len is not None and args.target_prefix_frac is not None
    ), "target prefix frac and target prefix len incompatible"

    # make directory to store generation results
    os.makedirs(pre_gen, exist_ok=True)

    # A scorer that is the generation model itself (and sees the full source)
    # needs no separately preprocessed data.
    rerank1_is_gen = (
        args.gen_model == args.score_model1 and args.source_prefix_frac is None
    )
    rerank2_is_gen = (
        args.gen_model == args.score_model2 and args.source_prefix_frac is None
    )

    if args.nbest_list is not None:
        rerank2_is_gen = True

    # make directories to store preprocessed nbest list for reranking
    for directory in (
        left_to_right_preprocessed_dir,
        right_to_left_preprocessed_dir,
        lm_preprocessed_dir,
        backwards_preprocessed_dir,
    ):
        os.makedirs(directory, exist_ok=True)

    score1_file = rerank_utils.rescore_file_name(
        pre_gen,
        args.prefix_len,
        args.model1_name,
        target_prefix_frac=args.target_prefix_frac,
        source_prefix_frac=args.source_prefix_frac,
        backwards=args.backwards1,
    )
    if args.score_model2 is not None:
        score2_file = rerank_utils.rescore_file_name(
            pre_gen,
            args.prefix_len,
            args.model2_name,
            target_prefix_frac=args.target_prefix_frac,
            source_prefix_frac=args.source_prefix_frac,
            backwards=args.backwards2,
        )

    predictions_bpe_file = pre_gen + "/generate_output_bpe.txt"

    using_nbest = args.nbest_list is not None

    if using_nbest:
        print("Using predefined n-best list from interactive.py")
        predictions_bpe_file = args.nbest_list

    else:
        if not os.path.isfile(predictions_bpe_file):
            print("STEP 1: generate predictions using the p(T|S) model with bpe")
            print(args.data)
            # NOTE(review): "--batch-size" is passed twice below; presumably
            # the later value (num_rescore) wins and args.batch_size is
            # effectively ignored -- confirm which one is intended.
            param1 = [
                args.data,
                "--path",
                args.gen_model,
                "--shard-id",
                str(args.shard_id),
                "--num-shards",
                str(args.num_shards),
                "--nbest",
                str(args.num_rescore),
                "--batch-size",
                str(args.batch_size),
                "--beam",
                str(args.num_rescore),
                "--batch-size",
                str(args.num_rescore),
                "--gen-subset",
                args.gen_subset,
                "--source-lang",
                args.source_lang,
                "--target-lang",
                args.target_lang,
            ]
            if args.sampling:
                param1 += ["--sampling"]

            gen_parser = options.get_generation_parser()
            input_args = options.parse_args_and_arch(gen_parser, param1)

            print(input_args)
            # generate.main prints its results; capture stdout into the file.
            with open(predictions_bpe_file, "w") as f:
                with redirect_stdout(f):
                    generate.main(input_args)

    gen_output = rerank_utils.BitextOutputFromGen(
        predictions_bpe_file,
        bpe_symbol=args.post_process,
        nbest=using_nbest,
        prefix_len=args.prefix_len,
        target_prefix_frac=args.target_prefix_frac,
    )

    if args.diff_bpe:
        # Write the BPE-free text, then re-apply BPE with the rescoring
        # model's codes into <pre_gen>/rescore_data.<lang>.
        rerank_utils.write_reprocessed(
            gen_output.no_bpe_source,
            gen_output.no_bpe_hypo,
            gen_output.no_bpe_target,
            pre_gen + "/source_gen_bpe." + args.source_lang,
            pre_gen + "/target_gen_bpe." + args.target_lang,
            pre_gen + "/reference_gen_bpe." + args.target_lang,
        )
        bitext_bpe = args.rescore_bpe_code
        _apply_bpe(
            bitext_bpe,
            pre_gen + "/source_gen_bpe." + args.source_lang,
            pre_gen + "/rescore_data." + args.source_lang,
        )
        _apply_bpe(
            bitext_bpe,
            pre_gen + "/target_gen_bpe." + args.target_lang,
            pre_gen + "/rescore_data." + args.target_lang,
        )

    if (not os.path.isfile(score1_file) and not rerank1_is_gen) or (
        args.score_model2 is not None
        and not os.path.isfile(score2_file)
        and not rerank2_is_gen
    ):
        print(
            "STEP 2: process the output of generate.py so we have clean text files with the translations"
        )

        rescore_file = "/rescore_data"
        # Default both binarization prefixes to the plain rescore data. The
        # branches below only override them when prefix-specific files are
        # written; without these defaults STEP 3 hit an unbound local when
        # args.diff_bpe was set (the re-BPE'd files live at rescore_file).
        bw_rescore_file = rescore_file
        fw_rescore_file = rescore_file

        if not args.right_to_left1 or not args.right_to_left2:
            if not args.diff_bpe:
                rerank_utils.write_reprocessed(
                    gen_output.source,
                    gen_output.hypo,
                    gen_output.target,
                    pre_gen + rescore_file + "." + args.source_lang,
                    pre_gen + rescore_file + "." + args.target_lang,
                    pre_gen + "/reference_file",
                    bpe_symbol=args.post_process,
                )
                if args.prefix_len is not None:
                    bw_rescore_file = rescore_file + "prefix" + str(args.prefix_len)
                    rerank_utils.write_reprocessed(
                        gen_output.source,
                        gen_output.hypo,
                        gen_output.target,
                        pre_gen + bw_rescore_file + "." + args.source_lang,
                        pre_gen + bw_rescore_file + "." + args.target_lang,
                        pre_gen + "/reference_file",
                        prefix_len=args.prefix_len,
                        bpe_symbol=args.post_process,
                    )
                elif args.target_prefix_frac is not None:
                    bw_rescore_file = (
                        rescore_file
                        + "target_prefix_frac"
                        + str(args.target_prefix_frac)
                    )
                    rerank_utils.write_reprocessed(
                        gen_output.source,
                        gen_output.hypo,
                        gen_output.target,
                        pre_gen + bw_rescore_file + "." + args.source_lang,
                        pre_gen + bw_rescore_file + "." + args.target_lang,
                        pre_gen + "/reference_file",
                        bpe_symbol=args.post_process,
                        target_prefix_frac=args.target_prefix_frac,
                    )

                if args.source_prefix_frac is not None:
                    fw_rescore_file = (
                        rescore_file
                        + "source_prefix_frac"
                        + str(args.source_prefix_frac)
                    )
                    rerank_utils.write_reprocessed(
                        gen_output.source,
                        gen_output.hypo,
                        gen_output.target,
                        pre_gen + fw_rescore_file + "." + args.source_lang,
                        pre_gen + fw_rescore_file + "." + args.target_lang,
                        pre_gen + "/reference_file",
                        bpe_symbol=args.post_process,
                        source_prefix_frac=args.source_prefix_frac,
                    )

        if args.right_to_left1 or args.right_to_left2:
            rerank_utils.write_reprocessed(
                gen_output.source,
                gen_output.hypo,
                gen_output.target,
                pre_gen + "/right_to_left_rescore_data." + args.source_lang,
                pre_gen + "/right_to_left_rescore_data." + args.target_lang,
                pre_gen + "/right_to_left_reference_file",
                right_to_left=True,
                bpe_symbol=args.post_process,
            )

        print("STEP 3: binarize the translations")
        # Parentheses spell out the precedence the original expression
        # already had ("and" binds tighter than "or").
        if (
            not args.right_to_left1
            or (args.score_model2 is not None and not args.right_to_left2)
            or not rerank1_is_gen
        ):

            if args.backwards1 or args.backwards2:
                if args.backwards_score_dict_dir is not None:
                    bw_dict = args.backwards_score_dict_dir
                else:
                    bw_dict = args.score_dict_dir
                _run_preprocess(
                    [
                        "--source-lang",
                        scorer1_src,
                        "--target-lang",
                        scorer1_tgt,
                        "--trainpref",
                        pre_gen + bw_rescore_file,
                        "--srcdict",
                        bw_dict + "/dict." + scorer1_src + ".txt",
                        "--tgtdict",
                        bw_dict + "/dict." + scorer1_tgt + ".txt",
                        "--destdir",
                        backwards_preprocessed_dir,
                    ]
                )

            _run_preprocess(
                [
                    "--source-lang",
                    scorer1_src,
                    "--target-lang",
                    scorer1_tgt,
                    "--trainpref",
                    pre_gen + fw_rescore_file,
                    "--srcdict",
                    args.score_dict_dir + "/dict." + scorer1_src + ".txt",
                    "--tgtdict",
                    args.score_dict_dir + "/dict." + scorer1_tgt + ".txt",
                    "--destdir",
                    left_to_right_preprocessed_dir,
                ]
            )

        if args.right_to_left1 or args.right_to_left2:
            _run_preprocess(
                [
                    "--source-lang",
                    scorer1_src,
                    "--target-lang",
                    scorer1_tgt,
                    "--trainpref",
                    pre_gen + "/right_to_left_rescore_data",
                    "--srcdict",
                    args.score_dict_dir + "/dict." + scorer1_src + ".txt",
                    "--tgtdict",
                    args.score_dict_dir + "/dict." + scorer1_tgt + ".txt",
                    "--destdir",
                    right_to_left_preprocessed_dir,
                ]
            )

    return gen_output


def cli_main():
    """Parse the reranking command-line options and run the pipeline."""
    parser = rerank_options.get_reranking_parser()
    args = options.parse_args_and_arch(parser)
    gen_and_reprocess_nbest(args)


if __name__ == "__main__":
    cli_main()