from pathlib import Path

import torch
import torchvision
from torch import nn

import data_setup
import engine


def train_and_save_effnetb2(seed: int = 42):
    """Train an EfficientNetB2 feature extractor and save its weights.

    Builds a frozen, pretrained EfficientNetB2 with a fresh 3-class
    classifier head (pizza, steak, sushi), trains it for 10 epochs with
    Adam (lr=1e-3) and cross-entropy loss, and writes the resulting
    ``state_dict`` to ``foodvision_mini.pth`` in the current working
    directory.

    Args:
        seed: Value passed to ``torch.manual_seed`` immediately before the
            classifier head is created, for reproducible initialisation.

    Returns:
        A ``(model, transforms, model_save_path)`` tuple: the trained
        model, the preprocessing transforms bundled with the pretrained
        weights, and the ``Path`` the state dict was saved to.
    """
    # Reuse the shared builder so the two entry points cannot drift apart.
    model, transforms = create_effnetb2_model(num_classes=3, seed=seed)

    # NOTE(review): create_dataloaders() is called with its defaults; the
    # data location and batch size are defined in data_setup, not here.
    train_dataloader, test_dataloader, _ = data_setup.create_dataloaders()

    # Frozen parameters never receive gradients, so only the new
    # classifier head is actually updated by the optimizer.
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
    loss_fn = nn.CrossEntropyLoss()

    engine.train(
        model,
        train_dataloader,
        test_dataloader,
        optimizer,
        loss_fn,
        epochs=10,
    )

    model_save_path = Path("foodvision_mini.pth")
    torch.save(obj=model.state_dict(), f=model_save_path)
    print(f"[INFO] Saving model to: {model_save_path}")

    return model, transforms, model_save_path


def create_effnetb2_model(num_classes: int = 3, seed: int = 42):
    """Create a frozen EfficientNetB2 feature extractor with a new head.

    Loads EfficientNetB2 with ``EfficientNet_B2_Weights.DEFAULT``, sets
    ``requires_grad = False`` on every pretrained parameter, then replaces
    ``model.classifier`` with ``Dropout(p=0.3)`` followed by a linear layer
    producing ``num_classes`` outputs.

    Args:
        num_classes: Number of output units in the new classifier head.
        seed: Value passed to ``torch.manual_seed`` immediately before the
            classifier head is created, for reproducible initialisation.

    Returns:
        A ``(model, transforms)`` tuple: the model and the preprocessing
        transforms bundled with the pretrained weights.
    """
    weights = torchvision.models.EfficientNet_B2_Weights.DEFAULT
    transforms = weights.transforms()
    model = torchvision.models.efficientnet_b2(weights=weights)

    # Freeze the pretrained backbone; the head assigned below is created
    # afterwards, so its parameters keep the default requires_grad=True.
    for param in model.parameters():
        param.requires_grad = False

    # Seed right before constructing the head so its random initial
    # weights are reproducible.
    torch.manual_seed(seed)
    model.classifier = nn.Sequential(
        nn.Dropout(p=0.3, inplace=True),
        # 1408 is the input width of the EfficientNetB2 classifier layer.
        nn.Linear(in_features=1408, out_features=num_classes),
    )

    return model, transforms