import torch
from transformers import PreTrainedModel

from .configuration_my import MyConfig
class MyModel(PreTrainedModel):
    """Minimal ``PreTrainedModel`` wrapping a single ``torch.nn.Linear`` layer.

    The wrapped layer has ``in_features=10`` and ``out_features=2``. Both sizes
    are hard-coded here and are not read from ``config``.
    """

    # Configuration class paired with this model for PreTrainedModel's
    # config handling.
    config_class = MyConfig

    def __init__(self, config: MyConfig) -> None:
        """Build the model.

        Args:
            config: Model configuration, forwarded unchanged to
                ``PreTrainedModel.__init__``.
        """
        super().__init__(config)
        # NOTE(review): layer sizes (10 -> 2) are fixed constants rather than
        # config fields; consider exposing them on MyConfig if they must vary.
        self.model = torch.nn.Linear(10, 2)

    def forward(self, tensor: torch.Tensor) -> torch.Tensor:
        """Apply the linear layer to ``tensor``.

        Args:
            tensor: Input whose last dimension must be 10 to match the layer's
                ``in_features`` (presumably a float tensor — confirm with callers).

        Returns:
            ``self.model(tensor)``: same leading dimensions, last dimension 2.
        """
        return self.model(tensor)