import torch
import torchvision
from torch import nn


def load_model(path='airplane_uncompiled.pth', device='cpu'):
    """Load the saved airplane model and the EfficientNet-B2 preprocessing transforms.

    Args:
        path: Path of the checkpoint passed to ``torch.load``. Defaults to
            ``'airplane_uncompiled.pth'`` (the previously hard-coded file).
        device: Device (string or ``torch.device``) that tensors in the
            checkpoint are mapped onto while loading. Defaults to ``'cpu'``.

    Returns:
        A ``(loaded_model, transforms)`` tuple. ``loaded_model`` is whatever
        object was pickled into the checkpoint. Presumably it is a full
        ``nn.Module`` rather than a state_dict; TODO confirm against the
        training script. ``transforms`` is the preprocessing pipeline from
        ``torchvision.models.EfficientNet_B2_Weights.DEFAULT.transforms()``.
    """
    # Since PyTorch 2.6, torch.load defaults to weights_only=True. That mode
    # refuses to unpickle arbitrary objects such as a whole model, so opt out
    # explicitly to keep this loader working.
    # SECURITY: unpickling can execute arbitrary code, so only load trusted
    # checkpoint files here.
    loaded_model = torch.load(
        path,
        map_location=torch.device(device),
        weights_only=False,
    )
    model_weights = torchvision.models.EfficientNet_B2_Weights.DEFAULT
    transforms = model_weights.transforms()
    return loaded_model, transforms