# Spaces: Runtime error
from typing import Optional, Tuple

import torch
import torch.nn as nn
import transformers


class BertAD(nn.Module):
    """BERT encoder followed by a linear head that emits two logits per position.

    The two output channels of the head are returned separately by
    :meth:`forward` as ``start_logits`` and ``end_logits``.

    Args:
        config_path: Directory or model identifier passed to
            ``transformers.AutoConfig.from_pretrained``. Defaults to
            ``'./model'``.
    """

    def __init__(self, config_path: str = './model'):
        super().__init__()
        model_config = transformers.AutoConfig.from_pretrained(config_path)
        # NOTE(review): hidden states of every layer are requested, but
        # forward() only reads output[0] (the final layer), so the extra
        # tensors are computed and discarded. Left as-is so that
        # self.bert.config stays unchanged for any external users.
        model_config.update({"output_hidden_states": True})
        # Constructing BertModel from a config builds the architecture with
        # freshly initialised weights; no pretrained weights are loaded here.
        # Presumably a trained state_dict for the whole BertAD is loaded by
        # the caller -- confirm.
        self.bert = transformers.BertModel(model_config)
        # Size the head from the config rather than a hard-coded 768 so that
        # configs with a different hidden_size also work. The attribute name
        # ``layer`` is kept so existing state_dict keys still match.
        self.layer = nn.Linear(model_config.hidden_size, 2)

    def forward(
        self,
        ids: torch.Tensor,
        mask: Optional[torch.Tensor] = None,
        token_type: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Run BERT and split the head's output into start and end logits.

        Args:
            ids: Token ids, forwarded as ``input_ids``. Presumably shaped
                (batch, seq_len) -- confirm against the tokenizer pipeline.
            mask: Forwarded as ``attention_mask``. Optional; with ``None``
                BertModel attends to every position.
            token_type: Forwarded as ``token_type_ids``. Optional; with
                ``None`` BertModel falls back to its default segment ids.

        Returns:
            ``(start_logits, end_logits)``: the two channels of the linear
            head applied to ``output[0]``. Each has that tensor's shape with
            the last dimension removed.
        """
        output = self.bert(
            input_ids=ids,
            attention_mask=mask,
            token_type_ids=token_type,
        )
        # output[0] is the last-layer hidden state for both the tuple and the
        # ModelOutput return forms.
        logits = self.layer(output[0])
        # The last dim has exactly 2 channels, so unbind(-1) is equivalent to
        # split(1, dim=-1) followed by squeeze(-1) on each piece.
        start_logits, end_logits = logits.unbind(dim=-1)
        return start_logits, end_logits