# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

"""Score previously generated hypotheses with up to two rescoring models.

For each configured scoring model, ``fairseq_cli.generate`` is run with
``--score-reference`` over a preprocessed data directory and its standard
output is captured into a score file.
"""

import os
from contextlib import redirect_stdout

from fairseq import options
from fairseq_cli import generate

from examples.noisychannel import rerank_options, rerank_utils


def _scorer_langs(args, backwards):
    """Return the ``(source, target)`` language pair for a scoring model.

    When ``backwards`` is truthy the pair is swapped, i.e.
    ``(args.target_lang, args.source_lang)`` is returned.
    """
    if backwards:
        return args.target_lang, args.source_lang
    return args.source_lang, args.target_lang


def _select_rerank_data(right_to_left, backwards, r2l_dir, backwards_dir, l2r_dir):
    """Pick the preprocessed data directory for a scoring model.

    ``right_to_left`` takes precedence over ``backwards``; if neither flag is
    set the left-to-right directory is returned.
    """
    if right_to_left:
        return r2l_dir
    if backwards:
        return backwards_dir
    return l2r_dir


def _score_to_file(rerank_data, gen_param, model_path, src_lang, tgt_lang, score_file):
    """Run ``generate.main`` and capture its stdout into ``score_file``.

    The output is first written to ``<score_file>.tmp`` and moved into place
    only after ``generate.main`` returns. If generation raises (or is
    interrupted), the temporary file is removed and the exception is
    re-raised, so ``score_file`` never exists in a partially written state.
    This matters because callers use the existence of ``score_file`` to
    decide whether scoring can be skipped.

    Args:
        rerank_data: data directory passed as the positional argument.
        gen_param: list of generation flags shared by all scoring models.
        model_path: value passed to ``--path``.
        src_lang: value passed to ``--source-lang``.
        tgt_lang: value passed to ``--target-lang``.
        score_file: destination path for the captured output.
    """
    model_param = [
        "--path",
        model_path,
        "--source-lang",
        src_lang,
        "--target-lang",
        tgt_lang,
    ]
    gen_parser = options.get_generation_parser()
    input_args = options.parse_args_and_arch(
        gen_parser, [rerank_data] + gen_param + model_param
    )

    tmp_file = f"{score_file}.tmp"
    try:
        with open(tmp_file, "w") as f, redirect_stdout(f):
            generate.main(input_args)
    except BaseException:
        # Do not leave a truncated file behind; a later run would otherwise
        # risk treating leftover output as complete.
        if os.path.isfile(tmp_file):
            os.remove(tmp_file)
        raise
    # Publish the finished output under its final name in a single step.
    os.replace(tmp_file, score_file)


def score_bw(args):
    """Produce score files for scoring model 1 and, if set, scoring model 2.

    Scoring for a model is skipped when its score file already exists, or
    when the model path equals ``args.gen_model`` and
    ``args.source_prefix_frac`` is ``None`` (presumably because the
    generation output already carries that model's scores -- confirm against
    the reranking pipeline).

    Args:
        args: parsed reranking arguments (see
            ``rerank_options.get_reranking_parser``). The attributes read
            here include ``score_model1``/``score_model2``,
            ``backwards1``/``backwards2``,
            ``right_to_left1``/``right_to_left2``,
            ``model1_name``/``model2_name``, the language pair and the
            directory/prefix settings forwarded to ``rerank_utils``.
    """
    has_model2 = args.score_model2 is not None

    rerank1_is_gen = (
        args.gen_model == args.score_model1 and args.source_prefix_frac is None
    )
    rerank2_is_gen = (
        args.gen_model == args.score_model2 and args.source_prefix_frac is None
    )

    (
        pre_gen,
        left_to_right_preprocessed_dir,
        right_to_left_preprocessed_dir,
        backwards_preprocessed_dir,
        lm_preprocessed_dir,
    ) = rerank_utils.get_directories(
        args.data_dir_name,
        args.num_rescore,
        args.gen_subset,
        args.gen_model_name,
        args.shard_id,
        args.num_shards,
        args.sampling,
        args.prefix_len,
        args.target_prefix_frac,
        args.source_prefix_frac,
    )

    score1_file = rerank_utils.rescore_file_name(
        pre_gen,
        args.prefix_len,
        args.model1_name,
        target_prefix_frac=args.target_prefix_frac,
        source_prefix_frac=args.source_prefix_frac,
        backwards=args.backwards1,
    )
    if has_model2:
        score2_file = rerank_utils.rescore_file_name(
            pre_gen,
            args.prefix_len,
            args.model2_name,
            target_prefix_frac=args.target_prefix_frac,
            source_prefix_frac=args.source_prefix_frac,
            backwards=args.backwards2,
        )

    # Flags shared by both scoring runs; the subset name is fixed to "train"
    # here regardless of args.gen_subset.
    gen_param = ["--batch-size", str(128), "--score-reference", "--gen-subset", "train"]

    if not rerank1_is_gen and not os.path.isfile(score1_file):
        print("STEP 4: score the translations for model 1")
        scorer1_src, scorer1_tgt = _scorer_langs(args, args.backwards1)
        rerank_data1 = _select_rerank_data(
            args.right_to_left1,
            args.backwards1,
            right_to_left_preprocessed_dir,
            backwards_preprocessed_dir,
            left_to_right_preprocessed_dir,
        )
        _score_to_file(
            rerank_data1,
            gen_param,
            args.score_model1,
            scorer1_src,
            scorer1_tgt,
            score1_file,
        )

    if has_model2 and not os.path.isfile(score2_file) and not rerank2_is_gen:
        print("STEP 4: score the translations for model 2")
        scorer2_src, scorer2_tgt = _scorer_langs(args, args.backwards2)
        rerank_data2 = _select_rerank_data(
            args.right_to_left2,
            args.backwards2,
            right_to_left_preprocessed_dir,
            backwards_preprocessed_dir,
            left_to_right_preprocessed_dir,
        )
        _score_to_file(
            rerank_data2,
            gen_param,
            args.score_model2,
            scorer2_src,
            scorer2_tgt,
            score2_file,
        )


def cli_main():
    """Parse reranking command-line arguments and run :func:`score_bw`."""
    parser = rerank_options.get_reranking_parser()
    args = options.parse_args_and_arch(parser)
    score_bw(args)


if __name__ == "__main__":
    cli_main()