# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import math
from multiprocessing import Pool

import numpy as np
from fairseq import options
from fairseq.data import dictionary
from fairseq.scoring import bleu

from examples.noisychannel import (
    rerank_generate,
    rerank_options,
    rerank_score_bw,
    rerank_score_lm,
    rerank_utils,
)


def _get_shard_ids(args):
    """Return the shard ids to process.

    Every shard in ``range(args.num_shards)`` when ``args.all_shards`` is
    truthy, otherwise only ``args.shard_id``.
    """
    if args.all_shards:
        return list(range(args.num_shards))
    return [args.shard_id]


def score_target_hypo(
    args, a, b, c, lenpen, target_outfile, hypo_outfile, write_hypos, normalize
):
    """Pick the best-scoring hypothesis per n-best group and score the result.

    For each shard returned by ``load_score_files`` every candidate is scored
    with ``rerank_utils.get_score``; the highest-scoring candidate of each
    group is kept and added to a BLEU scorer together with its reference.

    Args:
        args: parsed reranking options (reads ``num_rescore``, ``prefix_len``
            and ``num_shards`` here; passed on to ``load_score_files``).
        a, b, c: the three weights forwarded positionally to
            ``rerank_utils.get_score`` (printed as weight1/weight2/weight3).
        lenpen: length penalty forwarded to ``rerank_utils.get_score``.
        target_outfile: path the selected references are written to.
        hypo_outfile: path the selected hypotheses are written to.
        write_hypos: when truthy, collect the selected hypotheses/targets for
            writing and print the scorer's result string.
        normalize: forwarded to ``rerank_utils.get_score``.

    Returns:
        The value ``rerank_utils.parse_bleu_scoring`` extracts from the
        scorer's result string.

    Raises:
        AssertionError: if ``args.prefix_len`` is None and a selected
            hypothesis is not found in the generation output for its key.
    """
    print("lenpen", lenpen, "weight1", a, "weight2", b, "weight3", c)
    gen_output_lst, bitext1_lst, bitext2_lst, lm_res_lst = load_score_files(args)
    tgt_dict = dictionary.Dictionary()
    scorer = bleu.Scorer(
        bleu.BleuConfig(
            pad=tgt_dict.pad(),
            eos=tgt_dict.eos(),
            unk=tgt_dict.unk(),
        )
    )

    ordered_hypos = {}
    ordered_targets = {}

    for bitext1, bitext2, gen_output, lm_res in zip(
        bitext1_lst, bitext2_lst, gen_output_lst, lm_res_lst
    ):
        total = len(bitext1.rescore_source.keys())
        hypo_lst = []
        # j is the 1-based position of the current candidate inside its group
        j = 1
        best_score = -math.inf
        # Same value the loop resets to after each group, so the name is
        # always bound even if no score ever exceeds -inf.
        best_hypo = ""

        for i in range(total):
            # length is measured in terms of words, not bpe tokens, since models may not share the same bpe
            target_len = len(bitext1.rescore_hypo[i].split())

            lm_score = lm_res.score[i] if lm_res is not None else 0

            if bitext2 is not None:
                bitext2_score = bitext2.rescore_score[i]
                bitext2_backwards = bitext2.backwards
            else:
                bitext2_score = None
                bitext2_backwards = None

            score = rerank_utils.get_score(
                a,
                b,
                c,
                target_len,
                bitext1.rescore_score[i],
                bitext2_score,
                lm_score=lm_score,
                lenpen=lenpen,
                src_len=bitext1.source_lengths[i],
                tgt_len=bitext1.target_lengths[i],
                bitext1_backwards=bitext1.backwards,
                bitext2_backwards=bitext2_backwards,
                normalize=normalize,
            )

            if score > best_score:
                best_score = score
                best_hypo = bitext1.rescore_hypo[i]

            # A group ends once we have seen all of its hypotheses or the
            # rescoring limit, whichever is hit first.
            if j == gen_output.num_hypos[i] or j == args.num_rescore:
                j = 1
                hypo_lst.append(best_hypo)
                best_score = -math.inf
                best_hypo = ""
            else:
                j += 1

        gen_keys = sorted(gen_output.no_bpe_target.keys())

        for idx, gen_key in enumerate(gen_keys):
            candidates = gen_output.no_bpe_hypo[gen_key]
            target = gen_output.no_bpe_target[gen_key]

            if args.prefix_len is None:
                assert hypo_lst[idx] in candidates, (
                    "pred and rescore hypo mismatch: i: "
                    + str(idx)
                    + ", "
                    + str(hypo_lst[idx])
                    + str(gen_key)
                    + str(candidates)
                )
                final_hypo = hypo_lst[idx]
            else:
                final_hypo = rerank_utils.get_full_from_prefix(
                    hypo_lst[idx], candidates
                )

            sys_tok = tgt_dict.encode_line(final_hypo)
            ref_tok = tgt_dict.encode_line(target)
            scorer.add(ref_tok, sys_tok)

            # if only one set of hyper parameters is provided, write the predictions to a file
            if write_hypos:
                # recover the original ids from n best list generation
                ordered_hypos[gen_key] = final_hypo
                ordered_targets[gen_key] = target

    # write the hypos in the original order from nbest list generation
    # NOTE(review): both files are opened (and therefore truncated) whenever
    # the shard counts match, even when write_hypos is falsy and nothing was
    # collected -- confirm this is intended for the multi-setting pool path.
    # Entries are written as-is; no separator is added here.
    if args.num_shards == len(bitext1_lst):
        with open(target_outfile, "w") as t, open(hypo_outfile, "w") as h:
            for key in range(len(ordered_hypos)):
                t.write(ordered_targets[key])
                h.write(ordered_hypos[key])

    res = scorer.result_string(4)
    if write_hypos:
        print(res)
    score = rerank_utils.parse_bleu_scoring(res)
    return score


def match_target_hypo(args, target_outfile, hypo_outfile):
    """combine scores from the LM and bitext models, and write the top scoring hypothesis to a file

    With a single hyper-parameter setting ``score_target_hypo`` is called
    directly with ``write_hypos=True``. With several settings they are scored
    in a process pool with ``write_hypos=False`` and the best one is reported.

    Returns:
        A tuple ``(lenpen, weight1, weight2, weight3, score)`` for the
        setting with the highest score (the only setting if just one given).
    """
    num_settings = len(args.weight1)

    if num_settings == 1:
        rerank_scores = [
            score_target_hypo(
                args,
                args.weight1[0],
                args.weight2[0],
                args.weight3[0],
                args.lenpen[0],
                target_outfile,
                hypo_outfile,
                True,
                args.normalize,
            )
        ]
    else:
        print("launching pool")
        settings = [
            (
                args,
                args.weight1[i],
                args.weight2[i],
                args.weight3[i],
                args.lenpen[i],
                target_outfile,
                hypo_outfile,
                False,
                args.normalize,
            )
            for i in range(num_settings)
        ]
        with Pool(32) as p:
            rerank_scores = p.starmap(score_target_hypo, settings)

    if len(rerank_scores) > 1:
        best_index = np.argmax(rerank_scores)
        print("best score", rerank_scores[best_index])
        print("best lenpen", args.lenpen[best_index])
        print("best weight1", args.weight1[best_index])
        print("best weight2", args.weight2[best_index])
        print("best weight3", args.weight3[best_index])
    else:
        best_index = 0

    return (
        args.lenpen[best_index],
        args.weight1[best_index],
        args.weight2[best_index],
        args.weight3[best_index],
        rerank_scores[best_index],
    )


def load_score_files(args):
    """Load generation output and rescoring results for each selected shard.

    Returns:
        Four parallel lists with one entry per shard:
        ``(gen_output_lst, bitext1_lst, bitext2_lst, lm_res1_lst)``. Entries
        of ``bitext2_lst`` and ``lm_res1_lst`` are None when no second
        scoring model / language model applies.

    Raises:
        AssertionError: if the source or target lengths of the two rescoring
            outputs differ, or if ``args.diff_bpe`` is set together with a
            second scoring model.
    """
    gen_output_lst = []
    bitext1_lst = []
    bitext2_lst = []
    lm_res1_lst = []

    for shard_id in _get_shard_ids(args):
        using_nbest = args.nbest_list is not None
        (
            pre_gen,
            left_to_right_preprocessed_dir,
            right_to_left_preprocessed_dir,
            backwards_preprocessed_dir,
            lm_preprocessed_dir,
        ) = rerank_utils.get_directories(
            args.data_dir_name,
            args.num_rescore,
            args.gen_subset,
            args.gen_model_name,
            shard_id,
            args.num_shards,
            args.sampling,
            args.prefix_len,
            args.target_prefix_frac,
            args.source_prefix_frac,
        )

        # When a scoring model is the generation model itself (and the full
        # source is used), the generation output is reused as its scores.
        rerank1_is_gen = (
            args.gen_model == args.score_model1 and args.source_prefix_frac is None
        )
        rerank2_is_gen = (
            args.gen_model == args.score_model2 and args.source_prefix_frac is None
        )

        score1_file = rerank_utils.rescore_file_name(
            pre_gen,
            args.prefix_len,
            args.model1_name,
            target_prefix_frac=args.target_prefix_frac,
            source_prefix_frac=args.source_prefix_frac,
            backwards=args.backwards1,
        )
        if args.score_model2 is not None:
            score2_file = rerank_utils.rescore_file_name(
                pre_gen,
                args.prefix_len,
                args.model2_name,
                target_prefix_frac=args.target_prefix_frac,
                source_prefix_frac=args.source_prefix_frac,
                backwards=args.backwards2,
            )
        if args.language_model is not None:
            lm_score_file = rerank_utils.rescore_file_name(
                pre_gen, args.prefix_len, args.lm_name, lm_file=True
            )

        # get gen output
        predictions_bpe_file = pre_gen + "/generate_output_bpe.txt"
        if using_nbest:
            print("Using predefined n-best list from interactive.py")
            predictions_bpe_file = args.nbest_list
        gen_output = rerank_utils.BitextOutputFromGen(
            predictions_bpe_file,
            bpe_symbol=args.post_process,
            nbest=using_nbest,
            prefix_len=args.prefix_len,
            target_prefix_frac=args.target_prefix_frac,
        )

        if rerank1_is_gen:
            bitext1 = gen_output
        else:
            bitext1 = rerank_utils.BitextOutput(
                score1_file,
                args.backwards1,
                args.right_to_left1,
                args.post_process,
                args.prefix_len,
                args.target_prefix_frac,
                args.source_prefix_frac,
            )

        if args.score_model2 is not None or args.nbest_list is not None:
            if rerank2_is_gen:
                bitext2 = gen_output
            else:
                # NOTE(review): score2_file is only assigned when
                # args.score_model2 is not None; reaching this branch via
                # args.nbest_list alone would hit an unbound name -- confirm
                # whether that combination is expected to be supported.
                bitext2 = rerank_utils.BitextOutput(
                    score2_file,
                    args.backwards2,
                    args.right_to_left2,
                    args.post_process,
                    args.prefix_len,
                    args.target_prefix_frac,
                    args.source_prefix_frac,
                )

                assert (
                    bitext2.source_lengths == bitext1.source_lengths
                ), "source lengths for rescoring models do not match"
                assert (
                    bitext2.target_lengths == bitext1.target_lengths
                ), "target lengths for rescoring models do not match"
        else:
            if args.diff_bpe:
                assert args.score_model2 is None
                bitext2 = gen_output
            else:
                bitext2 = None

        if args.language_model is not None:
            lm_res1 = rerank_utils.LMOutput(
                lm_score_file,
                args.lm_dict,
                args.prefix_len,
                args.post_process,
                args.target_prefix_frac,
            )
        else:
            lm_res1 = None

        gen_output_lst.append(gen_output)
        bitext1_lst.append(bitext1)
        bitext2_lst.append(bitext2)
        lm_res1_lst.append(lm_res1)

    return gen_output_lst, bitext1_lst, bitext2_lst, lm_res1_lst


def rerank(args):
    """Run generation, rescoring and hypothesis selection for ``args``.

    ``args.lenpen`` and ``args.weight1..3`` are wrapped into single-element
    lists in place when they are not already lists. For each selected shard
    the n-best generation, backward scoring and LM scoring steps are run;
    afterwards ``match_target_hypo`` chooses the best setting.

    Returns:
        ``(best_lenpen, best_weight1, best_weight2, best_weight3, best_score)``
        as returned by ``match_target_hypo``.
    """
    for name in ("lenpen", "weight1", "weight2", "weight3"):
        value = getattr(args, name)
        if not isinstance(value, list):
            setattr(args, name, [value])

    for shard_id in _get_shard_ids(args):
        (
            pre_gen,
            left_to_right_preprocessed_dir,
            right_to_left_preprocessed_dir,
            backwards_preprocessed_dir,
            lm_preprocessed_dir,
        ) = rerank_utils.get_directories(
            args.data_dir_name,
            args.num_rescore,
            args.gen_subset,
            args.gen_model_name,
            shard_id,
            args.num_shards,
            args.sampling,
            args.prefix_len,
            args.target_prefix_frac,
            args.source_prefix_frac,
        )
        rerank_generate.gen_and_reprocess_nbest(args)
        rerank_score_bw.score_bw(args)
        rerank_score_lm.score_lm(args)

        # The output paths are recomputed per shard; the values from the last
        # shard processed are the ones used below.
        if args.write_hypos is None:
            write_targets = pre_gen + "/matched_targets"
            write_hypos = pre_gen + "/matched_hypos"
        else:
            write_targets = args.write_hypos + "_targets" + args.gen_subset
            write_hypos = args.write_hypos + "_hypos" + args.gen_subset

    if args.all_shards:
        write_targets += "_all_shards"
        write_hypos += "_all_shards"

    (
        best_lenpen,
        best_weight1,
        best_weight2,
        best_weight3,
        best_score,
    ) = match_target_hypo(args, write_targets, write_hypos)

    return best_lenpen, best_weight1, best_weight2, best_weight3, best_score


def cli_main():
    """Parse the reranking command-line options and run ``rerank``."""
    parser = rerank_options.get_reranking_parser()
    args = options.parse_args_and_arch(parser)
    rerank(args)


if __name__ == "__main__":
    cli_main()