#!/usr/bin/env python
""" Translator Class and builder """
import codecs
import os
import time
import numpy as np
from itertools import count, zip_longest

import torch

from onmt.constants import DefaultTokens
import onmt.model_builder
import onmt.inputters as inputters
import onmt.decoders.ensemble
from onmt.inputters.text_dataset import InferenceDataIterator
from onmt.translate.beam_search import BeamSearch, BeamSearchLM
from onmt.translate.greedy_search import GreedySearch, GreedySearchLM
from onmt.utils.misc import tile, set_random_seed, report_matrix
from onmt.utils.alignment import extract_alignment, build_align_pharaoh
from onmt.modules.copy_generator import collapse_copy_scores
from onmt.constants import ModelTask


def build_translator(opt, report_score=True, logger=None, out_file=None):
    if out_file is None:
        out_file = codecs.open(opt.output, "w+", "utf-8")

    load_test_model = (
        onmt.decoders.ensemble.load_test_model
        if len(opt.models) > 1
        else onmt.model_builder.load_test_model
    )
    fields, model, model_opt = load_test_model(opt)

    scorer = onmt.translate.GNMTGlobalScorer.from_opt(opt)

    if model_opt.model_task == ModelTask.LANGUAGE_MODEL:
        translator = GeneratorLM.from_opt(
            model,
            fields,
            opt,
            model_opt,
            global_scorer=scorer,
            out_file=out_file,
            report_align=opt.report_align,
            report_score=report_score,
            logger=logger,
        )
    else:
        translator = Translator.from_opt(
            model,
            fields,
            opt,
            model_opt,
            global_scorer=scorer,
            out_file=out_file,
            report_align=opt.report_align,
            report_score=report_score,
            logger=logger,
        )
    return translator


def max_tok_len(new, count, sofar):
    """
    In the token batching scheme, the number of sequences is limited
    such that the total number of src/tgt tokens (including padding)
    in a batch <= batch_size.
    """
    # Maintains the longest src and tgt length in the current batch
    global max_src_in_batch  # this is a hack
    # Reset current longest length at a new batch (count=1)
    if count == 1:
        max_src_in_batch = 0
        # max_tgt_in_batch = 0
    # Src: [<bos> w1 ... wN <eos>]
    max_src_in_batch = max(max_src_in_batch, len(new.src[0]) + 2)
    # Tgt: [w1 ... wM <eos>]
    src_elements = count * max_src_in_batch
    return src_elements
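
# Illustrative sketch (not part of the original code): how ``max_tok_len``
# grows the batch cost estimate; the sentence lengths below are made up.
#
#   example added   len(new.src[0]) + 2   max_src_in_batch   src_elements
#   count=1         9                     9                  1 * 9  = 9
#   count=2         12                    12                 2 * 12 = 24
#   count=3         7                     12                 3 * 12 = 36
#
# The iterator stops adding examples once ``src_elements`` (tokens counted
# with padding) would exceed the requested ``batch_size``.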


class Inference(object):
    """Translate a batch of sentences with a saved model.

    Args:
        model (onmt.modules.NMTModel): NMT model to use for translation
        fields (dict[str, torchtext.data.Field]): A dict
            mapping each side to its list of name-Field pairs.
        src_reader (onmt.inputters.DataReaderBase): Source reader.
        tgt_reader (onmt.inputters.TextDataReader): Target reader.
        gpu (int): GPU device. Set to negative for no GPU.
        n_best (int): How many beams to wait for.
        min_length (int): See
            :class:`onmt.translate.decode_strategy.DecodeStrategy`.
        max_length (int): See
            :class:`onmt.translate.decode_strategy.DecodeStrategy`.
        beam_size (int): Number of beams.
        random_sampling_topk (int): See
            :class:`onmt.translate.greedy_search.GreedySearch`.
        random_sampling_temp (float): See
            :class:`onmt.translate.greedy_search.GreedySearch`.
        stepwise_penalty (bool): Whether coverage penalty is applied every step
            or not.
        dump_beam (bool): Debugging option.
        block_ngram_repeat (int): See
            :class:`onmt.translate.decode_strategy.DecodeStrategy`.
        ignore_when_blocking (set or frozenset): See
            :class:`onmt.translate.decode_strategy.DecodeStrategy`.
        replace_unk (bool): Replace unknown token.
        tgt_prefix (bool): Force the predictions to begin with the
            provided -tgt.
        data_type (str): Source data type.
        verbose (bool): Print/log every translation.
        report_time (bool): Print/log total time/frequency.
        copy_attn (bool): Use copy attention.
        global_scorer (onmt.translate.GNMTGlobalScorer): Translation
            scoring/reranking object.
        out_file (TextIO or codecs.StreamReaderWriter): Output file.
        report_score (bool): Whether to report scores.
        logger (logging.Logger or NoneType): Logger.
    """

    def __init__(
        self,
        model,
        fields,
        src_reader,
        tgt_reader,
        gpu=-1,
        n_best=1,
        min_length=0,
        max_length=100,
        ratio=0.0,
        beam_size=30,
        random_sampling_topk=0,
        random_sampling_topp=0.0,
        random_sampling_temp=1.0,
        stepwise_penalty=None,
        dump_beam=False,
        block_ngram_repeat=0,
        ignore_when_blocking=frozenset(),
        replace_unk=False,
        ban_unk_token=False,
        tgt_prefix=False,
        phrase_table="",
        data_type="text",
        verbose=False,
        report_time=False,
        copy_attn=False,
        global_scorer=None,
        out_file=None,
        report_align=False,
        report_score=True,
        logger=None,
        seed=-1,
    ):
        self.model = model
        self.fields = fields
        tgt_field = dict(self.fields)["tgt"].base_field
        self._tgt_vocab = tgt_field.vocab
        self._tgt_eos_idx = self._tgt_vocab.stoi[tgt_field.eos_token]
        self._tgt_pad_idx = self._tgt_vocab.stoi[tgt_field.pad_token]
        self._tgt_bos_idx = self._tgt_vocab.stoi[tgt_field.init_token]
        self._tgt_unk_idx = self._tgt_vocab.stoi[tgt_field.unk_token]
        self._tgt_vocab_len = len(self._tgt_vocab)

        self._gpu = gpu
        self._use_cuda = gpu > -1
        self._dev = (
            torch.device("cuda", self._gpu)
            if self._use_cuda
            else torch.device("cpu")
        )

        self.n_best = n_best
        self.max_length = max_length

        self.beam_size = beam_size
        self.random_sampling_temp = random_sampling_temp
        self.sample_from_topk = random_sampling_topk
        self.sample_from_topp = random_sampling_topp

        self.min_length = min_length
        self.ban_unk_token = ban_unk_token
        self.ratio = ratio
        self.stepwise_penalty = stepwise_penalty
        self.dump_beam = dump_beam
        self.block_ngram_repeat = block_ngram_repeat
        self.ignore_when_blocking = ignore_when_blocking
        self._exclusion_idxs = {
            self._tgt_vocab.stoi[t] for t in self.ignore_when_blocking
        }
        self.src_reader = src_reader
        self.tgt_reader = tgt_reader
        self.replace_unk = replace_unk
        if self.replace_unk and not self.model.decoder.attentional:
            raise ValueError("replace_unk requires an attentional decoder.")
        self.tgt_prefix = tgt_prefix
        self.phrase_table = phrase_table
        self.data_type = data_type
        self.verbose = verbose
        self.report_time = report_time

        self.copy_attn = copy_attn

        self.global_scorer = global_scorer
        if (
            self.global_scorer.has_cov_pen
            and not self.model.decoder.attentional
        ):
            raise ValueError(
                "Coverage penalty requires an attentional decoder."
            )
        self.out_file = out_file
        self.report_align = report_align
        self.report_score = report_score
        self.logger = logger

        self.use_filter_pred = False
        self._filter_pred = None

        # for debugging
        self.beam_trace = self.dump_beam != ""
        self.beam_accum = None
        if self.beam_trace:
            self.beam_accum = {
                "predicted_ids": [],
                "beam_parent_ids": [],
                "scores": [],
                "log_probs": [],
            }

        set_random_seed(seed, self._use_cuda)
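
    # Illustrative sketch (an assumption, not part of the original code):
    # the special-token lookups in ``__init__`` behave like plain dict
    # lookups on the torchtext vocab, e.g. with a toy vocab
    #
    #   stoi = {"<unk>": 0, "<blank>": 1, "<s>": 2, "</s>": 3}
    #   tgt_pad_idx = stoi["<blank>"]   # -> 1
    #   tgt_bos_idx = stoi["<s>"]       # -> 2
    #
    # The actual token strings come from the Fields saved in the checkpoint.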

    @classmethod
    def from_opt(
        cls,
        model,
        fields,
        opt,
        model_opt,
        global_scorer=None,
        out_file=None,
        report_align=False,
        report_score=True,
        logger=None,
    ):
        """Alternate constructor.

        Args:
            model (onmt.modules.NMTModel): See :func:`__init__()`.
            fields (dict[str, torchtext.data.Field]): See
                :func:`__init__()`.
            opt (argparse.Namespace): Command line options
            model_opt (argparse.Namespace): Command line options saved with
                the model checkpoint.
            global_scorer (onmt.translate.GNMTGlobalScorer): See
                :func:`__init__()`.
            out_file (TextIO or codecs.StreamReaderWriter): See
                :func:`__init__()`.
            report_align (bool): See :func:`__init__()`.
            report_score (bool): See :func:`__init__()`.
            logger (logging.Logger or NoneType): See :func:`__init__()`.
        """
        # TODO: maybe add dynamic part
        cls.validate_task(model_opt.model_task)

        src_reader = inputters.str2reader[opt.data_type].from_opt(opt)
        tgt_reader = inputters.str2reader["text"].from_opt(opt)
        return cls(
            model,
            fields,
            src_reader,
            tgt_reader,
            gpu=opt.gpu,
            n_best=opt.n_best,
            min_length=opt.min_length,
            max_length=opt.max_length,
            ratio=opt.ratio,
            beam_size=opt.beam_size,
            random_sampling_topk=opt.random_sampling_topk,
            random_sampling_topp=opt.random_sampling_topp,
            random_sampling_temp=opt.random_sampling_temp,
            stepwise_penalty=opt.stepwise_penalty,
            dump_beam=opt.dump_beam,
            block_ngram_repeat=opt.block_ngram_repeat,
            ignore_when_blocking=set(opt.ignore_when_blocking),
            replace_unk=opt.replace_unk,
            ban_unk_token=opt.ban_unk_token,
            tgt_prefix=opt.tgt_prefix,
            phrase_table=opt.phrase_table,
            data_type=opt.data_type,
            verbose=opt.verbose,
            report_time=opt.report_time,
            copy_attn=model_opt.copy_attn,
            global_scorer=global_scorer,
            out_file=out_file,
            report_align=report_align,
            report_score=report_score,
            logger=logger,
            seed=opt.seed,
        )

    def _log(self, msg):
        if self.logger:
            self.logger.info(msg)
        else:
            print(msg)

    def _gold_score(
        self,
        batch,
        memory_bank,
        src_lengths,
        src_vocabs,
        use_src_map,
        enc_states,
        batch_size,
        src,
    ):
        if "tgt" in batch.__dict__:
            gs = self._score_target(
                batch,
                memory_bank,
                src_lengths,
                src_vocabs,
                batch.src_map if use_src_map else None,
            )
            self.model.decoder.init_state(src, memory_bank, enc_states)
        else:
            gs = [0] * batch_size
        return gs

    def translate_dynamic(
        self,
        src,
        transform,
        src_feats={},
        tgt=None,
        batch_size=None,
        batch_type="sents",
        attn_debug=False,
        align_debug=False,
        phrase_table="",
    ):
        if batch_size is None:
            raise ValueError("batch_size must be set")

        if self.tgt_prefix and tgt is None:
            raise ValueError("Prefix should be fed to tgt if -tgt_prefix.")

        data_iter = InferenceDataIterator(src, tgt, src_feats, transform)

        data = inputters.DynamicDataset(
            self.fields,
            data=data_iter,
            sort_key=inputters.str2sortkey[self.data_type],
            filter_pred=self._filter_pred,
        )

        return self._translate(
            data,
            tgt=tgt,
            batch_size=batch_size,
            batch_type=batch_type,
            attn_debug=attn_debug,
            align_debug=align_debug,
            phrase_table=phrase_table,
            dynamic=True,
            transform=transform,
        )
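
    # Minimal usage sketch for ``translate`` below (hypothetical
    # ``translator`` instance and example data, text modality):
    #
    #   scores, predictions = translator.translate(
    #       src=["Hello world !"], batch_size=1
    #   )
    #   # scores      -> [[tensor(-1.23)]]          n_best entries per example
    #   # predictions -> [["Bonjour le monde !"]]   n_best entries per example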

    def translate(
        self,
        src,
        src_feats={},
        tgt=None,
        batch_size=None,
        batch_type="sents",
        attn_debug=False,
        align_debug=False,
        phrase_table="",
    ):
        """Translate content of ``src`` and get gold scores from ``tgt``.

        Args:
            src: See :func:`self.src_reader.read()`.
            tgt: See :func:`self.tgt_reader.read()`.
            src_feats: See :func:`self.src_reader.read()`.
            batch_size (int): size of examples per mini-batch
            attn_debug (bool): enables the attention logging
            align_debug (bool): enables the word alignment logging

        Returns:
            (`list`, `list`)

            * all_scores is a list of `batch_size` lists of `n_best` scores
            * all_predictions is a list of `batch_size` lists
              of `n_best` predictions
        """

        if batch_size is None:
            raise ValueError("batch_size must be set")

        if self.tgt_prefix and tgt is None:
            raise ValueError("Prefix should be fed to tgt if -tgt_prefix.")

        src_data = {
            "reader": self.src_reader,
            "data": src,
            "features": src_feats,
        }
        tgt_data = {
            "reader": self.tgt_reader,
            "data": tgt,
            "features": {},
        }
        _readers, _data = inputters.Dataset.config(
            [("src", src_data), ("tgt", tgt_data)]
        )

        data = inputters.Dataset(
            self.fields,
            readers=_readers,
            data=_data,
            sort_key=inputters.str2sortkey[self.data_type],
            filter_pred=self._filter_pred,
        )

        return self._translate(
            data,
            tgt=tgt,
            batch_size=batch_size,
            batch_type=batch_type,
            attn_debug=attn_debug,
            align_debug=align_debug,
            phrase_table=phrase_table,
        )

    def _translate(
        self,
        data,
        tgt=None,
        batch_size=None,
        batch_type="sents",
        attn_debug=False,
        align_debug=False,
        phrase_table="",
        dynamic=False,
        transform=None,
    ):
        data_iter = inputters.OrderedIterator(
            dataset=data,
            device=self._dev,
            batch_size=batch_size,
            batch_size_fn=max_tok_len if batch_type == "tokens" else None,
            train=False,
            sort=False,
            sort_within_batch=True,
            shuffle=False,
        )

        xlation_builder = onmt.translate.TranslationBuilder(
            data,
            self.fields,
            self.n_best,
            self.replace_unk,
            tgt,
            self.phrase_table,
        )

        # Statistics
        counter = count(1)
        pred_score_total, pred_words_total = 0, 0
        gold_score_total, gold_words_total = 0, 0

        all_scores = []
        all_predictions = []

        start_time = time.time()

        for batch in data_iter:
            batch_data = self.translate_batch(
                batch, data.src_vocabs, attn_debug
            )
            translations = xlation_builder.from_batch(batch_data)

            for trans in translations:
                all_scores += [trans.pred_scores[: self.n_best]]
                pred_score_total += trans.pred_scores[0]
                pred_words_total += len(trans.pred_sents[0])
                if tgt is not None:
                    gold_score_total += trans.gold_score
                    gold_words_total += len(trans.gold_sent) + 1

                n_best_preds = [
                    " ".join(pred) for pred in trans.pred_sents[: self.n_best]
                ]

                if self.report_align:
                    align_pharaohs = [
                        build_align_pharaoh(align)
                        for align in trans.word_aligns[: self.n_best]
                    ]
                    n_best_preds_align = [
                        " ".join(align) for align in align_pharaohs
                    ]
                    n_best_preds = [
                        pred + DefaultTokens.ALIGNMENT_SEPARATOR + align
                        for pred, align in zip(
                            n_best_preds, n_best_preds_align
                        )
                    ]

                if dynamic:
                    n_best_preds = [
                        transform.apply_reverse(x) for x in n_best_preds
                    ]
                all_predictions += [n_best_preds]
                self.out_file.write("\n".join(n_best_preds) + "\n")
                self.out_file.flush()

                if self.verbose:
                    sent_number = next(counter)
                    output = trans.log(sent_number)
                    if self.logger:
                        self.logger.info(output)
                    else:
                        os.write(1, output.encode("utf-8"))

                if attn_debug:
                    preds = trans.pred_sents[0]
                    preds.append(DefaultTokens.EOS)
                    attns = trans.attns[0].tolist()
                    if self.data_type == "text":
                        srcs = trans.src_raw
                    else:
                        srcs = [str(item) for item in range(len(attns[0]))]
                    output = report_matrix(srcs, preds, attns)
                    if self.logger:
                        self.logger.info(output)
                    else:
                        os.write(1, output.encode("utf-8"))

                if align_debug:
                    tgts = trans.pred_sents[0]
                    align = trans.word_aligns[0].tolist()
                    if self.data_type == "text":
                        srcs = trans.src_raw
                    else:
                        srcs = [str(item) for item in range(len(align[0]))]
                    output = report_matrix(srcs, tgts, align)
                    if self.logger:
                        self.logger.info(output)
                    else:
                        os.write(1, output.encode("utf-8"))

        end_time = time.time()
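
        # Note on the scores reported below: PPL is computed as
        # exp(-score_total / words_total), e.g. a summed log-prob of -69.3
        # over 100 tokens gives exp(0.693) ~= 2.0.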
        if self.report_score:
            msg = self._report_score(
                "PRED", pred_score_total, pred_words_total
            )
            self._log(msg)
            if tgt is not None:
                msg = self._report_score(
                    "GOLD", gold_score_total, gold_words_total
                )
                self._log(msg)

        if self.report_time:
            total_time = end_time - start_time
            self._log("Total translation time (s): %f" % total_time)
            self._log(
                "Average translation time (s): %f"
                % (total_time / len(all_predictions))
            )
            self._log(
                "Tokens per second: %f" % (pred_words_total / total_time)
            )

        if self.dump_beam:
            import json

            json.dump(
                self.beam_accum,
                codecs.open(self.dump_beam, "w", "utf-8"),
            )
        return all_scores, all_predictions

    def _align_pad_prediction(self, predictions, bos, pad):
        """
        Pad predictions in batch and add BOS.

        Args:
            predictions (List[List[Tensor]]): `(batch, n_best,)`, for each
                src sequence, the n_best tgt predictions, all of which end
                with the eos id.
            bos (int): bos index to be used.
            pad (int): pad index to be used.

        Return:
            batched_nbest_predict (torch.LongTensor): `(batch, n_best, tgt_l)`
        """
        dtype, device = predictions[0][0].dtype, predictions[0][0].device
        flatten_tgt = [
            best.tolist() for bests in predictions for best in bests
        ]
        padded_tgt = torch.tensor(
            list(zip_longest(*flatten_tgt, fillvalue=pad)),
            dtype=dtype,
            device=device,
        ).T
        bos_tensor = torch.full(
            [padded_tgt.size(0), 1], bos, dtype=dtype, device=device
        )
        full_tgt = torch.cat((bos_tensor, padded_tgt), dim=-1)
        batched_nbest_predict = full_tgt.view(
            len(predictions), -1, full_tgt.size(-1)
        )  # (batch, n_best, tgt_l)
        return batched_nbest_predict

    def _report_score(self, name, score_total, words_total):
        if words_total == 0:
            msg = "%s No words predicted" % (name,)
        else:
            avg_score = score_total / words_total
            ppl = np.exp(-score_total.item() / words_total)
            msg = "%s AVG SCORE: %.4f, %s PPL: %.4f" % (
                name,
                avg_score,
                name,
                ppl,
            )
        return msg
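
    # Shape sketch for ``_decode_and_generate`` below (one beam-search step;
    # sizes are illustrative assumptions, not checked against the code):
    #
    #   decoder_in : (1, batch * beam, 1)
    #   dec_out    : (1, batch * beam, hidden)
    #   log_probs  : (batch * beam, vocab)            standard generator path
    #   scores     : (batch * beam, vocab + n_copied) copy-generator path,
    #                collapsed back onto the tgt vocab before the log().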

    def _decode_and_generate(
        self,
        decoder_in,
        memory_bank,
        batch,
        src_vocabs,
        memory_lengths,
        src_map=None,
        step=None,
        batch_offset=None,
    ):
        if self.copy_attn:
            # Turn any copied words into UNKs.
            decoder_in = decoder_in.masked_fill(
                decoder_in.gt(self._tgt_vocab_len - 1), self._tgt_unk_idx
            )

        # Decoder forward, takes [tgt_len, batch, nfeats] as input
        # and [src_len, batch, hidden] as memory_bank
        # in case of inference tgt_len = 1, batch = beam times batch_size
        # in case of Gold Scoring tgt_len = actual length, batch = 1 batch
        dec_out, dec_attn = self.model.decoder(
            decoder_in, memory_bank, memory_lengths=memory_lengths, step=step
        )

        # Generator forward.
        if not self.copy_attn:
            if "std" in dec_attn:
                attn = dec_attn["std"]
            else:
                attn = None
            log_probs = self.model.generator(dec_out.squeeze(0))
            # returns [(batch_size x beam_size), vocab] when 1 step
            # or [tgt_len, batch_size, vocab] when full sentence
        else:
            attn = dec_attn["copy"]
            scores = self.model.generator(
                dec_out.view(-1, dec_out.size(2)),
                attn.view(-1, attn.size(2)),
                src_map,
            )
            # here we have scores [tgt_len x batch, vocab]
            # or [beam x batch, vocab]
            if batch_offset is None:
                scores = scores.view(-1, batch.batch_size, scores.size(-1))
                scores = scores.transpose(0, 1).contiguous()
            else:
                scores = scores.view(-1, self.beam_size, scores.size(-1))
            scores = collapse_copy_scores(
                scores,
                batch,
                self._tgt_vocab,
                src_vocabs,
                batch_dim=0,
                batch_offset=batch_offset,
            )
            scores = scores.view(decoder_in.size(0), -1, scores.size(-1))
            log_probs = scores.squeeze(0).log()
            # returns [(batch_size x beam_size), vocab] when 1 step
            # or [tgt_len, batch_size, vocab] when full sentence
        return log_probs, attn

    def translate_batch(self, batch, src_vocabs, attn_debug):
        """Translate a batch of sentences."""
        raise NotImplementedError

    def _score_target(
        self, batch, memory_bank, src_lengths, src_vocabs, src_map
    ):
        raise NotImplementedError

    def report_results(
        self,
        gold_score,
        batch,
        batch_size,
        src,
        src_lengths,
        src_vocabs,
        use_src_map,
        decode_strategy,
    ):
        results = {
            "predictions": None,
            "scores": None,
            "attention": None,
            "batch": batch,
            "gold_score": gold_score,
        }

        results["scores"] = decode_strategy.scores
        results["predictions"] = decode_strategy.predictions
        results["attention"] = decode_strategy.attention
        if self.report_align:
            results["alignment"] = self._align_forward(
                batch, decode_strategy.predictions
            )
        else:
            results["alignment"] = [[] for _ in range(batch_size)]
        return results


class Translator(Inference):
    @classmethod
    def validate_task(cls, task):
        if task != ModelTask.SEQ2SEQ:
            raise ValueError(
                f"Translator does not support task {task}."
                f" Tasks supported: {ModelTask.SEQ2SEQ}"
            )
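
    # ``tile`` (from onmt.utils.misc) repeats a tensor along one dimension;
    # a sketch of the repetition used by ``_align_forward`` below (shapes
    # are illustrative assumptions):
    #
    #   memory_bank : (src_len, batch, hidden)
    #   tile(memory_bank, n_best, dim=1)
    #       -> (src_len, batch * n_best, hidden), each batch entry repeated
    #          n_best times so every n-best hypothesis gets its own copy of
    #          the encoder states.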

    def _align_forward(self, batch, predictions):
        """
        For a batch of inputs and their predictions, return a list of
        alignment Tensors of src indices, of size ``(batch, n_best,)``.
        """
        # (0) add BOS and padding to tgt prediction
        batch_tgt_idxs = self._align_pad_prediction(
            predictions, bos=self._tgt_bos_idx, pad=self._tgt_pad_idx
        )
        tgt_mask = (
            batch_tgt_idxs.eq(self._tgt_pad_idx)
            | batch_tgt_idxs.eq(self._tgt_eos_idx)
            | batch_tgt_idxs.eq(self._tgt_bos_idx)
        )

        n_best = batch_tgt_idxs.size(1)
        # (1) Encoder forward.
        src, enc_states, memory_bank, src_lengths = self._run_encoder(batch)

        # (2) Repeat src objects `n_best` times.
        # We use batch_size x n_best, get ``(src_len, batch * n_best, nfeat)``
        src = tile(src, n_best, dim=1)
        enc_states = tile(enc_states, n_best, dim=1)
        if isinstance(memory_bank, tuple):
            memory_bank = tuple(tile(x, n_best, dim=1) for x in memory_bank)
        else:
            memory_bank = tile(memory_bank, n_best, dim=1)
        src_lengths = tile(src_lengths, n_best)  # ``(batch * n_best,)``

        # (3) Init decoder with n_best src.
        self.model.decoder.init_state(src, memory_bank, enc_states)
        # reshape tgt to ``(len, batch * n_best, nfeat)``
        tgt = batch_tgt_idxs.view(-1, batch_tgt_idxs.size(-1)).T.unsqueeze(-1)
        dec_in = tgt[:-1]  # exclude last target from inputs
        _, attns = self.model.decoder(
            dec_in, memory_bank, memory_lengths=src_lengths, with_align=True
        )

        alignment_attn = attns["align"]  # ``(B, tgt_len-1, src_len)``
        # masked_select
        align_tgt_mask = tgt_mask.view(-1, tgt_mask.size(-1))
        prediction_mask = align_tgt_mask[:, 1:]  # exclude bos to match pred
        # get aligned src id for each prediction's valid tgt tokens
        alignment = extract_alignment(
            alignment_attn, prediction_mask, src_lengths, n_best
        )
        return alignment

    def translate_batch(self, batch, src_vocabs, attn_debug):
        """Translate a batch of sentences."""
        with torch.no_grad():
            if self.sample_from_topk != 0 or self.sample_from_topp != 0:
                decode_strategy = GreedySearch(
                    pad=self._tgt_pad_idx,
                    bos=self._tgt_bos_idx,
                    eos=self._tgt_eos_idx,
                    unk=self._tgt_unk_idx,
                    batch_size=batch.batch_size,
                    global_scorer=self.global_scorer,
                    min_length=self.min_length,
                    max_length=self.max_length,
                    block_ngram_repeat=self.block_ngram_repeat,
                    exclusion_tokens=self._exclusion_idxs,
                    return_attention=attn_debug or self.replace_unk,
                    sampling_temp=self.random_sampling_temp,
                    keep_topk=self.sample_from_topk,
                    keep_topp=self.sample_from_topp,
                    beam_size=self.beam_size,
                    ban_unk_token=self.ban_unk_token,
                )
            else:
                # TODO: support these blacklisted features
                assert not self.dump_beam
                decode_strategy = BeamSearch(
                    self.beam_size,
                    batch_size=batch.batch_size,
                    pad=self._tgt_pad_idx,
                    bos=self._tgt_bos_idx,
                    eos=self._tgt_eos_idx,
                    unk=self._tgt_unk_idx,
                    n_best=self.n_best,
                    global_scorer=self.global_scorer,
                    min_length=self.min_length,
                    max_length=self.max_length,
                    return_attention=attn_debug or self.replace_unk,
                    block_ngram_repeat=self.block_ngram_repeat,
                    exclusion_tokens=self._exclusion_idxs,
                    stepwise_penalty=self.stepwise_penalty,
                    ratio=self.ratio,
                    ban_unk_token=self.ban_unk_token,
                )
            return self._translate_batch_with_strategy(
                batch, src_vocabs, decode_strategy
            )

    def _run_encoder(self, batch):
        src, src_lengths = (
            batch.src if isinstance(batch.src, tuple) else (batch.src, None)
        )

        enc_states, memory_bank, src_lengths = self.model.encoder(
            src, src_lengths
        )

        if src_lengths is None:
            assert not isinstance(
                memory_bank, tuple
            ), "Ensemble decoding only supported for text data"
            src_lengths = (
                torch.Tensor(batch.batch_size)
                .type_as(memory_bank)
                .long()
                .fill_(memory_bank.size(0))
            )
        return src, enc_states, memory_bank, src_lengths
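
    # Decoding-loop contract used by the strategy below (a sketch; exact
    # shapes depend on the decode strategy): each step feeds the previous
    # predictions as a (1, batch * parallel_paths, 1) tensor, receives
    # log_probs of shape (batch * parallel_paths, vocab), advances the
    # strategy, then prunes finished hypotheses with index_select on the
    # batch dimension.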
    def _translate_batch_with_strategy(
        self, batch, src_vocabs, decode_strategy
    ):
        """Translate a batch of sentences step by step using cache.

        Args:
            batch: a batch of sentences, yielded by the data iterator.
            src_vocabs (list): list of torchtext.data.Vocab if can_copy.
            decode_strategy (DecodeStrategy): A decode strategy to use for
                generating translations step by step.

        Returns:
            results (dict): The translation results.
        """
        # (0) Prep the components of the search.
        use_src_map = self.copy_attn
        parallel_paths = decode_strategy.parallel_paths  # beam_size
        batch_size = batch.batch_size

        # (1) Run the encoder on the src.
        src, enc_states, memory_bank, src_lengths = self._run_encoder(batch)
        self.model.decoder.init_state(src, memory_bank, enc_states)

        gold_score = self._gold_score(
            batch,
            memory_bank,
            src_lengths,
            src_vocabs,
            use_src_map,
            enc_states,
            batch_size,
            src,
        )

        # (2) prep decode_strategy. Possibly repeat src objects.
        src_map = batch.src_map if use_src_map else None
        target_prefix = batch.tgt if self.tgt_prefix else None
        (
            fn_map_state,
            memory_bank,
            memory_lengths,
            src_map,
        ) = decode_strategy.initialize(
            memory_bank, src_lengths, src_map, target_prefix=target_prefix
        )
        if fn_map_state is not None:
            self.model.decoder.map_state(fn_map_state)

        # (3) Begin decoding step by step:
        for step in range(decode_strategy.max_length):
            decoder_input = decode_strategy.current_predictions.view(1, -1, 1)

            log_probs, attn = self._decode_and_generate(
                decoder_input,
                memory_bank,
                batch,
                src_vocabs,
                memory_lengths=memory_lengths,
                src_map=src_map,
                step=step,
                batch_offset=decode_strategy.batch_offset,
            )

            decode_strategy.advance(log_probs, attn)
            any_finished = decode_strategy.is_finished.any()
            if any_finished:
                decode_strategy.update_finished()
                if decode_strategy.done:
                    break

            select_indices = decode_strategy.select_indices

            if any_finished:
                # Reorder states.
                if isinstance(memory_bank, tuple):
                    memory_bank = tuple(
                        x.index_select(1, select_indices) for x in memory_bank
                    )
                else:
                    memory_bank = memory_bank.index_select(1, select_indices)

                memory_lengths = memory_lengths.index_select(
                    0, select_indices
                )

                if src_map is not None:
                    src_map = src_map.index_select(1, select_indices)

            if parallel_paths > 1 or any_finished:
                self.model.decoder.map_state(
                    lambda state, dim: state.index_select(dim, select_indices)
                )

        return self.report_results(
            gold_score,
            batch,
            batch_size,
            src,
            src_lengths,
            src_vocabs,
            use_src_map,
            decode_strategy,
        )

    def _score_target(
        self, batch, memory_bank, src_lengths, src_vocabs, src_map
    ):
        tgt = batch.tgt
        tgt_in = tgt[:-1]

        log_probs, attn = self._decode_and_generate(
            tgt_in,
            memory_bank,
            batch,
            src_vocabs,
            memory_lengths=src_lengths,
            src_map=src_map,
        )

        log_probs[:, :, self._tgt_pad_idx] = 0
        gold = tgt[1:]
        gold_scores = log_probs.gather(2, gold)
        gold_scores = gold_scores.sum(dim=0).view(-1)

        return gold_scores


class GeneratorLM(Inference):
    @classmethod
    def validate_task(cls, task):
        if task != ModelTask.LANGUAGE_MODEL:
            raise ValueError(
                f"GeneratorLM does not support task {task}."
                f" Tasks supported: {ModelTask.LANGUAGE_MODEL}"
            )

    def _align_forward(self, batch, predictions):
        """
        For a batch of inputs and their predictions, return a list of
        alignment Tensors of src indices, of size ``(batch, n_best,)``.
        """
        raise NotImplementedError
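
    # Why ``translate`` below forces batch_size to 1: LM prompts in a batch
    # have different lengths, so (hypothetically) batching a 5-token prompt
    # with a 9-token one would leave the shorter prompt feeding repeated
    # end-of-input tokens for 4 steps before its generation starts, as the
    # warning message describes.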
Then" " generation will start.") if self.logger: self.logger.info(warning_msg) else: os.write(1, warning_msg.encode("utf-8")) return super(GeneratorLM, self).translate( src, src_feats, tgt, batch_size=1, batch_type=batch_type, attn_debug=attn_debug, align_debug=align_debug, phrase_table=phrase_table, ) def translate_batch(self, batch, src_vocabs, attn_debug): """Translate a batch of sentences.""" with torch.no_grad(): if self.sample_from_topk != 0 or self.sample_from_topp != 0: decode_strategy = GreedySearchLM( pad=self._tgt_pad_idx, bos=self._tgt_bos_idx, eos=self._tgt_eos_idx, unk=self._tgt_unk_idx, batch_size=batch.batch_size, global_scorer=self.global_scorer, min_length=self.min_length, max_length=self.max_length, block_ngram_repeat=self.block_ngram_repeat, exclusion_tokens=self._exclusion_idxs, return_attention=attn_debug or self.replace_unk, sampling_temp=self.random_sampling_temp, keep_topk=self.sample_from_topk, keep_topp=self.sample_from_topp, beam_size=self.beam_size, ban_unk_token=self.ban_unk_token, ) else: # TODO: support these blacklisted features assert not self.dump_beam decode_strategy = BeamSearchLM( self.beam_size, batch_size=batch.batch_size, pad=self._tgt_pad_idx, bos=self._tgt_bos_idx, eos=self._tgt_eos_idx, unk=self._tgt_unk_idx, n_best=self.n_best, global_scorer=self.global_scorer, min_length=self.min_length, max_length=self.max_length, return_attention=attn_debug or self.replace_unk, block_ngram_repeat=self.block_ngram_repeat, exclusion_tokens=self._exclusion_idxs, stepwise_penalty=self.stepwise_penalty, ratio=self.ratio, ban_unk_token=self.ban_unk_token, ) return self._translate_batch_with_strategy( batch, src_vocabs, decode_strategy ) @classmethod def split_src_to_prevent_padding(cls, src, src_lengths): min_len_batch = torch.min(src_lengths).item() target_prefix = None if min_len_batch > 0 and min_len_batch < src.size(0): target_prefix = src[min_len_batch:] src = src[:min_len_batch] src_lengths[:] = min_len_batch return src, src_lengths, target_prefix def tile_to_beam_size_after_initial_step(self, fn_map_state, log_probs): if fn_map_state is not None: log_probs = fn_map_state(log_probs, dim=1) self.model.decoder.map_state(fn_map_state) log_probs = log_probs[-1] return log_probs def _translate_batch_with_strategy( self, batch, src_vocabs, decode_strategy ): """Translate a batch of sentences step by step using cache. Args: batch: a batch of sentences, yield by data iterator. src_vocabs (list): list of torchtext.data.Vocab if can_copy. decode_strategy (DecodeStrategy): A decode strategy to use for generate translation step by step. Returns: results (dict): The translation results. """ # (0) Prep the components of the search. use_src_map = self.copy_attn parallel_paths = decode_strategy.parallel_paths # beam_size batch_size = batch.batch_size # (1) split src into src and target_prefix to avoid padding. src, src_lengths = ( batch.src if isinstance(batch.src, tuple) else (batch.src, None) ) src, src_lengths, target_prefix = self.split_src_to_prevent_padding( src, src_lengths ) # (2) init decoder self.model.decoder.init_state(src, None, None) gold_score = self._gold_score( batch, None, src_lengths, src_vocabs, use_src_map, None, batch_size, src, ) # (3) prep decode_strategy. Possibly repeat src objects. 
    def _translate_batch_with_strategy(
        self, batch, src_vocabs, decode_strategy
    ):
        """Translate a batch of sentences step by step using cache.

        Args:
            batch: a batch of sentences, yielded by the data iterator.
            src_vocabs (list): list of torchtext.data.Vocab if can_copy.
            decode_strategy (DecodeStrategy): A decode strategy to use for
                generating translations step by step.

        Returns:
            results (dict): The translation results.
        """
        # (0) Prep the components of the search.
        use_src_map = self.copy_attn
        parallel_paths = decode_strategy.parallel_paths  # beam_size
        batch_size = batch.batch_size

        # (1) split src into src and target_prefix to avoid padding.
        src, src_lengths = (
            batch.src if isinstance(batch.src, tuple) else (batch.src, None)
        )
        src, src_lengths, target_prefix = self.split_src_to_prevent_padding(
            src, src_lengths
        )

        # (2) init decoder
        self.model.decoder.init_state(src, None, None)
        gold_score = self._gold_score(
            batch,
            None,
            src_lengths,
            src_vocabs,
            use_src_map,
            None,
            batch_size,
            src,
        )

        # (3) prep decode_strategy. Possibly repeat src objects.
        src_map = batch.src_map if use_src_map else None
        (
            fn_map_state,
            src,
            memory_lengths,
            src_map,
        ) = decode_strategy.initialize(
            src,
            src_lengths,
            src_map,
            target_prefix=target_prefix,
        )

        # (4) Begin decoding step by step:
        for step in range(decode_strategy.max_length):
            decoder_input = (
                src
                if step == 0
                else decode_strategy.current_predictions.view(1, -1, 1)
            )

            log_probs, attn = self._decode_and_generate(
                decoder_input,
                None,
                batch,
                src_vocabs,
                memory_lengths=memory_lengths.clone(),
                src_map=src_map,
                step=step if step == 0 else step + src_lengths[0].item(),
                batch_offset=decode_strategy.batch_offset,
            )

            if step == 0:
                log_probs = self.tile_to_beam_size_after_initial_step(
                    fn_map_state, log_probs
                )

            decode_strategy.advance(log_probs, attn)
            any_finished = decode_strategy.is_finished.any()
            if any_finished:
                decode_strategy.update_finished()
                if decode_strategy.done:
                    break

            select_indices = decode_strategy.select_indices
            memory_lengths += 1
            if any_finished:
                # Reorder states.
                memory_lengths = memory_lengths.index_select(
                    0, select_indices
                )

                if src_map is not None:
                    src_map = src_map.index_select(1, select_indices)

            if parallel_paths > 1 or any_finished:
                # select indexes in model state/cache
                self.model.decoder.map_state(
                    lambda state, dim: state.index_select(dim, select_indices)
                )

        return self.report_results(
            gold_score,
            batch,
            batch_size,
            src,
            src_lengths,
            src_vocabs,
            use_src_map,
            decode_strategy,
        )

    def _score_target(
        self, batch, memory_bank, src_lengths, src_vocabs, src_map
    ):
        tgt = batch.tgt
        src, src_lengths = (
            batch.src if isinstance(batch.src, tuple) else (batch.src, None)
        )

        log_probs, attn = self._decode_and_generate(
            src,
            None,
            batch,
            src_vocabs,
            memory_lengths=src_lengths,
            src_map=src_map,
        )

        log_probs[:, :, self._tgt_pad_idx] = 0
        gold_scores = log_probs.gather(2, tgt)
        gold_scores = gold_scores.sum(dim=0).view(-1)

        return gold_scores
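

# End-to-end usage sketch (hypothetical ``opt`` namespace; in OpenNMT-py it
# is normally built from command-line flags such as -model/-src by the
# translate bin):
#
#   translator = build_translator(opt, report_score=True)
#   scores, predictions = translator.translate(
#       src=list_of_source_sentences, batch_size=opt.batch_size
#   )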