# PragFormer / init.py
# Uploaded by Pragformer — commit 6c2fcf7 ("Upload BERT_Arch").
from transformers import PretrainedConfig
from transformers import AutoModel, AutoConfig
import torch.nn as nn
from transformers import BertPreTrainedModel, AutoModel, PretrainedConfig
class PragFormerConfig(PretrainedConfig):
    """Configuration for the PragFormer classifier (``BERT_Arch``).

    Args:
        bert: mapping describing the backbone encoder (presumably a serialized
            encoder config dict); ``BERT_Arch`` loads the encoder from its
            ``'_name_or_path'`` entry.
        dropout: drop probability of the ``nn.Dropout`` layer in the head.
        fc1: output width of the first dense layer of the head.
        fc2: output width of the second (output) dense layer.
        softmax_dim: dimension the head's ``LogSoftmax`` normalizes over.
        **kwargs: forwarded unchanged to ``PretrainedConfig``.
    """

    model_type = "pragformer"

    def __init__(
        self,
        bert=None,
        dropout=0.2,
        fc1=512,
        fc2=2,
        softmax_dim=1,
        **kwargs,
    ):
        # Head hyper-parameters are stored first (same order as before); the
        # base class then consumes the remaining generic options.
        self.bert, self.dropout = bert, dropout
        self.fc1, self.fc2 = fc1, fc2
        self.softmax_dim = softmax_dim
        super().__init__(**kwargs)
class BERT_Arch(BertPreTrainedModel):
    """PragFormer classifier: a pretrained encoder plus a small dense head.

    The head is ``Linear(hidden, fc1) -> ReLU -> Dropout -> Linear(fc1, fc2)
    -> LogSoftmax``. The output is therefore log-probabilities, not
    probabilities.
    """

    config_class = PragFormerConfig

    def __init__(self, config):
        """Build the encoder and classification head from ``config``.

        Args:
            config: a ``PragFormerConfig``. ``config.bert`` must be a mapping
                with a ``'_name_or_path'`` entry naming the encoder to load.

        Raises:
            ValueError: if ``config.bert`` is missing or lacks
                ``'_name_or_path'``.
        """
        super().__init__(config)
        backbone = config.bert
        # The config default is ``bert=None``; fail with a clear message
        # instead of an opaque "NoneType is not subscriptable".
        if not backbone or "_name_or_path" not in backbone:
            raise ValueError(
                "PragFormerConfig.bert must be a mapping with a '_name_or_path' "
                "entry naming the pretrained encoder to load"
            )
        # Pretrained encoder, loaded from its own checkpoint.
        self.bert = AutoModel.from_pretrained(backbone["_name_or_path"])
        # Dropout applied after the first dense layer's activation.
        self.dropout = nn.Dropout(config.dropout)
        self.relu = nn.ReLU()
        # Size the first dense layer from the encoder actually loaded, not
        # from the serialized dict, so the two can never disagree.
        self.fc1 = nn.Linear(self.bert.config.hidden_size, config.fc1)
        # Output layer.
        self.fc2 = nn.Linear(config.fc1, config.fc2)
        # NOTE: this is LogSoftmax (log-probabilities). The attribute keeps
        # its historical name so existing attribute access stays compatible.
        self.softmax = nn.LogSoftmax(dim=config.softmax_dim)

    def forward(self, input_ids, attention_mask=None):
        """Classify a batch of token sequences.

        Args:
            input_ids: token ids, passed straight to the encoder.
            attention_mask: optional mask, passed straight to the encoder;
                ``None`` lets the encoder apply its own default.

        Returns:
            Log-probabilities from ``LogSoftmax`` over ``config.softmax_dim``.
        """
        outputs = self.bert(
            input_ids, attention_mask=attention_mask, return_dict=False
        )
        # Index rather than unpack: the tuple can carry more than two items
        # (e.g. when hidden states are requested). Element 1 is presumably the
        # pooled [CLS] representation (``pooler_output`` for BertModel).
        pooled = outputs[1]
        x = self.fc1(pooled)
        x = self.relu(x)
        x = self.dropout(x)
        x = self.fc2(x)
        return self.softmax(x)
def main(model_dir="./PragFormer", repo_id="PragFormer"):
    """Load the local PragFormer checkpoint and push it to the Hugging Face Hub.

    Args:
        model_dir: local directory holding the saved config and weights.
        repo_id: Hub repository name to push to.
    """
    # Mark the custom classes for the Auto* API before pushing, so the
    # uploaded checkpoint can be resolved through AutoConfig / AutoModel.
    PragFormerConfig.register_for_auto_class()
    BERT_Arch.register_for_auto_class("AutoModel")
    # from_pretrained builds BERT_Arch from the saved config and loads the
    # saved weights in one step. A second BERT_Arch(config) instance with this
    # state dict copied into it would hold identical weights, so it is not
    # built (that also avoids loading the encoder twice).
    model = BERT_Arch.from_pretrained(model_dir)
    model.push_to_hub(repo_id)
    # Alternative, process-local registration tried earlier:
    #   AutoConfig.register("pragformer", PragFormerConfig)
    #   AutoModel.register(PragFormerConfig, BERT_Arch)


if __name__ == "__main__":
    main()