# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
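"""Noisy channel reranking: combine the scores of a forward translation
model, an optional second rescoring model (e.g. a backwards / channel
model), and a language model to rerank n-best lists and report BLEU for
the selected hypotheses."""
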
import math
from multiprocessing import Pool

import numpy as np

from fairseq import options
from fairseq.data import dictionary
from fairseq.scoring import bleu

from examples.noisychannel import (
    rerank_generate,
    rerank_options,
    rerank_score_bw,
    rerank_score_lm,
    rerank_utils,
)


def score_target_hypo(
    args, a, b, c, lenpen, target_outfile, hypo_outfile, write_hypos, normalize
):
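    """Rerank with one setting of the weights (a, b, c) and length penalty.

    Combines the stored model scores for every hypothesis, keeps the best
    hypothesis per source sentence, and returns the resulting BLEU score.
    If `write_hypos` is True, the selected hypotheses and their references
    are also written to `hypo_outfile` and `target_outfile`.
    """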
print("lenpen", lenpen, "weight1", a, "weight2", b, "weight3", c) | |
gen_output_lst, bitext1_lst, bitext2_lst, lm_res_lst = load_score_files(args) | |
dict = dictionary.Dictionary() | |
scorer = scorer = bleu.Scorer( | |
bleu.BleuConfig( | |
pad=dict.pad(), | |
eos=dict.eos(), | |
unk=dict.unk(), | |
) | |
) | |
    ordered_hypos = {}
    ordered_targets = {}

    for shard_id in range(len(bitext1_lst)):
        bitext1 = bitext1_lst[shard_id]
        bitext2 = bitext2_lst[shard_id]
        gen_output = gen_output_lst[shard_id]
        lm_res = lm_res_lst[shard_id]
        total = len(bitext1.rescore_source.keys())

        source_lst = []
        hypo_lst = []
        score_lst = []
        reference_lst = []
        j = 1
        best_score = -math.inf
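        # scores are stored as one flat list over all hypotheses; `j` walks
        # through each source sentence's n-best group so the best-scoring
        # hypothesis can be recorded once the group is exhausted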
        for i in range(total):
            # length is measured in words, not BPE tokens, since the models
            # may not share the same BPE vocabulary
            target_len = len(bitext1.rescore_hypo[i].split())

            if lm_res is not None:
                lm_score = lm_res.score[i]
            else:
                lm_score = 0

            if bitext2 is not None:
                bitext2_score = bitext2.rescore_score[i]
                bitext2_backwards = bitext2.backwards
            else:
                bitext2_score = None
                bitext2_backwards = None

            score = rerank_utils.get_score(
                a,
                b,
                c,
                target_len,
                bitext1.rescore_score[i],
                bitext2_score,
                lm_score=lm_score,
                lenpen=lenpen,
                src_len=bitext1.source_lengths[i],
                tgt_len=bitext1.target_lengths[i],
                bitext1_backwards=bitext1.backwards,
                bitext2_backwards=bitext2_backwards,
                normalize=normalize,
            )
            if score > best_score:
                best_score = score
                best_hypo = bitext1.rescore_hypo[i]

            if j == gen_output.num_hypos[i] or j == args.num_rescore:
                # end of this sentence's n-best group: record the winner
                j = 1
                hypo_lst.append(best_hypo)
                score_lst.append(best_score)
                source_lst.append(bitext1.rescore_source[i])
                reference_lst.append(bitext1.rescore_target[i])
                best_score = -math.inf
                best_hypo = ""
            else:
                j += 1
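        # score the selected hypotheses against the BPE-free references,
        # using the sorted generation keys to align them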
        gen_keys = list(sorted(gen_output.no_bpe_target.keys()))

        for key in range(len(gen_keys)):
            if args.prefix_len is None:
                assert hypo_lst[key] in gen_output.no_bpe_hypo[gen_keys[key]], (
                    "pred and rescore hypo mismatch: i: "
                    + str(key)
                    + ", "
                    + str(hypo_lst[key])
                    + " "
                    + str(gen_keys[key])
                    + " "
                    + str(gen_output.no_bpe_hypo[gen_keys[key]])
                )
                sys_tok = bleu_dict.encode_line(hypo_lst[key])
                ref_tok = bleu_dict.encode_line(
                    gen_output.no_bpe_target[gen_keys[key]]
                )
                scorer.add(ref_tok, sys_tok)
            else:
                full_hypo = rerank_utils.get_full_from_prefix(
                    hypo_lst[key], gen_output.no_bpe_hypo[gen_keys[key]]
                )
                sys_tok = bleu_dict.encode_line(full_hypo)
                ref_tok = bleu_dict.encode_line(
                    gen_output.no_bpe_target[gen_keys[key]]
                )
                scorer.add(ref_tok, sys_tok)
        # if only one set of hyperparameters is provided, write the
        # predictions to a file
        if write_hypos:
            # recover the original ids from n-best list generation
            for key in range(len(gen_output.no_bpe_target)):
                if args.prefix_len is None:
                    assert hypo_lst[key] in gen_output.no_bpe_hypo[gen_keys[key]], (
                        "pred and rescore hypo mismatch: i: "
                        + str(key)
                        + ", "
                        + str(hypo_lst[key])
                        + " "
                        + str(gen_output.no_bpe_hypo[gen_keys[key]])
                    )
                    ordered_hypos[gen_keys[key]] = hypo_lst[key]
                    ordered_targets[gen_keys[key]] = gen_output.no_bpe_target[
                        gen_keys[key]
                    ]
                else:
                    full_hypo = rerank_utils.get_full_from_prefix(
                        hypo_lst[key], gen_output.no_bpe_hypo[gen_keys[key]]
                    )
                    ordered_hypos[gen_keys[key]] = full_hypo
                    ordered_targets[gen_keys[key]] = gen_output.no_bpe_target[
                        gen_keys[key]
                    ]
    # write the hypos in the original order from n-best list generation
    if args.num_shards == len(bitext1_lst):
        with open(target_outfile, "w") as t:
            with open(hypo_outfile, "w") as h:
                for key in range(len(ordered_hypos)):
                    t.write(ordered_targets[key])
                    h.write(ordered_hypos[key])

    res = scorer.result_string(4)
    if write_hypos:
        print(res)
    score = rerank_utils.parse_bleu_scoring(res)
    return score


def match_target_hypo(args, target_outfile, hypo_outfile):
    """Combine scores from the LM and bitext models, and write the
    top-scoring hypotheses to a file."""
    if len(args.weight1) == 1:
        res = score_target_hypo(
            args,
            args.weight1[0],
            args.weight2[0],
            args.weight3[0],
            args.lenpen[0],
            target_outfile,
            hypo_outfile,
            True,
            args.normalize,
        )
        rerank_scores = [res]
    else:
        print("launching pool")
        with Pool(32) as p:
            rerank_scores = p.starmap(
                score_target_hypo,
                [
                    (
                        args,
                        args.weight1[i],
                        args.weight2[i],
                        args.weight3[i],
                        args.lenpen[i],
                        target_outfile,
                        hypo_outfile,
                        False,
                        args.normalize,
                    )
                    for i in range(len(args.weight1))
                ],
            )
    if len(rerank_scores) > 1:
        best_index = np.argmax(rerank_scores)
        best_score = rerank_scores[best_index]
        print("best score", best_score)
        print("best lenpen", args.lenpen[best_index])
        print("best weight1", args.weight1[best_index])
        print("best weight2", args.weight2[best_index])
        print("best weight3", args.weight3[best_index])
        return (
            args.lenpen[best_index],
            args.weight1[best_index],
            args.weight2[best_index],
            args.weight3[best_index],
            best_score,
        )
    else:
        return (
            args.lenpen[0],
            args.weight1[0],
            args.weight2[0],
            args.weight3[0],
            rerank_scores[0],
        )


def load_score_files(args):
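    """Load the generation output, the rescoring-model outputs, and the
    language model output for each requested shard, returning one list of
    each, aligned by shard."""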
    if args.all_shards:
        shard_ids = list(range(args.num_shards))
    else:
        shard_ids = [args.shard_id]

    gen_output_lst = []
    bitext1_lst = []
    bitext2_lst = []
    lm_res1_lst = []

    for shard_id in shard_ids:
        using_nbest = args.nbest_list is not None
        (
            pre_gen,
            left_to_right_preprocessed_dir,
            right_to_left_preprocessed_dir,
            backwards_preprocessed_dir,
            lm_preprocessed_dir,
        ) = rerank_utils.get_directories(
            args.data_dir_name,
            args.num_rescore,
            args.gen_subset,
            args.gen_model_name,
            shard_id,
            args.num_shards,
            args.sampling,
            args.prefix_len,
            args.target_prefix_frac,
            args.source_prefix_frac,
        )
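        # when a rescoring model is the generation model itself (and no
        # source prefix is in use), its scores can be reused straight from
        # the generation output instead of a separate rescoring file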
        rerank1_is_gen = (
            args.gen_model == args.score_model1 and args.source_prefix_frac is None
        )
        rerank2_is_gen = (
            args.gen_model == args.score_model2 and args.source_prefix_frac is None
        )

        score1_file = rerank_utils.rescore_file_name(
            pre_gen,
            args.prefix_len,
            args.model1_name,
            target_prefix_frac=args.target_prefix_frac,
            source_prefix_frac=args.source_prefix_frac,
            backwards=args.backwards1,
        )
        if args.score_model2 is not None:
            score2_file = rerank_utils.rescore_file_name(
                pre_gen,
                args.prefix_len,
                args.model2_name,
                target_prefix_frac=args.target_prefix_frac,
                source_prefix_frac=args.source_prefix_frac,
                backwards=args.backwards2,
            )
        if args.language_model is not None:
            lm_score_file = rerank_utils.rescore_file_name(
                pre_gen, args.prefix_len, args.lm_name, lm_file=True
            )
        # get gen output
        predictions_bpe_file = pre_gen + "/generate_output_bpe.txt"
        if using_nbest:
            print("Using predefined n-best list from interactive.py")
            predictions_bpe_file = args.nbest_list

        gen_output = rerank_utils.BitextOutputFromGen(
            predictions_bpe_file,
            bpe_symbol=args.post_process,
            nbest=using_nbest,
            prefix_len=args.prefix_len,
            target_prefix_frac=args.target_prefix_frac,
        )

        if rerank1_is_gen:
            bitext1 = gen_output
        else:
            bitext1 = rerank_utils.BitextOutput(
                score1_file,
                args.backwards1,
                args.right_to_left1,
                args.post_process,
                args.prefix_len,
                args.target_prefix_frac,
                args.source_prefix_frac,
            )

        if args.score_model2 is not None or args.nbest_list is not None:
            if rerank2_is_gen:
                bitext2 = gen_output
            else:
                bitext2 = rerank_utils.BitextOutput(
                    score2_file,
                    args.backwards2,
                    args.right_to_left2,
                    args.post_process,
                    args.prefix_len,
                    args.target_prefix_frac,
                    args.source_prefix_frac,
                )
                assert (
                    bitext2.source_lengths == bitext1.source_lengths
                ), "source lengths for rescoring models do not match"
                assert (
                    bitext2.target_lengths == bitext1.target_lengths
                ), "target lengths for rescoring models do not match"
        else:
            if args.diff_bpe:
                assert args.score_model2 is None
                bitext2 = gen_output
            else:
                bitext2 = None

        if args.language_model is not None:
            lm_res1 = rerank_utils.LMOutput(
                lm_score_file,
                args.lm_dict,
                args.prefix_len,
                args.post_process,
                args.target_prefix_frac,
            )
        else:
            lm_res1 = None

        gen_output_lst.append(gen_output)
        bitext1_lst.append(bitext1)
        bitext2_lst.append(bitext2)
        lm_res1_lst.append(lm_res1)

    return gen_output_lst, bitext1_lst, bitext2_lst, lm_res1_lst


def rerank(args):
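    """Run the full reranking pipeline (generate, rescore, combine) and
    return the best length penalty, weights, and BLEU score found."""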
    # sweep arguments may arrive as scalars or lists; normalize to lists
    if not isinstance(args.lenpen, list):
        args.lenpen = [args.lenpen]
    if not isinstance(args.weight1, list):
        args.weight1 = [args.weight1]
    if not isinstance(args.weight2, list):
        args.weight2 = [args.weight2]
    if not isinstance(args.weight3, list):
        args.weight3 = [args.weight3]

    if args.all_shards:
        shard_ids = list(range(args.num_shards))
    else:
        shard_ids = [args.shard_id]
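    # for each shard: generate and reprocess the n-best list, then score it
    # with the backward (channel) model and the language model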
    for shard_id in shard_ids:
        (
            pre_gen,
            left_to_right_preprocessed_dir,
            right_to_left_preprocessed_dir,
            backwards_preprocessed_dir,
            lm_preprocessed_dir,
        ) = rerank_utils.get_directories(
            args.data_dir_name,
            args.num_rescore,
            args.gen_subset,
            args.gen_model_name,
            shard_id,
            args.num_shards,
            args.sampling,
            args.prefix_len,
            args.target_prefix_frac,
            args.source_prefix_frac,
        )
        rerank_generate.gen_and_reprocess_nbest(args)
        rerank_score_bw.score_bw(args)
        rerank_score_lm.score_lm(args)
        if args.write_hypos is None:
            write_targets = pre_gen + "/matched_targets"
            write_hypos = pre_gen + "/matched_hypos"
        else:
            write_targets = args.write_hypos + "_targets" + args.gen_subset
            write_hypos = args.write_hypos + "_hypos" + args.gen_subset

    if args.all_shards:
        write_targets += "_all_shards"
        write_hypos += "_all_shards"

    (
        best_lenpen,
        best_weight1,
        best_weight2,
        best_weight3,
        best_score,
    ) = match_target_hypo(args, write_targets, write_hypos)

    return best_lenpen, best_weight1, best_weight2, best_weight3, best_score


def cli_main():
    parser = rerank_options.get_reranking_parser()
    args = options.parse_args_and_arch(parser)
    rerank(args)


if __name__ == "__main__":
    cli_main()