from transformers import LEDConfig, LEDModel, LEDPreTrainedModel
from transformers.modeling_outputs import TokenClassifierOutput

import torch.nn as nn


class CustomLEDForResultsIdModel(LEDPreTrainedModel):
    """LED encoder with a token-level classification head for results identification."""

    def __init__(self, config: LEDConfig, checkpoint=None):
        super().__init__(config)
        self.num_labels = config.num_labels
        print("Configs")
        print(config.num_labels)
        print(config.dropout)

        # Keep only the LED encoder; load pretrained weights when a checkpoint is given,
        # otherwise initialise the encoder from the config alone.
        if checkpoint:
            self.led = LEDModel.from_pretrained(checkpoint, config=config).get_encoder()
        else:
            self.led = LEDModel(config).get_encoder()

        # Token-classification head on top of the encoder's hidden states.
        self.dropout = nn.Dropout(config.dropout)
        self.classifier = nn.Linear(self.led.config.d_model, self.num_labels)

    def forward(self, input_ids=None, attention_mask=None, labels=None, global_attention_mask=None, return_loss=True):
        # return_loss is accepted for API compatibility; the loss is computed whenever labels are passed.
        # Encode the (long) input with the LED encoder.
        outputs = self.led(
            input_ids=input_ids,
            attention_mask=attention_mask,
            global_attention_mask=global_attention_mask,
        )

        # Per-token logits: dropout over the last hidden states, then the linear head.
        sequence_output = self.dropout(outputs.last_hidden_state)
        logits = self.classifier(sequence_output)

        # Standard token-classification loss when labels are provided.
        loss = None
        if labels is not None:
            loss_fct = nn.CrossEntropyLoss()
            loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))

        # Return the imported TokenClassifierOutput so the head plugs into the usual
        # Hugging Face training utilities (dict-style access to 'loss'/'logits' still works).
        return TokenClassifierOutput(loss=loss, logits=logits)
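

# --- Usage sketch (illustrative; not part of the original class) ---
# A minimal, hedged example assuming the public "allenai/led-base-16384" checkpoint
# and a two-label token-tagging task; the sample text, num_labels, and the dummy
# label tensor are placeholders, not values taken from the original code.
if __name__ == "__main__":
    import torch
    from transformers import AutoTokenizer

    ckpt = "allenai/led-base-16384"  # assumed LED checkpoint name
    config = LEDConfig.from_pretrained(ckpt, num_labels=2)
    model = CustomLEDForResultsIdModel(config, checkpoint=ckpt)
    model.eval()

    tokenizer = AutoTokenizer.from_pretrained(ckpt)
    enc = tokenizer("The results show a significant improvement.", return_tensors="pt")

    # LED convention: give the first (<s>) token global attention.
    global_attention_mask = torch.zeros_like(enc["input_ids"])
    global_attention_mask[:, 0] = 1

    out = model(
        input_ids=enc["input_ids"],
        attention_mask=enc["attention_mask"],
        global_attention_mask=global_attention_mask,
        labels=torch.zeros_like(enc["input_ids"]),  # dummy per-token labels
    )
    print(out.loss, out.logits.shape)  # logits: (batch, seq_len, num_labels)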