Spaces:
Running
Running
#!/usr/bin/env python3 -u
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
"""
Generate n-best translations using a trained model.
"""
import os | |
import subprocess | |
from contextlib import redirect_stdout | |
from fairseq import options | |
from fairseq_cli import generate, preprocess | |
from examples.noisychannel import rerank_options, rerank_utils | |
def _apply_bpe(bpe_code, input_path, output_path):
    """Run the bundled subword-nmt ``apply_bpe.py`` script on a single file.

    Raises:
        subprocess.CalledProcessError: if the script exits with a non-zero
            status, so a failed BPE step is reported here instead of as a
            missing-file error later on.
    """
    subprocess.check_call(
        [
            "python",
            os.path.join(
                os.path.dirname(__file__), "subword-nmt/subword_nmt/apply_bpe.py"
            ),
            "-c",
            bpe_code,
            "--input",
            input_path,
            "--output",
            output_path,
        ],
        shell=False,
    )


def _binarize(src_lang, tgt_lang, trainpref, dict_dir, destdir):
    """Binarize the bitext at ``trainpref`` into ``destdir`` via fairseq preprocess."""
    preprocess_param = [
        "--source-lang",
        src_lang,
        "--target-lang",
        tgt_lang,
        "--trainpref",
        trainpref,
        "--srcdict",
        dict_dir + "/dict." + src_lang + ".txt",
        "--tgtdict",
        dict_dir + "/dict." + tgt_lang + ".txt",
        "--destdir",
        destdir,
    ]
    preprocess_parser = options.get_preprocessing_parser()
    input_args = preprocess_parser.parse_args(preprocess_param)
    preprocess.main(input_args)


def gen_and_reprocess_nbest(args):
    """Generate (or load) an n-best list and prepare it for rescoring.

    The function runs up to three steps, skipping work whose output already
    exists on disk:

    1. Generate hypotheses with ``args.gen_model`` into
       ``<pre_gen>/generate_output_bpe.txt`` (skipped when ``args.nbest_list``
       is given or the file already exists).
    2. Write the source/hypothesis/reference text files used for rescoring.
    3. Binarize those text files with fairseq's preprocess entry point.

    Args:
        args: parsed reranking options. ``args.score_dict_dir`` is filled in
            from ``args.data`` when it is ``None`` (the namespace is mutated).

    Returns:
        The ``rerank_utils.BitextOutputFromGen`` object built from the
        generation output (or from the supplied n-best list).

    Raises:
        AssertionError: if mutually incompatible options are combined
            (prefix length with right-to-left models, n-best list with a
            second scoring model, backwards with right-to-left, prefix length
            with target prefix fraction).
        subprocess.CalledProcessError: if ``apply_bpe.py`` fails when
            ``args.diff_bpe`` is set.
    """
    if args.score_dict_dir is None:
        args.score_dict_dir = args.data
    if args.prefix_len is not None:
        assert (
            args.right_to_left1 is False
        ), "prefix length not compatible with right to left models"
        assert (
            args.right_to_left2 is False
        ), "prefix length not compatible with right to left models"

    if args.nbest_list is not None:
        assert args.score_model2 is None

    # A backwards scorer models p(S|T), so its source/target sides are swapped.
    if args.backwards1:
        scorer1_src = args.target_lang
        scorer1_tgt = args.source_lang
    else:
        scorer1_src = args.source_lang
        scorer1_tgt = args.target_lang

    store_data = (
        os.path.join(os.path.dirname(__file__)) + "/rerank_data/" + args.data_dir_name
    )
    # exist_ok avoids a check-then-create race when several shards start at once.
    os.makedirs(store_data, exist_ok=True)

    (
        pre_gen,
        left_to_right_preprocessed_dir,
        right_to_left_preprocessed_dir,
        backwards_preprocessed_dir,
        lm_preprocessed_dir,
    ) = rerank_utils.get_directories(
        args.data_dir_name,
        args.num_rescore,
        args.gen_subset,
        args.gen_model_name,
        args.shard_id,
        args.num_shards,
        args.sampling,
        args.prefix_len,
        args.target_prefix_frac,
        args.source_prefix_frac,
    )
    assert not (
        args.right_to_left1 and args.backwards1
    ), "backwards right to left not supported"
    assert not (
        args.right_to_left2 and args.backwards2
    ), "backwards right to left not supported"
    assert not (
        args.prefix_len is not None and args.target_prefix_frac is not None
    ), "target prefix frac and target prefix len incompatible"

    # make directory to store generation results
    os.makedirs(pre_gen, exist_ok=True)

    # When a scorer is the generation model itself (and no source prefix is
    # used), its score file does not need to be produced by a rescoring pass.
    rerank1_is_gen = (
        args.gen_model == args.score_model1 and args.source_prefix_frac is None
    )
    rerank2_is_gen = (
        args.gen_model == args.score_model2 and args.source_prefix_frac is None
    )

    if args.nbest_list is not None:
        rerank2_is_gen = True

    # make directories to store preprocessed nbest list for reranking
    for directory in (
        left_to_right_preprocessed_dir,
        right_to_left_preprocessed_dir,
        lm_preprocessed_dir,
        backwards_preprocessed_dir,
    ):
        os.makedirs(directory, exist_ok=True)

    score1_file = rerank_utils.rescore_file_name(
        pre_gen,
        args.prefix_len,
        args.model1_name,
        target_prefix_frac=args.target_prefix_frac,
        source_prefix_frac=args.source_prefix_frac,
        backwards=args.backwards1,
    )
    # Only meaningful when a second scoring model is configured; every use
    # below is guarded by ``args.score_model2 is not None``.
    score2_file = None
    if args.score_model2 is not None:
        score2_file = rerank_utils.rescore_file_name(
            pre_gen,
            args.prefix_len,
            args.model2_name,
            target_prefix_frac=args.target_prefix_frac,
            source_prefix_frac=args.source_prefix_frac,
            backwards=args.backwards2,
        )

    predictions_bpe_file = pre_gen + "/generate_output_bpe.txt"

    using_nbest = args.nbest_list is not None

    if using_nbest:
        print("Using predefined n-best list from interactive.py")
        predictions_bpe_file = args.nbest_list
    elif not os.path.isfile(predictions_bpe_file):
        print("STEP 1: generate predictions using the p(T|S) model with bpe")
        print(args.data)
        param1 = [
            args.data,
            "--path",
            args.gen_model,
            "--shard-id",
            str(args.shard_id),
            "--num-shards",
            str(args.num_shards),
            "--nbest",
            str(args.num_rescore),
            "--batch-size",
            str(args.batch_size),
            "--beam",
            str(args.num_rescore),
            # NOTE(review): "--batch-size" is passed twice; argparse keeps the
            # last occurrence, so presumably args.num_rescore is the effective
            # batch size. Kept as-is to preserve behavior -- confirm intent.
            "--batch-size",
            str(args.num_rescore),
            "--gen-subset",
            args.gen_subset,
            "--source-lang",
            args.source_lang,
            "--target-lang",
            args.target_lang,
        ]
        if args.sampling:
            param1 += ["--sampling"]

        gen_parser = options.get_generation_parser()
        input_args = options.parse_args_and_arch(gen_parser, param1)

        print(input_args)
        # generate.main prints its results; capture stdout into the file.
        with open(predictions_bpe_file, "w") as f:
            with redirect_stdout(f):
                generate.main(input_args)

    gen_output = rerank_utils.BitextOutputFromGen(
        predictions_bpe_file,
        bpe_symbol=args.post_process,
        nbest=using_nbest,
        prefix_len=args.prefix_len,
        target_prefix_frac=args.target_prefix_frac,
    )

    if args.diff_bpe:
        # The scorer uses a different BPE: dump the de-BPE'd text and re-apply
        # the rescoring model's BPE codes to it.
        rerank_utils.write_reprocessed(
            gen_output.no_bpe_source,
            gen_output.no_bpe_hypo,
            gen_output.no_bpe_target,
            pre_gen + "/source_gen_bpe." + args.source_lang,
            pre_gen + "/target_gen_bpe." + args.target_lang,
            pre_gen + "/reference_gen_bpe." + args.target_lang,
        )
        bitext_bpe = args.rescore_bpe_code
        _apply_bpe(
            bitext_bpe,
            pre_gen + "/source_gen_bpe." + args.source_lang,
            pre_gen + "/rescore_data." + args.source_lang,
        )
        _apply_bpe(
            bitext_bpe,
            pre_gen + "/target_gen_bpe." + args.target_lang,
            pre_gen + "/rescore_data." + args.target_lang,
        )

    need_score1 = not os.path.isfile(score1_file) and not rerank1_is_gen
    need_score2 = (
        args.score_model2 is not None
        and not os.path.isfile(score2_file)
        and not rerank2_is_gen
    )
    if need_score1 or need_score2:
        print(
            "STEP 2: process the output of generate.py so we have clean text files with the translations"
        )

        rescore_file = "/rescore_data"
        # Default both prefixes to the plain rescore data so the binarization
        # step below never reads an unbound local (e.g. with --diff-bpe, where
        # the BPE step above already wrote <pre_gen>/rescore_data.<lang>).
        bw_rescore_file = rescore_file
        fw_rescore_file = rescore_file

        if args.prefix_len is not None:
            prefix_len_rescore_file = rescore_file + "prefix" + str(args.prefix_len)
        if args.target_prefix_frac is not None:
            target_prefix_frac_rescore_file = (
                rescore_file + "target_prefix_frac" + str(args.target_prefix_frac)
            )
        if args.source_prefix_frac is not None:
            source_prefix_frac_rescore_file = (
                rescore_file + "source_prefix_frac" + str(args.source_prefix_frac)
            )

        if not args.right_to_left1 or not args.right_to_left2:
            if not args.diff_bpe:
                rerank_utils.write_reprocessed(
                    gen_output.source,
                    gen_output.hypo,
                    gen_output.target,
                    pre_gen + rescore_file + "." + args.source_lang,
                    pre_gen + rescore_file + "." + args.target_lang,
                    pre_gen + "/reference_file",
                    bpe_symbol=args.post_process,
                )
                if args.prefix_len is not None:
                    bw_rescore_file = prefix_len_rescore_file
                    rerank_utils.write_reprocessed(
                        gen_output.source,
                        gen_output.hypo,
                        gen_output.target,
                        pre_gen + prefix_len_rescore_file + "." + args.source_lang,
                        pre_gen + prefix_len_rescore_file + "." + args.target_lang,
                        pre_gen + "/reference_file",
                        prefix_len=args.prefix_len,
                        bpe_symbol=args.post_process,
                    )
                elif args.target_prefix_frac is not None:
                    bw_rescore_file = target_prefix_frac_rescore_file
                    rerank_utils.write_reprocessed(
                        gen_output.source,
                        gen_output.hypo,
                        gen_output.target,
                        pre_gen
                        + target_prefix_frac_rescore_file
                        + "."
                        + args.source_lang,
                        pre_gen
                        + target_prefix_frac_rescore_file
                        + "."
                        + args.target_lang,
                        pre_gen + "/reference_file",
                        bpe_symbol=args.post_process,
                        target_prefix_frac=args.target_prefix_frac,
                    )

                if args.source_prefix_frac is not None:
                    fw_rescore_file = source_prefix_frac_rescore_file
                    rerank_utils.write_reprocessed(
                        gen_output.source,
                        gen_output.hypo,
                        gen_output.target,
                        pre_gen
                        + source_prefix_frac_rescore_file
                        + "."
                        + args.source_lang,
                        pre_gen
                        + source_prefix_frac_rescore_file
                        + "."
                        + args.target_lang,
                        pre_gen + "/reference_file",
                        bpe_symbol=args.post_process,
                        source_prefix_frac=args.source_prefix_frac,
                    )

        if args.right_to_left1 or args.right_to_left2:
            rerank_utils.write_reprocessed(
                gen_output.source,
                gen_output.hypo,
                gen_output.target,
                pre_gen + "/right_to_left_rescore_data." + args.source_lang,
                pre_gen + "/right_to_left_rescore_data." + args.target_lang,
                pre_gen + "/right_to_left_reference_file",
                right_to_left=True,
                bpe_symbol=args.post_process,
            )

        print("STEP 3: binarize the translations")
        # Parentheses spell out the precedence the original relied on
        # (``and`` binds tighter than ``or``).
        if (
            not args.right_to_left1
            or (args.score_model2 is not None and not args.right_to_left2)
            or not rerank1_is_gen
        ):
            if args.backwards1 or args.backwards2:
                if args.backwards_score_dict_dir is not None:
                    bw_dict = args.backwards_score_dict_dir
                else:
                    bw_dict = args.score_dict_dir
                _binarize(
                    scorer1_src,
                    scorer1_tgt,
                    pre_gen + bw_rescore_file,
                    bw_dict,
                    backwards_preprocessed_dir,
                )

            _binarize(
                scorer1_src,
                scorer1_tgt,
                pre_gen + fw_rescore_file,
                args.score_dict_dir,
                left_to_right_preprocessed_dir,
            )

        if args.right_to_left1 or args.right_to_left2:
            _binarize(
                scorer1_src,
                scorer1_tgt,
                pre_gen + "/right_to_left_rescore_data",
                args.score_dict_dir,
                right_to_left_preprocessed_dir,
            )

    return gen_output
def cli_main():
    """Parse the reranking command-line options and run n-best generation."""
    reranking_parser = rerank_options.get_reranking_parser()
    parsed_args = options.parse_args_and_arch(reranking_parser)
    gen_and_reprocess_nbest(parsed_args)
# Run the command-line entry point only when executed as a script.
if __name__ == "__main__":
    cli_main()