import torch
from torch import nn
from torch.nn import functional as F
from transformers import (
    AdamW,
    get_linear_schedule_with_warmup,
    BartModel,
)
from transformers.modeling_utils import PreTrainedModel
from transformers.modeling_outputs import Seq2SeqLMOutput
print("network.py")
class BartGen(PreTrainedModel):
    """BART-based encoder-decoder wrapper used for conditional generation."""

    def __init__(self, config, tokenizer):
        super(BartGen, self).__init__(config)
        self.config = config
        self.tokenizer = tokenizer
        self.transformer = BartModel.from_pretrained('facebook/bart-large')
        # Bias added to the LM logits; kept as a buffer so it can be resized with the vocabulary.
        self.register_buffer("final_logits_bias", torch.zeros((1, self.transformer.shared.num_embeddings)))

    def resize_token_embeddings(self):
        old_num_tokens = self.transformer.shared.num_embeddings
        new_embeddings = self.transformer.resize_token_embeddings(len(self.tokenizer))
        self.transformer.shared = new_embeddings
        self._resize_final_logits_bias(len(self.tokenizer), old_num_tokens)
        # Cache the new vocabulary size; forward() relies on it when computing the LM loss.
        self.vocab_size = len(self.tokenizer)
        return new_embeddings

    def _resize_final_logits_bias(self, new_num_tokens: int, old_num_tokens: int) -> None:
        if new_num_tokens <= old_num_tokens:
            new_bias = self.final_logits_bias[:, :new_num_tokens]
        else:
            extra_bias = torch.zeros((1, new_num_tokens - old_num_tokens), device=self.final_logits_bias.device)
            new_bias = torch.cat([self.final_logits_bias, extra_bias], dim=1)
        self.register_buffer("final_logits_bias", new_bias)

    def _init_weights(self, module):
        """Initialize the weights."""
        if isinstance(module, (nn.Linear, nn.Embedding)):
            # Slightly different from the TF version, which uses truncated_normal for initialization
            # cf https://github.com/pytorch/pytorch/pull/5617
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
        elif isinstance(module, torch.nn.LayerNorm):  # if using apex, this should be FusedLayerNorm
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)
        if isinstance(module, nn.Linear) and module.bias is not None:
            module.bias.data.zero_()

    def get_encoder(self):
        return self.transformer.encoder

    def get_output_embeddings(self):
        # This method is needed for generation: it exposes an output projection tied to the
        # shared embedding matrix (the weight reassignment below gives it shape vocab_size x emb_size,
        # so it maps decoder hidden states to vocabulary logits).
        vocab_size, emb_size = self.transformer.shared.weight.shape
        lin_layer = nn.Linear(vocab_size, emb_size, bias=False)
        lin_layer.weight.data = self.transformer.shared.weight.data
        return lin_layer

    def prepare_inputs_for_generation(
        self, decoder_input_ids, past, attention_mask, use_cache, encoder_outputs, **kwargs
    ):
        return {
            "input_ids": None,  # encoder_outputs is defined; input_ids not needed
            "encoder_outputs": encoder_outputs,
            "past_key_values": past,
            "decoder_input_ids": decoder_input_ids,
            "attention_mask": attention_mask,
            "use_cache": use_cache,  # change this to avoid caching (presumably for debugging)
        }

    def adjust_logits_during_generation(self, logits, cur_len, max_length):
        if cur_len == 1 and self.config.force_bos_token_to_be_generated:
            self._force_token_ids_generation(logits, self.config.bos_token_id)
        elif cur_len == max_length - 1 and self.config.eos_token_id is not None:
            self._force_token_ids_generation(logits, self.config.eos_token_id)
        return logits

    def _force_token_ids_generation(self, scores, token_id) -> None:
        """Force one of token_ids to be generated by setting the prob of all other tokens to 0 (logprob=-float("inf"))."""
        scores[:, [x for x in range(self.config.vocab_size) if x != token_id]] = -float("inf")

    @staticmethod
    def _reorder_cache(past, beam_idx):
        reordered_past = []
        for layer_past in past:
            # get the correct batch idx from the decoder layer's batch dim for cross- and self-attention
            layer_past_new = {
                attn_key: _reorder_buffer(attn_cache, beam_idx) for attn_key, attn_cache in layer_past.items()
            }
            reordered_past.append(layer_past_new)
        return reordered_past

    def forward(self, input_ids,
                attention_mask=None,
                encoder_outputs=None,
                use_cache=False,
                past_key_values=None,
                decoder_input_ids=None,
                decoder_attention_mask=None,
                output_attentions=None,
                output_hidden_states=None,
                return_dict=None,
                task=-1):
        # generation
        if task == -1:
            outputs = self.transformer(
                input_ids=input_ids,
                attention_mask=attention_mask,
                decoder_input_ids=decoder_input_ids,
                decoder_attention_mask=decoder_attention_mask,
                use_cache=use_cache,
                encoder_outputs=encoder_outputs,
                past_key_values=past_key_values,
                output_attentions=output_attentions,
                output_hidden_states=output_hidden_states,
                return_dict=return_dict,
            )
            lm_logits = F.linear(outputs[0], self.transformer.shared.weight, bias=self.final_logits_bias)
            masked_lm_loss = None
            if not return_dict:
                output = (lm_logits,) + outputs[1:]
                return ((masked_lm_loss,) + output) if masked_lm_loss is not None else output

            return Seq2SeqLMOutput(
                loss=masked_lm_loss,
                logits=lm_logits,
                past_key_values=outputs.past_key_values,
                decoder_hidden_states=outputs.decoder_hidden_states,
                decoder_attentions=outputs.decoder_attentions,
                encoder_last_hidden_state=outputs.encoder_last_hidden_state,
                encoder_hidden_states=outputs.encoder_hidden_states,
                encoder_attentions=outputs.encoder_attentions,
            )
        # training
        elif task == 0:
            assert decoder_input_ids is not None
            y_ids = decoder_input_ids[:, :-1]
            labels = decoder_input_ids[:, 1:].clone()
            # ignore padding positions in the loss
            labels[labels == self.tokenizer.pad_token_id] = -100
            # labels are just decoder_input_ids shifted to the right by 1
            outputs = self.transformer(
                input_ids,
                attention_mask=attention_mask,
                decoder_input_ids=y_ids,
                decoder_attention_mask=decoder_attention_mask[:, :-1],
                use_cache=False,
                past_key_values=past_key_values,
                output_attentions=output_attentions,
                output_hidden_states=output_hidden_states,
                return_dict=return_dict,
            )
            sequence_output = outputs[0]
            lm_logits = F.linear(sequence_output, self.transformer.shared.weight, bias=self.final_logits_bias)
            outputs = (lm_logits,) + outputs[1:]  # add cache, hidden states, and attentions if they are here
            loss_fct = nn.CrossEntropyLoss()
            # self.vocab_size is set in resize_token_embeddings(); call it before training.
            masked_lm_loss = loss_fct(lm_logits.view(-1, self.vocab_size), labels.view(-1))
            outputs = (masked_lm_loss,) + outputs
            return outputs
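

# Minimal usage sketch (not part of the original file), assuming a transformers
# release with the old-style BART generation API this class targets and assuming
# the 'facebook/bart-large' checkpoint and tokenizer are available. The example
# batch below is a hypothetical placeholder. For generation, model.generate(...)
# uses the default task=-1 forward path (details depend on the transformers version).
if __name__ == "__main__":
    from transformers import BartTokenizer, BartConfig

    tokenizer = BartTokenizer.from_pretrained('facebook/bart-large')
    config = BartConfig.from_pretrained('facebook/bart-large')
    model = BartGen(config, tokenizer)
    model.resize_token_embeddings()  # also caches model.vocab_size for the training loss

    enc = tokenizer(["The quick brown fox jumped over the fence."], return_tensors="pt", padding=True)
    dec = tokenizer(["<s> A fox jumped."], return_tensors="pt", padding=True, add_special_tokens=False)

    # training-style forward pass (task=0): returns (loss, logits, ...)
    loss = model(
        input_ids=enc["input_ids"],
        attention_mask=enc["attention_mask"],
        decoder_input_ids=dec["input_ids"],
        decoder_attention_mask=dec["attention_mask"],
        task=0,
    )[0]
    print("training loss:", loss.item())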