Spaces:
Runtime error
Runtime error
from typing import Optional, List, Union, Tuple | |
import torch | |
import torch.nn as nn | |
import random | |
from torch.nn import CrossEntropyLoss | |
from transformers.utils import ( | |
add_start_docstrings_to_model_forward, | |
add_end_docstrings, | |
replace_return_docstrings | |
) | |
from transformers import AutoModelForSeq2SeqLM | |
from transformers.models.bart.modeling_bart import ( | |
BartForConditionalGeneration, | |
_expand_mask, logger, | |
shift_tokens_right, | |
BartPretrainedModel, | |
BART_INPUTS_DOCSTRING, | |
_CONFIG_FOR_DOC, | |
BART_GENERATION_EXAMPLE, | |
BartModel, | |
BartDecoder | |
) | |
from .adapter import Adapter | |
from transformers.modeling_outputs import ( | |
BaseModelOutputWithPastAndCrossAttentions, | |
Seq2SeqModelOutput, | |
BaseModelOutput, | |
Seq2SeqLMOutput | |
) | |
class KeyBartAdapter(BartForConditionalGeneration):
    """KeyBART seq2seq model with a frozen pretrained backbone and decoder adapters.

    The checkpoint is loaded once in ``__init__``. Its ``model`` (shared embeddings,
    encoder, decoder) and ``lm_head`` are frozen by ``__fix_weights__`` and reused
    directly (shared, not copied). The decoder is wrapped by ``BartPlus`` /
    ``BartDecoderPlus``, which add one ``Adapter`` per decoder layer. The intent is
    that only those newly created adapters are trained.
    """

    def __init__(self, adapter_hid_dim: int, model_name_or_path: str = "bloomberg/KeyBART") -> None:
        """Load the pretrained model, freeze it and attach the adapter decoder.

        Args:
            adapter_hid_dim: hidden size forwarded to every decoder ``Adapter``.
            model_name_or_path: checkpoint passed to
                ``AutoModelForSeq2SeqLM.from_pretrained``. It must be a BART-style
                model exposing ``.model`` and ``.lm_head``. Defaults to
                ``"bloomberg/KeyBART"``, the value previously hard-coded.
        """
        keyBart = AutoModelForSeq2SeqLM.from_pretrained(model_name_or_path)
        self.__fix_weights__(keyBart)
        # NOTE: the parent constructor builds a freshly initialised BartModel and
        # lm_head from the config. Both are replaced right below by the
        # pretrained (frozen) modules.
        super().__init__(keyBart.model.config)
        self.lm_head = keyBart.lm_head
        self.model = BartPlus(keyBart, adapter_hid_dim)
        # Reuse the checkpoint's logits bias instead of resetting it to zeros, so the
        # wrapper keeps the pretrained output distribution. Fall back to zeros
        # (the previous behaviour) if the loaded model has no such buffer.
        pretrained_bias = getattr(keyBart, "final_logits_bias", None)
        if pretrained_bias is None:
            pretrained_bias = torch.zeros((1, self.model.shared.num_embeddings))
        self.register_buffer("final_logits_bias", pretrained_bias.detach().clone())

    def __fix_weights__(self, keyBart: BartForConditionalGeneration) -> None:
        """Set ``requires_grad = False`` on every parameter of ``keyBart.model`` and ``keyBart.lm_head``."""
        for module in (keyBart.model, keyBart.lm_head):
            module.requires_grad_(False)

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        decoder_input_ids: Optional[torch.LongTensor] = None,
        decoder_attention_mask: Optional[torch.LongTensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        decoder_head_mask: Optional[torch.Tensor] = None,
        cross_attn_head_mask: Optional[torch.Tensor] = None,
        encoder_outputs: Optional[List[torch.FloatTensor]] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        decoder_inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, Seq2SeqLMOutput]:
        r"""
        Run the adapter-augmented encoder-decoder and project to vocabulary logits.

        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
            config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
            (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.

        Returns:
            ``Seq2SeqLMOutput`` when ``return_dict`` is true. Otherwise a tuple
            ``(logits, *model_outputs[1:])``, prefixed with the loss when ``labels``
            were given.
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if labels is not None:
            if use_cache:
                logger.warning("The `use_cache` argument is changed to `False` since `labels` is provided.")
            use_cache = False
            # Derive decoder inputs by shifting the labels right when none were supplied.
            if decoder_input_ids is None and decoder_inputs_embeds is None:
                decoder_input_ids = shift_tokens_right(
                    labels, self.config.pad_token_id, self.config.decoder_start_token_id
                )

        outputs = self.model(
            input_ids,
            attention_mask=attention_mask,
            decoder_input_ids=decoder_input_ids,
            encoder_outputs=encoder_outputs,
            decoder_attention_mask=decoder_attention_mask,
            head_mask=head_mask,
            decoder_head_mask=decoder_head_mask,
            cross_attn_head_mask=cross_attn_head_mask,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            decoder_inputs_embeds=decoder_inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        lm_logits = self.lm_head(outputs[0]) + self.final_logits_bias

        masked_lm_loss = None
        if labels is not None:
            # Labels may arrive on a different device than the logits; align them first.
            labels = labels.to(lm_logits.device)
            loss_fct = CrossEntropyLoss()
            masked_lm_loss = loss_fct(lm_logits.view(-1, self.config.vocab_size), labels.view(-1))

        if not return_dict:
            output = (lm_logits,) + outputs[1:]
            return ((masked_lm_loss,) + output) if masked_lm_loss is not None else output

        return Seq2SeqLMOutput(
            loss=masked_lm_loss,
            logits=lm_logits,
            past_key_values=outputs.past_key_values,
            decoder_hidden_states=outputs.decoder_hidden_states,
            decoder_attentions=outputs.decoder_attentions,
            cross_attentions=outputs.cross_attentions,
            encoder_last_hidden_state=outputs.encoder_last_hidden_state,
            encoder_hidden_states=outputs.encoder_hidden_states,
            encoder_attentions=outputs.encoder_attentions,
        )
class BartDecoderPlus(BartDecoder):
    """BART decoder that reuses a pretrained decoder and applies an adapter after each layer.

    ``forward`` follows ``BartDecoder.forward`` step for step. The only addition is
    ``hidden_states = self.adapters[idx](hidden_states)`` right after decoder
    layer ``idx``.
    """

    def __init__(self, keyBart: BartForConditionalGeneration, adapter_hid_dim: int) -> None:
        """Wrap ``keyBart``'s decoder.

        Args:
            keyBart: pretrained BART seq2seq model. Its decoder modules are shared,
                not copied.
            adapter_hid_dim: second argument given to every ``Adapter``. The first
                argument is the decoder's ``d_model``.
        """
        # BartDecoder.__init__ allocates freshly initialised modules from the config.
        # The ones used by forward are re-bound to the pretrained modules below.
        super().__init__(keyBart.get_decoder().config)
        self.decoder = keyBart.model.decoder
        # One adapter per pretrained decoder layer.
        self.adapters = nn.ModuleList(
            [Adapter(self.decoder.config.d_model, adapter_hid_dim) for _ in range(len(self.decoder.layers))]
        )
        self.config = self.decoder.config
        self.dropout = self.decoder.dropout
        self.layerdrop = self.decoder.layerdrop
        self.padding_idx = self.decoder.padding_idx
        self.max_target_positions = self.decoder.max_target_positions
        self.embed_scale = self.decoder.embed_scale
        self.embed_tokens = self.decoder.embed_tokens
        self.embed_positions = self.decoder.embed_positions
        self.layers = self.decoder.layers
        self.layernorm_embedding = self.decoder.layernorm_embedding
        self.gradient_checkpointing = self.decoder.gradient_checkpointing

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        encoder_hidden_states: Optional[torch.FloatTensor] = None,
        encoder_attention_mask: Optional[torch.LongTensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        cross_attn_head_mask: Optional[torch.Tensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, BaseModelOutputWithPastAndCrossAttentions]:
        """Run the decoder stack, applying ``self.adapters[idx]`` after layer ``idx``.

        Arguments mirror ``BartDecoder.forward``. Exactly one of ``input_ids`` /
        ``inputs_embeds`` must be given.

        Sub-modules are read from this wrapper (``self.embed_tokens``,
        ``self.layers``, ...), not from the inner ``self.decoder``. Re-binding them
        on the wrapper therefore takes effect.
        NOTE(review): HF's ``BartModel.set_input_embeddings`` assigns
        ``decoder.embed_tokens`` on this object.

        Returns:
            ``BaseModelOutputWithPastAndCrossAttentions`` when ``return_dict`` is true.
            Otherwise a tuple of the non-``None`` values among (hidden states, cache,
            all hidden states, self-attentions, cross-attentions).

        Raises:
            ValueError: if both or neither of ``input_ids`` / ``inputs_embeds`` are
                given, or if ``head_mask`` / ``cross_attn_head_mask`` does not have
                one entry per layer.
        """
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        use_cache = use_cache if use_cache is not None else self.config.use_cache
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # retrieve input_ids and inputs_embeds
        if input_ids is not None and inputs_embeds is not None:
            raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time")
        elif input_ids is not None:
            decoder_input = input_ids
            input_shape = decoder_input.shape
        elif inputs_embeds is not None:
            input_shape = inputs_embeds.size()[:-1]
            # Mirrors BartDecoder: a 2-D slice stands in for token ids when feeding embed_positions.
            decoder_input = inputs_embeds[:, :, -1]
        else:
            raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds")

        # past_key_values_length
        past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0

        if inputs_embeds is None:
            inputs_embeds = self.embed_tokens(decoder_input) * self.embed_scale

        attention_mask = self._prepare_decoder_attention_mask(
            attention_mask, input_shape, inputs_embeds, past_key_values_length
        )

        # expand encoder attention mask
        if encoder_hidden_states is not None and encoder_attention_mask is not None:
            # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
            encoder_attention_mask = _expand_mask(encoder_attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1])

        # embed positions; keep them on the token-embedding device before adding
        positions = self.embed_positions(decoder_input, past_key_values_length)
        positions = positions.to(inputs_embeds.device)

        hidden_states = inputs_embeds + positions
        hidden_states = self.layernorm_embedding(hidden_states)
        hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)

        # decoder layers
        all_hidden_states = () if output_hidden_states else None
        all_self_attns = () if output_attentions else None
        all_cross_attentions = () if (output_attentions and encoder_hidden_states is not None) else None
        next_decoder_cache = () if use_cache else None

        # check if head_mask/cross_attn_head_mask has a correct number of layers specified if desired
        num_layers = len(self.layers)
        for attn_mask, mask_name in zip([head_mask, cross_attn_head_mask], ["head_mask", "cross_attn_head_mask"]):
            if attn_mask is not None and attn_mask.size()[0] != num_layers:
                # Report the size of the mask that failed the check (previously always head_mask,
                # which crashed with AttributeError when only cross_attn_head_mask was given).
                raise ValueError(
                    f"The `{mask_name}` should be specified for {num_layers} layers, but it is for"
                    f" {attn_mask.size()[0]}."
                )

        for idx, decoder_layer in enumerate(self.layers):
            if output_hidden_states:
                all_hidden_states += (hidden_states,)
            # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
            dropout_probability = random.uniform(0, 1)
            if self.training and (dropout_probability < self.layerdrop):
                # A dropped layer also skips its adapter.
                continue

            past_key_value = past_key_values[idx] if past_key_values is not None else None

            if self.gradient_checkpointing and self.training:
                if use_cache:
                    logger.warning(
                        "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
                    )
                    use_cache = False

                def create_custom_forward(module):
                    def custom_forward(*inputs):
                        # None for past_key_value
                        return module(*inputs, output_attentions, use_cache)

                    return custom_forward

                layer_outputs = torch.utils.checkpoint.checkpoint(
                    create_custom_forward(decoder_layer),
                    hidden_states,
                    attention_mask,
                    encoder_hidden_states,
                    encoder_attention_mask,
                    head_mask[idx] if head_mask is not None else None,
                    cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None,
                    None,
                )
            else:
                layer_outputs = decoder_layer(
                    hidden_states,
                    attention_mask=attention_mask,
                    encoder_hidden_states=encoder_hidden_states,
                    encoder_attention_mask=encoder_attention_mask,
                    layer_head_mask=(head_mask[idx] if head_mask is not None else None),
                    cross_attn_layer_head_mask=(
                        cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None
                    ),
                    past_key_value=past_key_value,
                    output_attentions=output_attentions,
                    use_cache=use_cache,
                )
            hidden_states = layer_outputs[0]
            # The only deviation from BartDecoder: pass each layer's output through its adapter.
            hidden_states = self.adapters[idx](hidden_states)

            if use_cache:
                next_decoder_cache += (layer_outputs[3 if output_attentions else 1],)

            if output_attentions:
                all_self_attns += (layer_outputs[1],)
                if encoder_hidden_states is not None:
                    all_cross_attentions += (layer_outputs[2],)

        # add hidden states from the last decoder layer
        if output_hidden_states:
            all_hidden_states += (hidden_states,)

        next_cache = next_decoder_cache if use_cache else None
        if not return_dict:
            return tuple(
                v
                for v in [hidden_states, next_cache, all_hidden_states, all_self_attns, all_cross_attentions]
                if v is not None
            )
        return BaseModelOutputWithPastAndCrossAttentions(
            last_hidden_state=hidden_states,
            past_key_values=next_cache,
            hidden_states=all_hidden_states,
            attentions=all_self_attns,
            cross_attentions=all_cross_attentions,
        )
class BartPlus(BartModel):
    """BART encoder-decoder assembled from a pretrained model, with an adapter decoder.

    The shared token embedding and the encoder are taken as-is from
    ``keyBart.model``. The decoder is a ``BartDecoderPlus`` wrapping
    ``keyBart.model.decoder``.
    """

    def __init__(self, keyBart: BartForConditionalGeneration, adapter_hid_dim: int) -> None:
        """Assemble the model from ``keyBart``'s modules.

        Args:
            keyBart: pretrained BART seq2seq model whose modules are shared, not copied.
            adapter_hid_dim: hidden size forwarded to the decoder's adapters.
        """
        pretrained = keyBart.model
        # BartModel.__init__ builds new modules from the config; they are overwritten below.
        super().__init__(pretrained.config)
        self.config = pretrained.config
        self.shared = pretrained.shared
        self.encoder = pretrained.encoder
        self.decoder = BartDecoderPlus(keyBart, adapter_hid_dim=adapter_hid_dim)

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        decoder_input_ids: Optional[torch.LongTensor] = None,
        decoder_attention_mask: Optional[torch.LongTensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        decoder_head_mask: Optional[torch.Tensor] = None,
        cross_attn_head_mask: Optional[torch.Tensor] = None,
        encoder_outputs: Optional[List[torch.FloatTensor]] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        decoder_inputs_embeds: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, Seq2SeqModelOutput]:
        """Encode (unless ``encoder_outputs`` is supplied) and decode one batch.

        When neither ``decoder_input_ids`` nor ``decoder_inputs_embeds`` is given, the
        decoder inputs are built from ``input_ids`` with ``shift_tokens_right``.
        Flags left as ``None`` fall back to the corresponding ``self.config`` values.

        Returns:
            ``Seq2SeqModelOutput`` when ``return_dict`` is true. Otherwise the
            decoder's output tuple concatenated with the encoder's.

        Raises:
            ValueError: if ``input_ids``, ``decoder_input_ids`` and
                ``decoder_inputs_embeds`` are all ``None``.
        """
        config = self.config

        # Unlike most seq2seq models, BART derives decoder inputs from input_ids when none are given.
        if decoder_input_ids is None and decoder_inputs_embeds is None:
            if input_ids is None:
                raise ValueError(
                    "If no `decoder_input_ids` or `decoder_inputs_embeds` are "
                    "passed, `input_ids` cannot be `None`. Please pass either "
                    "`input_ids` or `decoder_input_ids` or `decoder_inputs_embeds`."
                )
            decoder_input_ids = shift_tokens_right(input_ids, config.pad_token_id, config.decoder_start_token_id)

        # Resolve every unset flag from the config.
        if output_attentions is None:
            output_attentions = config.output_attentions
        if output_hidden_states is None:
            output_hidden_states = config.output_hidden_states
        if use_cache is None:
            use_cache = config.use_cache
        if return_dict is None:
            return_dict = config.use_return_dict

        if encoder_outputs is None:
            encoder_outputs = self.encoder(
                input_ids=input_ids,
                attention_mask=attention_mask,
                head_mask=head_mask,
                inputs_embeds=inputs_embeds,
                output_attentions=output_attentions,
                output_hidden_states=output_hidden_states,
                return_dict=return_dict,
            )
        elif return_dict and not isinstance(encoder_outputs, BaseModelOutput):
            # A plain tuple was supplied; wrap it so the attribute access below works.
            n_parts = len(encoder_outputs)
            encoder_outputs = BaseModelOutput(
                last_hidden_state=encoder_outputs[0],
                hidden_states=encoder_outputs[1] if n_parts > 1 else None,
                attentions=encoder_outputs[2] if n_parts > 2 else None,
            )

        # decoder outputs consist of (dec_features, past_key_value, dec_hidden, dec_attn)
        decoder_outputs = self.decoder(
            input_ids=decoder_input_ids,
            attention_mask=decoder_attention_mask,
            encoder_hidden_states=encoder_outputs[0],
            encoder_attention_mask=attention_mask,
            head_mask=decoder_head_mask,
            cross_attn_head_mask=cross_attn_head_mask,
            past_key_values=past_key_values,
            inputs_embeds=decoder_inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        if return_dict:
            return Seq2SeqModelOutput(
                last_hidden_state=decoder_outputs.last_hidden_state,
                past_key_values=decoder_outputs.past_key_values,
                decoder_hidden_states=decoder_outputs.hidden_states,
                decoder_attentions=decoder_outputs.attentions,
                cross_attentions=decoder_outputs.cross_attentions,
                encoder_last_hidden_state=encoder_outputs.last_hidden_state,
                encoder_hidden_states=encoder_outputs.hidden_states,
                encoder_attentions=encoder_outputs.attentions,
            )
        return decoder_outputs + encoder_outputs