#!/usr/bin/env python3 -u
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
"""
Compute average-pooled encoder representations for pre-processed data with a
trained model, and save them (together with the corresponding source
sentences) to disk in shards.
"""

import numpy as np
import torch

from fairseq import checkpoint_utils, options, progress_bar, tasks, utils
from fairseq.sequence_generator import EnsembleModel
from fairseq.utils import safe_hasattr


def get_avg_pool(
    models, sample, prefix_tokens, src_dict, remove_bpe, has_langtok=False
):
    model = EnsembleModel(models)

    # model.forward normally channels prev_output_tokens into the decoder
    # separately, but SequenceGenerator directly calls model.encoder
    encoder_input = {
        k: v for k, v in sample["net_input"].items() if k != "prev_output_tokens"
    }

    # Compute the encoder output for each beam.
    # encoder_out: (src_len, batch, embed_dim); encoder_padding_mask: (batch, src_len)
    encoder_outs = model.forward_encoder(encoder_input)
    np_encoder_outs = encoder_outs[0].encoder_out.cpu().numpy().astype(np.float32)
    # Invert the padding mask so real tokens are 1 and pad positions are 0,
    # then transpose and expand to (src_len, batch, 1) so it broadcasts
    # against the encoder output.
    encoder_mask = 1 - encoder_outs[0].encoder_padding_mask.cpu().numpy().astype(
        np.float32
    )
    encoder_mask = np.expand_dims(encoder_mask.T, axis=2)
    if has_langtok:
        # Drop the language-token position from both the mask and the
        # encoder output so they stay aligned at (src_len - 1, batch, ...).
        encoder_mask = encoder_mask[1:, :, :]
        np_encoder_outs = np_encoder_outs[1:, :, :]
    # Masked mean over source positions, yielding (batch, embed_dim).
    masked_encoder_outs = encoder_mask * np_encoder_outs
    avg_pool = (masked_encoder_outs / encoder_mask.sum(axis=0)).sum(axis=0)
    return avg_pool


def main(args):
    assert args.path is not None, "--path required for generation!"
    assert (
        not args.sampling or args.nbest == args.beam
    ), "--sampling requires --nbest to be equal to --beam"
    assert (
        args.replace_unk is None or args.raw_text
    ), "--replace-unk requires a raw text dataset (--raw-text)"

    args.beam = 1
    utils.import_user_module(args)

    if args.max_tokens is None:
        args.max_tokens = 12000
    print(args)

    use_cuda = torch.cuda.is_available() and not args.cpu

    # Load dataset splits
    task = tasks.setup_task(args)
    task.load_dataset(args.gen_subset)

    # Set dictionaries
    try:
        src_dict = getattr(task, "source_dictionary", None)
    except NotImplementedError:
        src_dict = None
    tgt_dict = task.target_dictionary

    # Load ensemble
    print("| loading model(s) from {}".format(args.path))
    models, _model_args = checkpoint_utils.load_model_ensemble(
        args.path.split(":"),
        arg_overrides=eval(args.model_overrides),
        task=task,
    )

    # Optimize ensemble for generation
    for model in models:
        model.make_generation_fast_(
            beamable_mm_beam_size=None if args.no_beamable_mm else args.beam,
            need_attn=args.print_alignment,
        )
        if args.fp16:
            model.half()
        if use_cuda:
            model.cuda()

    # Load alignment dictionary for unknown word replacement
    # (None if no unknown word replacement, empty if no path to align dictionary)
    align_dict = utils.load_align_dict(args.replace_unk)

    # Load dataset (possibly sharded)
    itr = task.get_batch_iterator(
        dataset=task.dataset(args.gen_subset),
        max_tokens=args.max_tokens,
        max_positions=utils.resolve_max_positions(
            task.max_positions(),
        ),
        ignore_invalid_inputs=args.skip_invalid_size_inputs_valid_test,
        required_batch_size_multiple=args.required_batch_size_multiple,
        num_shards=args.num_shards,
        shard_id=args.shard_id,
        num_workers=args.num_workers,
    ).next_epoch_itr(shuffle=False)

    num_sentences = 0
    source_sentences = []
    shard_id = 0
    all_avg_pool = None
    encoder_has_langtok = (
        safe_hasattr(task.args, "encoder_langtok")
        and task.args.encoder_langtok is not None
        and safe_hasattr(task.args, "lang_tok_replacing_bos_eos")
        and not task.args.lang_tok_replacing_bos_eos
    )
    with progress_bar.build_progress_bar(args, itr) as t:
        for sample in t:
            if sample is None:
                print("Skipping None")
                continue
            sample = utils.move_to_cuda(sample) if use_cuda else sample
            if "net_input" not in sample:
                continue

            prefix_tokens = None
            if args.prefix_size > 0:
                prefix_tokens = sample["target"][:, : args.prefix_size]

            with torch.no_grad():
                avg_pool = get_avg_pool(
                    models,
                    sample,
                    prefix_tokens,
                    src_dict,
                    args.post_process,
                    has_langtok=encoder_has_langtok,
                )
                if all_avg_pool is not None:
                    all_avg_pool = np.concatenate((all_avg_pool, avg_pool))
                else:
                    all_avg_pool = avg_pool

            if not isinstance(sample["id"], list):
                sample_ids = sample["id"].tolist()
            else:
                sample_ids = sample["id"]
            for i, sample_id in enumerate(sample_ids):
                # Remove padding
                src_tokens = utils.strip_pad(
                    sample["net_input"]["src_tokens"][i, :], tgt_dict.pad()
                )

                # Either retrieve the original sentences or regenerate them from tokens.
                if align_dict is not None:
                    src_str = task.dataset(args.gen_subset).src.get_original_text(
                        sample_id
                    )
                elif src_dict is not None:
                    src_str = src_dict.string(src_tokens, args.post_process)
                else:
                    src_str = ""

                if not args.quiet and src_dict is not None:
                    print("S-{}\t{}".format(sample_id, src_str))

                source_sentences.append(f"{sample_id}\t{src_str}")

            num_sentences += sample["nsentences"]
            # Flush a shard to disk once it holds at least one million rows.
            if all_avg_pool.shape[0] >= 1000000:
                # ndarray.tofile writes raw bytes, so open the file in binary mode.
                with open(
                    f"{args.encoder_save_dir}/all_avg_pool.{args.source_lang}.{shard_id}",
                    "wb",
                ) as avg_pool_file:
                    all_avg_pool.tofile(avg_pool_file)
                with open(
                    f"{args.encoder_save_dir}/sentences.{args.source_lang}.{shard_id}",
                    "w",
                ) as sentence_file:
                    sentence_file.writelines(f"{line}\n" for line in source_sentences)
                all_avg_pool = None
                source_sentences = []
                shard_id += 1

    # Write whatever is left over as the final shard.
    if all_avg_pool is not None:
        with open(
            f"{args.encoder_save_dir}/all_avg_pool.{args.source_lang}.{shard_id}", "wb"
        ) as avg_pool_file:
            all_avg_pool.tofile(avg_pool_file)
        with open(
            f"{args.encoder_save_dir}/sentences.{args.source_lang}.{shard_id}", "w"
        ) as sentence_file:
            sentence_file.writelines(f"{line}\n" for line in source_sentences)
    return None


def cli_main():
    parser = options.get_generation_parser()
    parser.add_argument(
        "--encoder-save-dir",
        default="",
        type=str,
        metavar="DIR",
        help="directory to save encoder outputs",
    )
    args = options.parse_args_and_arch(parser)
    main(args)


if __name__ == "__main__":
    cli_main()
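

# ----------------------------------------------------------------------------
# Usage sketch (illustrative, not executed): each `all_avg_pool.<lang>.<shard>`
# file written above is a raw float32 dump with one row per source sentence,
# so reading it back requires knowing the encoder embedding dimension. The
# snippet below is a minimal, hypothetical example -- ENCODER_EMBED_DIM (e.g.
# 1024 for mBART-style models) and the file paths are assumptions, not values
# this script records anywhere.
#
#   import numpy as np
#
#   ENCODER_EMBED_DIM = 1024  # assumption: must match the trained model
#   pooled = np.fromfile(
#       "encoder_out/all_avg_pool.en.0", dtype=np.float32
#   ).reshape(-1, ENCODER_EMBED_DIM)  # one row per source sentence
#   with open("encoder_out/sentences.en.0") as f:
#       # each line is "<sample_id>\t<source sentence>"
#       ids_and_text = [line.rstrip("\n").split("\t", 1) for line in f]
#   assert pooled.shape[0] == len(ids_and_text)
# ----------------------------------------------------------------------------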