# testspace / model.py
# Last change: "Update model.py" (commit 644d850) by noelsinghsr.
import torch
import torchvision
from torch import nn
def load_model(model_path='cifar10_effnet_sm_uncompiled.pth', device='cpu'):
    """Load the saved classifier checkpoint and its matching preprocessing transforms.

    The checkpoint is presumably an EfficientNet-B2 fine-tuned on CIFAR-10
    (inferred from the file name and the B2 transforms used below -- confirm).

    Args:
        model_path: Path to the file written by ``torch.save``. The default is
            the file name this app has always used, resolved relative to the
            current working directory.
        device: Device the checkpoint's tensors are mapped onto via
            ``map_location``. Defaults to CPU.

    Returns:
        A ``(loaded_model, transforms)`` tuple. ``loaded_model`` is the object
        stored in the checkpoint (switched to eval mode if it is an
        ``nn.Module``). ``transforms`` is the preprocessing pipeline that
        torchvision pairs with ``EfficientNet_B2_Weights.DEFAULT``.
    """
    # SECURITY: torch.load unpickles arbitrary Python objects; only point this
    # at trusted checkpoint files. weights_only=False is explicit because
    # torch >= 2.6 defaults to weights_only=True, which rejects a pickled
    # full nn.Module.
    loaded_model = torch.load(
        model_path,
        map_location=torch.device(device),
        weights_only=False,
    )
    if isinstance(loaded_model, nn.Module):
        # Inference mode: layers such as dropout / batch norm switch to their
        # deterministic evaluation behaviour, whatever mode the model was saved in.
        loaded_model.eval()

    # Reuse torchvision's preprocessing for the B2 weights so inputs match
    # what the backbone was trained to expect.
    model_weights = torchvision.models.EfficientNet_B2_Weights.DEFAULT
    transforms = model_weights.transforms()
    return loaded_model, transforms