from pathlib import Path
from typing import Optional, Tuple

import torch
from torch import nn
from torchvision.models import efficientnet_b2, EfficientNet_B2_Weights

# Use the GPU when PyTorch reports one as available; otherwise run on CPU.
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'


def create_effnetb2_model(
    num_classes: int,
    seed: Optional[int] = 42,
    load_st_dict: Optional[bool] = False
) -> Tuple[nn.Module, nn.Module]:
    """Build an EfficientNet-B2 with a ``num_classes``-way classifier head.

    The stock torchvision classifier is replaced by ``Dropout(0.3)`` followed
    by a ``Linear(1408, num_classes)`` layer. Every parameter of the returned
    model -- including the new head -- has ``requires_grad`` set to ``False``.

    Args:
        num_classes: Number of output units of the replacement classifier.
        seed: Seed passed to ``torch.manual_seed`` and
            ``torch.cuda.manual_seed`` before the model is built. ``None``
            skips seeding and leaves the global RNG state untouched.
        load_st_dict: When truthy, load weights from ``model.pth`` (resolved
            against the current working directory) into the model. The
            ImageNet checkpoint is not fetched in that case.

    Returns:
        A ``(model, transforms)`` tuple: the model moved to ``DEVICE`` and the
        preprocessing transform bundled with
        ``EfficientNet_B2_Weights.DEFAULT``.

    Raises:
        FileNotFoundError: If ``load_st_dict`` is truthy and ``model.pth``
            does not exist.
        RuntimeError: If the checkpoint's keys or shapes do not match the
            model (``load_state_dict`` is called in its default strict mode).
    """
    # torch.manual_seed() rejects None, so only seed when a value was given.
    if seed is not None:
        torch.manual_seed(seed)
        torch.cuda.manual_seed(seed)

    weights = EfficientNet_B2_Weights.DEFAULT
    transforms = weights.transforms()

    # When a local checkpoint is about to overwrite every tensor (strict
    # load below), fetching the ImageNet weights first is wasted network I/O.
    model = efficientnet_b2(weights=None if load_st_dict else weights)

    # in_features must equal the backbone's pooled feature width
    # (1408 for torchvision's EfficientNet-B2).
    model.classifier = nn.Sequential(
        nn.Dropout(p=0.3, inplace=True),
        nn.Linear(in_features=1408, out_features=num_classes, bias=True),
    )

    if load_st_dict:
        st_dict = Path('model.pth')
        # NOTE(review): torch.load unpickles the file, so only load
        # checkpoints from a trusted source. Consider weights_only=True if
        # the minimum supported torch version provides it.
        model.load_state_dict(torch.load(st_dict, map_location=DEVICE))

    # NOTE(review): this runs after the head is swapped in, so the new
    # classifier is frozen too. Callers that want to fine-tune must
    # re-enable gradients on the parameters they intend to train.
    for param in model.parameters():
        param.requires_grad = False

    # A single move at the end covers both the backbone and the new head.
    return model.to(DEVICE), transforms