# airplaneclassifier / model.py
# noelsinghsr's picture
# Update model.py
# 5a754f4
import torch
import torchvision
from torch import nn
def load_model(model_path='airplane_uncompiled.pth'):
    """Load the saved classifier checkpoint and its preprocessing transforms.

    Args:
        model_path: Path of the checkpoint file passed to ``torch.load``.
            Defaults to ``'airplane_uncompiled.pth'``, the file name that
            was previously hard-coded, so existing zero-argument callers
            behave exactly as before.

    Returns:
        A ``(loaded_model, transforms)`` tuple: the object deserialized by
        ``torch.load`` from ``model_path`` with its storages mapped to the
        CPU, and the preprocessing transforms bundled with torchvision's
        default EfficientNet-B2 weights.
    """
    # NOTE(security): torch.load relies on pickle-based deserialization, so
    # only point model_path at checkpoints from a trusted source.
    # NOTE(review): newer PyTorch releases default to weights_only=True, which
    # may reject a fully pickled model object - confirm against the pinned
    # torch version before relying on this call.
    cpu_device = torch.device('cpu')
    loaded_model = torch.load(model_path, map_location=cpu_device)

    # Only the transforms are taken from the weights enum; the pretrained
    # weights themselves are not loaded here.
    # NOTE(review): presumably the checkpoint is an EfficientNet-B2 variant,
    # which is why these transforms are paired with it - confirm.
    model_weights = torchvision.models.EfficientNet_B2_Weights.DEFAULT
    transforms = model_weights.transforms()

    return loaded_model, transforms