from typing import List

import torch
import torch.nn as nn
from transformers import PretrainedConfig
from transformers import PreTrainedModel

# Width of every hidden layer in MyModel.
_HIDDEN_DIM = 30


class MyModelConfig(PretrainedConfig):
    """Configuration for :class:`MyModel`.

    Args:
        input_dim: Size of the last dimension of the model input.
        layers_num: Number of linear layers in the model (must be >= 1;
            checked when the model is built).
        **kwargs: Forwarded unchanged to ``PretrainedConfig.__init__``.
    """

    # Every parameter must have a default value, otherwise an error is raised.
    def __init__(
        self,
        input_dim: int = 100,
        layers_num: int = 5,
        **kwargs,
    ):
        self.input_dim = input_dim
        self.layers_num = layers_num
        super().__init__(**kwargs)


class MyModel(PreTrainedModel):
    """Stack of ``nn.Linear`` layers mapping ``input_dim`` features to 1.

    With ``layers_num == 1`` the model is a single ``Linear(input_dim, 1)``.
    Otherwise it is ``Linear(input_dim, 30)``, followed by
    ``layers_num - 2`` layers of ``Linear(30, 30)``, followed by
    ``Linear(30, 1)``. No activation is applied between layers.
    """

    config_class = MyModelConfig

    def __init__(self, config):
        """Build the linear layers described by ``config``.

        Raises:
            ValueError: If ``config.layers_num`` is smaller than 1.
        """
        super().__init__(config)
        # Explicit check instead of ``assert``: asserts are stripped under
        # ``python -O``, which would silently build a 2-layer network for
        # layers_num <= 0.
        if config.layers_num < 1:
            raise ValueError(
                f"layers_num must be >= 1, got {config.layers_num}"
            )

        # Feature widths at every layer boundary, e.g. for layers_num == 3:
        # [input_dim, 30, 30, 1]. For layers_num == 1 this is [input_dim, 1].
        widths = (
            [config.input_dim]
            + [_HIDDEN_DIM] * (config.layers_num - 1)
            + [1]
        )
        self.model = nn.ModuleList(
            nn.Linear(n_in, n_out) for n_in, n_out in zip(widths, widths[1:])
        )

    def forward(self, tensor):
        """Apply every linear layer in order and return the result.

        Args:
            tensor: Input whose last dimension is ``config.input_dim``.

        Returns:
            The output of the final layer (last dimension of size 1).
        """
        # nn.ModuleList is only a container and defines no forward(), so
        # calling ``self.model(tensor)`` raises NotImplementedError; the
        # layers have to be applied one by one.
        for layer in self.model:
            tensor = layer(tensor)
        return tensor


def main():
    """Save an example config and the freshly initialised model weights."""
    save_config = MyModelConfig(input_dim=10, layers_num=3)
    save_config.save_pretrained("custom-mymodel")
    mymodel = MyModel(save_config)
    # 'pytorch_model.bin' is the conventional file name for the weights.
    # NOTE(review): the weights are written to the current working directory,
    # not into "custom-mymodel" next to the config -- confirm this is intended.
    torch.save(mymodel.state_dict(), 'pytorch_model.bin')


if __name__ == '__main__':
    main()