h5 / model.py
Diego192's picture
Upload 8 files
3e04925 verified
raw
history blame contribute delete
No virus
752 Bytes
import torch
import torch.nn as nn
import torchvision.models as models
from config import Resnet50Config
from transformers import PreTrainedModel
class Resnet50FER(PreTrainedModel):
    """ResNet-50 classifier wrapped as a Hugging Face ``PreTrainedModel``.

    The backbone is torchvision's ResNet-50 built without pretrained
    weights; its final fully connected layer is replaced by a new linear
    layer that outputs ``config.num_classes`` values.
    """

    # Configuration class associated with this model.
    config_class = Resnet50Config

    def __init__(self, config: Resnet50Config) -> None:
        """Build the backbone and attach the task-specific classifier head.

        Args:
            config: Model configuration; ``config.num_classes`` sets the
                output size of the classification layer.
        """
        super().__init__(config)

        # Build the full ResNet-50 with no pretrained weights. Relying on the
        # default (no weights) instead of the deprecated ``pretrained=False``
        # keyword avoids torchvision's deprecation warning while behaving the
        # same on both the old and the new torchvision APIs.
        self.resnet = models.resnet50()

        # Swap the stock classification layer for one sized to this task,
        # keeping the input width of the layer it replaces.
        num_ftrs = self.resnet.fc.in_features
        self.resnet.fc = nn.Linear(num_ftrs, config.num_classes)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run ``x`` through the ResNet-50 and return its output unchanged.

        Args:
            x: Input batch passed directly to the backbone (presumably image
                tensors shaped ``(batch, 3, H, W)`` -- confirm with caller).

        Returns:
            The output of the replaced fully connected layer; no activation
            or other post-processing is applied here.
        """
        return self.resnet(x)