#!/usr/bin/env python3 -u
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

"""
Sample from a trained LM; hacked fairseq-interactive
"""

from collections import namedtuple
import os
import ast
import numpy as np

from fairseq import checkpoint_utils, options, tasks, utils

import tqdm

Batch = namedtuple('Batch', 'ids src_tokens src_lengths')
Translation = namedtuple('Translation', 'src_str hypos pos_scores alignments')


def make_batches(lines, args, task, max_positions):
    """Encode ``lines`` with the task's source dictionary and yield batches.

    Args:
        lines: prompt strings, one per example.
        args: fairseq config; ``args.dataset`` supplies ``max_tokens``,
            ``batch_size`` and ``skip_invalid_size_inputs_valid_test``.
        task: fairseq task used to build the inference dataset/iterator.
        max_positions: maximum positions resolved from the task and models.

    Yields:
        Batch: ``ids`` (per-row example ids; fairseq's inference datasets set
        these to the example's index in ``lines``), plus the padded
        ``src_tokens`` and ``src_lengths`` tensors.
    """
    tokens = [
        task.source_dictionary.encode_line(
            src_str, add_if_not_exist=False
        ).long()
        for src_str in lines
    ]
    lengths = [t.numel() for t in tokens]
    itr = task.get_batch_iterator(
        dataset=task.build_dataset_for_inference(tokens, lengths),
        max_tokens=args.dataset.max_tokens,
        max_sentences=args.dataset.batch_size,
        max_positions=max_positions,
        # Over-long inputs may be dropped here, so batch rows are NOT
        # guaranteed to line up 1:1 with positions in ``lines``.
        ignore_invalid_inputs=args.dataset.skip_invalid_size_inputs_valid_test
    ).next_epoch_itr(shuffle=False)
    for batch in itr:
        yield Batch(
            ids=batch['id'],
            src_tokens=batch['net_input']['src_tokens'],
            src_lengths=batch['net_input']['src_lengths'],
        )


def _read_prompts(path, prefix_size, debug=False):
    """Read ``<seq_id>|<prompt>`` lines from ``path``.

    Blank lines are skipped. If ``prefix_size >= 0`` every prompt is cut to
    its first ``prefix_size`` whitespace-separated tokens (so ``0`` yields
    empty prompts); a negative ``prefix_size`` keeps the whole prompt. With
    ``debug`` only the first 10 prompts are kept.

    Args:
        path: prompts file, UTF-8 encoded.
        prefix_size: number of prompt tokens to keep, or negative for all.
        debug: if true, keep only the first 10 prompts.

    Returns:
        (seq_ids, prompts): parallel lists of strings.

    Raises:
        ValueError: if a non-blank line has no ``|`` separator.
    """
    seq_ids, prompts = [], []
    with open(path, 'r', encoding='utf-8') as fin:
        for lineno, line in enumerate(fin, start=1):
            if not line.strip():
                continue
            seq_id, sep, prompt = line.partition('|')
            if not sep:
                raise ValueError(
                    f'{path}:{lineno}: expected "<seq_id>|<prompt>", '
                    f'got {line!r}'
                )
            seq_ids.append(seq_id)
            prompts.append(prompt)

    if prefix_size >= 0:
        prompts = [' '.join(p.split()[:prefix_size]) for p in prompts]

    if debug:
        seq_ids, prompts = seq_ids[:10], prompts[:10]

    return seq_ids, prompts


def main(args):
    """Sample continuations for every prompt and write them to ``args.output``.

    Each output line has the form ``<seq_id>__<hypo_id>|<hypothesis>``.

    Args:
        args: argparse namespace from :func:`cli_main`; must carry the extra
            ``prompts``, ``output``, ``debug`` and ``samples_per_prompt``
            attributes in addition to fairseq's generation options.
    """
    # Read the script-specific flags before the namespace is converted.
    arg_prompts = args.prompts
    arg_output = args.output
    arg_debug = args.debug
    arg_sample_size = args.samples_per_prompt

    try:
        from fairseq.dataclass.utils import convert_namespace_to_omegaconf
    except ImportError:
        # Older fairseq without dataclass configs: keep args as parsed.
        pass
    else:
        # Conversion errors are no longer swallowed: the rest of this
        # function reads the structured config (args.common, ...).
        args = convert_namespace_to_omegaconf(args)

    if args.common.seed is not None:
        np.random.seed(args.common.seed)
        utils.set_torch_seed(args.common.seed)

    if args.generation.sampling:
        # Draw `samples_per_prompt` hypotheses per prompt by widening
        # both the beam and the number of returned hypotheses.
        args.generation.nbest = args.generation.beam = arg_sample_size

    task = tasks.setup_task(args.task)

    overrides = ast.literal_eval(args.common_eval.model_overrides)

    models, _model_args = checkpoint_utils.load_model_ensemble(
        args.common_eval.path.split(os.pathsep),
        arg_overrides=overrides,
        task=task,
        suffix=getattr(args, "checkpoint_suffix", ""),
    )

    # Set dictionaries
    src_dict = task.source_dictionary
    tgt_dict = task.target_dictionary

    # Honour --cpu instead of unconditionally moving everything to CUDA.
    use_cuda = not args.common.cpu

    # Optimize ensemble for generation
    for model in models:
        model.prepare_for_inference_(args)
        if use_cuda:
            model.cuda()

    # Load alignment dictionary for unknown word replacement
    # (None if no unknown word replacement, empty if no path to align dictionary)
    align_dict = utils.load_align_dict(args.generation.replace_unk)

    max_positions = utils.resolve_max_positions(
        task.max_positions(), *[model.max_positions() for model in models]
    )

    seq_ids, prompts = _read_prompts(
        arg_prompts, args.generation.prefix_size, debug=arg_debug
    )

    generator = task.build_generator(models, args.generation)

    with open(arg_output, 'w', encoding='utf-8') as output_file:
        with tqdm.tqdm(total=len(prompts)) as pbar:
            for batch in make_batches(prompts, args, task, max_positions):
                src_tokens = batch.src_tokens
                src_lengths = batch.src_lengths
                if use_cuda:
                    src_tokens = src_tokens.cuda()
                    src_lengths = src_lengths.cuda()

                sample = {
                    'net_input': {
                        'src_tokens': src_tokens,
                        'src_lengths': src_lengths,
                    },
                }

                translations = task.inference_step(generator, models, sample)

                # Key each result by its dataset id, not its row position:
                # the iterator can drop over-long inputs, so positions do not
                # reliably map back to entries of `seq_ids`.
                results = []
                for row, (sample_id, hypos) in enumerate(
                        zip(batch.ids.tolist(), translations)):
                    row_tokens = utils.strip_pad(src_tokens[row], tgt_dict.pad())
                    results.append((sample_id, row_tokens, hypos))

                # sort output to match input order
                for sample_id, row_tokens, hypos in sorted(
                        results, key=lambda r: r[0]):
                    src_str = ''
                    if src_dict is not None:
                        src_str = src_dict.string(
                            row_tokens, args.common_eval.post_process)

                    # Process top predictions
                    for hypo_id, hypo in enumerate(hypos):
                        _hypo_tokens, hypo_str, _alignment = \
                            utils.post_process_prediction(
                                hypo_tokens=hypo['tokens'].int().cpu(),
                                src_str=src_str,
                                alignment=hypo['alignment'],
                                align_dict=align_dict,
                                tgt_dict=tgt_dict,
                                remove_bpe=args.common_eval.post_process,
                            )
                        print(f'{seq_ids[sample_id]}__{hypo_id}|{hypo_str}',
                              file=output_file)
                    pbar.update(1)


def cli_main():
    """Parse command-line arguments and run :func:`main`."""
    parser = options.get_interactive_generation_parser()
    parser.add_argument('--prompts', type=str, default=None, required=True,
                        help='file of "<seq_id>|<prompt>" lines')
    parser.add_argument('--output', type=str, default=None, required=True,
                        help='file to write "<seq_id>__<hypo_id>|<text>" lines to')
    parser.add_argument('--debug', action='store_true',
                        help='only process the first 10 prompts')
    parser.add_argument('--samples-per-prompt', type=int, default=1,
                        help='hypotheses to draw per prompt when --sampling')
    args = options.parse_args_and_arch(parser)
    # Seeding is done inside main() (with the same seed, guarded against None).
    main(args)


if __name__ == '__main__':
    cli_main()