File size: 331 Bytes
79ba3b2 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 |
from transformers import PreTrainedModel
from .configuration_my import MyConfig
import torch
class MyModel(PreTrainedModel):
    """Minimal ``PreTrainedModel`` that wraps a single linear layer.

    The layer maps 10 input features to 2 output features. Both sizes are
    hard-coded here; they are not read from ``config``.
    """

    # Links this model class to its configuration class for from_pretrained etc.
    config_class = MyConfig

    def __init__(self, config):
        """Initialise the base model and build the linear layer.

        Args:
            config: A ``MyConfig`` instance, passed unchanged to
                ``PreTrainedModel.__init__``.
        """
        super(MyModel, self).__init__(config)
        # The attribute name ``model`` determines the state-dict keys
        # (``model.weight`` / ``model.bias``), so it must stay stable for
        # saved checkpoints to load.
        self.model = torch.nn.Linear(in_features=10, out_features=2)

    def forward(self, tensor):
        """Run ``tensor`` through the linear layer.

        Args:
            tensor: Input tensor; its last dimension must equal the layer's
                ``in_features`` (10).

        Returns:
            The linear layer's output, whose last dimension is 2.
        """
        projected = self.model(tensor)
        return projected