import config
import transformers
import torch.nn as nn


class BERTBaseUncased(nn.Module):
    """BERT-base classifier: pooled BERT output -> dropout -> 3-class linear head."""

    def __init__(self):
        super().__init__()
        # Pretrained backbone; checkpoint path comes from the project config module.
        self.bert = transformers.BertModel.from_pretrained(config.BERT_PATH)
        self.bert_drop = nn.Dropout(0.3)
        # Read the hidden size from the backbone instead of hard-coding 768,
        # so the head stays correct if BERT_PATH points at a differently-sized
        # model (for bert-base-uncased this is exactly 768, as before).
        self.out = nn.Linear(self.bert.config.hidden_size, 3)
        nn.init.xavier_uniform_(self.out.weight)

    def extract_features(self, ids, mask, token_type_ids):
        """Return the dropout-regularized pooled ([CLS]) representation.

        Args:
            ids: token id tensor (presumably (batch, seq) — TODO confirm with caller).
            mask: attention mask, same shape as ids.
            token_type_ids: segment ids, same shape as ids.

        Returns:
            Tensor of shape (batch, hidden_size) after dropout.
        """
        # Index position 1 (the pooled output) instead of tuple-unpacking:
        # this works both when the model returns a plain tuple (older
        # transformers) and a ModelOutput (transformers >= 4, where
        # `_, o2 = ...` would iterate the output's *keys* and bind a string).
        outputs = self.bert(ids, attention_mask=mask, token_type_ids=token_type_ids)
        pooled = outputs[1]
        return self.bert_drop(pooled)

    def forward(self, ids, mask, token_type_ids):
        """Compute 3-class logits for a batch of encoded sequences.

        Reuses extract_features so the BERT forward pass and dropout are
        defined in exactly one place.
        """
        return self.out(self.extract_features(ids, mask, token_type_ids))