from transformers import LEDConfig, LEDModel, LEDPreTrainedModel

import torch.nn as nn


class CustomLEDForQAModel(LEDPreTrainedModel):

    config_class = LEDConfig

    def __init__(self, config: LEDConfig, checkpoint):
        super().__init__(config)

        # Two labels: one logit for the answer start position, one for the end.
        config.num_labels = 2
        self.num_labels = config.num_labels

        # Extractive QA only needs token-level representations of the input,
        # so keep just the LED encoder.
        if checkpoint:
            self.led = LEDModel.from_pretrained(checkpoint, config=config).get_encoder()
        else:
            self.led = LEDModel(config).get_encoder()

        # Linear head that maps each token's hidden state to (start, end) scores.
        self.qa_outputs = nn.Linear(config.hidden_size, config.num_labels)

    def forward(self, input_ids=None, attention_mask=None, global_attention_mask=None,
                start_positions=None, end_positions=None):
        outputs = self.led(
            input_ids=input_ids,
            attention_mask=attention_mask,
            global_attention_mask=global_attention_mask,
        )

        # Split the per-token scores into start and end logits of shape
        # (batch_size, sequence_length).
        logits = self.qa_outputs(outputs.last_hidden_state)
        start_logits, end_logits = logits.split(1, dim=-1)
        start_logits = start_logits.squeeze(-1).contiguous()
        end_logits = end_logits.squeeze(-1).contiguous()

        total_loss = None

        if start_positions is not None and end_positions is not None:
            loss_fct = nn.CrossEntropyLoss()

            # If the positions carry an extra dimension (e.g. when gathered
            # across devices), reduce them to shape (batch_size,).
            if len(start_positions.size()) > 1:
                start_positions = start_positions.squeeze(-1)
            if len(end_positions.size()) > 1:
                end_positions = end_positions.squeeze(-1)

            start_loss = loss_fct(start_logits, start_positions)
            end_loss = loss_fct(end_logits, end_positions)

            # Average the two cross-entropy terms, as the standard
            # question-answering heads in transformers do.
            total_loss = (start_loss + end_loss) / 2

        return {
            'loss': total_loss,
            'start_logits': start_logits,
            'end_logits': end_logits,
        }
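

# A minimal usage sketch of the class above. Assumptions not taken from the
# original: the "allenai/led-base-16384" checkpoint/tokenizer and the toy
# question/context pair are placeholders, and the QA head is still untrained,
# so the decoded span is only a smoke test until the model is fine-tuned.
import torch
from transformers import LEDConfig, LEDTokenizerFast

checkpoint = "allenai/led-base-16384"
tokenizer = LEDTokenizerFast.from_pretrained(checkpoint)
config = LEDConfig.from_pretrained(checkpoint)
model = CustomLEDForQAModel(config, checkpoint)
model.eval()

question = "Who wrote the report?"
context = "The report was written by the research team in 2021."
inputs = tokenizer(question, context, return_tensors="pt")

# Give the question tokens global attention, as LED/Longformer QA setups do.
global_attention_mask = torch.zeros_like(inputs["input_ids"])
first_sep = (inputs["input_ids"][0] == tokenizer.sep_token_id).nonzero()[0].item()
global_attention_mask[0, : first_sep + 1] = 1

with torch.no_grad():
    outputs = model(
        input_ids=inputs["input_ids"],
        attention_mask=inputs["attention_mask"],
        global_attention_mask=global_attention_mask,
    )

# Pick the highest-scoring start and end tokens and decode the span.
start = outputs["start_logits"].argmax(dim=-1).item()
end = outputs["end_logits"].argmax(dim=-1).item()
print(tokenizer.decode(inputs["input_ids"][0][start : end + 1]))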