from typing import List import torch from torch import nn import numpy as np from torchvision import models from torchvision.models import ResNet18_Weights,ResNet50_Weights,VGG16_Weights,MobileNet_V2_Weights class EarlyStopping: def __init__(self, tolerance=5, min_delta=0): self.tolerance = tolerance self.min_delta = min_delta self.counter = 0 self.early_stop = False def __call__(self, train_loss, validation_loss): if (validation_loss - train_loss) > self.min_delta: self.counter +=1 if self.counter >= self.tolerance: self.early_stop = True class Resnet18(nn.Module): def __init__(self,out_shape:int = 1000) -> None: super().__init__() self.resnet = models.resnet18() self.resnet.fc = nn.Linear(512,out_shape) def forward(self,x): return self.resnet(x) class PretrainedResnet18(nn.Module): def __init__(self,out_shape:int = 1000) -> None: super().__init__() self.resnet = models.resnet18(weights=ResNet18_Weights.DEFAULT) # freeze all layers except last fc layer for parms in self.resnet.parameters(): parms.requires_grad = False self.resnet.fc = nn.Linear(512,out_shape) def forward(self,x): return self.resnet(x) class Resnet50(nn.Module): def __init__(self,out_shape:int = 1000) -> None: super().__init__() self.resnet = models.resnet50() self.resnet.fc = nn.Linear(2048,out_shape) def forward(self,x): return self.resnet(x) class PretrainedResnet50(nn.Module): def __init__(self,out_shape:int = 1000) -> None: super().__init__() self.resnet = models.resnet50(weights=ResNet50_Weights.DEFAULT) # freeze all layers except last fc layer for parms in self.resnet.parameters(): parms.requires_grad = False self.resnet.fc = nn.Linear(2048,out_shape) def forward(self,x): return self.resnet(x) class EfficentNetB0(nn.Module): def __init__(self,out_shape:int = 1000) -> None: super().__init__() self.effnet = models.efficientnet_b0() self.effnet.classifier = nn.Linear(1280,out_shape) def forward(self,x): return self.effnet(x) class MobileNetV2(nn.Module): def __init__(self,out_shape:int = 1000) 
-> None: super().__init__() self.mobilenet = models.mobilenet_v2() self.mobilenet.classifier[1] = nn.Linear(1280,out_shape) def forward(self,x): return self.mobilenet(x) class PretrainedMobileNetV2(nn.Module): def __init__(self,out_shape:int = 1000) -> None: super().__init__() self.mobilenet = models.mobilenet_v2(weights=MobileNet_V2_Weights.DEFAULT) # freeze all layers except last fc layer for parms in self.mobilenet.parameters(): parms.requires_grad = False self.mobilenet.classifier = nn.Sequential( nn.Dropout(p=0.2, inplace=False), nn.Linear(in_features=1280, out_features=1000, bias=True) ) def forward(self,x): return self.mobilenet(x) class VGG16(nn.Module): def __init__(self,out_shape:int = 1000) -> None: super().__init__() self.vgg = models.vgg16() self.vgg.classifier = nn.Sequential( nn.Linear(in_features=25088, out_features=4096, bias=True), nn.ReLU(inplace=True), nn.Dropout(p=0.5, inplace=False), nn.Linear(in_features=4096, out_features=4096, bias=True), nn.ReLU(inplace=True), nn.Dropout(p=0.5, inplace=False), nn.Linear(in_features=4096, out_features=out_shape, bias=True), ) def forward(self,x): return self.vgg(x) class PretrainedVGG16(nn.Module): def __init__(self,out_shape:int = 1000) -> None: super().__init__() self.vgg = models.vgg16(weights=VGG16_Weights.DEFAULT) # freeze all layers except last clf layer for parms in self.vgg.parameters(): parms.requires_grad = False self.vgg.classifier = nn.Sequential( nn.Linear(in_features=25088, out_features=4096, bias=True), nn.ReLU(inplace=True), nn.Dropout(p=0.5, inplace=False), nn.Linear(in_features=4096, out_features=4096, bias=True), nn.ReLU(inplace=True), nn.Dropout(p=0.5, inplace=False), nn.Linear(in_features=4096, out_features=out_shape, bias=True), ) def forward(self,x): return self.vgg(x) class VIT(nn.Module): def __init__(self,out_shape:int = 1000) -> None: super().__init__() self.vit = models.vit_b_16() self.vit.head = nn.Linear(768,out_shape) def forward(self,x): return self.vit(x) # functions to get 
# functions to get models
# Each factory defaults to a 7-class head (the value previously hard-coded);
# pass ``out_shape`` to build a head of a different size.


def get_resnet_18_model(out_shape: int = 7) -> "Resnet18":
    """Return a randomly initialised ``Resnet18`` with an ``out_shape``-way head."""
    return Resnet18(out_shape=out_shape)


def get_resnet_50_model(out_shape: int = 7) -> "Resnet50":
    """Return a randomly initialised ``Resnet50`` with an ``out_shape``-way head."""
    return Resnet50(out_shape=out_shape)


def get_vgg_16_model(out_shape: int = 7) -> "VGG16":
    """Return a randomly initialised ``VGG16`` with an ``out_shape``-way head."""
    return VGG16(out_shape=out_shape)


def get_mobilenet_v2_model(out_shape: int = 7) -> "MobileNetV2":
    """Return a randomly initialised ``MobileNetV2`` with an ``out_shape``-way head."""
    return MobileNetV2(out_shape=out_shape)