import torch
import torch.nn as nn
import torchvision.models as models
from transformers import PreTrainedModel

from config import Resnet50Config


class Resnet50FER(PreTrainedModel):
    """ResNet-50 classifier wrapped as a Hugging Face ``PreTrainedModel``.

    The backbone is torchvision's ``resnet50`` built without pretrained
    weights; its final fully connected layer is swapped for a new
    ``nn.Linear`` that outputs ``config.num_classes`` values.

    NOTE(review): "FER" presumably stands for facial expression
    recognition -- confirm against the project documentation.
    """

    # Tells the PreTrainedModel machinery which config class belongs to
    # this model.
    config_class = Resnet50Config

    def __init__(self, config: Resnet50Config) -> None:
        """Build the network from ``config``.

        Args:
            config: Model configuration; only ``config.num_classes`` is read
                here, as the output size of the classification head.
        """
        super().__init__(config)

        # Build the full ResNet-50 (including its stock `fc` head) with
        # randomly initialised weights. No arguments are passed on purpose:
        # the default on both the legacy (`pretrained=False`) and current
        # (`weights=None`) torchvision APIs is "no pretrained weights", and
        # this avoids the deprecation warning for `pretrained`.
        self.resnet = models.resnet50()

        # Swap the stock head for one sized to this task, reusing the input
        # width of the layer being replaced.
        num_ftrs = self.resnet.fc.in_features
        self.resnet.fc = nn.Linear(num_ftrs, config.num_classes)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run ``x`` through the ResNet-50 and return its raw output.

        Args:
            x: Input batch. Presumably image tensors shaped
                ``(batch, 3, H, W)`` -- TODO confirm against the caller.

        Returns:
            The output of the replaced ``fc`` layer (no activation is
            applied here).
        """
        return self.resnet(x)