# athenas-lens/models.py
import torch
from torch import nn
from transformers import RobertaPreTrainedModel
from transformers.modeling_outputs import TokenClassifierOutput
from transformers.models.roberta.modeling_roberta import RobertaConfig, RobertaModel
from utils import batched_index_select
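
# The project's `batched_index_select` lives in utils.py and is not shown here.
# The sketch below (a common gather-based implementation) only documents the
# behaviour this file appears to rely on, namely selecting, per batch element,
# the slice of `input` along `dim` given by `index`. It is an assumption, not
# the project's actual code, and it is given a different name so it never
# shadows the imported helper.
def _reference_batched_index_select(input, dim, index):
    # index: (batch,) or (batch, k) integer tensor of positions along `dim`.
    views = [input.shape[0]] + [1 if d != dim else -1 for d in range(1, input.dim())]
    expanse = list(input.shape)
    expanse[0] = -1
    expanse[dim] = -1
    index = index.view(views).expand(expanse)
    # For a 3-D input and dim == 1: out[b, k, :] == input[b, index[b, k], :].
    return torch.gather(input, dim, index)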


class DependencyRobertaForTokenClassification(RobertaPreTrainedModel):
    """For every token, scores every position in the sequence as a candidate
    syntactic head using additive attention over the encoder states."""

    config_class = RobertaConfig  # type: ignore

    def __init__(self, config):
        super().__init__(config)
        self.roberta = RobertaModel(config, add_pooling_layer=False)
        # Additive (Bahdanau-style) attention parameters; the hard-coded 768
        # assumes a roberta-base sized encoder (config.hidden_size == 768).
        self.u_a = nn.Linear(768, 768)
        self.w_a = nn.Linear(768, 768)
        self.v_a_inv = nn.Linear(768, 1, bias=False)
        # NLLLoss ignores the -100 padding label by default.
        self.criterion = nn.NLLLoss()
        self.init_weights()
    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        token_type_ids=None,
        labels=None,
        **kwargs,
    ):
        loss = 0.0
        # Last hidden states of the encoder: (batch_size, seq_len, hidden_size).
        output = self.roberta(
            input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids
        )[0]
        batch_size, seq_len, _ = output.size()
        parent_prob_table = []
        for i in range(seq_len):
            # Repeat the embedding of token i across all sequence positions.
            target = output[:, i, :].expand(seq_len, batch_size, -1).transpose(0, 1)
            # Mask positions whose first hidden feature matches token i's
            # (in practice, token i itself), so a token cannot be its own head.
            mask = output.eq(target)[:, :, 0].unsqueeze(2)
            # Log-probability of each position being the head of token i.
            p_head = self.attention(output, target, mask)
            if labels is not None:
                current_loss = self.criterion(p_head.squeeze(-1), labels[:, i])
                # Skip columns where every label is the ignore index (-100),
                # which would otherwise contribute a NaN loss.
                if not torch.all(labels[:, i] == -100):
                    loss += current_loss
            parent_prob_table.append(torch.exp(p_head))
        # (batch_size, seq_len, seq_len): row i is the head distribution of token i.
        parent_prob_table = torch.cat(parent_prob_table, dim=2).data.transpose(1, 2)
        prob, topi = parent_prob_table.topk(k=1, dim=2)
        preds = topi.squeeze(-1)
        loss = loss / seq_len
        # Note: `logits` carries the predicted head indices, not raw scores.
        output = TokenClassifierOutput(loss=loss, logits=preds)
        if labels is not None:
            return output, preds, parent_prob_table, labels
        return output, preds, parent_prob_table
    def attention(self, source, target, mask=None):
        # Additive attention score g = v_a^T tanh(U_a source + W_a target),
        # normalised with a log-softmax over the sequence dimension.
        function_g = self.v_a_inv(torch.tanh(self.u_a(source) + self.w_a(target)))
        if mask is not None:
            # Large negative value rather than -inf to stay numerically safe.
            function_g.masked_fill_(mask, -1e4)
        return nn.functional.log_softmax(function_g, dim=1)
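

# Minimal smoke test for the dependency head model, included only to illustrate
# the expected shapes. The tiny config below is an assumption (any RobertaConfig
# with hidden_size=768 works, since 768 is hard-coded above); it is not the
# configuration used by the project.
def _demo_dependency_parser():
    config = RobertaConfig(
        vocab_size=100,
        hidden_size=768,
        num_hidden_layers=2,
        num_attention_heads=12,
        intermediate_size=256,
        max_position_embeddings=64,
    )
    model = DependencyRobertaForTokenClassification(config)
    model.eval()
    input_ids = torch.randint(5, 100, (2, 6))
    attention_mask = torch.ones_like(input_ids)
    with torch.no_grad():
        output, preds, prob_table = model(
            input_ids=input_ids, attention_mask=attention_mask
        )
    # preds[b, i] is the predicted head position of token i; prob_table holds
    # the full per-token head distribution.
    assert preds.shape == (2, 6)
    assert prob_table.shape == (2, 6, 6)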


class LabelRobertaForTokenClassification(RobertaPreTrainedModel):
    """Predicts the dependency relation label of each token from the
    concatenation of its own embedding and the embedding of its head."""

    config_class = RobertaConfig  # type: ignore

    def __init__(self, config):
        super().__init__(config)
        self.roberta = RobertaModel(config, add_pooling_layer=False)
        # 33 relation labels and a hidden size of 768 (roberta-base) are hard-coded.
        self.num_labels = 33
        self.hidden = nn.Linear(768 * 2, 768)
        self.relu = nn.ReLU()
        self.out = nn.Linear(768, self.num_labels)
        # CrossEntropyLoss ignores the -100 padding label by default.
        self.loss_fct = nn.CrossEntropyLoss()
        self.init_weights()
    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        token_type_ids=None,
        labels=None,
        **kwargs,
    ):
        loss = 0.0
        # Last hidden states of the encoder: (batch_size, seq_len, hidden_size).
        output = self.roberta(
            input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids
        )[0]
        batch_size, seq_len, _ = output.size()
        logits = []
        for i in range(seq_len):
            current_token = output[:, i, :]
            # Index of the head token that position i is attached to.
            connected_with_index = kwargs["head_labels"][:, i]
            # Map the ignore index (-100) to position 0 so the gather below
            # stays in range.
            connected_with_index[connected_with_index == -100] = 0
            connected_with_embedding = batched_index_select(
                output.clone(), 1, connected_with_index.clone()
            )
            # Concatenate token and head embeddings: (batch_size, 2 * hidden_size).
            combined_embeddings = torch.cat(
                (current_token, connected_with_embedding.squeeze(1)), -1
            )
            pred = self.out(self.relu(self.hidden(combined_embeddings)))
            pred = pred.view(-1, self.num_labels)
            logits.append(pred)
            if labels is not None:
                current_loss = self.loss_fct(pred, labels[:, i].view(-1))
                # Skip columns where every label is the ignore index (-100).
                if not torch.all(labels[:, i] == -100):
                    loss += current_loss
        loss = loss / seq_len
        # (batch_size, seq_len, num_labels)
        logits = torch.stack(logits, dim=1)
        return TokenClassifierOutput(loss=loss, logits=logits)
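

# Companion smoke test for the relation-label model, again only an illustration
# of the expected shapes; the tiny config and the random head labels are
# assumptions, not project data.
def _demo_label_classifier():
    config = RobertaConfig(
        vocab_size=100,
        hidden_size=768,
        num_hidden_layers=2,
        num_attention_heads=12,
        intermediate_size=256,
        max_position_embeddings=64,
    )
    model = LabelRobertaForTokenClassification(config)
    model.eval()
    input_ids = torch.randint(5, 100, (2, 6))
    attention_mask = torch.ones_like(input_ids)
    # Head index of every token (e.g. gold heads or the dependency model's preds).
    head_labels = torch.randint(0, 6, (2, 6))
    with torch.no_grad():
        output = model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            head_labels=head_labels,
        )
    # One score per relation label for every token.
    assert output.logits.shape == (2, 6, 33)


if __name__ == "__main__":
    _demo_dependency_parser()
    _demo_label_classifier()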