Spaces:
Sleeping
Sleeping
| import torch | |
| import torch.nn as nn | |
| from transformers import RobertaModel, RobertaPreTrainedModel | |
class RobertaMultiTask(RobertaPreTrainedModel):
    """RoBERTa encoder shared by a sequence-classification head and a per-token head.

    * ``classifier`` maps the encoder's (dropout-regularised) ``pooler_output``
      to ``config.num_labels`` logits.
    * ``span_classifier`` maps every token's final hidden state to 2 logits,
      i.e. a binary per-token label (presumably "token is inside a span" —
      confirm against the data pipeline).

    Training objective: ``cls_loss + span_loss_weight * span_loss``, where each
    term is included only when its labels are supplied. ``span_loss_weight``
    is read from ``config.span_loss_weight`` if the config defines it, and
    defaults to 0.3 otherwise.
    """

    def __init__(self, config):
        super().__init__(config)
        self.num_labels = config.num_labels
        # Optional config knob; 0.3 preserves the original hard-coded weight.
        self.span_loss_weight = getattr(config, "span_loss_weight", 0.3)
        self.roberta = RobertaModel(config)
        # One dropout module is shared by both heads.
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        self.classifier = nn.Linear(config.hidden_size, config.num_labels)
        self.span_classifier = nn.Linear(config.hidden_size, 2)
        # Let the Hugging Face base class initialise the freshly added heads.
        self.post_init()

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        token_type_ids=None,
        labels=None,
        span_labels=None,
    ):
        """Run the encoder and both heads, and optionally compute the joint loss.

        Args:
            input_ids: Token ids passed straight to ``RobertaModel``.
            attention_mask: Attention mask passed straight to ``RobertaModel``.
            token_type_ids: Segment ids passed straight to ``RobertaModel``
                (RoBERTa tokenizers normally omit these; ``None`` is fine).
            labels: Sequence-level class indices; flattened with
                ``reshape(-1)`` before the cross-entropy.
            span_labels: Per-token class indices in ``{0, 1}``; positions set
                to ``-100`` (e.g. padding) are ignored by the span loss.

        Returns:
            dict with keys:
              ``"loss"``: ``cls_loss + span_loss_weight * span_loss`` when both
                label sets are given; only the term(s) whose labels are given
                otherwise; ``None`` when neither is given.
              ``"logits"``: output of ``classifier``, last dim ``num_labels``.
              ``"span_logits"``: output of ``span_classifier``, last dim 2
                (assumes the usual ``(batch, seq_len)`` input, giving
                ``(batch, seq_len, 2)`` — confirm with caller).
        """
        # Force a ModelOutput so the attribute access below works even when
        # the config sets return_dict=False.
        outputs = self.roberta(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            return_dict=True,
        )
        sequence_output = self.dropout(outputs.last_hidden_state)
        pooled_output = self.dropout(outputs.pooler_output)

        logits = self.classifier(pooled_output)
        span_logits = self.span_classifier(sequence_output)

        # Compute each loss term independently so callers that supply only
        # one kind of label (e.g. classification-only eval) still get a loss.
        loss = None
        if labels is not None:
            loss = nn.CrossEntropyLoss()(
                logits.reshape(-1, self.num_labels),
                labels.reshape(-1),
            )
        if span_labels is not None:
            span_loss = nn.CrossEntropyLoss(ignore_index=-100)(
                span_logits.reshape(-1, 2),
                span_labels.reshape(-1),
            )
            weighted_span_loss = self.span_loss_weight * span_loss
            loss = weighted_span_loss if loss is None else loss + weighted_span_loss

        return {
            "loss": loss,
            "logits": logits,
            "span_logits": span_logits,
        }