ibrim commited on
Commit
0d4900d
1 Parent(s): 5720001

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +110 -0
app.py CHANGED
@@ -0,0 +1,110 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from gradio_utils import *
2
+ def process_images_gradcam(show_gradcam, gradcam_count, gradcam_layer, gradcam_opacity):
3
+ if show_gradcam:
4
+ inv_normalize = transforms.Normalize(
5
+ mean=[-1.9899, -1.9844, -1.7111],
6
+ std=[4.0486, 4.1152, 3.8314])
7
+ classes = ('plane', 'car', 'bird', 'cat', 'deer',
8
+ 'dog', 'frog', 'horse', 'ship', 'truck')
9
+ misclassified_data = get_misclassified_data(modelfin, "cuda", test_loader)
10
+ if gradcam_layer=="1":
11
+ images = display_gradcam_output(misclassified_data, classes, inv_normalize, modelfin, target_layers= [modelfin.model.layer1[-1]], targets=None, number_of_samples=gradcam_count, transparency=gradcam_opacity)
12
+ if gradcam_layer=="2":
13
+ images = display_gradcam_output(misclassified_data, classes, inv_normalize, modelfin, target_layers= [modelfin.model.layer2[-1]], targets=None, number_of_samples=gradcam_count, transparency=gradcam_opacity)
14
+ if gradcam_layer=="3":
15
+ images = display_gradcam_output(misclassified_data, classes, inv_normalize, modelfin, target_layers= [modelfin.model.layer3[-1]], targets=None, number_of_samples=gradcam_count, transparency=gradcam_opacity)
16
+ if gradcam_layer=="4":
17
+ images = display_gradcam_output(misclassified_data, classes, inv_normalize, modelfin, target_layers= [modelfin.model.layer4[-1]], targets=None, number_of_samples=gradcam_count, transparency=gradcam_opacity)
18
+ return images
19
+
20
+ def process_images_misclass(show_misclassify, misclassify_count):
21
+ if show_misclassify:
22
+ misclassified_data = get_misclassified_data(modelfin, "cuda", test_loader)
23
+ image = display_cifar_misclassified_data(misclassified_data, classes, inv_normalize, number_of_samples=misclassify_count)
24
+ return image
25
+
26
+ def predict_classes(upload_image, top_classes):
27
+ transform = transforms.Compose([
28
+ transforms.Resize((32, 32)), # Resize to 32x32 pixels
29
+ transforms.ToTensor(), # Convert image to tensor
30
+ transforms.Normalize(mean=[0.4914, 0.4822, 0.4465], # CIFAR-10 normalization
31
+ std=[0.2023, 0.1994, 0.2010])])
32
+
33
+ # Load and transform an image
34
+ image = upload_image
35
+ image = transform(image)
36
+ image = image.unsqueeze(0)
37
+ device = next(modelfin.parameters()).device
38
+ image = image.to(device)
39
+ # Ensure the model is in evaluation mode
40
+ modelfin.eval()
41
+
42
+ # Disable gradient computation for inference
43
+ with torch.no_grad():
44
+ output = modelfin(image)
45
+
46
+ # Get the top 5 predictions and their indices
47
+ probabilities = torch.nn.functional.softmax(output, dim=1)
48
+ top_prob, top_catid = torch.topk(probabilities, top_classes)
49
+
50
+ # CIFAR-10 classes
51
+ classes = ['airplane', 'automobile', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck']
52
+
53
+
54
+
55
+ # Initialize an empty string to collect predictions
56
+ predictions_str = ""
57
+
58
+ # Collect top 5 predictions in the string with line breaks
59
+ for i in range(top_prob.size(1)):
60
+ predictions_str += f"{classes[top_catid[0][i]]}: {top_prob[0][i].item()*100:.2f}%\n"
61
+
62
+ # Print or return the complete predictions string
63
+ return predictions_str
64
+
65
+
66
+ with gr.Blocks() as demo:
67
+ with gr.Row():
68
+ with gr.Column():
69
+ show_gradcam = gr.Checkbox(label="Show GradCAM Images?")
70
+ gradcam_count = gr.Number(label="How many GradCAM images?", value=1, precision=0)
71
+ gradcam_layer = gr.Radio(choices=["1", "2", "3", "4"], label="Choose a layer", value=4)
72
+ gradcam_opacity = gr.Slider(minimum=0, maximum=1, label="Opacity of overlay", value=0.5)
73
+ # with gr.Column():
74
+ # show_misclassified = gr.Checkbox(label="Show Misclassified Images?")
75
+ # misclassified_count = gr.Number(label="How many misclassified images?", value=1, precision=0)
76
+
77
+ #uploaded_images = gr.File(label="Upload New Images", type="file", accept="image/*", multiple=True)
78
+ #top_classes = gr.Number(label="How many top classes to show?", value=5, minimum=1, maximum=10, precision=0)
79
+
80
+ submit_button = gr.Button("GradCam")
81
+ outputs = gr.Image(label="Output")
82
+
83
+ show_misclassify = gr.Checkbox(label="Show misclassified images?")
84
+ misclassify_count=gr.Number(label="How many misclassified images?")
85
+ submit_button_misclass = gr.Button("Misclassified")
86
+ outputs_misclass = gr.Image(label="Output")
87
+
88
+ upload_image = gr.Image(label="Upload your image", interactive = True, type='pil')
89
+ top_classes = gr.Number(label="How many top classes would you like to see?", maximum=10)
90
+ upload_btn = gr.Button("Classify your image")
91
+ show_classes = gr.Textbox(label="Your top classes", interactive=False)
92
+
93
+ submit_button_misclass.click(
94
+ process_images_misclass,
95
+ inputs=[show_misclassify, misclassify_count],
96
+ outputs=outputs_misclass
97
+ )
98
+ submit_button.click(
99
+ process_images_gradcam,
100
+ inputs=[show_gradcam, gradcam_count, gradcam_layer, gradcam_opacity],
101
+ outputs=outputs
102
+ )
103
+
104
+ upload_btn.click(
105
+ predict_classes,
106
+ inputs=[upload_image, top_classes],
107
+ outputs=show_classes
108
+ )
109
+
110
+ demo.launch()