Spaces:
Sleeping
Sleeping
File size: 751 Bytes
2ee1a14 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 |
from typing import Callable

import torch
import torchvision
from torch import nn
def create_effnet(
    pretrained_weights: torchvision.models.Weights,
    model: Callable[..., nn.Module],
    in_features: int,
    dropout: float,
    out_features: int,
    device: torch.device,
):
    """Build a pretrained model with a frozen ``features`` backbone and a new head.

    Args:
        pretrained_weights: Weights passed to ``model`` as ``weights=``; its
            ``transforms()`` result is returned alongside the model.
        model: Model builder (e.g. a ``torchvision.models`` constructor) that
            accepts a ``weights`` keyword and returns a module with a
            ``features`` submodule and a ``classifier`` attribute.
        in_features: Input size of the new ``nn.Linear`` head. Presumably this
            must equal the width of whatever the model feeds into
            ``classifier`` for the chosen variant — TODO confirm per model.
        dropout: Dropout probability ``p`` applied before the linear layer.
        out_features: Output size of the new linear layer (e.g. class count).
        device: Device the model and the new head are moved to.

    Returns:
        A ``(model, transforms)`` tuple: the adapted model and the object
        returned by ``pretrained_weights.transforms()``.
    """
    # Instantiate the model with the pretrained weights and fetch the
    # preprocessing transforms that belong to those weights.
    model = model(weights=pretrained_weights).to(device)
    transforms = pretrained_weights.transforms()

    # Freeze the backbone: parameters under `features` will not receive
    # gradients. Parameters outside `features` are left untouched.
    for param in model.features.parameters():
        param.requires_grad = False

    # Replace the classifier head, then move it to the same device as the
    # rest of the model.
    model.classifier = nn.Sequential(
        nn.Dropout(p=dropout, inplace=True),
        nn.Linear(in_features=in_features, out_features=out_features),
    ).to(device)
    return model, transforms
|