File size: 324 Bytes
d88ab8c
 
 
 
 
 
5a754f4
d88ab8c
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
import torch
import torchvision
from torch import nn

def load_model(path="airplane_uncompiled.pth", device="cpu"):
    """Load the saved airplane model and its matching preprocessing transforms.

    Args:
        path: Checkpoint file to load. Defaults to ``'airplane_uncompiled.pth'``,
            resolved relative to the current working directory (the original
            hard-coded location).
        device: Device that loaded tensors are mapped onto; anything accepted by
            ``torch.device``. Defaults to CPU, as before.

    Returns:
        A ``(loaded_model, transforms)`` tuple: the object unpickled from *path*
        and the default inference transforms of torchvision's EfficientNet-B2
        weights.

    Raises:
        Whatever ``torch.load`` raises, e.g. ``FileNotFoundError`` when *path*
        does not exist.
    """
    import inspect

    load_kwargs = {"map_location": torch.device(device)}
    # The checkpoint appears to be a whole pickled model (it is returned
    # directly as the model), not a state_dict. torch >= 2.6 defaults to
    # weights_only=True, which would reject it, so opt out explicitly.
    # torch < 1.13 has no such parameter, so only pass it when supported.
    # NOTE(review): weights_only=False executes arbitrary pickle code; only
    # load checkpoint files from a trusted source.
    if "weights_only" in inspect.signature(torch.load).parameters:
        load_kwargs["weights_only"] = False
    loaded_model = torch.load(path, **load_kwargs)

    # The preprocessing must match the backbone the model was trained from.
    # Presumably an EfficientNet-B2 fine-tune; TODO confirm against training code.
    model_weights = torchvision.models.EfficientNet_B2_Weights.DEFAULT
    transforms = model_weights.transforms()

    return loaded_model, transforms