"""PragFormer: a pretrained encoder backbone with a small classification head.

Running this module as a script does the following:
- registers ``PragFormerConfig`` and ``BERT_Arch`` for the ``transformers`` auto classes;
- loads the checkpoint stored in ``./PragFormer``;
- pushes it to the Hugging Face Hub as ``PragFormer``.
"""
from typing import Any, Dict, Optional

import torch
import torch.nn as nn
from transformers import AutoConfig, AutoModel, BertPreTrainedModel, PretrainedConfig


class PragFormerConfig(PretrainedConfig):
    """Configuration for :class:`BERT_Arch`.

    Args:
        bert: Configuration dict of the backbone encoder. ``BERT_Arch`` reads
            two keys from it:
            - ``'_name_or_path'``: the checkpoint passed to ``AutoModel.from_pretrained``;
            - ``'hidden_size'``: the input width of the first dense layer.
        dropout: Dropout probability applied after the first dense layer.
        fc1: Output width of the first dense layer.
        fc2: Output width of the final dense layer (number of output scores).
        softmax_dim: Dimension along which the log-softmax is computed.
        **kwargs: Forwarded unchanged to :class:`transformers.PretrainedConfig`.
    """

    model_type = "pragformer"

    def __init__(
        self,
        bert: Optional[Dict[str, Any]] = None,
        dropout: float = 0.2,
        fc1: int = 512,
        fc2: int = 2,
        softmax_dim: int = 1,
        **kwargs: Any,
    ) -> None:
        self.bert = bert
        self.dropout = dropout
        self.fc1 = fc1
        self.fc2 = fc2
        self.softmax_dim = softmax_dim
        super().__init__(**kwargs)


class BERT_Arch(BertPreTrainedModel):
    """Backbone encoder followed by a two-layer classification head.

    Pipeline:
        backbone pooled output -> ``fc1`` -> ReLU -> dropout -> ``fc2`` -> ``LogSoftmax``

    The model therefore returns log-probabilities, not probabilities.
    """

    config_class = PragFormerConfig

    def __init__(self, config: PragFormerConfig) -> None:
        """Build the backbone and head described by *config*.

        Raises:
            ValueError: if ``config.bert`` is ``None``. The backbone checkpoint
                name and hidden size are both read from that dict.
        """
        super().__init__(config)
        if config.bert is None:
            # Fail with an actionable message instead of
            # "'NoneType' object is not subscriptable".
            raise ValueError(
                "PragFormerConfig.bert must be the backbone config dict "
                "(containing '_name_or_path' and 'hidden_size')."
            )

        # Loads pretrained backbone weights from config.bert['_name_or_path']
        # every time this class is constructed.
        self.bert = AutoModel.from_pretrained(config.bert["_name_or_path"])
        self.dropout = nn.Dropout(config.dropout)
        self.relu = nn.ReLU()
        # Head input width must match the backbone's hidden size.
        self.fc1 = nn.Linear(config.bert["hidden_size"], config.fc1)
        # Output layer.
        self.fc2 = nn.Linear(config.fc1, config.fc2)
        # NOTE: despite the attribute name, this is LogSoftmax. The name is
        # kept so existing code that accesses ``model.softmax`` still works.
        self.softmax = nn.LogSoftmax(dim=config.softmax_dim)

    def forward(
        self,
        input_ids: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Return log-probabilities over the ``config.fc2`` outputs.

        Args:
            input_ids: Token ids, passed positionally to the backbone.
                Presumably shaped (batch, seq_len); TODO confirm against callers.
            attention_mask: Forwarded unchanged to the backbone. May be omitted
                (defaults to ``None``).

        Returns:
            ``LogSoftmax`` of the ``fc2`` output, taken along ``config.softmax_dim``.
        """
        outputs = self.bert(input_ids, attention_mask=attention_mask, return_dict=False)
        # With return_dict=False the backbone returns a tuple, and index 1 is
        # the pooled output for BERT-style models. Indexing (rather than
        # unpacking into exactly two names) keeps this working when the
        # backbone also returns hidden states or attentions.
        cls_hs = outputs[1]

        x = self.fc1(cls_hs)
        x = self.relu(x)
        x = self.dropout(x)
        x = self.fc2(x)
        return self.softmax(x)


if __name__ == "__main__":
    # Register the custom classes with the auto classes, so that saving or
    # pushing records them for AutoConfig / AutoModel loading.
    PragFormerConfig.register_for_auto_class()
    BERT_Arch.register_for_auto_class("AutoModel")

    config = PragFormerConfig.from_pretrained("./PragFormer")
    model = BERT_Arch(config)
    # Copy the trained weights from the local checkpoint into the freshly built model.
    # NOTE(review): pretrained_model already holds the same weights; pushing it
    # directly would likely avoid the extra construction. Confirm before changing.
    pretrained_model = BERT_Arch.from_pretrained("./PragFormer")
    model.load_state_dict(pretrained_model.state_dict())
    model.push_to_hub("PragFormer")

    # Alternative (unused): register locally instead of via auto_map, with
    # AutoConfig.register("pragformer", PragFormerConfig) and
    # AutoModel.register(PragFormerConfig, BERT_Arch).