# -*- coding: utf-8 -*-
# @Time    : 2023/4/05 18:02 PM
# @Author  : NuoChen
# @File    : code_generation.py

from transformers import PLBartTokenizer, PLBartForSequenceClassification, PLBartConfig, PLBartForConditionalGeneration
from typing import Any, Dict, List, Optional, Tuple, Union
from transformers.modeling_outputs import (
    BaseModelOutput,
    BaseModelOutputWithPastAndCrossAttentions,
    CausalLMOutputWithCrossAttentions,
    Seq2SeqLMOutput,
    Seq2SeqModelOutput,
    Seq2SeqSequenceClassifierOutput,
)
import torch
from torch import nn
from typing import Optional, List, Union, Tuple
from torch.nn import CrossEntropyLoss, MSELoss, BCEWithLogitsLoss
from transformers import RobertaModel, RobertaPreTrainedModel
from transformers.models.plbart.modeling_plbart import PLBartPreTrainedModel, PLBartModel
from transformers.models.plbart.configuration_plbart import PLBartConfig


def shift_tokens_right(input_ids: torch.Tensor, pad_token_id: int):
    """
    Shift input ids one token to the right, wrapping the last non-pad token of each row to position 0.

    The wrapped token becomes the first decoder input (presumably the language-id token that terminates
    PLBart/MBart-style targets -- confirm against the tokenizer in use). Note that MBart does not have a
    single `decoder_start_token_id` in contrast to other Bart-like models.

    Args:
        input_ids: 2-D tensor of token ids, indexed as `(row, position)`. Entries equal to `-100` are
            treated as padding.
        pad_token_id: Id of the padding token.

    Returns:
        A new tensor with the same shape as `input_ids`; the argument itself is not modified.

    Raises:
        ValueError: If `pad_token_id` is `None`.
    """
    # Validate before doing any tensor work.
    if pad_token_id is None:
        raise ValueError("self.model.config.pad_token_id has to be defined.")

    prev_output_tokens = input_ids.clone()
    # replace possible -100 values in labels by `pad_token_id`
    prev_output_tokens.masked_fill_(prev_output_tokens == -100, pad_token_id)

    # Position of the last non-pad token in every row, shaped (rows, 1) for `gather`.
    index_of_eos = (prev_output_tokens.ne(pad_token_id).sum(dim=1) - 1).unsqueeze(-1)
    # Squeeze only the gathered dimension so a batch of one row keeps its batch dimension.
    decoder_start_tokens = prev_output_tokens.gather(1, index_of_eos).squeeze(-1)

    # The right-hand side is cloned because source and destination slices overlap.
    prev_output_tokens[:, 1:] = prev_output_tokens[:, :-1].clone()
    prev_output_tokens[:, 0] = decoder_start_tokens
    return prev_output_tokens


class PLBARTForCodeGeneration(PLBartPreTrainedModel):
    """PLBart encoder-decoder model with a linear language-modeling head on top of the decoder output."""

    base_model_prefix = "model"
    # Checkpoint keys that are allowed to be missing at load time
    # (presumably consumed by the Transformers loading code -- see `PreTrainedModel`).
    _keys_to_ignore_on_load_missing = [
        r"final_logits_bias",
        r"encoder.version",
        r"decoder.version",
        r"lm_head.weight",
    ]

    def __init__(self, config: PLBartConfig):
        """Build the PLBart backbone, the logits bias buffer and the LM head from `config`."""
        super().__init__(config)
        self.model = PLBartModel(config)
        # Non-trainable bias added to the LM logits; one entry per vocabulary item.
        self.register_buffer("final_logits_bias", torch.zeros((1, self.model.shared.num_embeddings)))
        self.lm_head = nn.Linear(config.d_model, self.model.shared.num_embeddings, bias=False)

        self.init_weights()

    def get_encoder(self):
        """Return the encoder of the wrapped `PLBartModel`."""
        return self.model.get_encoder()

    def get_decoder(self):
        """Return the decoder of the wrapped `PLBartModel`."""
        return self.model.get_decoder()

    def resize_token_embeddings(self, new_num_tokens: int) -> nn.Embedding:
        """
        Resize the token embeddings and keep `final_logits_bias` the same width.

        Args:
            new_num_tokens: Requested vocabulary size, forwarded unchanged to the parent implementation.

        Returns:
            The embedding module returned by the parent implementation.
        """
        new_embeddings = super().resize_token_embeddings(new_num_tokens)
        # Size the bias from the embedding matrix actually produced, so the two cannot drift apart
        # (this also covers `new_num_tokens=None`, where the parent leaves the size unchanged).
        self._resize_final_logits_bias(new_embeddings.weight.shape[0])
        return new_embeddings

    def _resize_final_logits_bias(self, new_num_tokens: int) -> None:
        """Truncate or zero-pad `final_logits_bias` so that its last dimension is `new_num_tokens`."""
        old_num_tokens = self.final_logits_bias.shape[-1]
        if new_num_tokens <= old_num_tokens:
            new_bias = self.final_logits_bias[:, :new_num_tokens]
        else:
            # `new_zeros` keeps both the device and the dtype of the existing buffer.
            extra_bias = self.final_logits_bias.new_zeros((1, new_num_tokens - old_num_tokens))
            new_bias = torch.cat([self.final_logits_bias, extra_bias], dim=1)
        self.register_buffer("final_logits_bias", new_bias)

    def get_output_embeddings(self):
        """Return the LM head."""
        return self.lm_head

    def set_output_embeddings(self, new_embeddings):
        """Replace the LM head with `new_embeddings`."""
        self.lm_head = new_embeddings

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.LongTensor] = None,
        decoder_input_ids: Optional[torch.LongTensor] = None,
        decoder_attention_mask: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        decoder_head_mask: Optional[torch.LongTensor] = None,
        cross_attn_head_mask: Optional[torch.Tensor] = None,
        encoder_outputs: Optional[List[torch.FloatTensor]] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        decoder_inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.Tensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple[torch.Tensor], Seq2SeqLMOutput]:
        r"""
        Run the PLBart backbone, project its first output through the LM head and optionally compute the loss.

        All arguments other than `labels` and `return_dict` are forwarded unchanged to the wrapped
        `PLBartModel`.

        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
            config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
            (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
            When given without `decoder_input_ids` and `decoder_inputs_embeds`, the decoder inputs are
            derived from the labels with `shift_tokens_right`.

        Returns:
            A `Seq2SeqLMOutput` when `return_dict` resolves to true, otherwise a tuple
            `(loss?, lm_logits, *backbone_outputs[1:])` where the loss is present only if `labels` was given.
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # Derive decoder inputs from the labels only when the caller supplied neither ids nor embeddings;
        # otherwise the decoder would receive both `decoder_input_ids` and `decoder_inputs_embeds`.
        if labels is not None and decoder_input_ids is None and decoder_inputs_embeds is None:
            decoder_input_ids = shift_tokens_right(labels, self.config.pad_token_id)

        outputs = self.model(
            input_ids,
            attention_mask=attention_mask,
            decoder_input_ids=decoder_input_ids,
            encoder_outputs=encoder_outputs,
            decoder_attention_mask=decoder_attention_mask,
            head_mask=head_mask,
            decoder_head_mask=decoder_head_mask,
            cross_attn_head_mask=cross_attn_head_mask,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            decoder_inputs_embeds=decoder_inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        lm_logits = self.lm_head(outputs[0]) + self.final_logits_bias

        masked_lm_loss = None
        if labels is not None:
            loss_fct = CrossEntropyLoss()
            # Flatten with the real logits width so the loss stays valid even if
            # `config.vocab_size` is out of sync with the LM head.
            masked_lm_loss = loss_fct(lm_logits.view(-1, lm_logits.size(-1)), labels.view(-1))

        if not return_dict:
            output = (lm_logits,) + outputs[1:]
            return ((masked_lm_loss,) + output) if masked_lm_loss is not None else output

        return Seq2SeqLMOutput(
            loss=masked_lm_loss,
            logits=lm_logits,
            past_key_values=outputs.past_key_values,
            decoder_hidden_states=outputs.decoder_hidden_states,
            decoder_attentions=outputs.decoder_attentions,
            cross_attentions=outputs.cross_attentions,
            encoder_last_hidden_state=outputs.encoder_last_hidden_state,
            encoder_hidden_states=outputs.encoder_hidden_states,
            encoder_attentions=outputs.encoder_attentions,
        )

    def prepare_inputs_for_generation(
        self,
        decoder_input_ids: torch.LongTensor,
        past: Optional[List[torch.FloatTensor]] = None,
        attention_mask: Optional[torch.LongTensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        decoder_head_mask: Optional[torch.Tensor] = None,
        cross_attn_head_mask: Optional[torch.Tensor] = None,
        use_cache: Optional[bool] = None,
        encoder_outputs: Optional[List[torch.FloatTensor]] = None,
        **kwargs  # only `past_key_values` is read from here; every other extra keyword is ignored
    ) -> Dict[str, Any]:
        """
        Assemble the keyword arguments for one decoding step of `forward`.

        Args:
            decoder_input_ids: Decoder token ids generated so far.
            past: Cached decoder state under its legacy keyword name.
            attention_mask, head_mask, decoder_head_mask, cross_attn_head_mask, use_cache,
            encoder_outputs: Passed through unchanged.
            **kwargs: May carry the cache as `past_key_values`; it is used when `past` is `None`.

        Returns:
            A dict of keyword arguments for `forward`, with `input_ids` set to `None`.
        """
        # Accept the cache under either keyword so decoding stays incremental
        # whichever name the generation loop uses.
        if past is None:
            past = kwargs.get("past_key_values")

        # cut decoder_input_ids if past is used
        if past is not None:
            decoder_input_ids = decoder_input_ids[:, -1:]

        return {
            "input_ids": None,  # encoder_outputs is defined. input_ids not needed
            "encoder_outputs": encoder_outputs,
            "past_key_values": past,
            "decoder_input_ids": decoder_input_ids,
            "attention_mask": attention_mask,
            "head_mask": head_mask,
            "decoder_head_mask": decoder_head_mask,
            "cross_attn_head_mask": cross_attn_head_mask,
            "use_cache": use_cache,  # change this to avoid caching (presumably for debugging)
        }

    def prepare_decoder_input_ids_from_labels(self, labels: torch.Tensor):
        """Return decoder input ids obtained by applying `shift_tokens_right` to `labels`."""
        return shift_tokens_right(labels, self.config.pad_token_id)

    @staticmethod
    def _reorder_cache(past, beam_idx):
        """
        Reorder the cached states along dimension 0 according to `beam_idx`.

        Only the first two entries of each layer's cache are re-indexed; the remaining entries are
        kept as they are.
        """
        reordered_past = ()
        for layer_past in past:
            # cached cross_attention states don't have to be reordered -> they are always the same
            reordered_past += (
                tuple(past_state.index_select(0, beam_idx) for past_state in layer_past[:2]) + layer_past[2:],
            )
        return reordered_past