Mayanand commited on
Commit
f67ac78
1 Parent(s): cdfdfeb

update predict_image.py

Browse files
Files changed (1) hide show
  1. predict_image.py +3 -0
predict_image.py CHANGED
@@ -27,14 +27,17 @@ def load_model(model_name):
27
  if model_name == 'effb0':
28
  model = EffnetModel()
29
  fname = download_weights(effb0_net_url)
 
30
  elif model_name == 'res18':
31
  model = ResnetModel
32
  fname = download_weights(res18_net_url)
 
33
  else:
34
  raise ValueError('Enter correct model_name')
35
 
36
  # loading pretrained model
37
  state_dict = torch.load(fname)
 
38
  model.load_state_dict(state_dict['weights'])
39
  return model
40
 
 
27
  if model_name == 'effb0':
28
  model = EffnetModel()
29
  fname = download_weights(effb0_net_url)
30
+ print('loaded effnet')
31
  elif model_name == 'res18':
32
  model = ResnetModel
33
  fname = download_weights(res18_net_url)
34
+ print('loaded resnet')
35
  else:
36
  raise ValueError('Enter correct model_name')
37
 
38
  # loading pretrained model
39
  state_dict = torch.load(fname)
40
+ print(type(state_dict))
41
  model.load_state_dict(state_dict['weights'])
42
  return model
43