import torch
import torch.utils.checkpoint
from torch import nn
from transformers.models.gpt2.modeling_gpt2 import GPT2PreTrainedModel

from .configuration_backpack_gpt2_nli import BackpackGPT2NLIConfig
from .modeling_backpack_gpt2 import BackpackGPT2Model


class BackpackGPT2NLIModel(GPT2PreTrainedModel):
    config_class = BackpackGPT2NLIConfig

    def __init__(self, config):
        super().__init__(config)
        self.backpack = BackpackGPT2Model(config)
        self.n_embd = config.n_embd
        self.num_labels = config.num_labels  # 0: Entailment -- 1: Neutral -- 2: Contradiction
        self.nli_head = nn.Sequential(
            nn.Linear(self.n_embd, self.n_embd),
            nn.Dropout(0.1),
            nn.Linear(self.n_embd, self.num_labels),
        )

        # Freeze the Backpack encoder if requested in the config.
        self.backpack.requires_grad_(not config.freeze_backpack)

        self.loss_func = nn.CrossEntropyLoss()

        # Model parallel
        self.model_parallel = False

    def forward(self, input_ids=None, attention_mask=None, labels=None):
        backpack_outputs = self.backpack(input_ids=input_ids, position_ids=None)
        # The contextualization weights are returned by the Backpack but not used for classification.
        backpack_hidden_states, backpack_contextualization = (
            backpack_outputs.hidden_states,
            backpack_outputs.contextualization,
        )

        # Index of the last non-padding token of each sequence, according to the attention mask.
        last_toks_indices = attention_mask.shape[1] - 1 - attention_mask.flip((1,)).argmax(dim=1)
        # Gather the hidden state at that position for every sequence in the batch.
        last_backpack_hidden_states = backpack_hidden_states[
            torch.arange(backpack_hidden_states.shape[0]), last_toks_indices, :
        ]

        logits = self.nli_head(last_backpack_hidden_states)

        if labels is not None:
            loss = self.loss_func(logits, labels.view(-1))
            return {'logits': logits, 'loss': loss}
        return {'logits': logits}

    def predict(self, input_ids=None, attention_mask=None):
        logits = self.forward(input_ids, attention_mask, labels=None)['logits']
        pred_ids = torch.argmax(logits, dim=1)
        labels = [self.config.id2label[index.item()] for index in pred_ids]
        return {'labels': labels, 'logits': logits}
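

# ---------------------------------------------------------------------------
# Minimal usage sketch (not part of the original model code). It assumes a
# tokenizer and a fine-tuned checkpoint compatible with BackpackGPT2NLIConfig
# are available; "<path-or-id-of-nli-checkpoint>" is a hypothetical placeholder,
# and the premise/hypothesis concatenation format is an assumption that should
# match how the checkpoint was actually fine-tuned.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    from transformers import AutoTokenizer

    checkpoint = "<path-or-id-of-nli-checkpoint>"  # hypothetical placeholder
    tokenizer = AutoTokenizer.from_pretrained(checkpoint)
    model = BackpackGPT2NLIModel.from_pretrained(checkpoint, trust_remote_code=True)
    model.eval()

    # Assumed input format: premise and hypothesis joined into a single sequence.
    text = "A man is playing a guitar. A man is making music."
    encoded = tokenizer(text, return_tensors="pt")

    with torch.no_grad():
        out = model.predict(
            input_ids=encoded["input_ids"],
            attention_mask=encoded["attention_mask"],
        )
    print(out["labels"])  # e.g. ["Entailment"], depending on the checkpoint's id2label mapping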