# MIT License
#
# Copyright (c) 2019 Yang Liu and the HuggingFace team
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.

import copy
import math

import numpy as np
import torch
from configuration_bertabs import BertAbsConfig
from torch import nn
from torch.nn.init import xavier_uniform_

from transformers import BertConfig, BertModel, PreTrainedModel


MAX_SIZE = 5000

BERTABS_FINETUNED_MODEL_ARCHIVE_LIST = [
    "remi/bertabs-finetuned-cnndm-extractive-abstractive-summarization",
]


class BertAbsPreTrainedModel(PreTrainedModel):
    config_class = BertAbsConfig
    load_tf_weights = False
    base_model_prefix = "bert"


class BertAbs(BertAbsPreTrainedModel):
    def __init__(self, args, checkpoint=None, bert_extractive_checkpoint=None):
        super().__init__(args)
        self.args = args
        self.bert = Bert()

        # If pre-trained weights are passed for Bert, load these.
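        # The extractive checkpoint stores the encoder weights under keys that
        # start with "bert.model."; the slicing below (n[11:]) strips that
        # 11-character prefix so the state dict can be loaded into self.bert.model.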
        load_bert_pretrained_extractive = True if bert_extractive_checkpoint else False
        if load_bert_pretrained_extractive:
            self.bert.model.load_state_dict(
                {n[11:]: p for n, p in bert_extractive_checkpoint.items() if n.startswith("bert.model")},
                strict=True,
            )

        self.vocab_size = self.bert.model.config.vocab_size

        if args.max_pos > 512:
            my_pos_embeddings = nn.Embedding(args.max_pos, self.bert.model.config.hidden_size)
            my_pos_embeddings.weight.data[:512] = self.bert.model.embeddings.position_embeddings.weight.data
            my_pos_embeddings.weight.data[512:] = self.bert.model.embeddings.position_embeddings.weight.data[-1][
                None, :
            ].repeat(args.max_pos - 512, 1)
            self.bert.model.embeddings.position_embeddings = my_pos_embeddings

        tgt_embeddings = nn.Embedding(self.vocab_size, self.bert.model.config.hidden_size, padding_idx=0)
        tgt_embeddings.weight = copy.deepcopy(self.bert.model.embeddings.word_embeddings.weight)

        self.decoder = TransformerDecoder(
            self.args.dec_layers,
            self.args.dec_hidden_size,
            heads=self.args.dec_heads,
            d_ff=self.args.dec_ff_size,
            dropout=self.args.dec_dropout,
            embeddings=tgt_embeddings,
            vocab_size=self.vocab_size,
        )

        gen_func = nn.LogSoftmax(dim=-1)
        self.generator = nn.Sequential(nn.Linear(args.dec_hidden_size, args.vocab_size), gen_func)
        self.generator[0].weight = self.decoder.embeddings.weight

        load_from_checkpoints = False if checkpoint is None else True
        if load_from_checkpoints:
            self.load_state_dict(checkpoint)

    def init_weights(self):
        for module in self.decoder.modules():
            if isinstance(module, (nn.Linear, nn.Embedding)):
                module.weight.data.normal_(mean=0.0, std=0.02)
            elif isinstance(module, nn.LayerNorm):
                module.bias.data.zero_()
                module.weight.data.fill_(1.0)
            if isinstance(module, nn.Linear) and module.bias is not None:
                module.bias.data.zero_()
        for p in self.generator.parameters():
            if p.dim() > 1:
                xavier_uniform_(p)
            else:
                p.data.zero_()

    def forward(
        self,
        encoder_input_ids,
        decoder_input_ids,
        token_type_ids,
        encoder_attention_mask,
        decoder_attention_mask,
    ):
        encoder_output = self.bert(
            input_ids=encoder_input_ids,
            token_type_ids=token_type_ids,
            attention_mask=encoder_attention_mask,
        )
        encoder_hidden_states = encoder_output[0]
        dec_state = self.decoder.init_decoder_state(encoder_input_ids, encoder_hidden_states)
        decoder_outputs, _ = self.decoder(decoder_input_ids[:, :-1], encoder_hidden_states, dec_state)
        return decoder_outputs


class Bert(nn.Module):
    """This class is not really necessary and should probably disappear."""

    def __init__(self):
        super().__init__()
        config = BertConfig.from_pretrained("bert-base-uncased")
        self.model = BertModel(config)

    def forward(self, input_ids, attention_mask=None, token_type_ids=None, **kwargs):
        self.eval()
        with torch.no_grad():
            encoder_outputs, _ = self.model(
                input_ids, token_type_ids=token_type_ids, attention_mask=attention_mask, **kwargs
            )
        return encoder_outputs


class TransformerDecoder(nn.Module):
    """
    The Transformer decoder from "Attention is All You Need".

    Args:
        num_layers (int): number of decoder layers.
        d_model (int): size of the model
        heads (int): number of heads
        d_ff (int): size of the inner FF layer
        dropout (float): dropout parameters
        embeddings (:obj:`onmt.modules.Embeddings`):
            embeddings to use, should have positional encodings
        attn_type (str): if using a separate copy attention
    """

    def __init__(self, num_layers, d_model, heads, d_ff, dropout, embeddings, vocab_size):
        super().__init__()

        # Basic attributes.
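        # Target token embeddings plus sinusoidal positional encodings feed a stack
        # of `num_layers` TransformerDecoderLayer blocks followed by a final LayerNorm.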
        self.decoder_type = "transformer"
        self.num_layers = num_layers
        self.embeddings = embeddings
        self.pos_emb = PositionalEncoding(dropout, self.embeddings.embedding_dim)

        # Build TransformerDecoder.
        self.transformer_layers = nn.ModuleList(
            [TransformerDecoderLayer(d_model, heads, d_ff, dropout) for _ in range(num_layers)]
        )

        self.layer_norm = nn.LayerNorm(d_model, eps=1e-6)

    # forward(input_ids, attention_mask, encoder_hidden_states, encoder_attention_mask)
    # def forward(self, input_ids, state, attention_mask=None, memory_lengths=None,
    #             step=None, cache=None, encoder_attention_mask=None, encoder_hidden_states=None, memory_masks=None):
    def forward(
        self,
        input_ids,
        encoder_hidden_states=None,
        state=None,
        attention_mask=None,
        memory_lengths=None,
        step=None,
        cache=None,
        encoder_attention_mask=None,
    ):
        """
        See :obj:`onmt.modules.RNNDecoderBase.forward()`
        memory_bank = encoder_hidden_states
        """
        # Name conversion
        tgt = input_ids
        memory_bank = encoder_hidden_states
        memory_mask = encoder_attention_mask

        # src_words = state.src
        src_words = state.src
        src_batch, src_len = src_words.size()

        padding_idx = self.embeddings.padding_idx

        # Decoder padding mask
        tgt_words = tgt
        tgt_batch, tgt_len = tgt_words.size()
        tgt_pad_mask = tgt_words.data.eq(padding_idx).unsqueeze(1).expand(tgt_batch, tgt_len, tgt_len)

        # Encoder padding mask
        if memory_mask is not None:
            src_len = memory_mask.size(-1)
            src_pad_mask = memory_mask.expand(src_batch, tgt_len, src_len)
        else:
            src_pad_mask = src_words.data.eq(padding_idx).unsqueeze(1).expand(src_batch, tgt_len, src_len)

        # Pass through the embeddings
        emb = self.embeddings(input_ids)
        output = self.pos_emb(emb, step)
        assert emb.dim() == 3  # batch x len x embedding_dim

        if state.cache is None:
            saved_inputs = []

        for i in range(self.num_layers):
            prev_layer_input = None
            if state.cache is None:
                if state.previous_input is not None:
                    prev_layer_input = state.previous_layer_inputs[i]
            output, all_input = self.transformer_layers[i](
                output,
                memory_bank,
                src_pad_mask,
                tgt_pad_mask,
                previous_input=prev_layer_input,
                layer_cache=state.cache["layer_{}".format(i)] if state.cache is not None else None,
                step=step,
            )
            if state.cache is None:
                saved_inputs.append(all_input)

        if state.cache is None:
            saved_inputs = torch.stack(saved_inputs)

        output = self.layer_norm(output)

        if state.cache is None:
            state = state.update_state(tgt, saved_inputs)

        # Decoders in transformers return a tuple. Beam search will fail
        # if we don't follow this convention.
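        # `state` carries what the next decoding step needs: either the updated
        # previous inputs (no-cache path) or the layer caches mutated in place
        # during attention (beam-search path).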
        return output, state  # , state

    def init_decoder_state(self, src, memory_bank, with_cache=False):
        """Init decoder state"""
        state = TransformerDecoderState(src)
        if with_cache:
            state._init_cache(memory_bank, self.num_layers)
        return state


class PositionalEncoding(nn.Module):
    def __init__(self, dropout, dim, max_len=5000):
        super().__init__()
        pe = torch.zeros(max_len, dim)
        position = torch.arange(0, max_len).unsqueeze(1)
        div_term = torch.exp((torch.arange(0, dim, 2, dtype=torch.float) * -(math.log(10000.0) / dim)))
        pe[:, 0::2] = torch.sin(position.float() * div_term)
        pe[:, 1::2] = torch.cos(position.float() * div_term)
        pe = pe.unsqueeze(0)
        self.register_buffer("pe", pe)
        self.dropout = nn.Dropout(p=dropout)
        self.dim = dim

    def forward(self, emb, step=None):
        emb = emb * math.sqrt(self.dim)
        if step:
            emb = emb + self.pe[:, step][:, None, :]
        else:
            emb = emb + self.pe[:, : emb.size(1)]
        emb = self.dropout(emb)
        return emb

    def get_emb(self, emb):
        return self.pe[:, : emb.size(1)]


class TransformerDecoderLayer(nn.Module):
    """
    Args:
        d_model (int): the dimension of keys/values/queries in
            MultiHeadedAttention, also the input size of
            the first-layer of the PositionwiseFeedForward.
        heads (int): the number of heads for MultiHeadedAttention.
        d_ff (int): the second-layer of the PositionwiseFeedForward.
        dropout (float): dropout probability (0-1.0).
        self_attn_type (string): type of self-attention scaled-dot, average
    """

    def __init__(self, d_model, heads, d_ff, dropout):
        super().__init__()

        self.self_attn = MultiHeadedAttention(heads, d_model, dropout=dropout)
        self.context_attn = MultiHeadedAttention(heads, d_model, dropout=dropout)
        self.feed_forward = PositionwiseFeedForward(d_model, d_ff, dropout)
        self.layer_norm_1 = nn.LayerNorm(d_model, eps=1e-6)
        self.layer_norm_2 = nn.LayerNorm(d_model, eps=1e-6)
        self.drop = nn.Dropout(dropout)
        mask = self._get_attn_subsequent_mask(MAX_SIZE)
        # Register self.mask as a saved_state in TransformerDecoderLayer, so
        # it gets TransformerDecoderLayer's cuda behavior automatically.
        self.register_buffer("mask", mask)

    def forward(
        self,
        inputs,
        memory_bank,
        src_pad_mask,
        tgt_pad_mask,
        previous_input=None,
        layer_cache=None,
        step=None,
    ):
        """
        Args:
            inputs (`FloatTensor`): `[batch_size x 1 x model_dim]`
            memory_bank (`FloatTensor`): `[batch_size x src_len x model_dim]`
            src_pad_mask (`LongTensor`): `[batch_size x 1 x src_len]`
            tgt_pad_mask (`LongTensor`): `[batch_size x 1 x 1]`

        Returns:
            (`FloatTensor`, `FloatTensor`, `FloatTensor`):

            * output `[batch_size x 1 x model_dim]`
            * attn `[batch_size x 1 x src_len]`
            * all_input `[batch_size x current_step x model_dim]`
        """
        dec_mask = torch.gt(tgt_pad_mask + self.mask[:, : tgt_pad_mask.size(1), : tgt_pad_mask.size(1)], 0)
        input_norm = self.layer_norm_1(inputs)
        all_input = input_norm
        if previous_input is not None:
            all_input = torch.cat((previous_input, input_norm), dim=1)
            dec_mask = None

        query = self.self_attn(
            all_input,
            all_input,
            input_norm,
            mask=dec_mask,
            layer_cache=layer_cache,
            type="self",
        )

        query = self.drop(query) + inputs

        query_norm = self.layer_norm_2(query)
        mid = self.context_attn(
            memory_bank,
            memory_bank,
            query_norm,
            mask=src_pad_mask,
            layer_cache=layer_cache,
            type="context",
        )
        output = self.feed_forward(self.drop(mid) + query)

        return output, all_input
        # return output

    def _get_attn_subsequent_mask(self, size):
        """
        Get an attention mask to avoid using the subsequent info.
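        The mask is the strict upper triangle of a `size x size` matrix of ones,
        so position ``i`` may only attend to positions ``<= i``.
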
        Args:
            size: int

        Returns:
            (`LongTensor`):

            * subsequent_mask `[1 x size x size]`
        """
        attn_shape = (1, size, size)
        subsequent_mask = np.triu(np.ones(attn_shape), k=1).astype("uint8")
        subsequent_mask = torch.from_numpy(subsequent_mask)
        return subsequent_mask


class MultiHeadedAttention(nn.Module):
    """
    Multi-Head Attention module from "Attention is All You Need"
    :cite:`DBLP:journals/corr/VaswaniSPUJGKP17`.

    Similar to standard `dot` attention but uses
    multiple attention distributions simultaneously
    to select relevant items.

    .. mermaid::

        graph BT
            A[key]
            B[value]
            C[query]
            O[output]
            subgraph Attn
                D[Attn 1]
                E[Attn 2]
                F[Attn N]
            end
            A --> D
            C --> D
            A --> E
            C --> E
            A --> F
            C --> F
            D --> O
            E --> O
            F --> O
            B --> O

    Also includes several additional tricks.

    Args:
        head_count (int): number of parallel heads
        model_dim (int): the dimension of keys/values/queries,
            must be divisible by head_count
        dropout (float): dropout parameter
    """

    def __init__(self, head_count, model_dim, dropout=0.1, use_final_linear=True):
        assert model_dim % head_count == 0
        super().__init__()

        self.dim_per_head = model_dim // head_count
        self.model_dim = model_dim
        self.head_count = head_count
        self.linear_keys = nn.Linear(model_dim, head_count * self.dim_per_head)
        self.linear_values = nn.Linear(model_dim, head_count * self.dim_per_head)
        self.linear_query = nn.Linear(model_dim, head_count * self.dim_per_head)
        self.softmax = nn.Softmax(dim=-1)
        self.dropout = nn.Dropout(dropout)
        self.use_final_linear = use_final_linear
        if self.use_final_linear:
            self.final_linear = nn.Linear(model_dim, model_dim)

    def forward(
        self,
        key,
        value,
        query,
        mask=None,
        layer_cache=None,
        type=None,
        predefined_graph_1=None,
    ):
        """
        Compute the context vector and the attention vectors.

        Args:
            key (`FloatTensor`): set of `key_len`
                key vectors `[batch, key_len, dim]`
            value (`FloatTensor`): set of `key_len`
                value vectors `[batch, key_len, dim]`
            query (`FloatTensor`): set of `query_len`
                query vectors `[batch, query_len, dim]`
            mask: binary mask `[batch, query_len, key_len]`; non-zero entries
                mark positions that are masked out of the attention

        Returns:
            (`FloatTensor`, `FloatTensor`):

            * output context vectors `[batch, query_len, dim]`
            * one of the attention vectors `[batch, query_len, key_len]`
        """
        batch_size = key.size(0)
        dim_per_head = self.dim_per_head
        head_count = self.head_count

        def shape(x):
            """projection"""
            return x.view(batch_size, -1, head_count, dim_per_head).transpose(1, 2)

        def unshape(x):
            """compute context"""
            return x.transpose(1, 2).contiguous().view(batch_size, -1, head_count * dim_per_head)

        # 1) Project key, value, and query.
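        # When a layer cache is present (incremental decoding), "self" attention
        # appends the new key/value to the cached ones, while "context" attention
        # projects the encoder memory bank only once and then reuses it.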
        if layer_cache is not None:
            if type == "self":
                query, key, value = (
                    self.linear_query(query),
                    self.linear_keys(query),
                    self.linear_values(query),
                )

                key = shape(key)
                value = shape(value)

                if layer_cache is not None:
                    device = key.device
                    if layer_cache["self_keys"] is not None:
                        key = torch.cat((layer_cache["self_keys"].to(device), key), dim=2)
                    if layer_cache["self_values"] is not None:
                        value = torch.cat((layer_cache["self_values"].to(device), value), dim=2)
                    layer_cache["self_keys"] = key
                    layer_cache["self_values"] = value
            elif type == "context":
                query = self.linear_query(query)
                if layer_cache is not None:
                    if layer_cache["memory_keys"] is None:
                        key, value = self.linear_keys(key), self.linear_values(value)
                        key = shape(key)
                        value = shape(value)
                    else:
                        key, value = (
                            layer_cache["memory_keys"],
                            layer_cache["memory_values"],
                        )
                    layer_cache["memory_keys"] = key
                    layer_cache["memory_values"] = value
                else:
                    key, value = self.linear_keys(key), self.linear_values(value)
                    key = shape(key)
                    value = shape(value)
        else:
            key = self.linear_keys(key)
            value = self.linear_values(value)
            query = self.linear_query(query)
            key = shape(key)
            value = shape(value)

        query = shape(query)

        # 2) Calculate and scale scores.
        query = query / math.sqrt(dim_per_head)
        scores = torch.matmul(query, key.transpose(2, 3))

        if mask is not None:
            mask = mask.unsqueeze(1).expand_as(scores)
            scores = scores.masked_fill(mask, -1e18)

        # 3) Apply attention dropout and compute context vectors.
        attn = self.softmax(scores)

        if predefined_graph_1 is not None:
            attn_masked = attn[:, -1] * predefined_graph_1
            attn_masked = attn_masked / (torch.sum(attn_masked, 2).unsqueeze(2) + 1e-9)
            attn = torch.cat([attn[:, :-1], attn_masked.unsqueeze(1)], 1)

        drop_attn = self.dropout(attn)
        if self.use_final_linear:
            context = unshape(torch.matmul(drop_attn, value))
            output = self.final_linear(context)
            return output
        else:
            context = torch.matmul(drop_attn, value)
            return context


class DecoderState(object):
    """Interface for grouping together the current state of a recurrent
    decoder. In the simplest case just represents the hidden state of
    the model. But can also be used for implementing various forms of
    input_feeding and non-recurrent models.

    Modules need to implement this to utilize beam search decoding.
    """

    def detach(self):
        """Need to document this"""
        self.hidden = tuple([_.detach() for _ in self.hidden])
        self.input_feed = self.input_feed.detach()

    def beam_update(self, idx, positions, beam_size):
        """Need to document this"""
        for e in self._all:
            sizes = e.size()
            br = sizes[1]
            if len(sizes) == 3:
                sent_states = e.view(sizes[0], beam_size, br // beam_size, sizes[2])[:, :, idx]
            else:
                sent_states = e.view(sizes[0], beam_size, br // beam_size, sizes[2], sizes[3])[:, :, idx]

            sent_states.data.copy_(sent_states.data.index_select(1, positions))

    def map_batch_fn(self, fn):
        raise NotImplementedError()


class TransformerDecoderState(DecoderState):
    """Transformer Decoder state base class"""

    def __init__(self, src):
        """
        Args:
            src (FloatTensor): a sequence of source words tensors
                with optional feature tensors, of size (len x batch).
        """
        self.src = src
        self.previous_input = None
        self.previous_layer_inputs = None
        self.cache = None

    @property
    def _all(self):
        """
        Contains attributes that need to be updated in self.beam_update().
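        Before any decoding step has been recorded, only ``self.src`` is returned.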
""" if self.previous_input is not None and self.previous_layer_inputs is not None: return (self.previous_input, self.previous_layer_inputs, self.src) else: return (self.src,) def detach(self): if self.previous_input is not None: self.previous_input = self.previous_input.detach() if self.previous_layer_inputs is not None: self.previous_layer_inputs = self.previous_layer_inputs.detach() self.src = self.src.detach() def update_state(self, new_input, previous_layer_inputs): state = TransformerDecoderState(self.src) state.previous_input = new_input state.previous_layer_inputs = previous_layer_inputs return state def _init_cache(self, memory_bank, num_layers): self.cache = {} for l in range(num_layers): layer_cache = {"memory_keys": None, "memory_values": None} layer_cache["self_keys"] = None layer_cache["self_values"] = None self.cache["layer_{}".format(l)] = layer_cache def repeat_beam_size_times(self, beam_size): """Repeat beam_size times along batch dimension.""" self.src = self.src.data.repeat(1, beam_size, 1) def map_batch_fn(self, fn): def _recursive_map(struct, batch_dim=0): for k, v in struct.items(): if v is not None: if isinstance(v, dict): _recursive_map(v) else: struct[k] = fn(v, batch_dim) self.src = fn(self.src, 0) if self.cache is not None: _recursive_map(self.cache) def gelu(x): return 0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3)))) class PositionwiseFeedForward(nn.Module): """A two-layer Feed-Forward-Network with residual layer norm. Args: d_model (int): the size of input for the first-layer of the FFN. d_ff (int): the hidden layer size of the second-layer of the FNN. dropout (float): dropout probability in :math:`[0, 1)`. """ def __init__(self, d_model, d_ff, dropout=0.1): super().__init__() self.w_1 = nn.Linear(d_model, d_ff) self.w_2 = nn.Linear(d_ff, d_model) self.layer_norm = nn.LayerNorm(d_model, eps=1e-6) self.actv = gelu self.dropout_1 = nn.Dropout(dropout) self.dropout_2 = nn.Dropout(dropout) def forward(self, x): inter = self.dropout_1(self.actv(self.w_1(self.layer_norm(x)))) output = self.dropout_2(self.w_2(inter)) return output + x # # TRANSLATOR # The following code is used to generate summaries using the # pre-trained weights and beam search. # def build_predictor(args, tokenizer, symbols, model, logger=None): # we should be able to refactor the global scorer a lot scorer = GNMTGlobalScorer(args.alpha, length_penalty="wu") translator = Translator(args, model, tokenizer, symbols, global_scorer=scorer, logger=logger) return translator class GNMTGlobalScorer(object): """ NMT re-ranking score from "Google's Neural Machine Translation System" :cite:`wu2016google` Args: alpha (float): length parameter beta (float): coverage parameter """ def __init__(self, alpha, length_penalty): self.alpha = alpha penalty_builder = PenaltyBuilder(length_penalty) self.length_penalty = penalty_builder.length_penalty() def score(self, beam, logprobs): """ Rescores a prediction based on penalty functions """ normalized_probs = self.length_penalty(beam, logprobs, self.alpha) return normalized_probs class PenaltyBuilder(object): """ Returns the Length and Coverage Penalty function for Beam Search. 
    Args:
        length_pen (str): option name of length pen
        cov_pen (str): option name of cov pen
    """

    def __init__(self, length_pen):
        self.length_pen = length_pen

    def length_penalty(self):
        if self.length_pen == "wu":
            return self.length_wu
        elif self.length_pen == "avg":
            return self.length_average
        else:
            return self.length_none

    # Below are all the different penalty terms implemented so far.

    def length_wu(self, beam, logprobs, alpha=0.0):
        """
        NMT length re-ranking score from
        "Google's Neural Machine Translation System" :cite:`wu2016google`.
        """
        modifier = ((5 + len(beam.next_ys)) ** alpha) / ((5 + 1) ** alpha)
        return logprobs / modifier

    def length_average(self, beam, logprobs, alpha=0.0):
        """
        Returns the average probability of tokens in a sequence.
        """
        return logprobs / len(beam.next_ys)

    def length_none(self, beam, logprobs, alpha=0.0, beta=0.0):
        """
        Returns unmodified scores.
        """
        return logprobs


class Translator(object):
    """
    Uses a model to translate a batch of sentences.

    Args:
        model (:obj:`onmt.modules.NMTModel`): NMT model to use for translation
        fields (dict of Fields): data fields
        beam_size (int): size of beam to use
        n_best (int): number of translations produced
        max_length (int): maximum length output to produce
        global_scorer (:obj:`GlobalScorer`): object to rescore final translations
        copy_attn (bool): use copy attention during translation
        beam_trace (bool): trace beam search for debugging
        logger (logging.Logger): logger.
    """

    def __init__(self, args, model, vocab, symbols, global_scorer=None, logger=None):
        self.logger = logger
        self.args = args
        self.model = model
        self.generator = self.model.generator
        self.vocab = vocab
        self.symbols = symbols
        self.start_token = symbols["BOS"]
        self.end_token = symbols["EOS"]

        self.global_scorer = global_scorer
        self.beam_size = args.beam_size
        self.min_length = args.min_length
        self.max_length = args.max_length

    def translate(self, batch, step, attn_debug=False):
        """Generates summaries from one batch of data."""
        self.model.eval()
        with torch.no_grad():
            batch_data = self.translate_batch(batch)
            translations = self.from_batch(batch_data)
        return translations

    def translate_batch(self, batch, fast=False):
        """
        Translate a batch of sentences.

        Mostly a wrapper around :obj:`Beam`.

        Args:
            batch (:obj:`Batch`): a batch from a dataset object
            fast (bool): enables fast beam search (may not support all features)
        """
        with torch.no_grad():
            return self._fast_translate_batch(batch, self.max_length, min_length=self.min_length)

    # Where the beam search lives
    # I have no idea why it is being called from the method above
    def _fast_translate_batch(self, batch, max_length, min_length=0):
        """Beam Search using the encoder inputs contained in `batch`."""
        # The batch object is funny
        # Instead of just looking at the size of the arguments we encapsulate
        # a size argument.
        # Where is it defined?
        beam_size = self.beam_size
        batch_size = batch.batch_size
        src = batch.src
        segs = batch.segs
        mask_src = batch.mask_src

        src_features = self.model.bert(src, segs, mask_src)
        dec_states = self.model.decoder.init_decoder_state(src, src_features, with_cache=True)
        device = src_features.device

        # Tile states and memory beam_size times.
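        # Every example in the batch is repeated beam_size times along the batch
        # dimension so that all beams are decoded in a single batched forward pass.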
        dec_states.map_batch_fn(lambda state, dim: tile(state, beam_size, dim=dim))
        src_features = tile(src_features, beam_size, dim=0)
        batch_offset = torch.arange(batch_size, dtype=torch.long, device=device)
        beam_offset = torch.arange(0, batch_size * beam_size, step=beam_size, dtype=torch.long, device=device)
        alive_seq = torch.full([batch_size * beam_size, 1], self.start_token, dtype=torch.long, device=device)

        # Give full probability to the first beam on the first step.
        topk_log_probs = torch.tensor([0.0] + [float("-inf")] * (beam_size - 1), device=device).repeat(batch_size)

        # Structure that holds finished hypotheses.
        hypotheses = [[] for _ in range(batch_size)]  # noqa: F812

        results = {}
        results["predictions"] = [[] for _ in range(batch_size)]  # noqa: F812
        results["scores"] = [[] for _ in range(batch_size)]  # noqa: F812
        results["gold_score"] = [0] * batch_size
        results["batch"] = batch

        for step in range(max_length):
            decoder_input = alive_seq[:, -1].view(1, -1)

            # Decoder forward.
            decoder_input = decoder_input.transpose(0, 1)

            dec_out, dec_states = self.model.decoder(decoder_input, src_features, dec_states, step=step)

            # Generator forward.
            log_probs = self.generator(dec_out.transpose(0, 1).squeeze(0))
            vocab_size = log_probs.size(-1)

            if step < min_length:
                log_probs[:, self.end_token] = -1e20

            # Multiply probs by the beam probability.
            log_probs += topk_log_probs.view(-1).unsqueeze(1)

            alpha = self.global_scorer.alpha
            length_penalty = ((5.0 + (step + 1)) / 6.0) ** alpha

            # Flatten probs into a list of possibilities.
            curr_scores = log_probs / length_penalty

            if self.args.block_trigram:
                cur_len = alive_seq.size(1)
                if cur_len > 3:
                    for i in range(alive_seq.size(0)):
                        fail = False
                        words = [int(w) for w in alive_seq[i]]
                        words = [self.vocab.ids_to_tokens[w] for w in words]
                        words = " ".join(words).replace(" ##", "").split()
                        if len(words) <= 3:
                            continue
                        trigrams = [(words[i - 1], words[i], words[i + 1]) for i in range(1, len(words) - 1)]
                        trigram = tuple(trigrams[-1])
                        if trigram in trigrams[:-1]:
                            fail = True
                        if fail:
                            curr_scores[i] = -10e20

            curr_scores = curr_scores.reshape(-1, beam_size * vocab_size)
            topk_scores, topk_ids = curr_scores.topk(beam_size, dim=-1)

            # Recover log probs.
            topk_log_probs = topk_scores * length_penalty

            # Resolve beam origin and true word ids.
            topk_beam_index = topk_ids.div(vocab_size)
            topk_ids = topk_ids.fmod(vocab_size)

            # Map beam_index to batch_index in the flat representation.
            batch_index = topk_beam_index + beam_offset[: topk_beam_index.size(0)].unsqueeze(1)
            select_indices = batch_index.view(-1)

            # Append last prediction.
            alive_seq = torch.cat([alive_seq.index_select(0, select_indices), topk_ids.view(-1, 1)], -1)

            is_finished = topk_ids.eq(self.end_token)
            if step + 1 == max_length:
                is_finished.fill_(1)

            # End condition is top beam is finished.
            end_condition = is_finished[:, 0].eq(1)

            # Save finished hypotheses.
            if is_finished.any():
                predictions = alive_seq.view(-1, beam_size, alive_seq.size(-1))
                for i in range(is_finished.size(0)):
                    b = batch_offset[i]
                    if end_condition[i]:
                        is_finished[i].fill_(1)
                    finished_hyp = is_finished[i].nonzero().view(-1)

                    # Store finished hypotheses for this batch.
                    for j in finished_hyp:
                        hypotheses[b].append((topk_scores[i, j], predictions[i, j, 1:]))

                    # If the batch reached the end, save the n_best hypotheses.
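                    # (In practice only the single best hypothesis is kept below,
                    # so n_best is effectively 1 for this summarizer.)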
                    if end_condition[i]:
                        best_hyp = sorted(hypotheses[b], key=lambda x: x[0], reverse=True)
                        score, pred = best_hyp[0]

                        results["scores"][b].append(score)
                        results["predictions"][b].append(pred)
                non_finished = end_condition.eq(0).nonzero().view(-1)
                # If all sentences are translated, no need to go further.
                if len(non_finished) == 0:
                    break
                # Remove finished batches for the next step.
                topk_log_probs = topk_log_probs.index_select(0, non_finished)
                batch_index = batch_index.index_select(0, non_finished)
                batch_offset = batch_offset.index_select(0, non_finished)
                alive_seq = predictions.index_select(0, non_finished).view(-1, alive_seq.size(-1))
            # Reorder states.
            select_indices = batch_index.view(-1)
            src_features = src_features.index_select(0, select_indices)
            dec_states.map_batch_fn(lambda state, dim: state.index_select(dim, select_indices))

        return results

    def from_batch(self, translation_batch):
        batch = translation_batch["batch"]
        assert len(translation_batch["gold_score"]) == len(translation_batch["predictions"])
        batch_size = batch.batch_size

        preds, _, _, tgt_str, src = (
            translation_batch["predictions"],
            translation_batch["scores"],
            translation_batch["gold_score"],
            batch.tgt_str,
            batch.src,
        )

        translations = []
        for b in range(batch_size):
            pred_sents = self.vocab.convert_ids_to_tokens([int(n) for n in preds[b][0]])
            pred_sents = " ".join(pred_sents).replace(" ##", "")
            gold_sent = " ".join(tgt_str[b].split())
            raw_src = [self.vocab.ids_to_tokens[int(t)] for t in src[b]][:500]
            raw_src = " ".join(raw_src)
            translation = (pred_sents, gold_sent, raw_src)
            translations.append(translation)

        return translations


def tile(x, count, dim=0):
    """
    Tiles x on dimension dim count times.
    """
    perm = list(range(len(x.size())))
    if dim != 0:
        perm[0], perm[dim] = perm[dim], perm[0]
        x = x.permute(perm).contiguous()
    out_size = list(x.size())
    out_size[0] *= count
    batch = x.size(0)
    x = x.view(batch, -1).transpose(0, 1).repeat(count, 1).transpose(0, 1).contiguous().view(*out_size)
    if dim != 0:
        x = x.permute(perm).contiguous()
    return x


#
# Optimizer for training. We keep this here in case we want to add
# a finetuning script.
#


class BertSumOptimizer(object):
    """Specific optimizer for BertSum.

    As described in [1], the authors fine-tune BertSum for abstractive
    summarization using two Adam Optimizers with different warm-up steps and
    learning rate. They also use a custom learning rate scheduler.

    [1] Liu, Yang, and Mirella Lapata. "Text summarization with pretrained encoders."
        arXiv preprint arXiv:1908.08345 (2019).
    """

    def __init__(self, model, lr, warmup_steps, beta_1=0.99, beta_2=0.999, eps=1e-8):
        self.encoder = model.encoder
        self.decoder = model.decoder
        self.lr = lr
        self.warmup_steps = warmup_steps

        self.optimizers = {
            "encoder": torch.optim.Adam(
                model.encoder.parameters(),
                lr=lr["encoder"],
                betas=(beta_1, beta_2),
                eps=eps,
            ),
            "decoder": torch.optim.Adam(
                model.decoder.parameters(),
                lr=lr["decoder"],
                betas=(beta_1, beta_2),
                eps=eps,
            ),
        }

        self._step = 0
        self.current_learning_rates = {}

    def _update_rate(self, stack):
        return self.lr[stack] * min(self._step ** (-0.5), self._step * self.warmup_steps[stack] ** (-1.5))

    def zero_grad(self):
        # The original referenced `self.optimizer_decoder` / `self.optimizer_encoder`,
        # attributes that are never defined; use the `optimizers` dict instead.
        self.optimizers["decoder"].zero_grad()
        self.optimizers["encoder"].zero_grad()

    def step(self):
        self._step += 1
        for stack, optimizer in self.optimizers.items():
            new_rate = self._update_rate(stack)
            for param_group in optimizer.param_groups:
                param_group["lr"] = new_rate
            optimizer.step()
            self.current_learning_rates[stack] = new_rate
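
# A minimal usage sketch for BertSumOptimizer, kept as a comment so the module's
# behavior is unchanged. The learning rates and warmup steps below are assumptions
# for illustration (not values taken from this repository), and `model` stands for
# any module exposing `.encoder` and `.decoder` submodules:
#
#     optimizer = BertSumOptimizer(
#         model,
#         lr={"encoder": 2e-3, "decoder": 0.1},
#         warmup_steps={"encoder": 20000, "decoder": 10000},
#     )
#     loss.backward()
#     optimizer.step()
#     optimizer.zero_grad()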