# Source: custom-mybertclassifier / modeling_my_bert_classifier.py
# Uploaded by Capstone-lpx in commit 0c45ca3 (verified), "Upload MyBertClassifier".
# Original upload size: 866 bytes.
from transformers import BertTokenizer, BertModel
from .configuration_my_bert_classifier import MyBertClassifierConfig
from torch import nn
from transformers.modeling_utils import PreTrainedModel
class MyBertClassifier(PreTrainedModel):
    """BERT-based sequence classifier with a 5-way linear head.

    The backbone is ``bert-base-cased`` loaded via ``BertModel.from_pretrained``.
    Its pooled output is passed through dropout, a linear layer, and a ReLU.

    NOTE(review): ``config`` is only forwarded to ``PreTrainedModel``. The
    backbone name, dropout rate (0.5) and number of classes (5) are hard-coded
    here rather than read from ``MyBertClassifierConfig``, whose fields are not
    visible from this file. Confirm whether they should live in the config.
    """

    config_class = MyBertClassifierConfig

    def __init__(self, config):
        """Build the BERT backbone and the classification head.

        Args:
            config: A ``MyBertClassifierConfig`` instance, passed through to
                ``PreTrainedModel.__init__``.
        """
        super().__init__(config)
        # NOTE(review): this fetches pretrained bert-base-cased weights every
        # time the model is constructed (presumably also when this class is
        # restored with from_pretrained). Consider building the backbone from
        # a config instead if that extra download is unwanted.
        self.bert = BertModel.from_pretrained('bert-base-cased')
        self.dropout = nn.Dropout(0.5)
        # Size the head from the backbone rather than a literal 768 so it stays
        # consistent with the loaded model (768 for bert-base-cased).
        self.linear = nn.Linear(self.bert.config.hidden_size, 5)
        self.relu = nn.ReLU()

    def forward(self, input_id, mask=None):
        """Compute 5 class scores per example.

        Args:
            input_id: Token ids fed to BERT as ``input_ids`` (presumably shape
                ``(batch, seq_len)``; TODO confirm against callers).
            mask: Attention mask passed as ``attention_mask``. ``None`` (the
                default) is forwarded as-is and left to ``BertModel``'s own
                default handling.

        Returns:
            ReLU-activated outputs of the linear head, 5 values per example
            (presumably shape ``(batch, 5)``).
        """
        # With return_dict=False BertModel returns a tuple whose second element
        # is the pooled output. Index it rather than unpacking exactly two
        # values, because the tuple gains entries when hidden states or
        # attentions are requested.
        outputs = self.bert(input_ids=input_id, attention_mask=mask, return_dict=False)
        pooled_output = outputs[1]
        dropout_output = self.dropout(pooled_output)
        linear_output = self.linear(dropout_output)
        # NOTE(review): a ReLU on classifier outputs is unusual. Negative scores
        # are clamped to 0, so several classes can tie at 0. It is kept
        # unchanged because removing it would change this model's outputs.
        final_layer = self.relu(linear_output)
        return final_layer