kpriyanshu256's picture
Memory release
4013271
raw
history blame contribute delete
768 Bytes
import torch
import torch.nn as nn
import transformers
class BertAD(nn.Module):
    """BERT encoder followed by a two-channel linear span head.

    A linear layer with two outputs is applied to every token's encoder
    representation; the two channels are returned separately as start and
    end logits.
    """

    def __init__(self, model_path='./model', output_hidden_states=True):
        """Build the encoder from the configuration stored at *model_path*.

        Args:
            model_path: path/identifier handed to
                ``transformers.AutoConfig.from_pretrained``. Defaults to the
                previously hard-coded ``'./model'``.
            output_hidden_states: value written into the config. ``forward``
                only consumes ``output[0]``, so ``False`` avoids keeping every
                layer's activations around; the default ``True`` preserves
                the original behaviour.

        Note:
            ``transformers.BertModel(config)`` builds the architecture with
            freshly initialised weights; only the configuration is read from
            *model_path*. Presumably trained weights are loaded afterwards
            (e.g. via ``load_state_dict``) — TODO confirm with the caller.
        """
        super().__init__()
        model_config = transformers.AutoConfig.from_pretrained(model_path)
        model_config.update({"output_hidden_states": output_hidden_states})
        self.bert = transformers.BertModel(model_config)
        # Size the head from the config rather than hard-coding 768, so
        # encoders with another hidden size do not fail with a shape mismatch.
        self.layer = nn.Linear(model_config.hidden_size, 2)

    def forward(self, ids, mask, token_type):
        """Return per-token start and end logits.

        Args:
            ids: passed to the encoder as ``input_ids``.
            mask: passed to the encoder as ``attention_mask``.
            token_type: passed to the encoder as ``token_type_ids``.
            (Presumably all are (batch, seq_len) tensors — TODO confirm.)

        Returns:
            ``(start_logits, end_logits)``: the two output channels of the
            linear head, each with that trailing channel dimension removed.
        """
        output = self.bert(input_ids=ids,
                           attention_mask=mask,
                           token_type_ids=token_type)
        # Indexing with [0] (rather than .last_hidden_state) works for both
        # tuple and ModelOutput return types; for BertModel it is the
        # final-layer hidden states.
        logits = self.layer(output[0])
        start_logits, end_logits = logits.split(1, dim=-1)
        return start_logits.squeeze(-1), end_logits.squeeze(-1)