File size: 751 Bytes
2ee1a14
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
from typing import Callable, Tuple

import torch
import torchvision
from torch import nn


def create_effnet(
    pretrained_weights: "torchvision.models.WeightsEnum",
    model: Callable[..., nn.Module],
    in_features: int,
    dropout: float,
    out_features: int,
    device: torch.device,
) -> Tuple[nn.Module, Callable]:
    """Build a pretrained backbone with a frozen feature extractor and a new head.

    Args:
        pretrained_weights: Weights object passed to ``model`` as
            ``weights=``; its ``transforms()`` supplies the preprocessing
            pipeline that is returned alongside the model.
        model: Model builder (e.g. a ``torchvision.models`` constructor)
            called as ``model(weights=pretrained_weights)``. The module it
            returns must expose ``features`` and ``classifier`` attributes.
        in_features: Input width of the new linear head; must match the
            width of the tensor the network feeds into ``classifier``.
        dropout: Dropout probability applied before the linear head.
        out_features: Output width (number of classes) of the new head.
        device: Device the finished model is moved to.

    Returns:
        ``(model, transforms)``: the network with frozen ``features`` and a
        freshly initialised, trainable ``classifier``, and the transform
        produced by ``pretrained_weights.transforms()``.
    """
    net = model(weights=pretrained_weights)
    transforms = pretrained_weights.transforms()

    # Freeze the backbone so that only the new head receives gradients.
    net.features.requires_grad_(False)

    # Replace the classifier head; its new parameters remain trainable.
    net.classifier = nn.Sequential(
        nn.Dropout(p=dropout, inplace=True),
        nn.Linear(in_features=in_features, out_features=out_features),
    )

    # Move once, after the head swap, so the discarded original head is
    # never copied to ``device``.
    return net.to(device), transforms