# Bart-gen-arg / src/genie/network.py
# Last commit 6c25ddb by adherent: "what is the <arg> in <trg>" (7.31 kB)
import torch
from torch import nn
from torch.nn import functional as F
from transformers import (
AdamW,
get_linear_schedule_with_warmup,
BartModel,
)
from transformers.modeling_utils import PreTrainedModel
from transformers.modeling_outputs import Seq2SeqLMOutput
# NOTE(review): debug trace printed every time this module is imported;
# consider switching to logging (or removing it) once debugging is done.
print("network.py")
class BartGen(PreTrainedModel):
    """Sequence-to-sequence generator built on a pretrained ``BartModel``.

    The language-model head is tied to the shared token embeddings: logits are
    ``F.linear(decoder_hidden, transformer.shared.weight, bias=final_logits_bias)``.
    ``forward`` has two modes selected by ``task``:

    * ``task == -1`` -- inference / generation step; returns logits (no loss).
    * ``task == 0``  -- teacher-forced training; returns
      ``(loss, logits, *transformer_outputs[1:])``.
    """

    def __init__(self, config, tokenizer, pretrained_model_name="facebook/bart-large"):
        """
        Args:
            config: configuration forwarded to ``PreTrainedModel``.
            tokenizer: tokenizer whose ``len()`` is the default target size for
                ``resize_token_embeddings`` and whose ``pad_token_id`` is
                excluded from the training loss.
            pretrained_model_name: checkpoint loaded into ``self.transformer``
                (defaults to the previously hard-coded ``facebook/bart-large``).
        """
        super().__init__(config)
        self.config = config
        self.tokenizer = tokenizer
        self.transformer = BartModel.from_pretrained(pretrained_model_name)
        num_embeddings = self.transformer.shared.num_embeddings
        # Defined up front so it exists even if resize_token_embeddings() is never called.
        self.vocab_size = num_embeddings
        self.register_buffer("final_logits_bias", torch.zeros((1, num_embeddings)))

    def resize_token_embeddings(self, new_num_tokens=None):
        """Resize the shared embeddings and the logits bias to a new vocabulary size.

        Args:
            new_num_tokens: target vocabulary size; defaults to
                ``len(self.tokenizer)`` (the original behavior).

        Returns:
            The new shared ``nn.Embedding`` returned by the transformer.
        """
        if new_num_tokens is None:
            new_num_tokens = len(self.tokenizer)
        old_num_tokens = self.transformer.shared.num_embeddings
        new_embeddings = self.transformer.resize_token_embeddings(new_num_tokens)
        self.transformer.shared = new_embeddings
        self._resize_final_logits_bias(new_num_tokens, old_num_tokens)
        self.vocab_size = new_num_tokens
        return new_embeddings

    def _resize_final_logits_bias(self, new_num_tokens: int, old_num_tokens: int) -> None:
        """Truncate or zero-pad ``final_logits_bias`` to ``new_num_tokens`` columns."""
        if new_num_tokens <= old_num_tokens:
            new_bias = self.final_logits_bias[:, :new_num_tokens]
        else:
            extra_bias = torch.zeros(
                (1, new_num_tokens - old_num_tokens), device=self.final_logits_bias.device
            )
            new_bias = torch.cat([self.final_logits_bias, extra_bias], dim=1)
        # Re-registering an existing buffer name replaces it.
        self.register_buffer("final_logits_bias", new_bias)

    def _init_weights(self, module):
        """ Initialize the weights """
        if isinstance(module, (nn.Linear, nn.Embedding)):
            # Slightly different from the TF version which uses truncated_normal for initialization
            # cf https://github.com/pytorch/pytorch/pull/5617
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
        elif isinstance(module, torch.nn.LayerNorm):  # if use apex, this should be FusedLayerNorm
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)
        if isinstance(module, nn.Linear) and module.bias is not None:
            module.bias.data.zero_()

    def get_encoder(self):
        """Return the encoder of the wrapped ``BartModel``."""
        return self.transformer.encoder

    def get_output_embeddings(self):
        """Return a bias-free ``nn.Linear`` whose weight is the shared embedding matrix.

        Needed by the generation utilities. The shared weight has shape
        ``(vocab, emb)``, so the layer maps ``emb -> vocab``; its ``weight.data``
        is the same tensor as ``transformer.shared.weight.data``.
        """
        vocab_size, emb_size = self.transformer.shared.weight.shape
        # nn.Linear(in, out) stores its weight as (out, in) == (vocab, emb),
        # so in_features/out_features now agree with the tied weight.
        lin_layer = nn.Linear(emb_size, vocab_size, bias=False)
        lin_layer.weight.data = self.transformer.shared.weight.data
        return lin_layer

    def prepare_inputs_for_generation(
        self, decoder_input_ids, past=None, attention_mask=None, use_cache=None, encoder_outputs=None, **kwargs
    ):
        """Build the keyword arguments ``forward`` receives at each generation step.

        ``past`` may also arrive as ``past_key_values`` in ``kwargs``.
        """
        if past is None:
            past = kwargs.get("past_key_values")
        return {
            "input_ids": None,  # encoder_outputs is defined. input_ids not needed
            "encoder_outputs": encoder_outputs,
            "past_key_values": past,
            "decoder_input_ids": decoder_input_ids,
            "attention_mask": attention_mask,
            "use_cache": use_cache,  # change this to avoid caching (presumably for debugging)
        }

    def adjust_logits_during_generation(self, logits, cur_len, max_length):
        """Force BOS at step 1 (if configured) and EOS at the last step; returns ``logits``."""
        if cur_len == 1 and getattr(self.config, "force_bos_token_to_be_generated", False):
            self._force_token_ids_generation(logits, self.config.bos_token_id)
        elif cur_len == max_length - 1 and self.config.eos_token_id is not None:
            self._force_token_ids_generation(logits, self.config.eos_token_id)
        return logits

    def _force_token_ids_generation(self, scores, token_id) -> None:
        """Force ``token_id`` by setting every other score to ``-inf`` (in place).

        The mask spans the full last dimension of ``scores`` rather than
        ``config.vocab_size``. ``self.config`` may not be updated when the
        embeddings are resized, so tokens added by ``resize_token_embeddings``
        are masked as well.
        """
        kept = scores[:, token_id].clone()
        scores.fill_(-float("inf"))
        scores[:, token_id] = kept

    @staticmethod
    def _reorder_cache(past, beam_idx):
        """Reorder cached attention tensors along dim 0 to follow the selected beams.

        Each element of ``past`` (one per decoder layer) may be either a dict
        mapping attention keys to dicts of tensors (the layout this code
        originally iterated with ``.items()``) or a tuple of tensors. ``None``
        entries are passed through unchanged.
        """
        def _select(tensor):
            return None if tensor is None else tensor.index_select(0, beam_idx)

        reordered_past = []
        for layer_past in past:
            if isinstance(layer_past, dict):
                # get the correct batch idx from decoder layer's batch dim for cross and self-attn
                reordered_past.append({
                    attn_key: {name: _select(buf) for name, buf in attn_cache.items()}
                    for attn_key, attn_cache in layer_past.items()
                })
            else:
                reordered_past.append(tuple(_select(t) for t in layer_past))
        return reordered_past

    def forward(self, input_ids,
                attention_mask=None,
                encoder_outputs=None,
                use_cache=False,
                past_key_values=None,
                decoder_input_ids=None,
                decoder_attention_mask=None,
                output_attentions=None,
                output_hidden_states=None,
                return_dict=None,
                task=-1):
        """Run one inference step (``task=-1``) or one teacher-forced training step (``task=0``).

        Args:
            input_ids: encoder token ids; ``None`` is passed during generation
                when ``encoder_outputs`` is supplied.
            decoder_input_ids: decoder token ids. Required for ``task=0``,
                where position ``t`` is trained to predict position ``t + 1``.
            decoder_attention_mask: optional mask aligned with ``decoder_input_ids``.
            task: ``-1`` for generation/inference, ``0`` for training.
            Remaining arguments are forwarded to the wrapped ``BartModel``.

        Returns:
            ``task == -1``: ``Seq2SeqLMOutput`` (with ``loss=None``) when
            ``return_dict`` is truthy, else ``(logits, *transformer_outputs[1:])``.
            ``task == 0``: ``(loss, logits, *transformer_outputs[1:])``.

        Raises:
            ValueError: if ``task`` is neither -1 nor 0, or if
                ``decoder_input_ids`` is missing for training.
        """
        # generation
        if task == -1:
            outputs = self.transformer(
                input_ids=input_ids,
                attention_mask=attention_mask,
                decoder_input_ids=decoder_input_ids,
                decoder_attention_mask=decoder_attention_mask,
                use_cache=use_cache,
                encoder_outputs=encoder_outputs,
                past_key_values=past_key_values,
                output_attentions=output_attentions,
                output_hidden_states=output_hidden_states,
                return_dict=return_dict,)
            lm_logits = F.linear(outputs[0], self.transformer.shared.weight, bias=self.final_logits_bias)
            # No loss is computed in generation mode.
            if not return_dict:
                return (lm_logits,) + outputs[1:]
            return Seq2SeqLMOutput(
                loss=None,
                logits=lm_logits,
                past_key_values=outputs.past_key_values,
                decoder_hidden_states=outputs.decoder_hidden_states,
                decoder_attentions=outputs.decoder_attentions,
                encoder_last_hidden_state=outputs.encoder_last_hidden_state,
                encoder_hidden_states=outputs.encoder_hidden_states,
                encoder_attentions=outputs.encoder_attentions,
            )

        # training
        if task == 0:
            if decoder_input_ids is None:
                raise ValueError("decoder_input_ids is required when task == 0 (training)")
            # Teacher forcing: the decoder reads tokens [0, T-1) and is trained to predict
            # tokens [1, T), i.e. labels are decoder_input_ids shifted LEFT by one.
            y_ids = decoder_input_ids[:, :-1]
            labels = decoder_input_ids[:, 1:].clone()
            # -100 is CrossEntropyLoss's default ignore_index, so padding adds no loss.
            labels[labels == self.tokenizer.pad_token_id] = -100
            if decoder_attention_mask is not None:
                decoder_attention_mask = decoder_attention_mask[:, :-1]
            outputs = self.transformer(
                input_ids,
                attention_mask=attention_mask,
                decoder_input_ids=y_ids,
                decoder_attention_mask=decoder_attention_mask,
                use_cache=False,
                past_key_values=past_key_values,
                output_attentions=output_attentions,
                output_hidden_states=output_hidden_states,
                return_dict=return_dict,)
            lm_logits = F.linear(outputs[0], self.transformer.shared.weight, bias=self.final_logits_bias)
            loss_fct = nn.CrossEntropyLoss()
            # Use the logits' own width so the loss works whether or not the vocab was resized.
            masked_lm_loss = loss_fct(lm_logits.reshape(-1, lm_logits.size(-1)), labels.reshape(-1))
            # Add cache, hidden states and attention if they are present.
            return (masked_lm_loss, lm_logits) + outputs[1:]

        raise ValueError(f"Unsupported task {task!r}; expected -1 (generation) or 0 (training)")