Spaces:
Sleeping
Sleeping
import torch | |
from torch import nn | |
from transformers import RobertaPreTrainedModel | |
from transformers.modeling_outputs import TokenClassifierOutput | |
from transformers.models.roberta.modeling_roberta import RobertaConfig, RobertaModel | |
from utils import batched_index_select | |
class DependencyRobertaForTokenClassification(RobertaPreTrainedModel):
    """RoBERTa encoder with an additive-attention head that picks each token's head.

    For every position ``i`` of a sequence, every position ``j`` of the same
    sequence is scored as the head of ``i`` with
    ``v_a_inv(tanh(u_a(h_j) + w_a(h_i)))``. The scores are log-softmax
    normalised over ``j``, and position ``i`` itself is masked out as a
    candidate.
    """

    config_class = RobertaConfig  # type: ignore

    def __init__(self, config):
        super().__init__(config)
        self.roberta = RobertaModel(config, add_pooling_layer=False)
        # Sized from the config rather than a hard-coded 768. This is identical
        # whenever hidden_size == 768 (the RobertaConfig default), and it also
        # lets encoders with other hidden sizes work.
        hidden_size = config.hidden_size
        self.u_a = nn.Linear(hidden_size, hidden_size)
        self.w_a = nn.Linear(hidden_size, hidden_size)
        self.v_a_inv = nn.Linear(hidden_size, 1, bias=False)
        self.criterion = nn.NLLLoss()
        self.init_weights()

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        token_type_ids=None,
        labels=None,
        **kwargs,
    ):
        """Score every position as the head of every other position.

        Args:
            input_ids, attention_mask, token_type_ids: forwarded to ``RobertaModel``.
            labels: optional integer tensor indexed as ``labels[:, i]``. Entry
                ``[b, i]`` is the index of the head position of token ``i`` in
                example ``b``, or ``-100`` to ignore that token.
            **kwargs: accepted and ignored.

        Returns:
            ``(output, preds, parent_prob_table)``, with ``labels`` appended as
            a fourth element when labels are given.
            - ``parent_prob_table[b, i, j]`` is the detached probability that
              ``j`` is the head of ``i``.
            - ``preds`` is the argmax of that table over ``j``.
            - ``output`` is a ``TokenClassifierOutput``. Its ``logits`` field
              holds ``preds``. Its ``loss`` is the sum of the per-position NLL
              losses divided by ``seq_len``, or ``None`` when no labels are
              given.
        """
        output = self.roberta(
            input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids
        )[0]
        seq_len = output.size(1)

        # Both projections depend only on the encoder output, not on the loop
        # index, so they are computed once instead of once per position.
        source_proj = self.u_a(output)  # candidate-head side, all positions j
        target_proj = self.w_a(output)  # dependent side, sliced per position i
        positions = torch.arange(seq_len, device=output.device)

        step_losses = []
        log_probs = []
        for i in range(seq_len):
            # Exclude j == i: a token cannot be its own head. Broadcasts over
            # the batch dimension inside masked_fill.
            self_mask = (positions == i).view(1, seq_len, 1)
            # NOTE(review): padding positions (attention_mask == 0) are still
            # valid head candidates here; consider masking them too.
            p_head = self._score_heads(
                source_proj, target_proj[:, i : i + 1, :], self_mask
            )  # (batch, seq_len, 1): log-probs over candidate heads j
            if labels is not None and (labels[:, i] != -100).any():
                # Skip columns that are entirely -100. Mean-reduced NLLLoss
                # yields NaN when every target in the batch is ignored.
                step_losses.append(self.criterion(p_head.squeeze(-1), labels[:, i]))
            log_probs.append(p_head)

        # Concatenating along dim 2 gives [b, j, i]; transpose to [b, i, j] so
        # that row i holds the head distribution of token i.
        parent_prob_table = torch.cat(log_probs, dim=2).detach().exp().transpose(1, 2)
        preds = parent_prob_table.topk(k=1, dim=2).indices.squeeze(-1)

        loss = None
        if labels is not None:
            # Divided by seq_len, not by the number of contributing positions.
            loss = (
                torch.stack(step_losses).sum() / seq_len
                if step_losses
                else output.new_zeros(())
            )
        result = TokenClassifierOutput(loss=loss, logits=preds)
        if labels is not None:
            return result, preds, parent_prob_table, labels
        return result, preds, parent_prob_table

    def attention(self, source, target, mask=None):
        """Return log-probabilities of each ``source`` position for ``target``.

        Args:
            source: encoder states of the candidate positions.
            target: encoder states broadcast-compatible with ``source``.
            mask: optional boolean tensor. Its ``True`` entries get the score
                ``-1e4`` before normalisation.

        Returns:
            ``log_softmax`` of the additive-attention scores over dim 1. The
            result has the shape of ``source`` with the last dimension reduced
            to 1.
        """
        return self._score_heads(self.u_a(source), self.w_a(target), mask)

    def _score_heads(self, source_proj, target_proj, mask=None):
        """Score already-projected states and log-softmax the scores over dim 1."""
        scores = self.v_a_inv(torch.tanh(source_proj + target_proj))
        if mask is not None:
            scores = scores.masked_fill(mask, -1e4)
        return nn.functional.log_softmax(scores, dim=1)
class LabelRobertaForTokenClassification(RobertaPreTrainedModel):
    """RoBERTa encoder that classifies each token together with its head token.

    Each token's encoding is concatenated with the encoding of the position
    given by ``head_labels``. The result goes through a one-hidden-layer ReLU
    MLP that outputs ``num_labels`` scores per token. These are presumably
    dependency-relation labels; confirm against the label inventory.
    """

    config_class = RobertaConfig  # type: ignore

    def __init__(self, config):
        super().__init__(config)
        self.roberta = RobertaModel(config, add_pooling_layer=False)
        # NOTE(review): hard-coded instead of config.num_labels. Presumably it
        # matches the label set the checkpoints were trained with; confirm
        # before switching to the config value.
        self.num_labels = 33
        hidden_size = config.hidden_size  # 768 for the RobertaConfig default
        self.hidden = nn.Linear(hidden_size * 2, hidden_size)
        self.relu = nn.ReLU()
        self.out = nn.Linear(hidden_size, self.num_labels)
        self.loss_fct = nn.CrossEntropyLoss()
        # Same initialisation path as DependencyRobertaForTokenClassification.
        self.init_weights()

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        token_type_ids=None,
        labels=None,
        **kwargs,
    ):
        """Classify every token given the position of its head token.

        Args:
            input_ids, attention_mask, token_type_ids: forwarded to ``RobertaModel``.
            labels: optional integer tensor indexed as ``labels[:, i]``.
                ``-100`` entries are ignored by the loss.
            **kwargs: must contain ``head_labels``, an integer tensor indexed
                as ``head_labels[:, i]``. It gives the position whose encoding
                is paired with token ``i``. ``-100`` entries are treated as
                position 0. The tensor is only read, never modified.

        Returns:
            A ``TokenClassifierOutput``:
            - ``logits`` has shape ``(batch, seq_len, num_labels)``.
            - ``loss`` is the summed per-position cross-entropy divided by
              ``seq_len``, or ``None`` when no labels are given.

        Raises:
            KeyError: if ``head_labels`` is not passed.
        """
        output = self.roberta(
            input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids
        )[0]
        seq_len = output.size(1)
        head_labels = kwargs["head_labels"]

        step_losses = []
        logits = []
        for i in range(seq_len):
            current_token = output[:, i, :]
            # Out-of-place masked_fill keeps the caller's head_labels (and its
            # -100 markers) intact. Assigning through the [:, i] view would
            # write the zeros back into the caller's tensor.
            head_column = head_labels[:, i]
            head_index = head_column.masked_fill(head_column == -100, 0)
            # NOTE(review): the output clone is kept defensively. Drop it if
            # batched_index_select is confirmed not to modify its input.
            head_embedding = batched_index_select(output.clone(), 1, head_index)
            combined_embeddings = torch.cat(
                (current_token, head_embedding.squeeze(1)), -1
            )
            pred = self.out(self.relu(self.hidden(combined_embeddings)))
            pred = pred.view(-1, self.num_labels)
            logits.append(pred)
            if labels is not None and (labels[:, i] != -100).any():
                # Skip columns that are entirely -100. Mean-reduced
                # CrossEntropyLoss yields NaN when every target is ignored.
                step_losses.append(self.loss_fct(pred, labels[:, i].view(-1)))

        loss = None
        if labels is not None:
            # Divided by seq_len, not by the number of contributing positions.
            loss = (
                torch.stack(step_losses).sum() / seq_len
                if step_losses
                else output.new_zeros(())
            )
        return TokenClassifierOutput(loss=loss, logits=torch.stack(logits, dim=1))