import torch
from torch import nn
from transformers import RobertaPreTrainedModel
from transformers.modeling_outputs import TokenClassifierOutput
from transformers.models.roberta.modeling_roberta import RobertaConfig, RobertaModel

from utils import batched_index_select  # local helper: per-batch gather of embeddings by index


class DependencyRobertaForTokenClassification(RobertaPreTrainedModel):
    """Predicts, for every token, the index of its syntactic head using
    additive attention over all positions in the sequence."""

    config_class = RobertaConfig  # type: ignore

    def __init__(self, config):
        super().__init__(config)
        self.roberta = RobertaModel(config, add_pooling_layer=False)
        # Additive (Bahdanau-style) attention parameters; 768 assumes a roberta-base hidden size.
        self.u_a = nn.Linear(768, 768)
        self.w_a = nn.Linear(768, 768)
        self.v_a_inv = nn.Linear(768, 1, bias=False)
        # NLLLoss ignores -100 targets by default, matching the padding convention used below.
        self.criterion = nn.NLLLoss()
        self.init_weights()

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        token_type_ids=None,
        labels=None,
        **kwargs,
    ):
        loss = 0.0
        output = self.roberta(
            input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids
        )[0]
        batch_size, seq_len, _ = output.size()

        parent_prob_table = []
        for i in range(seq_len):
            # Broadcast the i-th token embedding against every position in the sequence.
            target = output[:, i, :].expand(seq_len, batch_size, -1).transpose(0, 1)
            # Mask the position of the token itself so it cannot be chosen as its own head.
            mask = output.eq(target)[:, :, 0].unsqueeze(2)
            # Log-probability of every position being the head of token i.
            p_head = self.attention(output, target, mask)
            if labels is not None:
                # Skip positions where the whole batch is padding; NLLLoss would return NaN.
                if not torch.all(labels[:, i] == -100):
                    loss += self.criterion(p_head.squeeze(-1), labels[:, i])
            parent_prob_table.append(torch.exp(p_head))

        # (batch_size, seq_len, seq_len): row i holds the head distribution of token i.
        parent_prob_table = torch.cat(parent_prob_table, dim=2).detach().transpose(1, 2)
        prob, topi = parent_prob_table.topk(k=1, dim=2)
        preds = topi.squeeze(-1)

        loss = loss / seq_len if labels is not None else None
        output = TokenClassifierOutput(loss=loss, logits=preds)
        if labels is not None:
            return output, preds, parent_prob_table, labels
        return output, preds, parent_prob_table

    def attention(self, source, target, mask=None):
        # Additive attention score: g = v^T tanh(U * source + W * target).
        function_g = self.v_a_inv(torch.tanh(self.u_a(source) + self.w_a(target)))
        if mask is not None:
            function_g.masked_fill_(mask, -1e4)
        return nn.functional.log_softmax(function_g, dim=1)


class LabelRobertaForTokenClassification(RobertaPreTrainedModel):
    """Predicts the dependency relation label for each (token, head) pair.
    Head indices are supplied through the `head_labels` keyword argument."""

    config_class = RobertaConfig  # type: ignore

    def __init__(self, config):
        super().__init__(config)
        self.roberta = RobertaModel(config, add_pooling_layer=False)
        self.num_labels = 33  # number of dependency relation labels
        # Classifier over the concatenation of a token embedding and its head embedding.
        self.hidden = nn.Linear(768 * 2, 768)
        self.relu = nn.ReLU()
        self.out = nn.Linear(768, self.num_labels)
        # CrossEntropyLoss ignores -100 targets by default.
        self.loss_fct = nn.CrossEntropyLoss()
        self.init_weights()

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        token_type_ids=None,
        labels=None,
        **kwargs,
    ):
        loss = 0.0
        output = self.roberta(
            input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids
        )[0]
        batch_size, seq_len, _ = output.size()

        logits = []
        for i in range(seq_len):
            current_token = output[:, i, :]
            # Index of the head token; clone so replacing -100 padding with 0
            # does not modify the caller's tensor in place.
            connected_with_index = kwargs["head_labels"][:, i].clone()
            connected_with_index[connected_with_index == -100] = 0
            connected_with_embedding = batched_index_select(
                output.clone(), 1, connected_with_index
            )
            # Concatenate the token embedding with its head embedding and classify the relation.
            combined_embeddings = torch.cat(
                (current_token, connected_with_embedding.squeeze(1)), -1
            )
            pred = self.out(self.relu(self.hidden(combined_embeddings)))
            pred = pred.view(-1, self.num_labels)
            logits.append(pred)
            if labels is not None:
                # Skip positions where the whole batch is padding to avoid a NaN loss.
                if not torch.all(labels[:, i] == -100):
                    loss += self.loss_fct(pred, labels[:, i].view(-1))

        loss = loss / seq_len if labels is not None else None
        logits = torch.stack(logits, dim=1)
        output = TokenClassifierOutput(loss=loss, logits=logits)
        return output
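

# Minimal usage sketch, not part of the original module: it assumes the
# "roberta-base" checkpoint, a working `utils.batched_index_select`, and a toy
# input sentence, purely to illustrate the expected tensor shapes and the
# `head_labels` keyword consumed by LabelRobertaForTokenClassification.
if __name__ == "__main__":
    from transformers import RobertaTokenizerFast

    tokenizer = RobertaTokenizerFast.from_pretrained("roberta-base")
    encoding = tokenizer(["The cat sat on the mat ."], return_tensors="pt")

    # Head-index prediction: preds holds one predicted head position per token.
    dep_model = DependencyRobertaForTokenClassification.from_pretrained("roberta-base")
    dep_output, dep_preds, prob_table = dep_model(
        input_ids=encoding["input_ids"], attention_mask=encoding["attention_mask"]
    )
    print(dep_preds.shape)  # (1, seq_len)

    # Relation-label prediction: head indices (here the model's own predictions)
    # are passed in through the `head_labels` keyword.
    label_model = LabelRobertaForTokenClassification.from_pretrained("roberta-base")
    label_output = label_model(
        input_ids=encoding["input_ids"],
        attention_mask=encoding["attention_mask"],
        head_labels=dep_preds,
    )
    print(label_output.logits.shape)  # (1, seq_len, 33)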