# sketch_rec_Mini / models.py
# Author: d22cs051 - commit f217fdb ("uploading app and initial commit")
from typing import List
import torch
from torch import nn
import numpy as np
from torchvision import models
from torchvision.models import ResNet18_Weights,ResNet50_Weights,VGG16_Weights,MobileNet_V2_Weights
class EarlyStopping:
    """Flag training for early stopping once validation loss drifts above training loss.

    Every call compares one pair of losses. When ``validation_loss`` exceeds
    ``train_loss`` by more than ``min_delta``, a strike is added to ``counter``.
    After ``tolerance`` strikes, ``early_stop`` becomes True.

    NOTE: strikes are cumulative. ``counter`` is never reset when the gap
    closes again, so the strikes do not have to come from consecutive calls.
    """

    def __init__(self, tolerance=5, min_delta=0):
        # Number of strikes allowed before early_stop is raised.
        self.tolerance = tolerance
        # Smallest (validation - train) gap that does NOT yet count as a strike.
        self.min_delta = min_delta
        self.counter = 0
        self.early_stop = False

    def __call__(self, train_loss, validation_loss):
        """Record one pair of losses; set ``early_stop`` once tolerance is reached."""
        exceeded = (validation_loss - train_loss) > self.min_delta
        if not exceeded:
            return
        self.counter += 1
        if self.counter >= self.tolerance:
            self.early_stop = True
class Resnet18(nn.Module):
    """ResNet-18 without pretrained weights, with an ``out_shape``-way classifier.

    Args:
        out_shape: number of logits produced by the final linear layer.
    """

    def __init__(self, out_shape: int = 1000) -> None:
        super().__init__()
        backbone = models.resnet18()
        # Replace the default head with one sized to `out_shape`; the input
        # width is read from the existing head (512 for ResNet-18).
        backbone.fc = nn.Linear(backbone.fc.in_features, out_shape)
        self.resnet = backbone

    def forward(self, x):
        """Return the wrapped ResNet-18's output logits for batch ``x``."""
        return self.resnet(x)
class PretrainedResnet18(nn.Module):
    """ImageNet-pretrained ResNet-18 used as a frozen backbone with a new head.

    All pretrained parameters are frozen. Only the freshly created ``fc``
    layer (``out_shape`` outputs) has ``requires_grad=True``.

    NOTE: freezing parameters does not stop BatchNorm layers from updating
    their running statistics while the module is in ``train()`` mode.

    Args:
        out_shape: number of logits produced by the new final linear layer.
    """

    def __init__(self, out_shape: int = 1000) -> None:
        super().__init__()
        self.resnet = models.resnet18(weights=ResNet18_Weights.DEFAULT)
        # First freeze every pretrained parameter ...
        self.resnet.requires_grad_(False)
        # ... then attach the new head, whose parameters are trainable by default.
        self.resnet.fc = nn.Linear(self.resnet.fc.in_features, out_shape)

    def forward(self, x):
        """Return the wrapped ResNet-18's output logits for batch ``x``."""
        return self.resnet(x)
class Resnet50(nn.Module):
    """ResNet-50 without pretrained weights, with an ``out_shape``-way classifier.

    Args:
        out_shape: number of logits produced by the final linear layer.
    """

    def __init__(self, out_shape: int = 1000) -> None:
        super().__init__()
        backbone = models.resnet50()
        # Replace the default head with one sized to `out_shape`; the input
        # width is read from the existing head (2048 for ResNet-50).
        backbone.fc = nn.Linear(backbone.fc.in_features, out_shape)
        self.resnet = backbone

    def forward(self, x):
        """Return the wrapped ResNet-50's output logits for batch ``x``."""
        return self.resnet(x)
class PretrainedResnet50(nn.Module):
    """ImageNet-pretrained ResNet-50 used as a frozen backbone with a new head.

    All pretrained parameters are frozen. Only the freshly created ``fc``
    layer (``out_shape`` outputs) has ``requires_grad=True``.

    NOTE: freezing parameters does not stop BatchNorm layers from updating
    their running statistics while the module is in ``train()`` mode.

    Args:
        out_shape: number of logits produced by the new final linear layer.
    """

    def __init__(self, out_shape: int = 1000) -> None:
        super().__init__()
        self.resnet = models.resnet50(weights=ResNet50_Weights.DEFAULT)
        # First freeze every pretrained parameter ...
        self.resnet.requires_grad_(False)
        # ... then attach the new head, whose parameters are trainable by default.
        self.resnet.fc = nn.Linear(self.resnet.fc.in_features, out_shape)

    def forward(self, x):
        """Return the wrapped ResNet-50's output logits for batch ``x``."""
        return self.resnet(x)
class EfficentNetB0(nn.Module):
    """EfficientNet-B0 without pretrained weights and with a single-Linear classifier.

    NOTE(review): the class name is misspelled ("Efficent"). It is kept as-is
    because callers, and possibly pickled models, may refer to it.

    Args:
        out_shape: number of logits produced by the classifier.
    """

    def __init__(self, out_shape: int = 1000) -> None:
        super().__init__()
        net = models.efficientnet_b0()
        # The entire default classifier is replaced by one Linear layer
        # (1280 input features).
        # NOTE(review): torchvision's default B0 classifier also includes a
        # Dropout layer, which this replacement drops. Confirm this is intended.
        net.classifier = nn.Linear(1280, out_shape)
        self.effnet = net

    def forward(self, x):
        """Return the wrapped EfficientNet-B0's output logits for batch ``x``."""
        return self.effnet(x)
class MobileNetV2(nn.Module):
    """MobileNetV2 without pretrained weights, with an ``out_shape``-way final layer.

    Only ``classifier[1]`` is swapped, so any layer before it in the default
    classifier is left untouched.

    Args:
        out_shape: number of logits produced by the final linear layer.
    """

    def __init__(self, out_shape: int = 1000) -> None:
        super().__init__()
        net = models.mobilenet_v2()
        # The input width is read from the existing final Linear (1280).
        net.classifier[1] = nn.Linear(net.classifier[1].in_features, out_shape)
        self.mobilenet = net

    def forward(self, x):
        """Return the wrapped MobileNetV2's output logits for batch ``x``."""
        return self.mobilenet(x)
class PretrainedMobileNetV2(nn.Module):
    """ImageNet-pretrained MobileNetV2 with a frozen backbone and a new classifier.

    All pretrained parameters are frozen. The replacement classifier
    (Dropout -> Linear) is the only trainable part.

    NOTE: freezing parameters does not stop BatchNorm layers from updating
    their running statistics while the module is in ``train()`` mode.

    Args:
        out_shape: number of logits produced by the new classifier.
    """

    def __init__(self, out_shape: int = 1000) -> None:
        super().__init__()
        self.mobilenet = models.mobilenet_v2(weights=MobileNet_V2_Weights.DEFAULT)
        # Freeze every pretrained parameter; only the classifier created below
        # is trained.
        for param in self.mobilenet.parameters():
            param.requires_grad = False
        # Same Dropout -> Linear layout as before (state_dict keys are
        # unchanged), but the output size now honours `out_shape` instead of
        # being hard-coded to 1000.
        self.mobilenet.classifier = nn.Sequential(
            nn.Dropout(p=0.2, inplace=False),
            nn.Linear(in_features=1280, out_features=out_shape, bias=True),
        )

    def forward(self, x):
        """Return the wrapped MobileNetV2's output logits for batch ``x``."""
        return self.mobilenet(x)
class VGG16(nn.Module):
    """VGG-16 without pretrained weights, with an ``out_shape``-way classifier.

    The classifier is rebuilt as:
    Linear(25088->4096), ReLU, Dropout(0.5),
    Linear(4096->4096), ReLU, Dropout(0.5),
    Linear(4096->out_shape).

    Args:
        out_shape: number of logits produced by the final linear layer.
    """

    def __init__(self, out_shape: int = 1000) -> None:
        super().__init__()
        self.vgg = models.vgg16()
        hidden = 4096
        head = []
        # Two Linear/ReLU/Dropout stages: 25088 -> 4096, then 4096 -> 4096.
        for fan_in in (25088, hidden):
            head.extend((
                nn.Linear(fan_in, hidden),
                nn.ReLU(inplace=True),
                nn.Dropout(p=0.5),
            ))
        # Final projection onto the requested number of outputs.
        head.append(nn.Linear(hidden, out_shape))
        self.vgg.classifier = nn.Sequential(*head)

    def forward(self, x):
        """Return the wrapped VGG-16's output logits for batch ``x``."""
        return self.vgg(x)
class PretrainedVGG16(nn.Module):
    """ImageNet-pretrained VGG-16 with a frozen backbone and a freshly built classifier.

    Every pretrained parameter is frozen. The classifier is then replaced as a
    whole, so the pretrained classifier weights are discarded and all three new
    Linear layers are trainable (not just the last one).

    New classifier layout:
    Linear(25088->4096), ReLU, Dropout(0.5),
    Linear(4096->4096), ReLU, Dropout(0.5),
    Linear(4096->out_shape).

    Args:
        out_shape: number of logits produced by the final linear layer.
    """

    def __init__(self, out_shape: int = 1000) -> None:
        super().__init__()
        self.vgg = models.vgg16(weights=VGG16_Weights.DEFAULT)
        # Freeze everything that came with the pretrained weights.
        self.vgg.requires_grad_(False)
        hidden = 4096
        head = []
        # Two Linear/ReLU/Dropout stages: 25088 -> 4096, then 4096 -> 4096.
        for fan_in in (25088, hidden):
            head.extend((
                nn.Linear(fan_in, hidden),
                nn.ReLU(inplace=True),
                nn.Dropout(p=0.5),
            ))
        head.append(nn.Linear(hidden, out_shape))
        # The new layers default to requires_grad=True, so the whole head trains.
        self.vgg.classifier = nn.Sequential(*head)

    def forward(self, x):
        """Return the wrapped VGG-16's output logits for batch ``x``."""
        return self.vgg(x)
class VIT(nn.Module):
    """ViT-B/16 without pretrained weights, with an ``out_shape``-way classification head.

    Args:
        out_shape: number of logits produced by the classification head.
    """

    def __init__(self, out_shape: int = 1000) -> None:
        super().__init__()
        self.vit = models.vit_b_16()
        # torchvision's VisionTransformer classifies through `self.heads`, an
        # nn.Sequential whose final Linear is registered as "head". The old
        # `self.vit.head = ...` only added an unused module and left the
        # 1000-way head in place. 768 is ViT-B's hidden size.
        self.vit.heads.head = nn.Linear(768, out_shape)

    def forward(self, x):
        """Return the wrapped ViT's output logits for batch ``x``."""
        return self.vit(x)
# functions to get models
def get_resnet_18_model(out_shape: int = 7) -> "Resnet18":
    """Build a :class:`Resnet18` (no pretrained weights).

    Args:
        out_shape: number of output classes. Defaults to 7, the value this
            factory previously hard-coded.
    """
    return Resnet18(out_shape=out_shape)
def get_resnet_50_model(out_shape: int = 7) -> "Resnet50":
    """Build a :class:`Resnet50` (no pretrained weights).

    Args:
        out_shape: number of output classes. Defaults to 7, the value this
            factory previously hard-coded.
    """
    return Resnet50(out_shape=out_shape)
def get_vgg_16_model(out_shape: int = 7) -> "VGG16":
    """Build a :class:`VGG16` (no pretrained weights).

    Args:
        out_shape: number of output classes. Defaults to 7, the value this
            factory previously hard-coded.
    """
    return VGG16(out_shape=out_shape)
def get_mobilenet_v2_model(out_shape: int = 7) -> "MobileNetV2":
    """Build a :class:`MobileNetV2` (no pretrained weights).

    Args:
        out_shape: number of output classes. Defaults to 7, the value this
            factory previously hard-coded.
    """
    return MobileNetV2(out_shape=out_shape)