import torch
import torch.nn as nn
from torchvision import transforms
from torchvision.models import mobilenet_v2, MobileNet_V2_Weights

# Load MobileNetV2 with pre-trained weights
def create_mobilenet_model(num_classes: int = 4,
                           seed: int = 42):
"""Creates an EfficientNetB2 feature extractor model and transforms.
Args:
num_classes (int, optional): number of classes in the classifier head.
Defaults to 3.
seed (int, optional): random seed value. Defaults to 42.
Returns:
model (torch.nn.Module): EffNetB2 feature extractor model.
transforms (torchvision.transforms): EffNetB2 image transforms.
"""
    # Create the MobileNetV2 image transforms
    transform = transforms.Compose([
        transforms.Resize((224, 224)),                    # 1. Resize all images to 224x224 (the input size MobileNetV2 was trained on)
        transforms.ToTensor(),                            # 2. Convert image values to tensors between 0 & 1
        transforms.Normalize(mean=[0.485, 0.456, 0.406],  # 3. Normalize with the ImageNet per-channel mean
                             std=[0.229, 0.224, 0.225])   # 4. ... and the ImageNet per-channel standard deviation
    ])

    # Load MobileNetV2 with pre-trained ImageNet weights
    model = mobilenet_v2(weights=MobileNet_V2_Weights.DEFAULT)

    # Freeze all base layers by setting their requires_grad attribute to False
    for param in model.parameters():
        param.requires_grad = False

    # Since we're creating a new layer with random weights (torch.nn.Linear),
    # set the seed so the classifier head is initialised reproducibly
    torch.manual_seed(seed)

    # Update the classifier head to suit our problem
    model.classifier = nn.Sequential(
        nn.Dropout(p=0.2, inplace=True),
        nn.Linear(in_features=model.classifier[1].in_features,  # in_features of the original Linear head
                  out_features=num_classes,
                  bias=True)
    )

    return model, transform
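

# Example usage: a minimal sketch to sanity-check the model. The batch size and
# the 4-class setting below are illustrative assumptions, not part of the
# original file.
if __name__ == "__main__":
    model, transform = create_mobilenet_model(num_classes=4, seed=42)

    # Run a dummy batch through the frozen model and check the output shape
    dummy_batch = torch.randn(8, 3, 224, 224)  # (batch_size, channels, height, width)
    model.eval()
    with torch.inference_mode():
        logits = model(dummy_batch)
    print(logits.shape)  # expected: torch.Size([8, 4])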