import torch
from torch import nn
from torch.nn import functional as F
from transformers import (
    AdamW,
    get_linear_schedule_with_warmup,
    BartModel,
)
from transformers.modeling_utils import PreTrainedModel
from transformers.modeling_outputs import Seq2SeqLMOutput

print("network.py")
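
# Beam-search cache reordering helper used by BartGen._reorder_cache below.
# This is a minimal sketch, assuming the dict-style decoder cache of
# transformers 3.x BART (each layer's cache maps attention keys to buffers
# with the batch/beam dimension first); older transformers releases shipped
# an equivalent `_reorder_buffer` helper in `modeling_bart`.
def _reorder_buffer(attn_cache, new_order):
    # Select the cached key/value buffers belonging to the surviving beams.
    for k, input_buffer_k in attn_cache.items():
        if input_buffer_k is not None:
            attn_cache[k] = input_buffer_k.index_select(0, new_order)
    return attn_cache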

class BartGen(PreTrainedModel):
    def __init__(self, config, tokenizer):
        super(BartGen, self).__init__(config)
        self.config = config
        self.tokenizer = tokenizer
        self.transformer = BartModel.from_pretrained('facebook/bart-large')
        self.register_buffer(
            "final_logits_bias",
            torch.zeros((1, self.transformer.shared.num_embeddings)),
        )

    def resize_token_embeddings(self):
        old_num_tokens = self.transformer.shared.num_embeddings
        new_embeddings = self.transformer.resize_token_embeddings(len(self.tokenizer))
        self.transformer.shared = new_embeddings
        self._resize_final_logits_bias(len(self.tokenizer), old_num_tokens)
        self.vocab_size = len(self.tokenizer)
        return new_embeddings

    def _resize_final_logits_bias(self, new_num_tokens: int, old_num_tokens: int) -> None:
        if new_num_tokens <= old_num_tokens:
            new_bias = self.final_logits_bias[:, :new_num_tokens]
        else:
            extra_bias = torch.zeros(
                (1, new_num_tokens - old_num_tokens), device=self.final_logits_bias.device
            )
            new_bias = torch.cat([self.final_logits_bias, extra_bias], dim=1)
        self.register_buffer("final_logits_bias", new_bias)

    def _init_weights(self, module):
        """Initialize the weights."""
        if isinstance(module, (nn.Linear, nn.Embedding)):
            # Slightly different from the TF version, which uses truncated_normal for initialization
            # cf https://github.com/pytorch/pytorch/pull/5617
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
        elif isinstance(module, torch.nn.LayerNorm):  # if using apex, this should be FusedLayerNorm
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)
        if isinstance(module, nn.Linear) and module.bias is not None:
            module.bias.data.zero_()

    def get_encoder(self):
        return self.transformer.encoder

    def get_output_embeddings(self):
        # This method is needed for generation. Assigning the shared embedding
        # weight below ties the output projection to the input embeddings, so the
        # layer effectively maps hidden states to vocabulary logits.
        vocab_size, emb_size = self.transformer.shared.weight.shape
        lin_layer = nn.Linear(vocab_size, emb_size, bias=False)
        lin_layer.weight.data = self.transformer.shared.weight.data
        return lin_layer

    def prepare_inputs_for_generation(
        self, decoder_input_ids, past, attention_mask, use_cache, encoder_outputs, **kwargs
    ):
        return {
            "input_ids": None,  # encoder_outputs is defined, so input_ids are not needed
            "encoder_outputs": encoder_outputs,
            "past_key_values": past,
            "decoder_input_ids": decoder_input_ids,
            "attention_mask": attention_mask,
            "use_cache": use_cache,  # change this to avoid caching (presumably for debugging)
        }

    def adjust_logits_during_generation(self, logits, cur_len, max_length):
        if cur_len == 1 and self.config.force_bos_token_to_be_generated:
            self._force_token_ids_generation(logits, self.config.bos_token_id)
        elif cur_len == max_length - 1 and self.config.eos_token_id is not None:
            self._force_token_ids_generation(logits, self.config.eos_token_id)
        return logits

    def _force_token_ids_generation(self, scores, token_id) -> None:
        """Force one of token_ids to be generated by setting the prob of all other tokens to 0 (logprob=-float("inf"))."""
        scores[:, [x for x in range(self.config.vocab_size) if x != token_id]] = -float("inf")

    @staticmethod
    def _reorder_cache(past, beam_idx):
        reordered_past = []
        for layer_past in past:
            # get the correct batch idx from the decoder layer's batch dim for cross- and self-attention
            layer_past_new = {
                attn_key: _reorder_buffer(attn_cache, beam_idx) for attn_key, attn_cache in layer_past.items()
            }
            reordered_past.append(layer_past_new)
        return reordered_past

    def forward(self, input_ids,
                attention_mask=None,
                encoder_outputs=None,
                use_cache=False,
                past_key_values=None,
                decoder_input_ids=None,
                decoder_attention_mask=None,
                output_attentions=None,
                output_hidden_states=None,
                return_dict=None,
                task=-1):
        # generation
        if task == -1:
            outputs = self.transformer(
                input_ids=input_ids,
                attention_mask=attention_mask,
                decoder_input_ids=decoder_input_ids,
                decoder_attention_mask=decoder_attention_mask,
                use_cache=use_cache,
                encoder_outputs=encoder_outputs,
                past_key_values=past_key_values,
                output_attentions=output_attentions,
                output_hidden_states=output_hidden_states,
                return_dict=return_dict,
            )
            lm_logits = F.linear(outputs[0], self.transformer.shared.weight, bias=self.final_logits_bias)
            masked_lm_loss = None
            if not return_dict:
                output = (lm_logits,) + outputs[1:]
                return ((masked_lm_loss,) + output) if masked_lm_loss is not None else output

            return Seq2SeqLMOutput(
                loss=masked_lm_loss,
                logits=lm_logits,
                past_key_values=outputs.past_key_values,
                decoder_hidden_states=outputs.decoder_hidden_states,
                decoder_attentions=outputs.decoder_attentions,
                encoder_last_hidden_state=outputs.encoder_last_hidden_state,
                encoder_hidden_states=outputs.encoder_hidden_states,
                encoder_attentions=outputs.encoder_attentions,
            )

        # training
        elif task == 0:
            assert decoder_input_ids is not None
            y_ids = decoder_input_ids[:, :-1]
            labels = decoder_input_ids[:, 1:].clone()
            labels[labels == self.tokenizer.pad_token_id] = -100
            # labels are decoder_input_ids shifted left by one position (next-token prediction)
            outputs = self.transformer(
                input_ids,
                attention_mask=attention_mask,
                decoder_input_ids=y_ids,
                decoder_attention_mask=decoder_attention_mask[:, :-1],
                use_cache=False,
                past_key_values=past_key_values,
                output_attentions=output_attentions,
                output_hidden_states=output_hidden_states,
                return_dict=return_dict,
            )
            sequence_output = outputs[0]
            lm_logits = F.linear(sequence_output, self.transformer.shared.weight, bias=self.final_logits_bias)
            outputs = (lm_logits,) + outputs[1:]  # add cache, hidden states and attentions if they are here
            loss_fct = nn.CrossEntropyLoss()
            masked_lm_loss = loss_fct(lm_logits.view(-1, self.vocab_size), labels.view(-1))
            outputs = (masked_lm_loss,) + outputs
            return outputs
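

# Minimal usage sketch (not part of the original module): it assumes a
# transformers 3.x-era environment where BartTokenizer/BartConfig and the
# 'facebook/bart-large' weights loaded above are available, and simply
# exercises the training (task=0) and generation (task=-1) forward paths.
if __name__ == "__main__":
    from transformers import BartConfig, BartTokenizer

    tokenizer = BartTokenizer.from_pretrained('facebook/bart-large')
    config = BartConfig.from_pretrained('facebook/bart-large')
    model = BartGen(config, tokenizer)
    model.resize_token_embeddings()  # also sets model.vocab_size, used by the training loss

    enc = tokenizer(["The quick brown fox jumps over the lazy dog."],
                    return_tensors="pt", padding=True)
    dec = tokenizer(["A fox jumps over a dog."], return_tensors="pt", padding=True)

    # Training-style forward pass: returns (loss, logits, ...).
    out = model(enc["input_ids"],
                attention_mask=enc["attention_mask"],
                decoder_input_ids=dec["input_ids"],
                decoder_attention_mask=dec["attention_mask"],
                task=0)
    print("training loss:", out[0].item())

    # Generation-style forward pass (task=-1): returns logits first when return_dict is falsy.
    gen_out = model(enc["input_ids"],
                    attention_mask=enc["attention_mask"],
                    decoder_input_ids=dec["input_ids"][:, :1],
                    task=-1)
    print("logits shape:", gen_out[0].shape)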