CustomLEDForQAonNQ / modeling_CustomLEDForQA.py
# from transformers.models.led.modeling_led import LEDEncoder
from transformers import LEDConfig, LEDModel, LEDPreTrainedModel
import torch.nn as nn
# Extractive QA head on top of the LED encoder; subclassing LEDPreTrainedModel
# keeps the model compatible with save_pretrained / from_pretrained.
class CustomLEDForQAModel(LEDPreTrainedModel):
    config_class = LEDConfig

    def __init__(self, config: LEDConfig, checkpoint=None):
        super().__init__(config)
        # Two labels: one score per token for the answer start and one for the answer end.
        config.num_labels = 2
        self.num_labels = config.num_labels
        # Only the LED encoder is needed for extractive QA; the decoder is dropped.
        if checkpoint:
            self.led = LEDModel.from_pretrained(checkpoint, config=config).get_encoder()
        else:
            self.led = LEDModel(config).get_encoder()
        self.qa_outputs = nn.Linear(config.hidden_size, config.num_labels)
    def forward(self, input_ids=None, attention_mask=None, global_attention_mask=None, start_positions=None, end_positions=None):
        # Encode the (long) input with the LED encoder only.
        outputs = self.led(input_ids=input_ids, attention_mask=attention_mask, global_attention_mask=global_attention_mask)
        # Project each token's hidden state to two scores, then split into start/end logits.
        logits = self.qa_outputs(outputs.last_hidden_state)
        start_logits, end_logits = logits.split(1, dim=-1)
        start_logits = start_logits.squeeze(-1).contiguous()
        end_logits = end_logits.squeeze(-1).contiguous()
        total_loss = None
        if start_positions is not None and end_positions is not None:
            loss_fct = nn.CrossEntropyLoss()
            # If the positions carry an extra trailing dimension (e.g. shape [batch, 1]), flatten it.
            if start_positions.dim() > 1:
                start_positions = start_positions.squeeze(-1)
            if end_positions.dim() > 1:
                end_positions = end_positions.squeeze(-1)
            start_loss = loss_fct(start_logits, start_positions)
            end_loss = loss_fct(end_logits, end_positions)
            total_loss = (start_loss + end_loss) / 2
        return {
            'loss': total_loss,
            'start_logits': start_logits,
            'end_logits': end_logits,
        }
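

# --- Minimal usage sketch (illustrative, not part of the uploaded module) ---
# Assumes the public "allenai/led-base-16384" checkpoint and its tokenizer; this repo does not
# state which base checkpoint it was trained from, so swap in the correct one if it differs.
if __name__ == "__main__":
    import torch
    from transformers import AutoTokenizer

    checkpoint = "allenai/led-base-16384"  # assumed base checkpoint
    tokenizer = AutoTokenizer.from_pretrained(checkpoint)
    config = LEDConfig.from_pretrained(checkpoint)
    model = CustomLEDForQAModel(config, checkpoint)
    model.eval()

    question = "Who proposed the Longformer Encoder-Decoder?"
    context = "The Longformer Encoder-Decoder (LED) was proposed by Beltagy, Peters, and Cohan."
    inputs = tokenizer(question, context, return_tensors="pt", truncation=True)

    # Give the question tokens (everything up to the first separator) global attention,
    # as is common practice for Longformer/LED question answering.
    global_attention_mask = torch.zeros_like(inputs["input_ids"])
    sep_index = (inputs["input_ids"][0] == tokenizer.sep_token_id).nonzero()[0].item()
    global_attention_mask[0, : sep_index + 1] = 1

    with torch.no_grad():
        outputs = model(
            input_ids=inputs["input_ids"],
            attention_mask=inputs["attention_mask"],
            global_attention_mask=global_attention_mask,
        )

    # Decode the highest-scoring start/end span (no sanity checks on span order here).
    start = outputs["start_logits"].argmax(dim=-1).item()
    end = outputs["end_logits"].argmax(dim=-1).item()
    print(tokenizer.decode(inputs["input_ids"][0, start : end + 1]))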