|
from transformers import PreTrainedModel |
|
from transformers import PretrainedConfig |
|
from typing import List |
|
import torch.nn as nn |
|
import torch |
|
|
|
|
|
class MyModelConfig(PretrainedConfig):
    """Hyper-parameter container for ``MyModel``.

    Stores the input feature width and the linear-layer count, then hands
    every remaining keyword on to the HuggingFace base config.
    """

    def __init__(self, input_dim=100, layers_num=5, **kwargs):
        """
        Args:
            input_dim: Number of input features the model accepts.
            layers_num: Total number of linear layers in the model.
            **kwargs: Generic ``PretrainedConfig`` options (forwarded).
        """
        # Record our own hyper-parameters first, then let the base class
        # consume the HF-generic kwargs.
        self.input_dim = input_dim
        self.layers_num = layers_num
        super().__init__(**kwargs)
|
|
|
class MyModel(PreTrainedModel):
    """A stack of ``nn.Linear`` layers mapping ``input_dim`` features to one scalar.

    Layer widths: ``input_dim -> 30 -> ... -> 30 -> 1`` (a single
    ``input_dim -> 1`` layer when ``layers_num == 1``).
    """

    config_class = MyModelConfig

    def __init__(self, config):
        """
        Args:
            config: A ``MyModelConfig`` with ``input_dim`` and ``layers_num``.

        Raises:
            ValueError: If ``config.layers_num`` is less than 1.
        """
        super().__init__(config)
        # Validate with an exception, not `assert`: asserts are stripped
        # when Python runs with -O.
        if config.layers_num < 1:
            raise ValueError(f"layers_num must be >= 1, got {config.layers_num}")
        if config.layers_num == 1:
            modules = [nn.Linear(config.input_dim, 1)]
        else:
            # First layer projects to the hidden width (30), then
            # layers_num - 2 hidden layers, then a final projection to 1.
            modules = [nn.Linear(config.input_dim, 30)]
            modules.extend(nn.Linear(30, 30) for _ in range(config.layers_num - 2))
            modules.append(nn.Linear(30, 1))
        # BUGFIX: nn.ModuleList is not callable, so `self.model(tensor)` in
        # forward() raised TypeError. nn.Sequential chains the layers and is
        # callable, while keeping the same `model.{i}.*` state-dict keys.
        self.model = nn.Sequential(*modules)

    def forward(self, tensor):
        """Apply the linear stack to ``tensor`` and return the result."""
        return self.model(tensor)
|
|
|
if __name__ == '__main__':
    # Build a small 3-layer configuration, persist it to disk, then save
    # the randomly initialised weights of a model built from it.
    config = MyModelConfig(input_dim=10, layers_num=3)
    config.save_pretrained("custom-mymodel")

    model = MyModel(config)
    torch.save(model.state_dict(), 'pytorch_model.bin')