# Alternative: from transformers.models.led.modeling_led import LEDEncoder
from transformers import LEDConfig, LEDModel, LEDPreTrainedModel
from transformers.modeling_outputs import TokenClassifierOutput
import torch.nn as nn


class CustomLEDForResultsIdModel(LEDPreTrainedModel):
    """LED encoder with a token-classification head on top."""

    def __init__(self, config: LEDConfig, checkpoint=None):
        super().__init__(config)
        self.num_labels = config.num_labels
        print(f"Config: num_labels={config.num_labels}, dropout={config.dropout}")

        # Load the model from the given checkpoint (if any) and keep only its
        # encoder as the body; otherwise build a randomly initialized encoder.
        if checkpoint:
            self.led = LEDModel.from_pretrained(checkpoint, config=config).get_encoder()
        else:
            self.led = LEDModel(config).get_encoder()
        # Equivalently: LEDEncoder.from_pretrained(checkpoint, config=config)

        self.dropout = nn.Dropout(config.dropout)
        # Classification head; nn.Linear initializes its own weights.
        self.classifier = nn.Linear(self.led.config.d_model, self.num_labels)

    def forward(self, input_ids=None, attention_mask=None, labels=None,
                global_attention_mask=None):
        # Run the encoder body to get per-token hidden states.
        outputs = self.led(
            input_ids=input_ids,
            attention_mask=attention_mask,
            global_attention_mask=global_attention_mask,
        )
        sequence_output = self.dropout(outputs.last_hidden_state)
        # Project each token's hidden state (d_model) onto the label space.
        logits = self.classifier(sequence_output)

        # Compute a token-level cross-entropy loss when labels are provided.
        loss = None
        if labels is not None:
            loss_fct = nn.CrossEntropyLoss()
            loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))

        return TokenClassifierOutput(loss=loss, logits=logits)
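

# --- Usage sketch (illustrative; not part of the original file) ---
# A minimal, hedged example of how this model might be instantiated and
# called. The checkpoint name "allenai/led-base-16384", the label count,
# and the dummy batch shapes are assumptions for demonstration only.
if __name__ == "__main__":
    import torch
    from transformers import LEDConfig

    checkpoint = "allenai/led-base-16384"  # assumed base LED checkpoint
    config = LEDConfig.from_pretrained(checkpoint, num_labels=3)
    model = CustomLEDForResultsIdModel(config, checkpoint=checkpoint)

    batch_size, seq_len = 2, 16
    input_ids = torch.randint(0, config.vocab_size, (batch_size, seq_len))
    attention_mask = torch.ones(batch_size, seq_len, dtype=torch.long)
    # LED expects a global attention mask; here only the first token of each
    # sequence gets global attention (a common default).
    global_attention_mask = torch.zeros(batch_size, seq_len, dtype=torch.long)
    global_attention_mask[:, 0] = 1
    labels = torch.randint(0, config.num_labels, (batch_size, seq_len))

    outputs = model(
        input_ids=input_ids,
        attention_mask=attention_mask,
        global_attention_mask=global_attention_mask,
        labels=labels,
    )
    print(outputs.loss, outputs.logits.shape)  # logits: (batch, seq_len, num_labels)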