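# Spaces: Runtime error
# NOTE(review): the line above appears to be page-status text pasted in with
# this code; it is not part of the program.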
import operator
from typing import List, Union

import torch
import torch.nn as nn
from torchinfo import summary
from torchvision.models import efficientnet_b2, EfficientNet_B2_Weights


def create_effnet_b2_model(class_names: Union[int, List[str]] = 3,
                           seed=42,
                           print_summary=False):
    """Build a frozen EfficientNet-B2 backbone with a fresh, trainable head.

    Loads torchvision's default pretrained EfficientNet-B2 weights, freezes
    every pretrained parameter, and replaces ``classifier`` with a new
    ``Dropout -> Linear`` head sized for ``class_names``.

    Args:
        class_names: Number of output classes, or a sequence of class names
            (its length is used as the number of classes). A bare ``str`` is
            rejected because its length would silently be used as a count.
        seed: Seed passed to ``torch.manual_seed`` before the model is built,
            so the new head's random initialisation is reproducible.
        print_summary: If true, print a torchinfo summary of the model for a
            ``[1, 3, 288, 288]`` input.

    Returns:
        A ``(model, transforms)`` tuple: the modified EfficientNet-B2 module
        and the preprocessing transforms bundled with the pretrained weights.

    Raises:
        TypeError: If ``class_names`` is a ``str``, or is neither
            integer-like nor sized.
    """
    if isinstance(class_names, str):
        raise TypeError(
            "class_names must be an int or a sequence of class names, not str"
        )
    try:
        # Any integer-like value (int, bool, numpy integers) is a class count.
        num_classes = operator.index(class_names)
    except TypeError:
        # Otherwise treat it as a collection of class names.
        num_classes = len(class_names)

    # torch.manual_seed seeds the CPU and all CUDA devices, so no separate
    # torch.cuda.manual_seed call is needed.
    torch.manual_seed(seed)

    weights = EfficientNet_B2_Weights.DEFAULT
    effnet_b2_transforms = weights.transforms()
    effnet_b2 = efficientnet_b2(weights=weights)

    # Freeze every pretrained parameter; the replacement head built below is
    # created afterwards and is therefore unaffected.
    effnet_b2.requires_grad_(False)

    # Take the head's input width from the stock classifier (a Sequential
    # whose last layer is the Linear) instead of hard-coding it.
    in_features = effnet_b2.classifier[-1].in_features
    effnet_b2.classifier = nn.Sequential(
        nn.Dropout(p=0.3, inplace=True),
        nn.Linear(in_features=in_features, out_features=num_classes, bias=True),
    )
    # requires_grad_ is a method and must be called; assigning to it would
    # shadow the method on this module instead of changing any parameter.
    effnet_b2.classifier.requires_grad_(True)

    if print_summary:
        # verbose=0 plus print() shows the table exactly once, whether or not
        # torchinfo detects an interactive session (where it stays quiet and
        # relies on the return value being displayed).
        print(summary(model=effnet_b2,
                      input_size=[1, 3, 288, 288],
                      col_names=["input_size",
                                 "output_size",
                                 "trainable"],
                      verbose=0))

    return effnet_b2, effnet_b2_transforms