Food_Vision_Big / model.py
v0idcr0w's picture
Upload 6 files
bfd7be3
raw
history blame contribute delete
No virus
549 Bytes
import torchvision
import torch
from torch import nn
"""Script to load and create and instance of effnetb2 """
def create_effnetb2_model(num_classes: int = 101):
    """Build a frozen, pretrained EfficientNet-B2 with a new classifier head.

    Loads torchvision's EfficientNet-B2 with its ``DEFAULT`` pretrained
    weights, disables gradients on every pretrained parameter, and swaps the
    classifier for a new ``Dropout -> Linear`` head with ``num_classes``
    outputs. The new head is created after the freeze, so its parameters are
    the only ones left trainable.

    Args:
        num_classes: Number of output units of the new linear layer.
            Defaults to 101.

    Returns:
        A ``(model, transforms)`` tuple: the modified EfficientNet-B2 module
        and the preprocessing transforms bundled with the pretrained weights.
    """
    weights = torchvision.models.EfficientNet_B2_Weights.DEFAULT
    transforms = weights.transforms()
    model = torchvision.models.efficientnet_b2(weights=weights)

    # Freeze the pretrained network; only the head built below will train.
    for param in model.parameters():
        param.requires_grad = False

    # Take the feature width from the backbone's own final linear layer
    # rather than hard-coding it, so the new head always matches the
    # features the backbone actually produces.
    in_features = model.classifier[-1].in_features

    model.classifier = nn.Sequential(
        nn.Dropout(p=0.3, inplace=True),
        nn.Linear(in_features=in_features, out_features=num_classes),
    )
    return model, transforms