|
from transformers import PretrainedConfig |
|
from transformers import AutoModel, AutoConfig |
|
import torch.nn as nn |
|
from transformers import BertPreTrainedModel, AutoModel, PretrainedConfig |
|
|
|
|
|
|
|
class PragFormerConfig(PretrainedConfig):
    """Hyper-parameter container for the PragFormer classifier.

    Args:
        bert: Description of the backbone encoder, stored without
            validation. ``BERT_Arch`` indexes it with the
            ``'_name_or_path'`` and ``'hidden_size'`` keys, so a mapping
            providing both is expected there; the default is ``None``.
        dropout: Probability handed to the ``nn.Dropout`` layer of the head.
        fc1: Output width of the first linear layer of the head (and input
            width of the second one).
        fc2: Output width of the second linear layer of the head.
        softmax_dim: Dimension over which the final ``nn.LogSoftmax`` is
            computed.
        **kwargs: Passed through unchanged to ``PretrainedConfig.__init__``.
    """

    model_type = "pragformer"

    def __init__(self, bert=None, dropout=0.2, fc1=512, fc2=2, softmax_dim=1, **kwargs):
        # Backbone description first, then the head settings; the assignment
        # order is kept stable so the attribute order on the instance is too.
        self.bert = bert
        self.dropout = dropout
        self.fc1, self.fc2 = fc1, fc2
        self.softmax_dim = softmax_dim

        # Everything not consumed above is handled by the base class.
        super().__init__(**kwargs)
|
|
|
|
|
class BERT_Arch(BertPreTrainedModel):
    """Pretrained encoder followed by a two-layer classification head.

    The head is ``Linear -> ReLU -> Dropout -> Linear -> LogSoftmax`` and is
    applied to the second element of the backbone's tuple output.
    """

    config_class = PragFormerConfig

    def __init__(self, config):
        """Create the backbone and the classification head.

        Args:
            config: Configuration object exposing ``bert`` (indexed with
                ``'_name_or_path'`` and ``'hidden_size'``), ``dropout``,
                ``fc1``, ``fc2`` and ``softmax_dim``.
        """
        super().__init__(config)

        # Read the backbone description once so both lookups below use the
        # same object (the original mixed ``config`` and ``self.config``).
        backbone_cfg = config.bert

        # The backbone is loaded from the checkpoint named in the config.
        self.bert = AutoModel.from_pretrained(backbone_cfg['_name_or_path'])

        # Head layers; registration order is kept as in the original so the
        # parameter ordering of the module does not change.
        self.dropout = nn.Dropout(config.dropout)
        self.relu = nn.ReLU()
        self.fc1 = nn.Linear(backbone_cfg['hidden_size'], config.fc1)
        self.fc2 = nn.Linear(config.fc1, config.fc2)
        self.softmax = nn.LogSoftmax(dim=config.softmax_dim)

    def forward(self, input_ids, attention_mask):
        """Run the backbone and the head on one batch.

        Args:
            input_ids: Passed positionally as the first argument of the
                backbone.
            attention_mask: Passed to the backbone as ``attention_mask``.

        Returns:
            The ``nn.LogSoftmax`` output computed from the ``fc2`` logits.
        """
        backbone_out = self.bert(
            input_ids, attention_mask=attention_mask, return_dict=False
        )
        # Select by index rather than unpacking exactly two values: the tuple
        # can carry additional entries (e.g. hidden states or attentions when
        # the backbone's config enables them), which made the strict
        # ``_, cls_hs = ...`` unpacking raise ValueError.
        # NOTE(review): element 1 is the pooled output for BERT-style
        # backbones - confirm if a different encoder family is configured.
        cls_hs = backbone_out[1]

        x = self.fc1(cls_hs)
        x = self.relu(x)
        x = self.dropout(x)
        x = self.fc2(x)
        return self.softmax(x)
|
|
|
|
|
if __name__ == "__main__":
    # Local directory holding the trained checkpoint (config + weights).
    checkpoint_dir = "./PragFormer"

    # Register the custom config/model classes for the Auto* API before
    # anything is uploaded.
    PragFormerConfig.register_for_auto_class()
    BERT_Arch.register_for_auto_class("AutoModel")

    # Build a model from the stored configuration ...
    export_model = BERT_Arch(PragFormerConfig.from_pretrained(checkpoint_dir))

    # ... then copy the trained weights into it from the loaded checkpoint.
    trained_weights = BERT_Arch.from_pretrained(checkpoint_dir).state_dict()
    export_model.load_state_dict(trained_weights)

    # Upload the populated model under the repository name "PragFormer".
    export_model.push_to_hub("PragFormer")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|