Spaces:
Running
Running
import torch | |
from pathlib import Path | |
from torch import nn | |
from torchvision.models import efficientnet_b2, EfficientNet_B2_Weights | |
from typing import Optional, Tuple | |
# Device string used for every tensor/module placement in this module:
# 'cuda' when torch reports an available CUDA device, otherwise 'cpu'.
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
def create_effnetb2_model(
    num_classes: int,
    seed: Optional[int] = 42,
    load_st_dict: Optional[bool] = False
) -> Tuple[nn.Module, nn.Module]:
    """Build an EfficientNet-B2 with a fresh classification head.

    The backbone is created with ``EfficientNet_B2_Weights.DEFAULT`` and its
    ``classifier`` is replaced by ``Dropout(p=0.3)`` followed by a ``Linear``
    layer with ``num_classes`` outputs.

    Args:
        num_classes: Number of output units of the new classifier layer.
        seed: Value passed to ``torch.manual_seed`` and
            ``torch.cuda.manual_seed`` before the model is built. ``None``
            skips seeding and leaves the global RNG state untouched.
        load_st_dict: When truthy, load a state dict from ``model.pth``
            (resolved against the current working directory) into the model
            and then set ``requires_grad = False`` on every parameter.

    Returns:
        A ``(model, transforms)`` tuple: the model moved to ``DEVICE`` and the
        preprocessing transforms obtained from ``weights.transforms()``.

    Raises:
        FileNotFoundError: Propagated from ``torch.load`` if ``load_st_dict``
            is truthy and ``model.pth`` does not exist.
        RuntimeError: Propagated from ``load_state_dict`` if the checkpoint
            keys/shapes do not match the model (e.g. a different
            ``num_classes`` than the checkpoint was saved with).
    """
    # The signature allows seed=None, but torch.manual_seed(None) raises
    # TypeError, so seeding is only performed for a real integer seed.
    if seed is not None:
        torch.manual_seed(seed)
        torch.cuda.manual_seed(seed)

    weights = EfficientNet_B2_Weights.DEFAULT
    transforms = weights.transforms()
    model = efficientnet_b2(weights=weights)

    # Swap in a head sized for this task; in_features=1408 must match the
    # feature width that the backbone feeds into its classifier.
    model.classifier = nn.Sequential(
        nn.Dropout(p=0.3, inplace=True),
        nn.Linear(in_features=1408, out_features=num_classes, bias=True)
    )

    # Move the whole model (backbone and new head) to the target device once.
    model = model.to(DEVICE)

    if load_st_dict:
        st_dict = Path('model.pth')
        # map_location lets a checkpoint saved on GPU be loaded on a CPU-only
        # host (and vice versa).
        model.load_state_dict(torch.load(st_dict, map_location=DEVICE))
        # NOTE(review): the original indentation was lost, so the nesting of
        # this loop is an assumption. It is kept inside the load branch so
        # that a model created without a checkpoint stays trainable - confirm
        # against the original file.
        for param in model.parameters():
            param.requires_grad = False

    return model, transforms