# -*- coding: utf-8 -*-
# @Time   : 2023/4/05 18:02 PM
# @Author : NuoChen
# @File   : code_generation.py

from typing import Any, Dict, List, Optional, Tuple, Union

import torch
from torch import nn
from torch.nn import CrossEntropyLoss

from transformers import PLBartTokenizer
from transformers.modeling_outputs import Seq2SeqLMOutput
from transformers.models.plbart.configuration_plbart import PLBartConfig
from transformers.models.plbart.modeling_plbart import PLBartModel, PLBartPreTrainedModel


def shift_tokens_right(input_ids: torch.Tensor, pad_token_id: int):
    """
    Shift input ids one token to the right, and wrap the last non-pad token (the <LID> token). Note that PLBart,
    like MBart, does not have a single `decoder_start_token_id` in contrast to other Bart-like models.
    """
    prev_output_tokens = input_ids.clone()
    if pad_token_id is None:
        raise ValueError("self.model.config.pad_token_id has to be defined.")
    # replace possible -100 values in labels by `pad_token_id`
    prev_output_tokens.masked_fill_(prev_output_tokens == -100, pad_token_id)
    index_of_eos = (prev_output_tokens.ne(pad_token_id).sum(dim=1) - 1).unsqueeze(-1)
    decoder_start_tokens = prev_output_tokens.gather(1, index_of_eos).squeeze()
    prev_output_tokens[:, 1:] = prev_output_tokens[:, :-1].clone()
    prev_output_tokens[:, 0] = decoder_start_tokens
    return prev_output_tokens
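
# Worked example (illustrative only, with made-up token ids; assumes pad_token_id = 1,
# eos = 2, and a language-id token 50001 placed after eos, as in PLBart-style inputs;
# any -100 label positions are first replaced by the pad id):
#   labels              = [[37, 42, 99, 2, 50001, 1, 1]]
#   shift_tokens_right  -> [[50001, 37, 42, 99, 2, 50001, 1]]
# i.e. the <LID> token that closes the sequence becomes the decoder start token.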


class PLBARTForCodeGeneration(PLBartPreTrainedModel):
    base_model_prefix = "model"
    _keys_to_ignore_on_load_missing = [
        r"final_logits_bias",
        r"encoder.version",
        r"decoder.version",
        r"lm_head.weight",
    ]

    def __init__(self, config: PLBartConfig):
        super().__init__(config)
        self.model = PLBartModel(config)
        self.register_buffer("final_logits_bias", torch.zeros((1, self.model.shared.num_embeddings)))
        self.lm_head = nn.Linear(config.d_model, self.model.shared.num_embeddings, bias=False)
        self.init_weights()
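
    # Note (added for clarity): `final_logits_bias` is the non-trainable vocabulary bias used by
    # BART-style LM heads; `resize_token_embeddings` below keeps it in sync with the embedding
    # matrix whenever the vocabulary size changes.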

    def get_encoder(self):
        return self.model.get_encoder()

    def get_decoder(self):
        return self.model.get_decoder()

    def resize_token_embeddings(self, new_num_tokens: int) -> nn.Embedding:
        new_embeddings = super().resize_token_embeddings(new_num_tokens)
        self._resize_final_logits_bias(new_num_tokens)
        return new_embeddings

    def _resize_final_logits_bias(self, new_num_tokens: int) -> None:
        old_num_tokens = self.final_logits_bias.shape[-1]
        if new_num_tokens <= old_num_tokens:
            new_bias = self.final_logits_bias[:, :new_num_tokens]
        else:
            extra_bias = torch.zeros((1, new_num_tokens - old_num_tokens), device=self.final_logits_bias.device)
            new_bias = torch.cat([self.final_logits_bias, extra_bias], dim=1)
        self.register_buffer("final_logits_bias", new_bias)

    def get_output_embeddings(self):
        return self.lm_head

    def set_output_embeddings(self, new_embeddings):
        self.lm_head = new_embeddings

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.LongTensor] = None,
        decoder_input_ids: Optional[torch.LongTensor] = None,
        decoder_attention_mask: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        decoder_head_mask: Optional[torch.LongTensor] = None,
        cross_attn_head_mask: Optional[torch.Tensor] = None,
        encoder_outputs: Optional[List[torch.FloatTensor]] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        decoder_inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.Tensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple[torch.Tensor], Seq2SeqLMOutput]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
            config.vocab_size]` or -100 (see the `input_ids` docstring). Tokens with indices set to `-100` are
            ignored (masked); the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.

        Returns:
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if labels is not None:
            if decoder_input_ids is None:
                decoder_input_ids = shift_tokens_right(labels, self.config.pad_token_id)

        outputs = self.model(
            input_ids,
            attention_mask=attention_mask,
            decoder_input_ids=decoder_input_ids,
            encoder_outputs=encoder_outputs,
            decoder_attention_mask=decoder_attention_mask,
            head_mask=head_mask,
            decoder_head_mask=decoder_head_mask,
            cross_attn_head_mask=cross_attn_head_mask,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            decoder_inputs_embeds=decoder_inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        lm_logits = self.lm_head(outputs[0]) + self.final_logits_bias

        masked_lm_loss = None
        if labels is not None:
            loss_fct = CrossEntropyLoss()
            masked_lm_loss = loss_fct(lm_logits.view(-1, self.config.vocab_size), labels.view(-1))

        if not return_dict:
            output = (lm_logits,) + outputs[1:]
            return ((masked_lm_loss,) + output) if masked_lm_loss is not None else output

        return Seq2SeqLMOutput(
            loss=masked_lm_loss,
            logits=lm_logits,
            past_key_values=outputs.past_key_values,
            decoder_hidden_states=outputs.decoder_hidden_states,
            decoder_attentions=outputs.decoder_attentions,
            cross_attentions=outputs.cross_attentions,
            encoder_last_hidden_state=outputs.encoder_last_hidden_state,
            encoder_hidden_states=outputs.encoder_hidden_states,
            encoder_attentions=outputs.encoder_attentions,
        )
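
    # Training sketch (added for illustration; `batch` and its keys are hypothetical):
    # passing `labels` makes `forward` derive `decoder_input_ids` via `shift_tokens_right`
    # and return a token-level cross-entropy loss in `out.loss`, e.g.:
    #   out = model(input_ids=batch["input_ids"],
    #               attention_mask=batch["attention_mask"],
    #               labels=batch["labels"])   # positions set to -100 are ignored by the loss
    #   out.loss.backward()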

    def prepare_inputs_for_generation(
        self,
        decoder_input_ids: torch.LongTensor,
        past: Optional[List[torch.FloatTensor]] = None,
        attention_mask: Optional[torch.LongTensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        decoder_head_mask: Optional[torch.Tensor] = None,
        cross_attn_head_mask: Optional[torch.Tensor] = None,
        use_cache: Optional[bool] = None,
        encoder_outputs: Optional[List[torch.FloatTensor]] = None,
        **kwargs  # unused here, but kept so `generate` can pass extra model kwargs without errors
    ) -> Dict[str, Any]:
        # cut decoder_input_ids if past is used
        if past is not None:
            decoder_input_ids = decoder_input_ids[:, -1:]

        return {
            "input_ids": None,  # encoder_outputs is defined. input_ids not needed
            "encoder_outputs": encoder_outputs,
            "past_key_values": past,
            "decoder_input_ids": decoder_input_ids,
            "attention_mask": attention_mask,
            "head_mask": head_mask,
            "decoder_head_mask": decoder_head_mask,
            "cross_attn_head_mask": cross_attn_head_mask,
            "use_cache": use_cache,  # change this to avoid caching (presumably for debugging)
        }

    def prepare_decoder_input_ids_from_labels(self, labels: torch.Tensor):
        return shift_tokens_right(labels, self.config.pad_token_id)

    @staticmethod
    def _reorder_cache(past, beam_idx):
        reordered_past = ()
        for layer_past in past:
            # cached cross_attention states don't have to be reordered -> they are always the same
            reordered_past += (
                tuple(past_state.index_select(0, beam_idx) for past_state in layer_past[:2]) + layer_past[2:],
            )
        return reordered_past
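

# Minimal end-to-end sketch (not part of the original module). The checkpoint name,
# language codes, and generation settings below are illustrative assumptions; the public
# "uclanlp/plbart-base" weights are not fine-tuned for text-to-code, so substitute your
# own fine-tuned checkpoint for meaningful output.
if __name__ == "__main__":
    tokenizer = PLBartTokenizer.from_pretrained("uclanlp/plbart-base", src_lang="en_XX", tgt_lang="java")
    model = PLBARTForCodeGeneration.from_pretrained("uclanlp/plbart-base")

    inputs = tokenizer("returns the maximum of two integers", return_tensors="pt")
    # `generate` (inherited from the base class) routes through
    # `prepare_inputs_for_generation` and `_reorder_cache` defined above.
    generated_ids = model.generate(
        **inputs,
        decoder_start_token_id=tokenizer.lang_code_to_id["java"],
        num_beams=4,
        max_length=64,
    )
    print(tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0])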