File size: 752 Bytes
3e04925
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
import torch
import torch.nn as nn
import torchvision.models as models
from config import Resnet50Config
from transformers import PreTrainedModel

class Resnet50FER(PreTrainedModel):
    """ResNet-50 image classifier wrapped as a Hugging Face ``PreTrainedModel``.

    The backbone is torchvision's ResNet-50 built without pretrained weights.
    Its final fully connected layer is replaced by one that outputs
    ``config.num_classes`` logits. Subclassing ``PreTrainedModel`` with
    ``config_class = Resnet50Config`` lets the model use the transformers
    save/load machinery.
    """

    config_class = Resnet50Config

    def __init__(self, config):
        """Build the network.

        Args:
            config: A ``Resnet50Config``; only ``config.num_classes`` is read
                directly here (the whole config is also handed to
                ``PreTrainedModel.__init__``).
        """
        super().__init__(config)
        # Build the full ResNet-50 (including its stock 1000-way ``fc`` head)
        # with randomly initialised weights. No weights argument is passed on
        # purpose:
        #   - ``pretrained=`` is deprecated since torchvision 0.13.
        #   - ``weights=`` does not exist before 0.13.
        #   - The default means "no pretrained weights" in both API
        #     generations.
        self.resnet = models.resnet50()

        # Replace the stock classification head with one sized for this task,
        # keeping the backbone's feature width as the input size.
        num_ftrs = self.resnet.fc.in_features
        self.resnet.fc = nn.Linear(num_ftrs, config.num_classes)

    def forward(self, x):
        """Run a batch of images through ResNet-50 and return class logits.

        Args:
            x: Image batch. Presumably a float tensor of shape
                ``(batch, 3, H, W)``; TODO confirm against the data pipeline.

        Returns:
            Tensor of shape ``(batch, config.num_classes)``. These are raw,
            unnormalised logits; no softmax is applied here.
        """
        return self.resnet(x)