import torch
from transformers import PreTrainedModel

from .configuration_my import MyConfig


class MyModel(PreTrainedModel):
    """Minimal ``PreTrainedModel`` wrapping a single ``torch.nn.Linear(10, 2)`` layer."""

    # Tells the transformers machinery which configuration class pairs with this model.
    config_class = MyConfig

    def __init__(self, config: MyConfig) -> None:
        """Build the model.

        Args:
            config: Model configuration, forwarded to ``PreTrainedModel.__init__``.
                The layer sizes below are hard-coded and do not read from it.
        """
        super().__init__(config)
        self.model = torch.nn.Linear(10, 2)
        # NOTE(review): recent transformers docs recommend calling
        # ``self.post_init()`` at the end of ``__init__``. It is not added
        # here because it may re-initialize weights; confirm whether it is wanted.

    def forward(self, tensor: torch.Tensor) -> torch.Tensor:
        """Apply the linear layer to ``tensor``.

        Args:
            tensor: Input whose last dimension must be 10, as required by the
                ``Linear(10, 2)`` layer.

        Returns:
            The layer output, with the last dimension mapped from 10 to 2.
        """
        return self.model(tensor)