|
from transformers import BertTokenizer, BertModel |
|
from .configuration_my_bert_classifier import MyBertClassifierConfig |
|
from torch import nn |
|
from transformers.modeling_utils import PreTrainedModel |
|
|
|
class MyBertClassifier(PreTrainedModel):
    """Sequence classifier: a pretrained BERT encoder followed by a small head.

    The head maps BERT's pooled output through dropout (p=0.5), a linear
    layer with 5 outputs, and a ReLU.
    """

    # Associates this model with its config class for the PreTrainedModel
    # save/load machinery.
    config_class = MyBertClassifierConfig

    def __init__(self, config):
        """Build the BERT encoder and the classification head.

        Args:
            config: A ``MyBertClassifierConfig``. It is only forwarded to
                ``PreTrainedModel.__init__``; none of its fields are used to
                build the layers below.
        """
        super().__init__(config)

        # NOTE(review): the backbone is always loaded from the 'bert-base-cased'
        # checkpoint, independent of ``config``. Because PreTrainedModel's
        # ``from_pretrained`` constructs the model through ``__init__``, restoring
        # a saved MyBertClassifier also fetches these base weights first and then
        # overwrites them with the saved ones. Consider building from config.
        self.bert = BertModel.from_pretrained('bert-base-cased')

        self.dropout = nn.Dropout(0.5)
        # Take the encoder width from the backbone rather than hard-coding 768,
        # so the head stays consistent if the checkpoint above is changed.
        self.linear = nn.Linear(self.bert.config.hidden_size, 5)
        self.relu = nn.ReLU()

    def forward(self, input_id, mask):
        """Run BERT and the classification head.

        Args:
            input_id: Token ids, passed to BERT as ``input_ids``; presumably
                shaped (batch, seq_len) -- confirm against the caller.
            mask: Attention mask, passed to BERT as ``attention_mask``.

        Returns:
            The ReLU-activated output of the linear layer: 5 non-negative
            values per example.
        """
        outputs = self.bert(
            input_ids=input_id,
            attention_mask=mask,
            return_dict=False,
        )
        # With return_dict=False the result is a tuple that starts with
        # (sequence_output, pooled_output); hidden states / attentions are
        # appended when enabled, so index instead of unpacking exactly two.
        pooled_output = outputs[1]

        dropout_output = self.dropout(pooled_output)
        linear_output = self.linear(dropout_output)

        # NOTE(review): ReLU on the final scores clamps every negative value to
        # 0. If these are consumed as logits (e.g. by CrossEntropyLoss) that
        # discards information and zeroes the gradient for negative scores;
        # consider returning ``linear_output`` directly. Kept here to preserve
        # the existing outputs.
        final_layer = self.relu(linear_output)

        return final_layer