Spaces:
Runtime error
Runtime error
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import os | |
from fairseq import options | |
from examples.noisychannel import rerank_options, rerank_utils | |
def score_lm(args):
    """Score candidate translations with a target-side language model, P(T).

    This is the LM step of noisy-channel reranking. The candidates come from
    ``<pre_gen>/generate_output_bpe.txt`` or, when ``args.nbest_list`` is set,
    from that predefined n-best list instead.

    The resulting score file path (``rerank_utils.rescore_file_name(...,
    lm_file=True)``) is handed to ``rerank_utils.lm_scoring``. If that file
    already exists, the step is treated as done and nothing is recomputed.
    Nothing is scored at all when ``args.language_model`` is None.

    Args:
        args: Parsed reranking options (built from
            ``rerank_options.get_reranking_parser()``).
    """
    using_nbest = args.nbest_list is not None

    # get_directories returns five paths; only the generation-output dir and
    # the LM-preprocessed dir are needed by this step.
    pre_gen, _, _, _, lm_preprocessed_dir = rerank_utils.get_directories(
        args.data_dir_name,
        args.num_rescore,
        args.gen_subset,
        args.gen_model_name,
        args.shard_id,
        args.num_shards,
        args.sampling,
        args.prefix_len,
        args.target_prefix_frac,
        args.source_prefix_frac,
    )

    predictions_bpe_file = pre_gen + "/generate_output_bpe.txt"
    if using_nbest:
        print("Using predefined n-best list from interactive.py")
        predictions_bpe_file = args.nbest_list

    # Guard clauses: skip before BitextOutputFromGen is built from the
    # hypothesis file. They also ensure lm_score_file is always bound before
    # it is used.
    if args.language_model is None:
        return
    lm_score_file = rerank_utils.rescore_file_name(
        pre_gen, args.prefix_len, args.lm_name, lm_file=True
    )
    if os.path.isfile(lm_score_file):
        return

    gen_output = rerank_utils.BitextOutputFromGen(
        predictions_bpe_file, bpe_symbol=args.post_process, nbest=using_nbest
    )

    print("STEP 4.5: language modeling for P(T)")
    # Translate the LM's BPE setting into the status string lm_scoring takes.
    if args.lm_bpe_code is None:
        bpe_status = "no bpe"
    elif args.lm_bpe_code == "shared":
        bpe_status = "shared"
    else:
        bpe_status = "different"

    rerank_utils.lm_scoring(
        lm_preprocessed_dir,
        bpe_status,
        gen_output,
        pre_gen,
        args.lm_dict,
        args.lm_name,
        args.language_model,
        args.lm_bpe_code,
        128,  # NOTE(review): presumably the scoring batch size; confirm in lm_scoring
        lm_score_file,
        args.target_lang,
        args.source_lang,
        prefix_len=args.prefix_len,
    )
def cli_main():
    """Command-line entry point: parse the reranking options, then run LM scoring."""
    reranking_parser = rerank_options.get_reranking_parser()
    score_lm(options.parse_args_and_arch(reranking_parser))


if __name__ == "__main__":
    # Run only when executed as a script, not when imported as a module.
    cli_main()