Spaces:
Build error
Build error
update predict_image.py
Browse files- predict_image.py +3 -0
predict_image.py
CHANGED
@@ -27,14 +27,17 @@ def load_model(model_name):
|
|
27 |
if model_name == 'effb0':
|
28 |
model = EffnetModel()
|
29 |
fname = download_weights(effb0_net_url)
|
|
|
30 |
elif model_name == 'res18':
|
31 |
model = ResnetModel
|
32 |
fname = download_weights(res18_net_url)
|
|
|
33 |
else:
|
34 |
raise ValueError('Enter correct model_name')
|
35 |
|
36 |
# loading pretrained model
|
37 |
state_dict = torch.load(fname)
|
|
|
38 |
model.load_state_dict(state_dict['weights'])
|
39 |
return model
|
40 |
|
|
|
27 |
if model_name == 'effb0':
|
28 |
model = EffnetModel()
|
29 |
fname = download_weights(effb0_net_url)
|
30 |
+
print('loaded effnet')
|
31 |
elif model_name == 'res18':
|
32 |
model = ResnetModel
|
33 |
fname = download_weights(res18_net_url)
|
34 |
+
print('loaded resnet')
|
35 |
else:
|
36 |
raise ValueError('Enter correct model_name')
|
37 |
|
38 |
# loading pretrained model
|
39 |
state_dict = torch.load(fname)
|
40 |
+
print(type(state_dict))
|
41 |
model.load_state_dict(state_dict['weights'])
|
42 |
return model
|
43 |
|