relation-extraction-api / Nested /nn /BertSeqTagger.py
TymaaHammouda's picture
Upload 39 files
2551344 verified
raw
history blame contribute delete
504 Bytes
import torch.nn as nn
from transformers import BertModel
class BertSeqTagger(nn.Module):
    """Token-level tagger: a pretrained BERT encoder, dropout, and a linear head.

    The linear head maps each token's final hidden state to ``num_labels``
    scores.
    """

    def __init__(self, bert_model: str, num_labels: int = 2, dropout: float = 0.1):
        """Build the tagger.

        Args:
            bert_model: Identifier handed to ``BertModel.from_pretrained``
                (presumably a hub model name or a local checkpoint path).
            num_labels: Number of output scores produced per token.
            dropout: Dropout probability applied to the encoder output before
                the linear head.
        """
        super().__init__()
        self.bert = BertModel.from_pretrained(bert_model)
        self.dropout = nn.Dropout(dropout)
        # Size the head from the loaded encoder instead of assuming 768, so
        # checkpoints with a different hidden width work as well.
        self.linear = nn.Linear(self.bert.config.hidden_size, num_labels)

    def forward(self, x, attention_mask=None):
        """Return per-token logits for ``x``.

        Args:
            x: Passed to the encoder as its first positional argument
                (presumably a tensor of token ids, batch x seq_len — confirm
                against callers).
            attention_mask: Optional mask forwarded to the encoder so padded
                positions can be ignored. ``None`` (the default) reproduces
                the original single-argument call.

        Returns:
            The linear head's output over the encoder's ``last_hidden_state``;
            the last dimension has size ``num_labels``.
        """
        encoded = self.bert(x, attention_mask=attention_mask)
        hidden = self.dropout(encoded["last_hidden_state"])
        return self.linear(hidden)