import torch.nn as nn
from transformers import AutoConfig, AutoModel, BertPreTrainedModel, PretrainedConfig

class PragFormerConfig(PretrainedConfig):
    model_type = "pragformer"

    def __init__(self, bert=None, dropout=0.2, fc1=512, fc2=2, softmax_dim=1, **kwargs):
        self.bert = bert
        self.dropout = dropout
        self.fc1 = fc1
        self.fc2 = fc2
        self.softmax_dim = softmax_dim
        super().__init__(**kwargs)
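
# Note: `bert` is expected to be a dict holding the backbone's configuration;
# BERT_Arch below reads config.bert['_name_or_path'] and config.bert['hidden_size']
# from it, so both keys must be present in the saved config.json.
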
class BERT_Arch(BertPreTrainedModel):
    config_class = PragFormerConfig

    def __init__(self, config):
        super().__init__(config)
        # Reload the backbone from the name/path stored in the config dict
        self.bert = AutoModel.from_pretrained(config.bert['_name_or_path'])
        # Dropout layer
        self.dropout = nn.Dropout(config.dropout)
        # ReLU activation function
        self.relu = nn.ReLU()
        # Dense layer 1: projects the pooled [CLS] embedding down to fc1 units
        self.fc1 = nn.Linear(config.bert['hidden_size'], config.fc1)
        # Dense layer 2 (output layer): one unit per class
        self.fc2 = nn.Linear(config.fc1, config.fc2)
        # LogSoftmax activation (pairs with nn.NLLLoss during training)
        self.softmax = nn.LogSoftmax(dim=config.softmax_dim)

    # Define the forward pass
    def forward(self, input_ids, attention_mask):
        # With return_dict=False the backbone returns a tuple
        # (last_hidden_state, pooler_output); keep only the pooled [CLS] vector
        _, cls_hs = self.bert(input_ids, attention_mask=attention_mask, return_dict=False)
        x = self.fc1(cls_hs)
        x = self.relu(x)
        x = self.dropout(x)
        # Output layer
        x = self.fc2(x)
        # Apply log-softmax over the class dimension
        x = self.softmax(x)
        return x
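
# Quick shape check (hypothetical sketch, not part of the original script): it
# assumes the config's `bert` entry points at "bert-base-uncased" with
# hidden_size=768, which BERT_Arch requires above.
#
# import torch
# cfg = PragFormerConfig(bert={'_name_or_path': 'bert-base-uncased', 'hidden_size': 768})
# clf = BERT_Arch(cfg)
# input_ids = torch.randint(0, 30000, (2, 16))   # (batch_size, seq_len)
# attention_mask = torch.ones_like(input_ids)
# log_probs = clf(input_ids, attention_mask)     # (batch_size, cfg.fc2)
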
# Register the custom classes so their source code is saved alongside the
# checkpoint, letting users load the repo with trust_remote_code=True
PragFormerConfig.register_for_auto_class()
BERT_Arch.register_for_auto_class("AutoModel")

# Rebuild the model from the local checkpoint and push it to the Hub
config = PragFormerConfig.from_pretrained('./Classifier/PragFormer')
model = BERT_Arch(config)
pretrained_model = BERT_Arch.from_pretrained("./Classifier/PragFormer")
model.load_state_dict(pretrained_model.state_dict())
model.push_to_hub("PragFormer")
# Alternative: register the classes with the Auto* factories directly instead
# of register_for_auto_class:
# AutoConfig.register("pragformer", PragFormerConfig)
# AutoModel.register(PragFormerConfig, BERT_Arch)
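
# Loading sketch (assumption: the checkpoint was pushed under the placeholder
# repo id "<user>/PragFormer"): because the classes were registered for
# auto-class use, downstream code can load them from the Hub directly.
#
# from transformers import AutoModel
# model = AutoModel.from_pretrained("<user>/PragFormer", trust_remote_code=True)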