Spaces:
Runtime error
Runtime error
File size: 1,069 Bytes
7bebc28 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 |
from torchvision.models import efficientnet_b2, EfficientNet_B2_Weights
from typing import List
import torch.nn as nn
import torch
from torchinfo import summary
def create_effnet_b2_model(class_names: int = 3,
                           seed: int = 42,
                           print_summary: bool = False):
    """Create a pretrained EfficientNet-B2 with a new, trainable classifier head.

    Every pretrained parameter is frozen (``requires_grad=False``). Only the
    replacement classifier head is trainable, i.e. the model is prepared for
    feature-extraction transfer learning.

    Args:
        class_names: Number of output classes of the new classifier head.
            Despite its name this is a count, not a list of class names.
        seed: Seed passed to ``torch.cuda.manual_seed`` and
            ``torch.manual_seed`` before the model is built, so the random
            initialisation of the new head is repeatable across runs.
        print_summary: If True, print a torchinfo summary of the model for a
            ``[1, 3, 288, 288]`` input.

    Returns:
        A ``(model, transforms)`` tuple: the modified EfficientNet-B2 and the
        preprocessing transforms that belong to its default pretrained weights.
    """
    torch.cuda.manual_seed(seed)
    torch.manual_seed(seed)

    weights = EfficientNet_B2_Weights.DEFAULT
    effnet_b2_transforms = weights.transforms()
    effnet_b2 = efficientnet_b2(weights=weights)

    # Freeze the whole pretrained network (feature-extraction setup).
    for param in effnet_b2.parameters():
        param.requires_grad = False

    # Size the new head from the existing final Linear layer instead of
    # hard-coding its width (1408 for B2), so the two can never disagree.
    in_features = effnet_b2.classifier[-1].in_features
    effnet_b2.classifier = nn.Sequential(
        nn.Dropout(p=0.3, inplace=True),
        nn.Linear(in_features=in_features, out_features=class_names, bias=True),
    )
    # Call (do not assign) requires_grad_: assigning would shadow the method
    # with a bool and leave the parameters unchanged. Freshly created layers
    # are already trainable; this makes the intent explicit.
    effnet_b2.classifier.requires_grad_(True)

    if print_summary:
        # torchinfo is quiet by default in notebooks and relies on the REPL
        # to display the returned object. That object would be discarded
        # here, so silence torchinfo's own output and print the result once.
        print(summary(model=effnet_b2,
                      input_size=[1, 3, 288, 288],
                      col_names=["input_size",
                                 "output_size",
                                 "trainable"],
                      verbose=0))

    return effnet_b2, effnet_b2_transforms
|