# MIT License

# Copyright (c) 2019 Yang Liu and the HuggingFace team

# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:

# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.

# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
import copy
import math

import numpy as np
import torch
from configuration_bertabs import BertAbsConfig
from torch import nn
from torch.nn.init import xavier_uniform_

from transformers import BertConfig, BertModel, PreTrainedModel


MAX_SIZE = 5000

BERTABS_FINETUNED_MODEL_ARCHIVE_LIST = [
    "remi/bertabs-finetuned-cnndm-extractive-abstractive-summarization",
]


class BertAbsPreTrainedModel(PreTrainedModel):
    config_class = BertAbsConfig
    load_tf_weights = False
    base_model_prefix = "bert"

class BertAbs(BertAbsPreTrainedModel):
    def __init__(self, args, checkpoint=None, bert_extractive_checkpoint=None):
        super().__init__(args)
        self.args = args
        self.bert = Bert()

        # If pre-trained weights are passed for Bert, load these.
        load_bert_pretrained_extractive = bert_extractive_checkpoint is not None
        if load_bert_pretrained_extractive:
            self.bert.model.load_state_dict(
                {n[11:]: p for n, p in bert_extractive_checkpoint.items() if n.startswith("bert.model")},
                strict=True,
            )

        self.vocab_size = self.bert.model.config.vocab_size

        # Extend the position embeddings beyond BERT's 512 positions by copying
        # the last learned position embedding into every extra slot.
        if args.max_pos > 512:
            my_pos_embeddings = nn.Embedding(args.max_pos, self.bert.model.config.hidden_size)
            my_pos_embeddings.weight.data[:512] = self.bert.model.embeddings.position_embeddings.weight.data
            my_pos_embeddings.weight.data[512:] = self.bert.model.embeddings.position_embeddings.weight.data[-1][
                None, :
            ].repeat(args.max_pos - 512, 1)
            self.bert.model.embeddings.position_embeddings = my_pos_embeddings
        tgt_embeddings = nn.Embedding(self.vocab_size, self.bert.model.config.hidden_size, padding_idx=0)

        tgt_embeddings.weight = copy.deepcopy(self.bert.model.embeddings.word_embeddings.weight)

        self.decoder = TransformerDecoder(
            self.args.dec_layers,
            self.args.dec_hidden_size,
            heads=self.args.dec_heads,
            d_ff=self.args.dec_ff_size,
            dropout=self.args.dec_dropout,
            embeddings=tgt_embeddings,
            vocab_size=self.vocab_size,
        )

        gen_func = nn.LogSoftmax(dim=-1)
        self.generator = nn.Sequential(nn.Linear(args.dec_hidden_size, args.vocab_size), gen_func)
        # Tie the generator's projection weights to the decoder's embedding weights.
        self.generator[0].weight = self.decoder.embeddings.weight

        load_from_checkpoints = checkpoint is not None
        if load_from_checkpoints:
            self.load_state_dict(checkpoint)

    def init_weights(self):
        for module in self.decoder.modules():
            if isinstance(module, (nn.Linear, nn.Embedding)):
                module.weight.data.normal_(mean=0.0, std=0.02)
            elif isinstance(module, nn.LayerNorm):
                module.bias.data.zero_()
                module.weight.data.fill_(1.0)
            if isinstance(module, nn.Linear) and module.bias is not None:
                module.bias.data.zero_()
        for p in self.generator.parameters():
            if p.dim() > 1:
                xavier_uniform_(p)
            else:
                p.data.zero_()

    def forward(
        self,
        encoder_input_ids,
        decoder_input_ids,
        token_type_ids,
        encoder_attention_mask,
        decoder_attention_mask,
    ):
        encoder_output = self.bert(
            input_ids=encoder_input_ids,
            token_type_ids=token_type_ids,
            attention_mask=encoder_attention_mask,
        )
        encoder_hidden_states = encoder_output[0]
        dec_state = self.decoder.init_decoder_state(encoder_input_ids, encoder_hidden_states)
        decoder_outputs, _ = self.decoder(decoder_input_ids[:, :-1], encoder_hidden_states, dec_state)
        return decoder_outputs

class Bert(nn.Module):
    """This class is not really necessary and should probably disappear."""

    def __init__(self):
        super().__init__()
        config = BertConfig.from_pretrained("bert-base-uncased")
        self.model = BertModel(config)

    def forward(self, input_ids, attention_mask=None, token_type_ids=None, **kwargs):
        self.eval()
        with torch.no_grad():
            encoder_outputs, _ = self.model(
                input_ids, token_type_ids=token_type_ids, attention_mask=attention_mask, **kwargs
            )
        return encoder_outputs

class TransformerDecoder(nn.Module):
    """
    The Transformer decoder from "Attention is All You Need".

    Args:
        num_layers (int): number of decoder layers.
        d_model (int): size of the model
        heads (int): number of heads
        d_ff (int): size of the inner FF layer
        dropout (float): dropout parameters
        embeddings (:obj:`onmt.modules.Embeddings`):
            embeddings to use, should have positional encodings
        attn_type (str): if using a separate copy attention
    """

    def __init__(self, num_layers, d_model, heads, d_ff, dropout, embeddings, vocab_size):
        super().__init__()

        # Basic attributes.
        self.decoder_type = "transformer"
        self.num_layers = num_layers
        self.embeddings = embeddings
        self.pos_emb = PositionalEncoding(dropout, self.embeddings.embedding_dim)

        # Build TransformerDecoder.
        self.transformer_layers = nn.ModuleList(
            [TransformerDecoderLayer(d_model, heads, d_ff, dropout) for _ in range(num_layers)]
        )

        self.layer_norm = nn.LayerNorm(d_model, eps=1e-6)

    # forward(input_ids, attention_mask, encoder_hidden_states, encoder_attention_mask)
    # def forward(self, input_ids, state, attention_mask=None, memory_lengths=None,
    # step=None, cache=None, encoder_attention_mask=None, encoder_hidden_states=None, memory_masks=None):
    def forward(
        self,
        input_ids,
        encoder_hidden_states=None,
        state=None,
        attention_mask=None,
        memory_lengths=None,
        step=None,
        cache=None,
        encoder_attention_mask=None,
    ):
        """
        See :obj:`onmt.modules.RNNDecoderBase.forward()`
        memory_bank = encoder_hidden_states
        """
        # Name conversion
        tgt = input_ids
        memory_bank = encoder_hidden_states
        memory_mask = encoder_attention_mask

        src_words = state.src
        src_batch, src_len = src_words.size()

        padding_idx = self.embeddings.padding_idx

        # Decoder padding mask
        tgt_words = tgt
        tgt_batch, tgt_len = tgt_words.size()
        tgt_pad_mask = tgt_words.data.eq(padding_idx).unsqueeze(1).expand(tgt_batch, tgt_len, tgt_len)

        # Encoder padding mask
        if memory_mask is not None:
            src_len = memory_mask.size(-1)
            src_pad_mask = memory_mask.expand(src_batch, tgt_len, src_len)
        else:
            src_pad_mask = src_words.data.eq(padding_idx).unsqueeze(1).expand(src_batch, tgt_len, src_len)

        # Pass through the embeddings
        emb = self.embeddings(input_ids)
        output = self.pos_emb(emb, step)
        assert emb.dim() == 3  # batch x len x embedding_dim

        if state.cache is None:
            saved_inputs = []

        for i in range(self.num_layers):
            prev_layer_input = None
            if state.cache is None:
                if state.previous_input is not None:
                    prev_layer_input = state.previous_layer_inputs[i]

            output, all_input = self.transformer_layers[i](
                output,
                memory_bank,
                src_pad_mask,
                tgt_pad_mask,
                previous_input=prev_layer_input,
                layer_cache=state.cache["layer_{}".format(i)] if state.cache is not None else None,
                step=step,
            )
            if state.cache is None:
                saved_inputs.append(all_input)

        if state.cache is None:
            saved_inputs = torch.stack(saved_inputs)

        output = self.layer_norm(output)

        if state.cache is None:
            state = state.update_state(tgt, saved_inputs)

        # Decoders in transformers return a tuple. Beam search will fail
        # if we don't follow this convention.
        return output, state

    def init_decoder_state(self, src, memory_bank, with_cache=False):
        """Init decoder state"""
        state = TransformerDecoderState(src)
        if with_cache:
            state._init_cache(memory_bank, self.num_layers)
        return state

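
# A minimal, illustrative sketch (not part of the original module) of how the decoder
# above can be driven standalone, outside of BertAbs. All tensor sizes below are made
# up for the example; only the TransformerDecoder / nn.Embedding APIs defined in this
# file are assumed. The function is never called at import time.
def _example_transformer_decoder_forward():
    vocab_size, d_model = 100, 64
    tgt_embeddings = nn.Embedding(vocab_size, d_model, padding_idx=0)
    decoder = TransformerDecoder(
        num_layers=2,
        d_model=d_model,
        heads=4,
        d_ff=128,
        dropout=0.1,
        embeddings=tgt_embeddings,
        vocab_size=vocab_size,
    )
    src = torch.randint(1, vocab_size, (2, 7))  # (batch, src_len) source token ids
    memory = torch.randn(2, 7, d_model)  # encoder hidden states for those tokens
    tgt = torch.randint(1, vocab_size, (2, 5))  # (batch, tgt_len) target token ids
    state = decoder.init_decoder_state(src, memory)
    output, state = decoder(tgt, memory, state)
    return output.shape  # torch.Size([2, 5, 64])
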
class PositionalEncoding(nn.Module):
    def __init__(self, dropout, dim, max_len=5000):
        pe = torch.zeros(max_len, dim)
        position = torch.arange(0, max_len).unsqueeze(1)
        div_term = torch.exp((torch.arange(0, dim, 2, dtype=torch.float) * -(math.log(10000.0) / dim)))
        pe[:, 0::2] = torch.sin(position.float() * div_term)
        pe[:, 1::2] = torch.cos(position.float() * div_term)
        pe = pe.unsqueeze(0)
        super().__init__()
        self.register_buffer("pe", pe)
        self.dropout = nn.Dropout(p=dropout)
        self.dim = dim

    def forward(self, emb, step=None):
        emb = emb * math.sqrt(self.dim)
        if step:
            emb = emb + self.pe[:, step][:, None, :]
        else:
            emb = emb + self.pe[:, : emb.size(1)]
        emb = self.dropout(emb)
        return emb

    def get_emb(self, emb):
        return self.pe[:, : emb.size(1)]

class TransformerDecoderLayer(nn.Module):
    """
    Args:
        d_model (int): the dimension of keys/values/queries in
            MultiHeadedAttention, also the input size of
            the first-layer of the PositionwiseFeedForward.
        heads (int): the number of heads for MultiHeadedAttention.
        d_ff (int): the hidden size of the inner layer of the
            PositionwiseFeedForward.
        dropout (float): dropout probability (0-1.0).
        self_attn_type (string): type of self-attention scaled-dot, average
    """

    def __init__(self, d_model, heads, d_ff, dropout):
        super().__init__()

        self.self_attn = MultiHeadedAttention(heads, d_model, dropout=dropout)

        self.context_attn = MultiHeadedAttention(heads, d_model, dropout=dropout)
        self.feed_forward = PositionwiseFeedForward(d_model, d_ff, dropout)
        self.layer_norm_1 = nn.LayerNorm(d_model, eps=1e-6)
        self.layer_norm_2 = nn.LayerNorm(d_model, eps=1e-6)
        self.drop = nn.Dropout(dropout)
        mask = self._get_attn_subsequent_mask(MAX_SIZE)
        # Register self.mask as a saved_state in TransformerDecoderLayer, so
        # it gets TransformerDecoderLayer's cuda behavior automatically.
        self.register_buffer("mask", mask)

    def forward(
        self,
        inputs,
        memory_bank,
        src_pad_mask,
        tgt_pad_mask,
        previous_input=None,
        layer_cache=None,
        step=None,
    ):
        """
        Args:
            inputs (`FloatTensor`): `[batch_size x 1 x model_dim]`
            memory_bank (`FloatTensor`): `[batch_size x src_len x model_dim]`
            src_pad_mask (`LongTensor`): `[batch_size x 1 x src_len]`
            tgt_pad_mask (`LongTensor`): `[batch_size x 1 x 1]`

        Returns:
            (`FloatTensor`, `FloatTensor`, `FloatTensor`):

            * output `[batch_size x 1 x model_dim]`
            * attn `[batch_size x 1 x src_len]`
            * all_input `[batch_size x current_step x model_dim]`
        """
        # Combine the padding mask with the causal (subsequent-position) mask.
        dec_mask = torch.gt(tgt_pad_mask + self.mask[:, : tgt_pad_mask.size(1), : tgt_pad_mask.size(1)], 0)
        input_norm = self.layer_norm_1(inputs)
        all_input = input_norm
        if previous_input is not None:
            all_input = torch.cat((previous_input, input_norm), dim=1)
            dec_mask = None

        query = self.self_attn(
            all_input,
            all_input,
            input_norm,
            mask=dec_mask,
            layer_cache=layer_cache,
            type="self",
        )

        query = self.drop(query) + inputs

        query_norm = self.layer_norm_2(query)
        mid = self.context_attn(
            memory_bank,
            memory_bank,
            query_norm,
            mask=src_pad_mask,
            layer_cache=layer_cache,
            type="context",
        )
        output = self.feed_forward(self.drop(mid) + query)

        return output, all_input

    def _get_attn_subsequent_mask(self, size):
        """
        Get an attention mask to avoid using the subsequent info.

        Args:
            size: int

        Returns:
            (`LongTensor`):

            * subsequent_mask `[1 x size x size]`
        """
        attn_shape = (1, size, size)
        subsequent_mask = np.triu(np.ones(attn_shape), k=1).astype("uint8")
        subsequent_mask = torch.from_numpy(subsequent_mask)
        return subsequent_mask

class MultiHeadedAttention(nn.Module):
    """
    Multi-Head Attention module from
    "Attention is All You Need"
    :cite:`DBLP:journals/corr/VaswaniSPUJGKP17`.

    Similar to standard `dot` attention but uses
    multiple attention distributions simultaneously
    to select relevant items.

    .. mermaid::

        graph BT
            A[key]
            B[value]
            C[query]
            O[output]
            subgraph Attn
                D[Attn 1]
                E[Attn 2]
                F[Attn N]
            end
            A --> D
            C --> D
            A --> E
            C --> E
            A --> F
            C --> F
            D --> O
            E --> O
            F --> O
            B --> O

    Also includes several additional tricks.

    Args:
        head_count (int): number of parallel heads
        model_dim (int): the dimension of keys/values/queries,
            must be divisible by head_count
        dropout (float): dropout parameter
    """

    def __init__(self, head_count, model_dim, dropout=0.1, use_final_linear=True):
        assert model_dim % head_count == 0
        self.dim_per_head = model_dim // head_count
        self.model_dim = model_dim

        super().__init__()
        self.head_count = head_count

        self.linear_keys = nn.Linear(model_dim, head_count * self.dim_per_head)
        self.linear_values = nn.Linear(model_dim, head_count * self.dim_per_head)
        self.linear_query = nn.Linear(model_dim, head_count * self.dim_per_head)
        self.softmax = nn.Softmax(dim=-1)
        self.dropout = nn.Dropout(dropout)
        self.use_final_linear = use_final_linear
        if self.use_final_linear:
            self.final_linear = nn.Linear(model_dim, model_dim)

    def forward(
        self,
        key,
        value,
        query,
        mask=None,
        layer_cache=None,
        type=None,
        predefined_graph_1=None,
    ):
        """
        Compute the context vector and the attention vectors.

        Args:
            key (`FloatTensor`): set of `key_len`
                key vectors `[batch, key_len, dim]`
            value (`FloatTensor`): set of `key_len`
                value vectors `[batch, key_len, dim]`
            query (`FloatTensor`): set of `query_len`
                query vectors `[batch, query_len, dim]`
            mask: binary mask indicating which keys have
                non-zero attention `[batch, query_len, key_len]`
        Returns:
            (`FloatTensor`, `FloatTensor`):

            * output context vectors `[batch, query_len, dim]`
            * one of the attention vectors `[batch, query_len, key_len]`
        """
        batch_size = key.size(0)
        dim_per_head = self.dim_per_head
        head_count = self.head_count

        def shape(x):
            """projection"""
            return x.view(batch_size, -1, head_count, dim_per_head).transpose(1, 2)

        def unshape(x):
            """compute context"""
            return x.transpose(1, 2).contiguous().view(batch_size, -1, head_count * dim_per_head)

        # 1) Project key, value, and query.
        if layer_cache is not None:
            if type == "self":
                query, key, value = (
                    self.linear_query(query),
                    self.linear_keys(query),
                    self.linear_values(query),
                )

                key = shape(key)
                value = shape(value)

                if layer_cache is not None:
                    device = key.device
                    if layer_cache["self_keys"] is not None:
                        key = torch.cat((layer_cache["self_keys"].to(device), key), dim=2)
                    if layer_cache["self_values"] is not None:
                        value = torch.cat((layer_cache["self_values"].to(device), value), dim=2)
                    layer_cache["self_keys"] = key
                    layer_cache["self_values"] = value
            elif type == "context":
                query = self.linear_query(query)
                if layer_cache is not None:
                    if layer_cache["memory_keys"] is None:
                        key, value = self.linear_keys(key), self.linear_values(value)
                        key = shape(key)
                        value = shape(value)
                    else:
                        key, value = (
                            layer_cache["memory_keys"],
                            layer_cache["memory_values"],
                        )
                    layer_cache["memory_keys"] = key
                    layer_cache["memory_values"] = value
                else:
                    key, value = self.linear_keys(key), self.linear_values(value)
                    key = shape(key)
                    value = shape(value)
        else:
            key = self.linear_keys(key)
            value = self.linear_values(value)
            query = self.linear_query(query)
            key = shape(key)
            value = shape(value)

        query = shape(query)

        # 2) Calculate and scale scores.
        query = query / math.sqrt(dim_per_head)
        scores = torch.matmul(query, key.transpose(2, 3))

        if mask is not None:
            mask = mask.unsqueeze(1).expand_as(scores)
            scores = scores.masked_fill(mask, -1e18)

        # 3) Apply attention dropout and compute context vectors.
        attn = self.softmax(scores)

        if predefined_graph_1 is not None:
            attn_masked = attn[:, -1] * predefined_graph_1
            attn_masked = attn_masked / (torch.sum(attn_masked, 2).unsqueeze(2) + 1e-9)

            attn = torch.cat([attn[:, :-1], attn_masked.unsqueeze(1)], 1)

        drop_attn = self.dropout(attn)
        if self.use_final_linear:
            context = unshape(torch.matmul(drop_attn, value))
            output = self.final_linear(context)
            return output
        else:
            context = torch.matmul(drop_attn, value)
            return context

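
# A small, self-contained sanity check added for illustration (not part of the original
# model). It exercises the module above on the "no cache" path and shows that queries
# and keys may have different lengths, as in the decoder's cross-attention. The sizes
# are arbitrary and the function is never called at import time.
def _example_multi_headed_attention_shapes():
    attn = MultiHeadedAttention(head_count=4, model_dim=32, dropout=0.0)
    query = torch.randn(2, 5, 32)  # (batch, query_len, model_dim)
    memory = torch.randn(2, 9, 32)  # (batch, key_len, model_dim)
    context = attn(memory, memory, query)  # arguments are (key, value, query)
    assert context.shape == (2, 5, 32)
    return context.shape
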
class DecoderState(object):
    """Interface for grouping together the current state of a recurrent
    decoder. In the simplest case just represents the hidden state of
    the model. But can also be used for implementing various forms of
    input_feeding and non-recurrent models.

    Modules need to implement this to utilize beam search decoding.
    """

    def detach(self):
        """Need to document this"""
        self.hidden = tuple([_.detach() for _ in self.hidden])
        self.input_feed = self.input_feed.detach()

    def beam_update(self, idx, positions, beam_size):
        """Need to document this"""
        for e in self._all:
            sizes = e.size()
            br = sizes[1]
            if len(sizes) == 3:
                sent_states = e.view(sizes[0], beam_size, br // beam_size, sizes[2])[:, :, idx]
            else:
                sent_states = e.view(sizes[0], beam_size, br // beam_size, sizes[2], sizes[3])[:, :, idx]

            sent_states.data.copy_(sent_states.data.index_select(1, positions))

    def map_batch_fn(self, fn):
        raise NotImplementedError()

class TransformerDecoderState(DecoderState):
    """Transformer Decoder state base class"""

    def __init__(self, src):
        """
        Args:
            src (FloatTensor): a sequence of source words tensors
                with optional feature tensors, of size (len x batch).
        """
        self.src = src
        self.previous_input = None
        self.previous_layer_inputs = None
        self.cache = None

    @property
    def _all(self):
        """
        Contains attributes that need to be updated in self.beam_update().
        """
        if self.previous_input is not None and self.previous_layer_inputs is not None:
            return (self.previous_input, self.previous_layer_inputs, self.src)
        else:
            return (self.src,)

    def detach(self):
        if self.previous_input is not None:
            self.previous_input = self.previous_input.detach()
        if self.previous_layer_inputs is not None:
            self.previous_layer_inputs = self.previous_layer_inputs.detach()
        self.src = self.src.detach()

    def update_state(self, new_input, previous_layer_inputs):
        state = TransformerDecoderState(self.src)
        state.previous_input = new_input
        state.previous_layer_inputs = previous_layer_inputs
        return state

    def _init_cache(self, memory_bank, num_layers):
        self.cache = {}

        for l in range(num_layers):  # noqa: E741
            layer_cache = {"memory_keys": None, "memory_values": None}
            layer_cache["self_keys"] = None
            layer_cache["self_values"] = None
            self.cache["layer_{}".format(l)] = layer_cache

    def repeat_beam_size_times(self, beam_size):
        """Repeat beam_size times along batch dimension."""
        self.src = self.src.data.repeat(1, beam_size, 1)

    def map_batch_fn(self, fn):
        def _recursive_map(struct, batch_dim=0):
            for k, v in struct.items():
                if v is not None:
                    if isinstance(v, dict):
                        _recursive_map(v)
                    else:
                        struct[k] = fn(v, batch_dim)

        self.src = fn(self.src, 0)
        if self.cache is not None:
            _recursive_map(self.cache)

def gelu(x):
    return 0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3))))

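
# Illustrative check (not in the original file): the function above is the tanh
# approximation of GELU. Recent PyTorch releases expose the same approximation as
# torch.nn.functional.gelu(x, approximate="tanh"); the comparison below assumes a
# PyTorch version that supports that keyword and is never run at import time.
def _example_gelu_close_to_builtin():
    x = torch.linspace(-3.0, 3.0, steps=7)
    return torch.allclose(gelu(x), nn.functional.gelu(x, approximate="tanh"), atol=1e-6)
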
class PositionwiseFeedForward(nn.Module):
    """A two-layer Feed-Forward-Network with residual layer norm.

    Args:
        d_model (int): the size of input for the first layer of the FFN.
        d_ff (int): the hidden layer size of the second layer
            of the FFN.
        dropout (float): dropout probability in :math:`[0, 1)`.
    """

    def __init__(self, d_model, d_ff, dropout=0.1):
        super().__init__()
        self.w_1 = nn.Linear(d_model, d_ff)
        self.w_2 = nn.Linear(d_ff, d_model)
        self.layer_norm = nn.LayerNorm(d_model, eps=1e-6)
        self.actv = gelu
        self.dropout_1 = nn.Dropout(dropout)
        self.dropout_2 = nn.Dropout(dropout)

    def forward(self, x):
        inter = self.dropout_1(self.actv(self.w_1(self.layer_norm(x))))
        output = self.dropout_2(self.w_2(inter))
        return output + x

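
# Quick usage sketch (illustrative only): the residual connection keeps the model
# dimension unchanged, so the block can be stacked freely. Sizes are arbitrary.
def _example_positionwise_feed_forward():
    ffn = PositionwiseFeedForward(d_model=32, d_ff=128, dropout=0.0)
    x = torch.randn(2, 5, 32)
    return ffn(x).shape  # torch.Size([2, 5, 32])
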
#
# TRANSLATOR
# The following code is used to generate summaries using the
# pre-trained weights and beam search.
#


def build_predictor(args, tokenizer, symbols, model, logger=None):
    # we should be able to refactor the global scorer a lot
    scorer = GNMTGlobalScorer(args.alpha, length_penalty="wu")
    translator = Translator(args, model, tokenizer, symbols, global_scorer=scorer, logger=logger)
    return translator

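
# Hedged usage sketch (not part of the original file): one way a predictor might be
# assembled around a loaded BertAbs model and a BERT tokenizer. The SimpleNamespace
# `args`, the hyper-parameter values, and the BOS/EOS token choices below are
# illustrative assumptions, not values taken from this file.
def _example_build_predictor(model, tokenizer):
    from types import SimpleNamespace

    args = SimpleNamespace(alpha=0.95, beam_size=5, min_length=50, max_length=200, block_trigram=True)
    # BOS/EOS mapped to unused BERT vocabulary entries, following the convention of
    # the accompanying run script (an assumption here).
    symbols = {
        "BOS": tokenizer.vocab["[unused0]"],
        "EOS": tokenizer.vocab["[unused1]"],
        "PAD": tokenizer.vocab["[PAD]"],
    }
    return build_predictor(args, tokenizer, symbols, model)
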
class GNMTGlobalScorer(object):
    """
    NMT re-ranking score from
    "Google's Neural Machine Translation System" :cite:`wu2016google`

    Args:
        alpha (float): length parameter
    """

    def __init__(self, alpha, length_penalty):
        self.alpha = alpha
        penalty_builder = PenaltyBuilder(length_penalty)
        self.length_penalty = penalty_builder.length_penalty()

    def score(self, beam, logprobs):
        """
        Rescores a prediction based on penalty functions
        """
        normalized_probs = self.length_penalty(beam, logprobs, self.alpha)
        return normalized_probs

class PenaltyBuilder(object):
    """
    Returns the Length and Coverage Penalty function for Beam Search.

    Args:
        length_pen (str): option name of length pen
        cov_pen (str): option name of cov pen
    """

    def __init__(self, length_pen):
        self.length_pen = length_pen

    def length_penalty(self):
        if self.length_pen == "wu":
            return self.length_wu
        elif self.length_pen == "avg":
            return self.length_average
        else:
            return self.length_none

    """
    Below are all the different penalty terms implemented so far
    """

    def length_wu(self, beam, logprobs, alpha=0.0):
        """
        NMT length re-ranking score from
        "Google's Neural Machine Translation System" :cite:`wu2016google`.
        """

        modifier = ((5 + len(beam.next_ys)) ** alpha) / ((5 + 1) ** alpha)
        return logprobs / modifier

    def length_average(self, beam, logprobs, alpha=0.0):
        """
        Returns the average probability of tokens in a sequence.
        """
        return logprobs / len(beam.next_ys)

    def length_none(self, beam, logprobs, alpha=0.0, beta=0.0):
        """
        Returns unmodified scores.
        """
        return logprobs

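
# Worked example (illustrative only): with alpha = 0.9 and a hypothesis of length 20,
# the GNMT length penalty above rescales log-probabilities by
# ((5 + 20) ** 0.9) / ((5 + 1) ** 0.9) ≈ 3.61, so longer hypotheses are not unduly
# penalised simply for accumulating more negative log-probability terms.
def _example_length_wu_modifier(hyp_len=20, alpha=0.9):
    return ((5 + hyp_len) ** alpha) / ((5 + 1) ** alpha)
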
class Translator(object):
    """
    Uses a model to translate a batch of sentences.

    Args:
        model (:obj:`onmt.modules.NMTModel`):
            NMT model to use for translation
        fields (dict of Fields): data fields
        beam_size (int): size of beam to use
        n_best (int): number of translations produced
        max_length (int): maximum length output to produce
        global_scorer (:obj:`GlobalScorer`):
            object to rescore final translations
        copy_attn (bool): use copy attention during translation
        beam_trace (bool): trace beam search for debugging
        logger(logging.Logger): logger.
    """

    def __init__(self, args, model, vocab, symbols, global_scorer=None, logger=None):
        self.logger = logger

        self.args = args
        self.model = model
        self.generator = self.model.generator
        self.vocab = vocab
        self.symbols = symbols
        self.start_token = symbols["BOS"]
        self.end_token = symbols["EOS"]

        self.global_scorer = global_scorer
        self.beam_size = args.beam_size
        self.min_length = args.min_length
        self.max_length = args.max_length

    def translate(self, batch, step, attn_debug=False):
        """Generates summaries from one batch of data."""
        self.model.eval()
        with torch.no_grad():
            batch_data = self.translate_batch(batch)
            translations = self.from_batch(batch_data)
        return translations

    def translate_batch(self, batch, fast=False):
        """
        Translate a batch of sentences.

        Mostly a wrapper around :obj:`Beam`.

        Args:
            batch (:obj:`Batch`): a batch from a dataset object
            fast (bool): enables fast beam search (may not support all features)
        """
        with torch.no_grad():
            return self._fast_translate_batch(batch, self.max_length, min_length=self.min_length)

    # Where the beam search lives
    # I have no idea why it is being called from the method above
    def _fast_translate_batch(self, batch, max_length, min_length=0):
        """Beam Search using the encoder inputs contained in `batch`."""

        # The batch object is funny
        # Instead of just looking at the size of the arguments we encapsulate
        # a size argument.
        # Where is it defined?
        beam_size = self.beam_size
        batch_size = batch.batch_size
        src = batch.src
        segs = batch.segs
        mask_src = batch.mask_src

        # `segs` are segment (token type) ids and `mask_src` is the attention mask;
        # pass them by keyword so they reach the right arguments of Bert.forward.
        src_features = self.model.bert(src, token_type_ids=segs, attention_mask=mask_src)
        dec_states = self.model.decoder.init_decoder_state(src, src_features, with_cache=True)
        device = src_features.device

        # Tile states and memory beam_size times.
        dec_states.map_batch_fn(lambda state, dim: tile(state, beam_size, dim=dim))
        src_features = tile(src_features, beam_size, dim=0)
        batch_offset = torch.arange(batch_size, dtype=torch.long, device=device)
        beam_offset = torch.arange(0, batch_size * beam_size, step=beam_size, dtype=torch.long, device=device)
        alive_seq = torch.full([batch_size * beam_size, 1], self.start_token, dtype=torch.long, device=device)

        # Give full probability to the first beam on the first step.
        topk_log_probs = torch.tensor([0.0] + [float("-inf")] * (beam_size - 1), device=device).repeat(batch_size)

        # Structure that holds finished hypotheses.
        hypotheses = [[] for _ in range(batch_size)]  # noqa: F812

        results = {}
        results["predictions"] = [[] for _ in range(batch_size)]  # noqa: F812
        results["scores"] = [[] for _ in range(batch_size)]  # noqa: F812
        results["gold_score"] = [0] * batch_size
        results["batch"] = batch

        for step in range(max_length):
            decoder_input = alive_seq[:, -1].view(1, -1)

            # Decoder forward.
            decoder_input = decoder_input.transpose(0, 1)

            dec_out, dec_states = self.model.decoder(decoder_input, src_features, dec_states, step=step)

            # Generator forward.
            log_probs = self.generator(dec_out.transpose(0, 1).squeeze(0))
            vocab_size = log_probs.size(-1)

            if step < min_length:
                log_probs[:, self.end_token] = -1e20

            # Multiply probs by the beam probability.
            log_probs += topk_log_probs.view(-1).unsqueeze(1)

            alpha = self.global_scorer.alpha
            length_penalty = ((5.0 + (step + 1)) / 6.0) ** alpha

            # Flatten probs into a list of possibilities.
            curr_scores = log_probs / length_penalty

            if self.args.block_trigram:
                cur_len = alive_seq.size(1)
                if cur_len > 3:
                    for i in range(alive_seq.size(0)):
                        fail = False
                        words = [int(w) for w in alive_seq[i]]
                        words = [self.vocab.ids_to_tokens[w] for w in words]
                        words = " ".join(words).replace(" ##", "").split()
                        if len(words) <= 3:
                            continue
                        trigrams = [(words[i - 1], words[i], words[i + 1]) for i in range(1, len(words) - 1)]
                        trigram = tuple(trigrams[-1])
                        if trigram in trigrams[:-1]:
                            fail = True
                        if fail:
                            curr_scores[i] = -10e20

            curr_scores = curr_scores.reshape(-1, beam_size * vocab_size)
            topk_scores, topk_ids = curr_scores.topk(beam_size, dim=-1)

            # Recover log probs.
            topk_log_probs = topk_scores * length_penalty

            # Resolve beam origin and true word ids (integer division).
            topk_beam_index = torch.div(topk_ids, vocab_size, rounding_mode="floor")
            topk_ids = topk_ids.fmod(vocab_size)

            # Map beam_index to batch_index in the flat representation.
            batch_index = topk_beam_index + beam_offset[: topk_beam_index.size(0)].unsqueeze(1)
            select_indices = batch_index.view(-1)

            # Append last prediction.
            alive_seq = torch.cat([alive_seq.index_select(0, select_indices), topk_ids.view(-1, 1)], -1)

            is_finished = topk_ids.eq(self.end_token)
            if step + 1 == max_length:
                is_finished.fill_(1)

            # End condition is top beam is finished.
            end_condition = is_finished[:, 0].eq(1)

            # Save finished hypotheses.
            if is_finished.any():
                predictions = alive_seq.view(-1, beam_size, alive_seq.size(-1))
                for i in range(is_finished.size(0)):
                    b = batch_offset[i]
                    if end_condition[i]:
                        is_finished[i].fill_(1)
                    finished_hyp = is_finished[i].nonzero().view(-1)

                    # Store finished hypotheses for this batch.
                    for j in finished_hyp:
                        hypotheses[b].append((topk_scores[i, j], predictions[i, j, 1:]))

                    # If the batch reached the end, save the n_best hypotheses.
                    if end_condition[i]:
                        best_hyp = sorted(hypotheses[b], key=lambda x: x[0], reverse=True)
                        score, pred = best_hyp[0]

                        results["scores"][b].append(score)
                        results["predictions"][b].append(pred)

                non_finished = end_condition.eq(0).nonzero().view(-1)

                # If all sentences are translated, no need to go further.
                if len(non_finished) == 0:
                    break

                # Remove finished batches for the next step.
                topk_log_probs = topk_log_probs.index_select(0, non_finished)
                batch_index = batch_index.index_select(0, non_finished)
                batch_offset = batch_offset.index_select(0, non_finished)
                alive_seq = predictions.index_select(0, non_finished).view(-1, alive_seq.size(-1))

            # Reorder states.
            select_indices = batch_index.view(-1)
            src_features = src_features.index_select(0, select_indices)
            dec_states.map_batch_fn(lambda state, dim: state.index_select(dim, select_indices))

        return results

    def from_batch(self, translation_batch):
        batch = translation_batch["batch"]
        assert len(translation_batch["gold_score"]) == len(translation_batch["predictions"])
        batch_size = batch.batch_size

        preds, _, _, tgt_str, src = (
            translation_batch["predictions"],
            translation_batch["scores"],
            translation_batch["gold_score"],
            batch.tgt_str,
            batch.src,
        )

        translations = []
        for b in range(batch_size):
            pred_sents = self.vocab.convert_ids_to_tokens([int(n) for n in preds[b][0]])
            pred_sents = " ".join(pred_sents).replace(" ##", "")
            gold_sent = " ".join(tgt_str[b].split())
            raw_src = [self.vocab.ids_to_tokens[int(t)] for t in src[b]][:500]
            raw_src = " ".join(raw_src)
            translation = (pred_sents, gold_sent, raw_src)
            translations.append(translation)

        return translations

def tile(x, count, dim=0):
    """
    Tiles x on dimension dim count times.
    """
    perm = list(range(len(x.size())))
    if dim != 0:
        perm[0], perm[dim] = perm[dim], perm[0]
        x = x.permute(perm).contiguous()
    out_size = list(x.size())
    out_size[0] *= count
    batch = x.size(0)
    x = x.view(batch, -1).transpose(0, 1).repeat(count, 1).transpose(0, 1).contiguous().view(*out_size)
    if dim != 0:
        x = x.permute(perm).contiguous()
    return x

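
# Illustration (not in the original file): `tile` repeats each entry along `dim`
# `count` times consecutively, which is how the beam search above expands encoder
# states from batch_size rows to batch_size * beam_size rows.
def _example_tile():
    x = torch.tensor([[1, 2], [3, 4]])
    return tile(x, 3, dim=0)
    # tensor([[1, 2],
    #         [1, 2],
    #         [1, 2],
    #         [3, 4],
    #         [3, 4],
    #         [3, 4]])
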
#
# Optimizer for training. We keep this here in case we want to add
# a finetuning script.
#


class BertSumOptimizer(object):
    """Specific optimizer for BertSum.

    As described in [1], the authors fine-tune BertSum for abstractive
    summarization using two Adam Optimizers with different warm-up steps and
    learning rate. They also use a custom learning rate scheduler.

    [1] Liu, Yang, and Mirella Lapata. "Text summarization with pretrained encoders."
        arXiv preprint arXiv:1908.08345 (2019).
    """

    def __init__(self, model, lr, warmup_steps, beta_1=0.99, beta_2=0.999, eps=1e-8):
        self.encoder = model.encoder
        self.decoder = model.decoder
        self.lr = lr
        self.warmup_steps = warmup_steps

        self.optimizers = {
            "encoder": torch.optim.Adam(
                model.encoder.parameters(),
                lr=lr["encoder"],
                betas=(beta_1, beta_2),
                eps=eps,
            ),
            "decoder": torch.optim.Adam(
                model.decoder.parameters(),
                lr=lr["decoder"],
                betas=(beta_1, beta_2),
                eps=eps,
            ),
        }

        self._step = 0
        self.current_learning_rates = {}

    def _update_rate(self, stack):
        return self.lr[stack] * min(self._step ** (-0.5), self._step * self.warmup_steps[stack] ** (-1.5))

    def zero_grad(self):
        # Clear gradients for both the encoder and decoder optimizers.
        for optimizer in self.optimizers.values():
            optimizer.zero_grad()

    def step(self):
        self._step += 1
        for stack, optimizer in self.optimizers.items():
            new_rate = self._update_rate(stack)
            for param_group in optimizer.param_groups:
                param_group["lr"] = new_rate
            optimizer.step()
            self.current_learning_rates[stack] = new_rate
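

# Hedged usage sketch (illustrative only; as the banner above notes, this optimizer is
# kept around for a possible finetuning script). It expects an object exposing
# `.encoder` and `.decoder` sub-modules. The learning rates and warm-up steps below are
# roughly the values reported in Liu & Lapata (2019) for abstractive fine-tuning, not
# values taken from this file.
def _example_bertsum_optimizer(model_with_encoder_and_decoder):
    lr = {"encoder": 2e-3, "decoder": 0.1}
    warmup_steps = {"encoder": 20000, "decoder": 10000}
    optimizer = BertSumOptimizer(model_with_encoder_and_decoder, lr, warmup_steps)

    # One training step: gradients are computed outside, then a single call updates
    # both stacks with their respective warm-up schedules.
    optimizer.zero_grad()
    # loss.backward() would go here
    optimizer.step()
    return optimizer.current_learning_rates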