"""Small two-layer CNN wrapped as a Hugging Face ``PreTrainedModel``."""

from transformers import PreTrainedModel
import torch
import torch.nn as nn

from .configuration_convnet import ConNetConfig


class ConvNet(nn.Module):
    """Convolutional neural network with two convolutional layers.

    Each stage is ``Conv2d -> BatchNorm2d -> ReLU -> MaxPool2d``. The 5x5
    convolutions use ``padding=2`` and ``stride=1`` so they keep the spatial
    size; each 2x2 max-pool halves it. The linear head is sized for a
    ``32 x 7 x 7`` feature map, which is what single-channel inputs of
    spatial size 28x28 produce after the two pooling steps.

    Args:
        num_classes: Number of output logits produced by the linear head.
    """

    def __init__(self, num_classes=10):
        super().__init__()
        # Stage 1: 1 input channel -> 16 feature maps, spatial size halved.
        self.layer1 = nn.Sequential(
            nn.Conv2d(1, 16, kernel_size=5, stride=1, padding=2),
            nn.BatchNorm2d(16),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),
        )
        # Stage 2: 16 -> 32 feature maps, spatial size halved again.
        self.layer2 = nn.Sequential(
            nn.Conv2d(16, 32, kernel_size=5, stride=1, padding=2),
            nn.BatchNorm2d(32),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),
        )
        # 7 * 7 * 32 = flattened size of the stage-2 output (see class doc).
        self.fc = nn.Linear(7 * 7 * 32, num_classes)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return the class logits for a batch of images.

        Args:
            x: Input batch with one channel (the first convolution is
                ``Conv2d(1, 16, ...)``).

        Returns:
            The output of the final linear layer, one row per batch element
            and ``num_classes`` columns.
        """
        out = self.layer1(x)
        out = self.layer2(out)
        # Flatten everything except the batch dimension for the linear head.
        out = out.reshape(out.size(0), -1)
        out = self.fc(out)
        return out


class ConvNetModel(PreTrainedModel):
    """``PreTrainedModel`` wrapper that exposes :class:`ConvNet`.

    The wrapped network is built from ``config.num_classes`` and stored on
    ``self.model``.
    """

    config_class = ConNetConfig

    def __init__(self, config):
        super().__init__(config)
        self.model = ConvNet(num_classes=config.num_classes)

    def forward(self, x):
        """Delegate to the wrapped :class:`ConvNet` and return its output."""
        out = self.model(x)
        return out


if __name__ == "__main__":
    # Smoke test: build a 10-class model and save it to ./my_models.
    # NOTE: this module uses a relative import, so it has to be executed as
    # part of its package (``python -m <package>.<module>``), not as a bare
    # script.
    convnet_config = ConNetConfig(num_classes=10)
    convnet = ConvNetModel(convnet_config)
    convnet.save_pretrained("my_models")