rdkulkarni commited on
Commit
70a466d
1 Parent(s): 39bf8ff

Upload predict.py

Browse files
Files changed (1) hide show
  1. predict.py +80 -0
predict.py ADDED
@@ -0,0 +1,80 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #Import libraries
2
+ import torch
3
+ import torchvision.models as models
4
+ import json
5
+
6
+ #Import User Defined libraries
7
+ from neural_network_model import initialize_existing_models, build_custom_models, set_parameter_requires_grad
8
+ from utilities import process_image, get_input_args_predict
9
+
10
+ def predict(image_path, model, topk=5):
11
+
12
+ device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
13
+ model.to(device)
14
+ model.eval()
15
+
16
+ tensor_img = torch.FloatTensor(process_image(image_path))
17
+ tensor_img = tensor_img.unsqueeze(0)
18
+ tensor_img = tensor_img.to(device)
19
+ log_ps = model(tensor_img)
20
+ result = log_ps.topk(topk)
21
+ if torch.cuda.is_available(): #gpu Move it from gpu to cpu for numpy
22
+ ps = torch.exp(result[0].data).cpu().numpy()[0]
23
+ idxs = result[1].data.cpu().numpy()[0]
24
+ else: #cpu Keep it on cpu for nump
25
+ ps = torch.exp(result[0].data).numpy()[0]
26
+ idxs = result[1].data.numpy()[0]
27
+
28
+ return (ps, idxs)
29
+
30
+ #0. Get user inputs
31
+ in_arg = vars(get_input_args_predict())
32
+ print("User arguments/hyperparameters or default used are as below")
33
+ print(in_arg)
34
+
35
+ #1. Get device for prediction and Load model from checkpoint along with some other information
36
+ if in_arg['gpu'] == 'gpu' and torch.cuda.is_available():
37
+ device = torch.device("cuda")
38
+ checkpoint = torch.load(in_arg['save_dir'])
39
+ else:
40
+ device = "cpu"
41
+ checkpoint = torch.load(in_arg['save_dir'], map_location = device)
42
+ print(f"Using {device} device for predicting/inference")
43
+
44
+ if checkpoint['arch_type'] == 'existing':
45
+ model_ft, input_size = initialize_existing_models(checkpoint['arch'], checkpoint['arch_type'], len(checkpoint['class_to_idx']),
46
+ checkpoint['feature_extract'], checkpoint['hidden_units'], use_pretrained=False)
47
+ elif checkpoint['arch_type'] == 'custom':
48
+ model_ft = build_custom_models(checkpoint['arch'], checkpoint['arch_type'], len(checkpoint['class_to_idx']), checkpoint['feature_extract'],
49
+ checkpoint['hidden_units'], use_pretrained=True)
50
+ else:
51
+ print("Nothing to predict")
52
+ exit()
53
+
54
+
55
+ model_ft.class_to_idx = checkpoint['class_to_idx']
56
+ model_ft.gpu_or_cpu = checkpoint['gpu_or_cpu']
57
+ model_ft.load_state_dict(checkpoint['state_dict'])
58
+ model_ft.to(device)
59
+
60
+ #Predict
61
+ # Get the prediction by passing image and other user preferences through the model
62
+ probs, idxs = predict(image_path = in_arg['path'], model = model_ft, topk = in_arg['top_k'])
63
+
64
+ # Swap class to index mapping with index to class mapping and then map the classes to flower category labels using the json file
65
+ idx_to_class = {v: k for k, v in model_ft.class_to_idx.items()}
66
+ with open('cat_to_name.json','r') as f:
67
+ cat_to_name = json.load(f)
68
+ names = list(map(lambda x: cat_to_name[f"{idx_to_class[x]}"],idxs))
69
+
70
+ # Display final prediction and Top k most probable flower categories
71
+ print("-"*60)
72
+ print(" PREDICTION")
73
+ print("-"*60)
74
+ print("Image provided : {}" .format(in_arg['path']))
75
+ print("Predicted Flower Name : {} (Class {} and Index {})" .format(names[0].upper(), idx_to_class[idxs[0]], idxs[0] ))
76
+ print("Model used : {}" .format(checkpoint['arch']))
77
+ print(f"The top {in_arg['top_k']} probabilities of the flower names")
78
+ for name, prob in zip(names, probs):
79
+ length = 30 - len(name)
80
+ print(f"{name.title()}{' '*length}{round(prob*100,2)}%")