rdkulkarni committed on
Commit
37fb3ed
1 Parent(s): 7eb04d1

Update predict.py

Files changed (1)
  1. predict.py +57 -50
predict.py CHANGED
@@ -8,7 +8,8 @@ from neural_network_model import initialize_existing_models, build_custom_models
 from utilities import process_image, get_input_args_predict
 
 def predict(image_path, model, topk=5):
-
+
+
     device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
     model.to(device)
     model.eval()
@@ -27,54 +28,60 @@ def predict(image_path, model, topk=5):
 
     return (ps, idxs)
 
-#0. Get user inputs
-in_arg = vars(get_input_args_predict())
-print("User arguments/hyperparameters or default used are as below")
-print(in_arg)
-
-#1. Get device for prediction and Load model from checkpoint along with some other information
-if in_arg['gpu'] == 'gpu' and torch.cuda.is_available():
-    device = torch.device("cuda")
-    checkpoint = torch.load(in_arg['save_dir'])
-else:
-    device = "cpu"
-    checkpoint = torch.load(in_arg['save_dir'], map_location = device)
-print(f"Using {device} device for predicting/inference")
-
-if checkpoint['arch_type'] == 'existing':
-    model_ft, input_size = initialize_existing_models(checkpoint['arch'], checkpoint['arch_type'], len(checkpoint['class_to_idx']),
-                                                      checkpoint['feature_extract'], checkpoint['hidden_units'], use_pretrained=False)
-elif checkpoint['arch_type'] == 'custom':
-    model_ft = build_custom_models(checkpoint['arch'], checkpoint['arch_type'], len(checkpoint['class_to_idx']), checkpoint['feature_extract'],
-                                   checkpoint['hidden_units'], use_pretrained=True)
-else:
-    print("Nothing to predict")
-    exit()
+def process_input(image_path):
+    #0. Get user inputs
+    #in_arg = vars(get_input_args_predict())
+    #print("User arguments/hyperparameters or default used are as below")
+    #print(in_arg)
+    in_arg = {}
+    in_arg['gpu'] = 'gpu'
+    in_arg['save_dir'] = 'checkpoint-densenet121.pth'
+    in_arg['path'] = image_path
+    in_arg['top_k'] = 5
 
+    #1. Get device for prediction and Load model from checkpoint along with some other information
+    if in_arg['gpu'] == 'gpu' and torch.cuda.is_available():
+        device = torch.device("cuda")
+        checkpoint = torch.load(in_arg['save_dir'])
+    else:
+        device = "cpu"
+        checkpoint = torch.load(in_arg['save_dir'], map_location = device)
+    #print(f"Using {device} device for predicting/inference")
 
-model_ft.class_to_idx = checkpoint['class_to_idx']
-model_ft.gpu_or_cpu = checkpoint['gpu_or_cpu']
-model_ft.load_state_dict(checkpoint['state_dict'])
-model_ft.to(device)
-
-#Predict
-# Get the prediction by passing image and other user preferences through the model
-probs, idxs = predict(image_path = in_arg['path'], model = model_ft, topk = in_arg['top_k'])
-
-# Swap class to index mapping with index to class mapping and then map the classes to flower category labels using the json file
-idx_to_class = {v: k for k, v in model_ft.class_to_idx.items()}
-with open('cat_to_name.json','r') as f:
-    cat_to_name = json.load(f)
-names = list(map(lambda x: cat_to_name[f"{idx_to_class[x]}"],idxs))
-
-# Display final prediction and Top k most probable flower categories
-print("-"*60)
-print(" PREDICTION")
-print("-"*60)
-print("Image provided : {}" .format(in_arg['path']))
-print("Predicted Flower Name : {} (Class {} and Index {})" .format(names[0].upper(), idx_to_class[idxs[0]], idxs[0] ))
-print("Model used : {}" .format(checkpoint['arch']))
-print(f"The top {in_arg['top_k']} probabilities of the flower names")
-for name, prob in zip(names, probs):
-    length = 30 - len(name)
-    print(f"{name.title()}{' '*length}{round(prob*100,2)}%")
+    if checkpoint['arch_type'] == 'existing':
+        model_ft, input_size = initialize_existing_models(checkpoint['arch'], checkpoint['arch_type'], len(checkpoint['class_to_idx']),
+                                                          checkpoint['feature_extract'], checkpoint['hidden_units'], use_pretrained=False)
+    elif checkpoint['arch_type'] == 'custom':
+        model_ft = build_custom_models(checkpoint['arch'], checkpoint['arch_type'], len(checkpoint['class_to_idx']), checkpoint['feature_extract'],
+                                       checkpoint['hidden_units'], use_pretrained=True)
+    else:
+        #print("Nothing to predict")
+        exit()
+
+
+    model_ft.class_to_idx = checkpoint['class_to_idx']
+    model_ft.gpu_or_cpu = checkpoint['gpu_or_cpu']
+    model_ft.load_state_dict(checkpoint['state_dict'])
+    model_ft.to(device)
+
+    #Predict
+    # Get the prediction by passing image and other user preferences through the model
+    probs, idxs = predict(image_path = in_arg['path'], model = model_ft, topk = in_arg['top_k'])
+
+    # Swap class to index mapping with index to class mapping and then map the classes to flower category labels using the json file
+    idx_to_class = {v: k for k, v in model_ft.class_to_idx.items()}
+    with open('cat_to_name.json','r') as f:
+        cat_to_name = json.load(f)
+    names = list(map(lambda x: cat_to_name[f"{idx_to_class[x]}"],idxs))
+
+    # Display final prediction and Top k most probable flower categories
+    #print("-"*60)
+    #print(" PREDICTION")
+    #print("-"*60)
+    #print("Image provided : {}" .format(in_arg['path']))
+    #print("Predicted Flower Name : {} (Class {} and Index {})" .format(names[0].upper(), idx_to_class[idxs[0]], idxs[0] ))
+    #print("Model used : {}" .format(checkpoint['arch']))
+    #print(f"The top {in_arg['top_k']} probabilities of the flower names")
+    #for name, prob in zip(names, probs):
+    #    length = 30 - len(name)
+    #    print(f"{name.title()}{' '*length}{round(prob*100,2)}%")
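For context on the loading logic in process_input: it rebuilds the network from a checkpoint dictionary and expects a specific set of keys next to the weights (arch, arch_type, class_to_idx, feature_extract, hidden_units, gpu_or_cpu, state_dict). Below is a minimal sketch of what writing such a checkpoint could look like on the training side. It is not part of this commit; the stand-in densenet121, the toy class_to_idx, hidden_units=512, and the file name are assumptions for illustration, and in this repo the state_dict would come from the model returned by initialize_existing_models or build_custom_models rather than a stock torchvision model.

import torch
from torchvision import models

# Sketch only: save a checkpoint in the shape that process_input expects.
# Key names match the loading code above; every value is a placeholder.
model_ft = models.densenet121(weights=None)   # stand-in for the trained model
class_to_idx = {'1': 0, '10': 1, '100': 2}    # stand-in class-to-index mapping

checkpoint = {
    'arch': 'densenet121',         # architecture name handed back to initialize_existing_models
    'arch_type': 'existing',       # 'existing' or 'custom' selects which rebuild path runs
    'class_to_idx': class_to_idx,  # its length sizes the classifier; inverted later to recover labels
    'feature_extract': True,
    'hidden_units': 512,
    'gpu_or_cpu': 'gpu',
    'state_dict': model_ft.state_dict(),
}
torch.save(checkpoint, 'checkpoint-densenet121.pth')

With a file like this on disk, process_input('path/to/image.jpg') would rebuild the model, call predict for the top-k probabilities, and map the indices back to flower names via cat_to_name.json; in the committed version those results are computed but neither printed nor returned, since the display block is commented out.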