File size: 752 Bytes
3e04925 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 |
import torch
import torch.nn as nn
import torchvision.models as models
from config import Resnet50Config
from transformers import PreTrainedModel
class Resnet50FER(PreTrainedModel):
    """ResNet-50 image classifier wrapped as a Hugging Face ``PreTrainedModel``.

    A torchvision ResNet-50 is built with randomly initialized weights (no
    pretrained checkpoint is requested). Its final fully connected layer is
    then replaced by a fresh ``nn.Linear`` that produces
    ``config.num_classes`` outputs.
    """

    config_class = Resnet50Config

    def __init__(self, config: Resnet50Config):
        """Build the backbone and the classification head.

        Args:
            config: Model configuration. Only ``config.num_classes`` is read
                here, to size the new output layer.
        """
        super().__init__(config)
        # Build the complete ResNet-50 (including its default head) with
        # random weights. A bare ``resnet50()`` call is equivalent to the
        # legacy ``pretrained=False`` but avoids the deprecation warning that
        # torchvision >= 0.13 emits for that flag, and still works on older
        # torchvision releases.
        self.resnet = models.resnet50()
        # Swap the default classification head for one sized to this task's
        # label set. The feature width is read from the existing layer rather
        # than hard-coded.
        in_features = self.resnet.fc.in_features
        self.resnet.fc = nn.Linear(in_features, config.num_classes)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run a forward pass through the ResNet-50.

        Args:
            x: Input image batch. Assumes the (N, 3, H, W) layout that
                torchvision's ResNet stem expects. TODO confirm against the
                caller's preprocessing.

        Returns:
            Output of the replaced ``fc`` layer, with last dimension
            ``config.num_classes``. No softmax or other activation is applied.
        """
        return self.resnet(x)