import torch
import torch.nn as nn

# torchvision is only needed by ResNet50Model; making it optional lets
# CustomModel be used in environments where torchvision is not installed.
try:
    import torchvision.models as models
except ImportError:
    models = None

N_EMOTIONS = 8
N_CELEBRITIES = 17


class _MultiHeadModel(nn.Module):
    """Shared neck and emotion/celebrity heads on top of a subclass's backbone.

    Subclasses assign ``self.backbone`` (a module mapping an image batch to a
    flat ``(N, in_features)`` tensor) and then call ``_build_heads``. This makes
    submodules get created and registered in the order backbone, neck, emotion
    head, celebrity head, which fixes the order in which random initial weights
    are drawn.

    ``mode`` selects what ``forward`` returns:
      * ``'emotion'``   -> emotion logits, shape ``(N, N_EMOTIONS)``
      * ``'celebrity'`` -> celebrity logits, shape ``(N, N_CELEBRITIES)``
      * anything else   -> tuple ``(emotion_logits, celebrity_logits)``
    """

    def __init__(self, mode='emotion'):
        super().__init__()
        self.mode = mode

    def _build_heads(self, in_features):
        """Create the shared MLP neck and both linear classification heads.

        Args:
            in_features: Width of the flat feature vector produced by the backbone.
        """
        self.in_features = in_features
        self.neck = nn.Sequential(
            nn.Linear(self.in_features, 128),
            nn.ReLU(),
            nn.Linear(128, 64),
            nn.ReLU(),
        )
        self.emotion_classifier = nn.Linear(64, N_EMOTIONS)
        self.celebrity_classifier = nn.Linear(64, N_CELEBRITIES)

    def forward(self, image):
        """Run backbone and neck, then the head(s) selected by ``self.mode``.

        Args:
            image: Batch of images accepted by ``self.backbone``.

        Returns:
            Emotion logits, celebrity logits, or a tuple of both (see class docs).
            Logits are raw linear outputs; no softmax is applied.
        """
        features = self.neck(self.backbone(image))
        if self.mode == 'emotion':
            return self.emotion_classifier(features)
        if self.mode == 'celebrity':
            return self.celebrity_classifier(features)
        # NOTE(review): any other value (including typos such as 'emotions')
        # silently selects the multi-task output; consider validating `mode`.
        return self.emotion_classifier(features), self.celebrity_classifier(features)


def _custom_backbone_side(input_size):
    """Return the spatial side length of CustomModel's backbone output.

    Follows the backbone layers for a square ``input_size x input_size`` input:
    two unpadded stride-1 3x3 convs, a 2x2 max-pool, another unpadded 3x3 conv
    and a second 2x2 max-pool (pooling floors odd sizes).
    """
    side = input_size - 2 - 2   # two 3x3 convs, stride 1, no padding
    side = side // 2            # MaxPool2d(kernel_size=2)
    side = side - 2             # 3x3 conv, stride 1, no padding
    return side // 2            # MaxPool2d(kernel_size=2)


class CustomModel(_MultiHeadModel):
    """Small from-scratch CNN with emotion and celebrity classification heads.

    Args:
        mode: Output selector; see ``_MultiHeadModel``.
        input_size: Side length of the square 3-channel inputs. The flattened
            backbone width is derived from it (224 -> 32 * 54 * 54).

    Raises:
        ValueError: if ``input_size`` leaves no spatial extent after the
            backbone (the minimum usable size is 12).
    """

    def __init__(self, mode='emotion', input_size=224):
        super().__init__(mode)
        side = _custom_backbone_side(input_size)
        if side < 1:
            raise ValueError(
                f"input_size={input_size} is too small for CustomModel's backbone "
                "(minimum is 12)"
            )
        # Shape comments below assume the default input_size=224.
        self.backbone = nn.Sequential(
            # in: 3 x 224 x 224
            nn.Conv2d(3, 64, kernel_size=3, stride=1, bias=False),
            nn.BatchNorm2d(64),
            nn.LeakyReLU(0.2, inplace=True),
            # out: 64 x 222 x 222
            nn.Conv2d(64, 32, kernel_size=3, stride=1, bias=False),
            nn.BatchNorm2d(32),
            nn.LeakyReLU(0.2, inplace=True),
            nn.MaxPool2d(kernel_size=2),
            nn.Dropout(0.2),
            # out: 32 x 110 x 110
            nn.Conv2d(32, 32, kernel_size=3, stride=1, bias=False),
            nn.BatchNorm2d(32),
            nn.LeakyReLU(0.2, inplace=True),
            nn.MaxPool2d(kernel_size=2),
            nn.Dropout(0.3),
            # out: 32 x 54 x 54
            nn.Flatten(),
        )
        self._build_heads(32 * side * side)


def _untrained_resnet50():
    """Build torchvision's ResNet-50 without pretrained weights."""
    try:
        return models.resnet50(weights=None)
    except TypeError:
        # torchvision < 0.13 has no ``weights`` keyword; use the legacy flag.
        return models.resnet50(pretrained=False)


class ResNet50Model(_MultiHeadModel):
    """Randomly initialised torchvision ResNet-50 with emotion/celebrity heads.

    The complete torchvision classifier, including its final 1000-way ``fc``
    layer, is the backbone, so the neck consumes 1000 features per image.

    Args:
        mode: Output selector; see ``_MultiHeadModel``.

    Raises:
        ImportError: if torchvision is not installed.
    """

    def __init__(self, mode='emotion'):
        if models is None:
            raise ImportError("ResNet50Model requires torchvision")
        super().__init__(mode)
        self.backbone = _untrained_resnet50()
        # NOTE(review): feeding the 1000-way fc output into the neck is unusual;
        # replacing backbone.fc with nn.Identity() (2048 features) is the common
        # choice, but it would change the checkpoint layout.
        self._build_heads(1000)