# backpack-gpt2-nli / modeling_backpack_gpt2_nli.py
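"""Backpack-GPT2 with a three-way NLI classification head.

The classifier reads the Backpack encoder's hidden state at the last
non-padding token of each input and maps it to entailment / neutral /
contradiction logits.
"""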
import torch
import torch.utils.checkpoint
from torch import nn
from transformers.models.gpt2.modeling_gpt2 import GPT2PreTrainedModel

from .configuration_backpack_gpt2_nli import BackpackGPT2NLIConfig
from .modeling_backpack_gpt2 import BackpackGPT2Model

class BackpackGPT2NLIModel(GPT2PreTrainedModel):
    config_class = BackpackGPT2NLIConfig

    def __init__(self, config):
        super().__init__(config)
        self.backpack = BackpackGPT2Model(config)
        self.n_embd = config.n_embd
        self.num_labels = config.num_labels  # 0: entailment, 1: neutral, 2: contradiction
        self.nli_head = nn.Sequential(
            nn.Linear(self.n_embd, self.n_embd),
            nn.Dropout(0.1),
            nn.Linear(self.n_embd, self.num_labels),
        )
        # Optionally freeze the Backpack encoder so only the NLI head is trained.
        self.backpack.requires_grad_(not config.freeze_backpack)
        self.loss_func = nn.CrossEntropyLoss()
        # Model parallel
        self.model_parallel = False

    def forward(self, input_ids=None, attention_mask=None, labels=None):
        backpack_outputs = self.backpack(input_ids=input_ids, position_ids=None)
        # The output also carries a contextualization tensor, which is unused here.
        backpack_hidden_states = backpack_outputs.hidden_states
        # Index of the last non-padding token per sequence: flip the attention mask,
        # find the first 1 from the end, and map it back to a forward index.
        last_toks_indices = attention_mask.shape[1] - 1 - attention_mask.flip((1,)).argmax(dim=1)
        # Gather that token's hidden state for every sequence in the batch.
        last_backpack_hidden_states = backpack_hidden_states[
            torch.arange(backpack_hidden_states.shape[0]), last_toks_indices, :
        ]
        logits = self.nli_head(last_backpack_hidden_states)  # (batch, num_labels)
        if labels is not None:
            # One sequence-level prediction per example; cross-entropy over the classes.
            loss = self.loss_func(logits, labels.view(-1))
            return {'logits': logits, 'loss': loss}
        else:
            return {'logits': logits}

    def predict(self, input_ids=None, attention_mask=None):
        logits = self.forward(input_ids, attention_mask, labels=None)['logits']
        # Map each highest-scoring class index to its string label via the config.
        preds = torch.argmax(logits, dim=-1)
        labels = [self.config.id2label[index.item()] for index in preds]
        return {'labels': labels, 'logits': logits}
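

if __name__ == "__main__":
    # ----------------------------------------------------------------------
    # Minimal usage sketch, not part of the original module. It assumes the
    # checkpoint is published on the Hugging Face Hub with `auto_map` set so
    # that `trust_remote_code=True` resolves to BackpackGPT2NLIModel; the
    # repo id below is an assumption inferred from the file header. Because
    # this module uses relative imports, copy this block into a standalone
    # script rather than executing this file directly.
    # ----------------------------------------------------------------------
    from transformers import AutoModel, AutoTokenizer

    repo_id = "ErfanMoosaviMonazzah/backpack-gpt2-nli"  # assumed repo id
    tokenizer = AutoTokenizer.from_pretrained(repo_id)
    model = AutoModel.from_pretrained(repo_id, trust_remote_code=True)
    model.eval()

    premise = "A man is playing a guitar on stage."
    hypothesis = "A person is performing music."
    enc = tokenizer(premise, hypothesis, return_tensors="pt")

    with torch.no_grad():
        out = model.predict(input_ids=enc["input_ids"],
                            attention_mask=enc["attention_mask"])
    print(out["labels"])  # one of: entailment / neutral / contradiction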