# Covid-19-Detection / model.py
# fbrynpk's picture
# Initial Commit
# 2ee1a14
# raw
# history blame contribute delete
# No virus
# 751 Bytes
from typing import Callable, Tuple

import torch
import torchvision
from torch import nn
def create_effnet(
    pretrained_weights: "torchvision.models.WeightsEnum",
    model: Callable[..., nn.Module],
    in_features: int,
    dropout: float,
    out_features: int,
    device: torch.device,
) -> Tuple[nn.Module, Callable]:
    """Build a pretrained model with a frozen backbone and a new classifier head.

    The builder is called with the given weights, every parameter under the
    resulting model's ``features`` submodule is frozen, and its ``classifier``
    attribute is replaced by ``Dropout -> Linear``. The model is therefore
    expected to expose ``features`` and ``classifier`` (as the function name
    suggests, EfficientNet-style models).

    Args:
        pretrained_weights: Weights object passed to the builder as
            ``weights=``; it must also provide a ``transforms()`` method.
        model: Callable that builds the network, invoked as
            ``model(weights=pretrained_weights)``.
        in_features: Input size of the new ``nn.Linear`` head; must match the
            size of the features fed to the classifier.
        dropout: Dropout probability used in the new head.
        out_features: Output size of the new ``nn.Linear`` head.
        device: Device the returned model is moved to.

    Returns:
        A ``(model, transforms)`` tuple: the modified network on ``device``
        and the object returned by ``pretrained_weights.transforms()``.
    """
    # Instantiate under a new name so the builder argument is not shadowed.
    net = model(weights=pretrained_weights)
    transforms = pretrained_weights.transforms()

    # Freeze the base layers so only the new head receives gradient updates.
    net.features.requires_grad_(False)

    # Swap in the new classifier head; its fresh parameters are trainable.
    net.classifier = nn.Sequential(
        nn.Dropout(p=dropout, inplace=True),
        nn.Linear(in_features=in_features, out_features=out_features),
    )

    # A single move after the head swap places backbone and head on `device`.
    net = net.to(device)
    return net, transforms