import warnings

from pytorch_lightning import LightningModule

from fengshen.models import transformer_utils

import torch
import torch.utils.checkpoint
from torch import nn
import torch.nn.functional as F
from dataclasses import dataclass
from typing import Optional, Tuple

from transformers.file_utils import *
from transformers.modeling_outputs import *
from transformers.models.bart import *
from transformers.models.bart.modeling_bart import BartClassificationHead

_CONFIG_FOR_DOC = "BartConfig"

# ------------------------ ZZ: CBart addition ------------------------
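# The two helpers below are adapted from older versions of
# `transformers.models.bart.modeling_bart`: `_reorder_buffer` realigns cached
# key/value tensors with the surviving beams during beam search (it is called
# from `_reorder_cache` further down), and `_make_linear_from_emb` builds an
# output projection layer that shares its weights with the token embedding matrix.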
def _reorder_buffer(attn_cache, new_order):
    for k, input_buffer_k in attn_cache.items():
        if input_buffer_k is not None:
            attn_cache[k] = input_buffer_k.index_select(0, new_order)
    return attn_cache


def _make_linear_from_emb(emb):
    vocab_size, emb_size = emb.weight.shape
    lin_layer = nn.Linear(vocab_size, emb_size, bias=False)
    lin_layer.weight.data = emb.weight.data
    return lin_layer

BART_GENERATION_EXAMPLE = r"""
    Summarization example::

        >>> from transformers import BartTokenizer, BartForConditionalGeneration, BartConfig

        >>> model = BartForConditionalGeneration.from_pretrained('facebook/bart-large-cnn')
        >>> tokenizer = BartTokenizer.from_pretrained('facebook/bart-large-cnn')

        >>> ARTICLE_TO_SUMMARIZE = "My friends are cool but they eat too many carbs."
        >>> inputs = tokenizer([ARTICLE_TO_SUMMARIZE], max_length=1024, return_tensors='pt')

        >>> # Generate Summary
        >>> summary_ids = model.generate(inputs['input_ids'], num_beams=4, max_length=5, early_stopping=True)
        >>> print([tokenizer.decode(g, skip_special_tokens=True, clean_up_tokenization_spaces=False) for g in summary_ids])

    Mask filling example::

        >>> from transformers import BartTokenizer, BartForConditionalGeneration
        >>> tokenizer = BartTokenizer.from_pretrained('facebook/bart-large')
        >>> TXT = "My friends are <mask> but they eat too many carbs."

        >>> model = BartForConditionalGeneration.from_pretrained('facebook/bart-large')
        >>> input_ids = tokenizer([TXT], return_tensors='pt')['input_ids']
        >>> logits = model(input_ids).logits

        >>> masked_index = (input_ids[0] == tokenizer.mask_token_id).nonzero().item()
        >>> probs = logits[0, masked_index].softmax(dim=0)
        >>> values, predictions = probs.topk(5)

        >>> tokenizer.decode(predictions).split()
"""

@dataclass
class CBartLMOutput(ModelOutput):
    """
    Base class for CBart specific language model outputs.
    Args:
        ....
    """
    loss: Optional[torch.FloatTensor] = None
    encoder_loss: Optional[torch.FloatTensor] = None
    decoder_loss: Optional[torch.FloatTensor] = None
    encoder_logits: torch.FloatTensor = None
    logits: torch.FloatTensor = None
    past_key_values: Optional[Tuple[torch.FloatTensor]] = None
    decoder_hidden_states: Optional[Tuple[torch.FloatTensor]] = None
    decoder_attentions: Optional[Tuple[torch.FloatTensor]] = None
    encoder_last_hidden_state: Optional[torch.FloatTensor] = None
    encoder_hidden_states: Optional[Tuple[torch.FloatTensor]] = None
    encoder_attentions: Optional[Tuple[torch.FloatTensor]] = None

class BartForTextInfill(BartPretrainedModel):
    """
    This class is designed for text infilling.
    During training, the encoder is used to predict the replace/insert labels,
    and the decoder is used to generate the original input.
    Compared with the BartForConditionalGeneration class,
    we add a classification module over the encoder and a new loss for the encoder.
    """
    base_model_prefix = "model"
    authorized_missing_keys = [r"final_logits_bias",
                               r"encoder\.version", r"decoder\.version"]

    def __init__(self, config: BartConfig):
        super().__init__(config)
        base_model = BartModel(config)
        self.model = base_model
        self.register_buffer("final_logits_bias", torch.zeros(
            (1, self.model.shared.num_embeddings)))
        # print(config.encoder_loss_type, config.num_labels)
        # add a new attribute into BartConfig class (revise BartConfig)
        self.encoder_loss_type = config.encoder_loss_type
        self.num_labels = config.num_labels
        if self.encoder_loss_type == 0:  # 0 is classification loss, 1 is regression loss
            # add a classification module for the encoder
            self.classification_head = BartClassificationHead(
                config.d_model, config.d_model, config.num_labels, config.classif_dropout,
            )
        else:
            # add a regression module for the encoder
            self.classification_head = BartClassificationHead(
                config.d_model, config.d_model, 1, config.classif_dropout,
            )
        self.model._init_weights(self.classification_head.dense)
        self.model._init_weights(self.classification_head.out_proj)
        self.loss_weight = config.loss_weight
        self.register_buffer("label_weights", torch.zeros((self.num_labels)))

    def resize_token_embeddings(self, new_num_tokens: int) -> nn.Embedding:
        old_num_tokens = self.model.shared.num_embeddings
        new_embeddings = super().resize_token_embeddings(new_num_tokens)
        self.model.shared = new_embeddings
        self._resize_final_logits_bias(new_num_tokens, old_num_tokens)
        return new_embeddings

    def _resize_final_logits_bias(self, new_num_tokens: int, old_num_tokens: int) -> None:
        if new_num_tokens <= old_num_tokens:
            new_bias = self.final_logits_bias[:, :new_num_tokens]
        else:
            extra_bias = torch.zeros((1, new_num_tokens - old_num_tokens),
                                     device=self.final_logits_bias.device)
            new_bias = torch.cat([self.final_logits_bias, extra_bias], dim=1)
        self.register_buffer("final_logits_bias", new_bias)
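    # Note: `final_logits_bias` has to track the vocabulary size because the LM
    # logits in `forward` are computed as
    # `F.linear(decoder_hidden_states, self.model.shared.weight, bias=self.final_logits_bias)`,
    # so both the shared embedding weight and this bias must have
    # `new_num_tokens` entries after `resize_token_embeddings` is called.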
    def forward(
        self,
        input_ids,
        attention_mask=None,
        encoder_outputs=None,
        decoder_input_ids=None,
        decoder_attention_mask=None,
        past_key_values=None,
        encoder_labels=None,
        labels=None,
        use_cache=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=True,
        **unused,
    ):
        r"""
        labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`, defaults to :obj:`None`):
            Labels for computing the masked language modeling loss.
            Indices should either be in ``[0, ..., config.vocab_size]`` or -100 (see ``input_ids`` docstring).
            Tokens with indices set to ``-100`` are ignored (masked), the loss is only computed for the tokens
            with labels in ``[0, ..., config.vocab_size]``.

        Returns:

        Conditional generation example::

            # Mask filling only works for bart-large
            from transformers import BartTokenizer, BartForConditionalGeneration
            tokenizer = BartTokenizer.from_pretrained('facebook/bart-large')
            TXT = "My friends are <mask> but they eat too many carbs."

            model = BartForConditionalGeneration.from_pretrained('facebook/bart-large')
            input_ids = tokenizer([TXT], return_tensors='pt')['input_ids']
            logits = model(input_ids).logits

            masked_index = (input_ids[0] == tokenizer.mask_token_id).nonzero().item()
            probs = logits[0, masked_index].softmax(dim=0)
            values, predictions = probs.topk(5)

            tokenizer.decode(predictions).split()
            # ['good', 'great', 'all', 'really', 'very']
        """
if "lm_labels" in unused: | |
warnings.warn( | |
"The `lm_labels` argument is deprecated and will be removed in a future version, use `labels` instead.", | |
FutureWarning, | |
) | |
labels = unused.pop("lm_labels") | |
if "decoder_cached_states" in unused: | |
warnings.warn( | |
"The `decoder_cached_states` argument is deprecated and will be removed in a future version, use `decoder_past_key_values` instead.", | |
FutureWarning, | |
) | |
decoder_past_key_values = unused.pop("decoder_cached_states") | |
return_dict = return_dict if return_dict is not None else False | |
if labels is not None: | |
use_cache = False | |
outputs = self.model( | |
input_ids, | |
attention_mask=attention_mask, | |
decoder_input_ids=decoder_input_ids, | |
encoder_outputs=encoder_outputs, | |
decoder_attention_mask=decoder_attention_mask, | |
past_key_values=past_key_values, | |
use_cache=use_cache, | |
output_attentions=output_attentions, | |
output_hidden_states=output_hidden_states, | |
return_dict=return_dict, | |
) | |
# logits and loss for the encoder | |
# last hidden state | |
encoder_last_hidden_state = outputs['encoder_last_hidden_state'] | |
# eos_mask = input_ids.eq(self.config.eos_token_id) | |
# if len(torch.unique(eos_mask.sum(1))) > 1: | |
# raise ValueError("All examples must have the same number of <eos> tokens.") | |
# sentence_representation = x[eos_mask, :].view(x.size(0), -1, x.size(-1))[:, -1, :] | |
encoder_logits = self.classification_head(encoder_last_hidden_state) | |
encoder_loss = None | |
if encoder_labels is not None: | |
# classification loss | |
if self.encoder_loss_type == 0: | |
# ZZ: seems like MSE loss does not support weighting, so only CEL has weighting applied for now | |
loss_fct = nn.CrossEntropyLoss(weight=self.label_weights) | |
encoder_loss = loss_fct( | |
encoder_logits.view(-1, self.config.num_labels), encoder_labels.view(-1)) | |
# regression loss | |
else: | |
encoder_logits = encoder_logits.view( | |
encoder_logits.size(0), -1) | |
encoder_logits = torch.sigmoid( | |
encoder_logits) * self.num_labels - 0.5 | |
loss_fct = nn.MSELoss(reduction='none') | |
_loss = loss_fct(encoder_logits, encoder_labels) | |
encoder_loss = torch.mean(_loss[encoder_labels >= 0]) | |
# encoder_loss =_loss[encoder_labels>=0] | |
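                # The sigmoid rescaling above maps each raw regression score into
                # (-0.5, self.num_labels - 0.5), so every integer label k in
                # 0, ..., num_labels - 1 is the midpoint of its own unit-wide
                # sub-interval (k - 0.5, k + 0.5); positions with negative labels
                # (presumably padding) are excluded from the mean squared error.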
        # logits and loss for the decoder
        lm_logits = F.linear(
            outputs[0], self.model.shared.weight, bias=self.final_logits_bias)
        masked_lm_loss = None
        if labels is not None:
            loss_fct = nn.CrossEntropyLoss()
            # TODO(SS): do we need to ignore pad tokens in labels?
            masked_lm_loss = loss_fct(
                lm_logits.view(-1, self.config.vocab_size), labels.view(-1))

        loss = None
        if masked_lm_loss is not None and encoder_loss is not None:
            loss = encoder_loss * self.loss_weight + masked_lm_loss

        if not return_dict:
            output = (lm_logits,) + outputs[1:]
            return ((loss,) + output) if loss is not None else output

        return CBartLMOutput(
            loss=loss,
            encoder_loss=encoder_loss,
            decoder_loss=masked_lm_loss,
            encoder_logits=encoder_logits,
            logits=lm_logits,
            past_key_values=outputs.past_key_values,
            decoder_hidden_states=outputs.decoder_hidden_states,
            decoder_attentions=outputs.decoder_attentions,
            encoder_last_hidden_state=outputs.encoder_last_hidden_state,
            encoder_hidden_states=outputs.encoder_hidden_states,
            encoder_attentions=outputs.encoder_attentions,
        )

    def prepare_inputs_for_generation(self, decoder_input_ids, past, attention_mask, use_cache, **kwargs):
        assert past is not None, "past has to be defined for encoder_outputs"
        encoder_outputs, past_key_values = past
        return {
            "input_ids": None,  # encoder_outputs is defined. input_ids not needed
            "encoder_outputs": encoder_outputs,
            "past_key_values": past_key_values,
            "decoder_input_ids": decoder_input_ids,
            "attention_mask": attention_mask,
            # change this to avoid caching (presumably for debugging)
            "use_cache": use_cache,
        }

    def adjust_logits_during_generation(self, logits, cur_len, max_length):
        if cur_len == 1:
            self._force_token_ids_generation(logits, self.config.bos_token_id)
        if cur_len == max_length - 1 and self.config.eos_token_id is not None:
            self._force_token_ids_generation(logits, self.config.eos_token_id)
        return logits
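    # `adjust_logits_during_generation` reproduces the behaviour of older BART
    # generation code: at the first decoding step only `bos_token_id` can be
    # produced, and at the last allowed step only `eos_token_id` can, because
    # `_force_token_ids_generation` sets the scores of every other token to -inf.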
    def _force_token_ids_generation(self, scores, token_ids) -> None:
        """force one of token_ids to be generated by setting prob of all other tokens to 0"""
        if isinstance(token_ids, int):
            token_ids = [token_ids]
        all_but_token_ids_mask = torch.tensor(
            [x for x in range(self.config.vocab_size) if x not in token_ids],
            dtype=torch.long,
            device=next(self.parameters()).device,
        )
        assert len(
            scores.shape) == 2, "scores should be of rank 2 with shape: [batch_size, vocab_size]"
        scores[:, all_but_token_ids_mask] = -float("inf")
    @staticmethod
    def _reorder_cache(past, beam_idx):
        ((enc_out, enc_mask), past_key_values) = past
        reordered_past = []
        for layer_past in past_key_values:
            # get the correct batch idx from decoder layer's batch dim for cross and self-attn
            layer_past_new = {
                attn_key: _reorder_buffer(attn_cache, beam_idx) for attn_key, attn_cache in layer_past.items()
            }
            reordered_past.append(layer_past_new)
        new_enc_out = enc_out if enc_out is None else enc_out.index_select(
            0, beam_idx)
        new_enc_mask = enc_mask if enc_mask is None else enc_mask.index_select(
            0, beam_idx)
        past = ((new_enc_out, new_enc_mask), reordered_past)
        return past
    def get_encoder(self):
        return self.model.encoder

    def get_output_embeddings(self):
        return _make_linear_from_emb(self.model.shared)  # make it on the fly

    def get_encoder_logits(self, input_ids, attention_mask=None):
        # print(input_ids, attention_mask)
        # encoder_outputs = self.model.get_encoder_outputs(
        #     self,
        #     input_ids,
        #     attention_mask=attention_mask,
        #     output_attentions=None,
        #     output_hidden_states=None,
        #     return_dict=None,
        # )
        encoder_outputs = self.model.encoder(
            input_ids=input_ids,
            attention_mask=attention_mask,
            return_dict=True
        )
        # logits and loss for the encoder
        # last hidden state
        encoder_last_hidden_state = encoder_outputs['last_hidden_state']
        encoder_logits = self.classification_head(encoder_last_hidden_state)
        # classification
        if self.encoder_loss_type == 0:
            # probs = torch.softmax(encoder_logits, dim=-1)
            pass
        # regression
        else:
            encoder_logits = encoder_logits.view(encoder_logits.size(0), -1)
            encoder_logits = torch.sigmoid(
                encoder_logits) * self.num_labels - 0.5
        return encoder_outputs, encoder_logits
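    # A hedged inference sketch for `get_encoder_logits` (variable names are
    # illustrative, not part of this module):
    #
    #   encoder_outputs, encoder_logits = model.get_encoder_logits(input_ids, attention_mask)
    #   if model.encoder_loss_type == 0:
    #       # classification head: most likely edit label per token
    #       token_labels = encoder_logits.argmax(dim=-1)
    #   else:
    #       # regression head: round the predicted score to the nearest label
    #       token_labels = encoder_logits.round().clamp(0, model.num_labels - 1).long()
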
class CBartLightning(LightningModule):
    @staticmethod
    def add_module_specific_args(parent_args):
        parser = parent_args.add_argument_group("CBart specific parameters")
        parser.add_argument('--num_labels', type=int, default=3)
        parser.add_argument('--encoder_loss_type', type=int, default=0)
        parser.add_argument('--loss_weight', type=float, default=1.0)
        parser.add_argument('--label_weights', type=float, nargs='+', default=[1.0, 1.0, 1.0])
        parser.add_argument('--masked_lm', type=float, default=0)
        return parent_args
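    # Example wiring (illustrative; assumes `parent_args` is an
    # argparse.ArgumentParser, with `--model_path` and `--train_batchsize`
    # added elsewhere, since `__init__` and `setup` below read them):
    #
    #   parser = argparse.ArgumentParser()
    #   CBartLightning.add_module_specific_args(parser)
    #   module = CBartLightning(parser.parse_args())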
    def __init__(
        self,
        args,
        **kwargs,
    ):
        super().__init__()
        self.save_hyperparameters(args)
        self.model = BartForTextInfill.from_pretrained(args.model_path, num_labels=self.hparams.num_labels,
                                                       encoder_loss_type=self.hparams.encoder_loss_type,
                                                       loss_weight=self.hparams.loss_weight,)
        self.model.label_weights = torch.tensor(
            self.hparams.label_weights, dtype=torch.half)

    def forward(self, **inputs):
        return self.model(**inputs)
    def training_step(self, batch, batch_idx):
        outputs = self(**batch)
        return outputs

    def validation_step(self, batch, batch_idx, dataloader_idx=0):
        outputs = self(**batch)
        val_loss = outputs["loss"]
        return {"loss": val_loss}
    def setup(self, stage=None) -> None:
        if stage != "fit":
            return
        # Get dataloader by calling it - train_dataloader() is called after setup() by default
        train_loader = self.trainer._data_connector._train_dataloader_source.dataloader()

        # Calculate total steps
        tb_size = self.hparams.train_batchsize * max(1, self.trainer.gpus)
        ab_size = self.trainer.accumulate_grad_batches * float(self.trainer.max_epochs)
        self.total_steps = (len(train_loader.dataset) // tb_size) // ab_size
    def configure_optimizers(self):
        return transformer_utils.configure_optimizers(self)