Modify layout and add TPH-YOLOv5 model

#1
Files changed (3)
  1. README.md +3 -3
  2. app.py +235 -107
  3. requirements.txt +6 -1
README.md CHANGED
@@ -1,8 +1,8 @@
---
title: XAITK-Gradio
- emoji: 🐢
- colorFrom: yellow
- colorTo: green
+ emoji: 🕵️‍♂️
+ colorFrom: purple
+ colorTo: blue
sdk: gradio
sdk_version: 4.7.1
app_file: app.py
app.py CHANGED
@@ -3,10 +3,12 @@
# This app makes use of the saliency generation example found in the base ``xaitk-saliency`` repo [here](https://github.com/XAITK/xaitk-saliency/blob/master/examples/OcclusionSaliency.ipynb), and explores integrating ``xaitk-saliency`` with ``Gradio`` to create an interactive interface for computing saliency maps.

import os
+ import sys
import PIL.Image
import matplotlib.pyplot as plt # type: ignore
import urllib
import numpy as np
+ from git import Repo

import gradio as gr
from gradio import ( # type: ignore
@@ -49,6 +51,7 @@ import torch
import torchvision.transforms as transforms
import torchvision.models as models
import torch.nn.functional
+ from torch.utils.data import Dataset, DataLoader

from smqtk_detection.impls.detect_image_objects.resnet_frcnn import ResNetFRCNN
from xaitk_saliency.impls.gen_image_classifier_blackbox_sal.slidingwindow import SlidingWindowStack
@@ -57,7 +60,9 @@ from xaitk_saliency.impls.gen_object_detector_blackbox_sal.drise import RandomGr
from xaitk_saliency.interfaces.gen_object_detector_blackbox_sal import GenerateObjectDetectorBlackboxSaliency
from smqtk_detection.interfaces.detect_image_objects import DetectImageObjects
from smqtk_classifier.interfaces.classify_image import ClassifyImage
+ from smqtk_image_io import AxisAlignedBoundingBox

+ from typing import Iterable, Dict, Hashable, Tuple

os.makedirs('data', exist_ok=True)
test_image_filename = 'data/catdog.jpg'
@@ -72,7 +77,7 @@ model_input_size = (224, 224)
model_mean = [0.485, 0.456, 0.406]
model_loader = transforms.Compose([
    transforms.ToPILImage(),
-     transforms.Resize(model_input_size),
+     transforms.Resize(model_input_size),
    transforms.ToTensor(),
    transforms.Normalize(
        mean=model_mean,
@@ -84,32 +89,32 @@ def get_sal_labels(classes_file, custom_categories_list=None):
    if not os.path.isfile(classes_file):
        url = "https://raw.githubusercontent.com/pytorch/hub/master/imagenet_classes.txt"
        _ = urllib.request.urlretrieve(url, classes_file)
-
+
    f = open(classes_file, "r")
    categories = [s.strip() for s in f.readlines()]
-
+
    if not custom_categories_list == None:
        sal_class_labels = custom_categories_list
    else:
        sal_class_labels = categories
-
+
    sal_class_idxs = [categories.index(lbl) for lbl in sal_class_labels]
-
+
    return sal_class_labels, sal_class_idxs

def get_det_sal_labels(classes_file, custom_categories_list=None):
    if not os.path.isfile(classes_file):
        url = "https://raw.githubusercontent.com/matlab-deep-learning/Object-Detection-Using-Pretrained-YOLO-v2/main/%2Bhelper/coco-classes.txt"
        _ = urllib.request.urlretrieve(url, classes_file)
-
+
    f = open(classes_file, "r")
    categories = [s.strip() for s in f.readlines()]
-
+
    if not custom_categories_list == None:
        sal_obj_labels = custom_categories_list
    else:
        sal_obj_labels = categories
-
+
    sal_obj_idxs = [categories.index(lbl) for lbl in sal_obj_labels]

    return sal_obj_labels, sal_obj_idxs
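
A usage sketch (not part of the diff) for the two label helpers above; the file path and custom labels are illustrative. Each helper returns the chosen label subset together with the indices of those labels in the downloaded category file:

    sal_class_labels, sal_class_idxs = get_sal_labels(
        'data/imagenet_classes.txt',
        custom_categories_list=['boxer', 'tiger cat'],
    )
    # sal_class_idxs holds each label's position in imagenet_classes.txt,
    # which classify_images() later uses to slice the softmax output.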
@@ -131,9 +136,160 @@ def get_detection_model(model_choice):
        blackbox_detector = ResNetFRCNN(
            box_thresh=0.05,
            img_batch_size=1,
-             use_cuda=False
+             use_cuda=CUDA_AVAILABLE
+         )
+     elif model_choice == "TPH-YOLOv5":
+         dest = os.path.join(data_path, 'tph-yolov5')
+         if not os.path.isdir(dest):
+             Repo.clone_from("https://github.com/cv516Buaa/tph-yolov5.git", dest)
+         sys.path.insert(1, dest)
+
+         # imports from the TPH-YOLOv5 GitHub repo
+         from utils.augmentations import letterbox
+         from models.experimental import attempt_load
+         from utils.datasets import LoadImages
+         from utils.general import non_max_suppression, scale_coords
+
+         class YOLOVisdrone(DetectImageObjects):
+             def __init__(
+                 self,
+                 weights,
+                 img_size=(640, 640),
+                 batch_size=1,
+                 conf_thresh=0.5,
+                 iou_thresh=0.5,
+                 use_cuda=False,
+                 num_workers=4
+             ):
+                 """
+                 img_size: size of the image input to the model
+                 batch_size: number of images to input at once
+                 conf_thresh: confidence threshold for detection results
+                 iou_thresh: IoU threshold for NMS
+                 use_cuda: use a CUDA device to compute detections
+                 num_workers: number of worker processes to use for data loading
+                 """
+
+                 self.img_size = np.array(img_size)
+
+                 if use_cuda:
+                     self.device = torch.device('cuda:0')
+                 else:
+                     self.device = torch.device('cpu')
+
+                 self.model = attempt_load(weights).to(self.device)
+                 self.model = self.model.eval()
+
+                 self.conf_thresh = conf_thresh
+                 self.iou_thresh = iou_thresh
+
+                 self.batch_size = batch_size
+                 self.num_workers = num_workers
+
+                 with torch.no_grad():
+                     _ = self.model(torch.zeros(1, 3, *self.img_size).to(self.device))  # warm up
+
+             def detect_objects(
+                 self,
+                 imgIter: Iterable[np.ndarray]
+             ) -> Iterable[Iterable[Tuple[AxisAlignedBoundingBox, Dict[Hashable, float]]]]:
+
+                 # PyTorch DataLoader for the passed images
+                 dataset = DataLoader(
+                     pytorchDataset(
+                         imgIter,
+                         img_size=self.img_size,
+                     ),
+                     batch_size=self.batch_size,
+                     num_workers=self.num_workers
+                 )
+
+                 # list of AxisAlignedBoundingBox detections to return
+                 preds = []
+                 for i, (img_batch, hs, ws) in enumerate(dataset):
+                     # load batch and normalize
+                     img_batch = img_batch.to(self.device)
+                     img_batch = img_batch.float()
+                     img_batch /= 255
+
+                     # pass through the model
+                     with torch.no_grad():
+                         pred_batch = self.model(img_batch)[0]
+
+                     # perform NMS and scale detections to the original image dimensions
+                     for img_pred, h, w in zip(pred_batch, hs, ws):
+                         img_pred = non_max_suppression(
+                             img_pred[None], conf_thres=self.conf_thresh, iou_thres=self.iou_thresh)[0]
+                         img_pred[:, :4] = scale_coords(
+                             img_batch.shape[2:], img_pred[:, :4], (h, w))
+                         img_pred = img_pred.cpu().numpy()
+
+                         preds.append(pred_mat_to_list(img_pred))
+
+                 return preds
+
+             # required by the interface
+             def get_config(self):
+                 return {}
+
+
+         class pytorchDataset(Dataset):
+             """
+             PyTorch Dataset for images. Resizes each image to the model input
+             size and returns the original height and width as well.
+             """
+
+             def __init__(self, imgs, img_size=[640, 640]):
+                 self.imgs = list(imgs)
+                 self.img_size = img_size
+
+             def __getitem__(self, idx):
+                 img = self.imgs[idx]
+                 h = img.shape[0]
+                 w = img.shape[1]
+
+                 img = letterbox(img, new_shape=self.img_size, auto=True)[0]
+                 img = img.transpose((2, 0, 1))
+                 img = np.ascontiguousarray(img)
+
+                 return img, h, w
+
+             def __len__(self):
+                 return len(self.imgs)
+
+
+         def pred_mat_to_list(preds):
+             """
+             Convert the prediction matrix output by the model to the
+             AxisAlignedBoundingBox format.
+             """
+             pred_list = []
+
+             for pred in preds:
+                 bbox = AxisAlignedBoundingBox(pred[0:2], pred[2:4])
+
+                 CLASS_NAMES = ['pedestrian', 'people', 'bicycle', 'car', 'van', 'truck',
+                                'tricycle', 'awning-tricycle', 'bus', 'motor']
+                 score_dict = dict.fromkeys(CLASS_NAMES, 0)
+                 score_dict[CLASS_NAMES[int(pred[5])]] = pred[4]
+
+                 pred_list.append((bbox, score_dict))
+
+             return pred_list
+
+         model_file = os.path.join(data_path, 'tph-yolov5.pth')
+         if not os.path.isfile(model_file):
+             urllib.request.urlretrieve('https://data.kitware.com/api/v1/item/623880d04acac99f429fe3bf/download', model_file)
+
+         blackbox_detector = YOLOVisdrone(
+             weights=model_file,
+             img_size=(1536,1536),
+             batch_size=1,
+             use_cuda=CUDA_AVAILABLE,
+             num_workers=4,
+             conf_thresh=0.1,
+             iou_thresh=0.5
        )
-
    else:
        raise Exception("Unknown Input")

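For reviewers, a minimal sketch (not part of the diff) of how the new detector branch is exercised, assuming the app's existing globals (data_path, CUDA_AVAILABLE) are initialized; the sample image path is hypothetical:

    import numpy as np
    import PIL.Image

    img = np.asarray(PIL.Image.open('data/sample.jpg'))  # hypothetical image

    # The first call clones tph-yolov5 and downloads the weights.
    detector = get_detection_model("TPH-YOLOv5")

    # Same access pattern run_detect uses: one iterable of
    # (AxisAlignedBoundingBox, {class: score}) pairs per input image.
    dets = list(list(detector([img]))[0])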
 
@@ -142,21 +298,21 @@ def get_saliency_algo(sal_choice):
def get_saliency_algo(sal_choice):
    if sal_choice == "RISE":
        gen_sal = RISEStack(
-             n=num_masks_state[-1],
-             s=spatial_res_state[-1],
-             p1=p1_state[-1],
-             seed=seed_state[-1],
-             threads=threads_state[-1],
+             n=num_masks_state[-1],
+             s=spatial_res_state[-1],
+             p1=p1_state[-1],
+             seed=seed_state[-1],
+             threads=threads_state[-1],
            debiased=debiased_state[-1]
        )
-
+
    elif sal_choice == "SlidingWindowStack":
        gen_sal = SlidingWindowStack(
            window_size=eval(window_size_state[-1]),
            stride=eval(stride_state[-1]),
            threads=threads_state[-1]
        )
-
+
    else:
        raise Exception("Unknown Input")

@@ -168,22 +324,22 @@ def get_detection_saliency_algo(sal_choice):
            n=num_masks_state[-1],
            s=eval(occlusion_grid_state[-1]),
            p1=p1_state[-1],
-             threads=threads_state[-1],
-             seed=seed_state[-1],
+             threads=threads_state[-1],
+             seed=seed_state[-1],
        )
-
+
    elif sal_choice == "DRISE":
        gen_sal = DRISEStack(
-             n=num_masks_state[-1],
-             s=spatial_res_state[-1],
-             p1=p1_state[-1],
-             seed=seed_state[-1],
+             n=num_masks_state[-1],
+             s=spatial_res_state[-1],
+             p1=p1_state[-1],
+             seed=seed_state[-1],
            threads=threads_state[-1]
        )
-
+
    else:
        raise Exception("Unknown Input")
-
+
    return gen_sal


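The parameter churn in these two hunks follows one pattern: each `*_state` list holds the most recent UI value at index -1. A sketch with hypothetical state values, using the same DRISEStack constructor arguments as the diff:

    from xaitk_saliency.impls.gen_object_detector_blackbox_sal.drise import DRISEStack

    num_masks_state = [200]      # hypothetical UI state values
    spatial_res_state = [8]
    p1_state = [0.5]
    seed_state = [0]
    threads_state = [4]

    gen_sal = DRISEStack(
        n=num_masks_state[-1],
        s=spatial_res_state[-1],
        p1=p1_state[-1],
        seed=seed_state[-1],
        threads=threads_state[-1]
    )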
 
@@ -202,27 +358,27 @@ class TorchResnet (ClassifyImage):

    def get_labels(self):
        return self.modified_class_labels
-
+
    def set_labels(self, class_labels):
        self.modified_class_labels = [lbl for lbl in class_labels]
-
+
    @torch.no_grad()
    def classify_images(self, image_iter):
        # Input may either be an NDArray, or some arbitrary iterable of NDArray images.
-
+
        model = get_model(img_cls_model_name[-1])
-
+
        for img in image_iter:
            image_tensor = model_loader(img).unsqueeze(0)
            if CUDA_AVAILABLE:
                image_tensor = image_tensor.cuda()
-
+
            feature_vec = model(image_tensor)
            # Converting feature extractor output to probabilities.
            class_conf = torch.nn.functional.softmax(feature_vec, dim=1).cpu().detach().numpy().squeeze()
            # Only return the confidences for the focus classes
            yield dict(zip(sal_class_labels, class_conf[sal_class_idxs]))
-
+
    def get_config(self):
        # Required by a parent class.
        return {}
@@ -256,7 +412,7 @@ def show_slider_parameters(choice):
        return Slider(visible=True), Slider(visible=False)
    else:
        raise Exception("Unknown Input")
-
+
# Modify checkbox parameters based on chosen saliency algorithm
def show_debiased_checkbox(choice):
    if choice == 'RISE':
@@ -268,7 +424,7 @@ def show_debiased_checkbox(choice):

# Function that is called after clicking the "Classify" button in the demo
def predict(x,top_n_classes):
-
+
    image_tensor = model_loader(x).unsqueeze(0)
    if CUDA_AVAILABLE:
        image_tensor = image_tensor.cuda()
@@ -277,18 +433,18 @@
    class_conf = torch.nn.functional.softmax(feature_vec, dim=1).cpu().detach().numpy().squeeze()
    labels = list(zip(sal_class_labels, class_conf[sal_class_idxs].tolist()))
    final_labels = dict(sorted(labels, key=lambda t: t[1],reverse=True)[:top_n_classes])
-
+
    return final_labels, Dropdown(choices=list(final_labels))

# Interpretation function for image classification that implements the selected saliency algorithm and generates the class-wise saliency map visualizations
- def interpretation_function(image: np.ndarray,
+ def interpretation_function(image: np.ndarray,
                            labels: dict,
-                             nth_class: str,
+                             nth_class: str,
                            img_alpha,
                            sal_alpha,
                            sal_range_min,
                            sal_range_max):
-
+
    sal_generator = get_saliency_algo(img_cls_saliency_algo_name[-1])
    sal_generator.fill = blackbox_fill
    labels_list = labels.keys()
@@ -301,10 +457,10 @@ def interpretation_function(image: np.ndarray,
                                  sal_alpha,
                                  sal_range_min,
                                  sal_range_max)
-
+
    return fig

- def visualize_saliency_plot(image: np.ndarray,
+ def visualize_saliency_plot(image: np.ndarray,
                            class_sal_map: np.ndarray,
                            img_alpha,
                            sal_alpha,
@@ -352,20 +508,20 @@ def run_detect(input_img: np.ndarray, num_detections: int):
        conf_score = str(round(score_list[int(max_scores_index[i,0])],4))
        label_with_score = str(i) + " : "+ label_name + " - " + conf_score
        final_label.append(label_with_score)
-
+
    bboxes_list = bboxes[:,:].astype(int).tolist()

    return (input_img, list(zip([f for f in bboxes_list], [l for l in final_label]))[:num_detections]), Dropdown(choices=[l for l in final_label][:num_detections])

# Run saliency algorithm on the object detect predictions and generate corresponding visualizations
- def run_detect_saliency(input_img: np.ndarray,
+ def run_detect_saliency(input_img: np.ndarray,
                        num_predictions,
-                         obj_label,
+                         obj_label,
                        img_alpha,
                        sal_alpha,
                        sal_range_min,
                        sal_range_max):
-
+
    detect_model = get_detection_model(obj_det_model_name[-1])
    img_preds = list(list(detect_model([input_img]))[0])
    ref_preds = img_preds[:int(num_predictions)]
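
The label strings built in run_detect double as a transport format: run_detect_saliency later recovers the detection index from the "<index> : <name> - <score>" prefix. A small sketch of that round trip (values hypothetical):

    label_with_score = "0 : car - 0.9132"                    # built by run_detect
    nth_class_index = int(label_with_score.split(' : ')[0])  # -> 0, used to index sal_maps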
@@ -383,15 +539,11 @@ def run_detect_saliency(input_img: np.ndarray,

    ref_bboxes = np.array(ref_bboxes)
    ref_scores = np.array(ref_scores)
-
-     print(f"Ref bboxes: {ref_bboxes.shape}")
-     print(f"Ref scores: {ref_scores.shape}")
-
+
    sal_generator = get_detection_saliency_algo(obj_det_saliency_algo_name[-1])
    sal_generator.fill = blackbox_fill
-
+
    sal_maps = gen_det_saliency(input_img, detect_model, sal_generator,ref_bboxes,ref_scores)
-     print(f"Saliency maps: {sal_maps.shape}")

    nth_class_index = int(obj_label.split(' : ')[0])
    scores = sal_maps[nth_class_index,:,:]
@@ -401,7 +553,7 @@
                                  sal_alpha,
                                  sal_range_min,
                                  sal_range_max)
-
+
    scores = np.clip(scores, sal_range_min, sal_range_max)

    return fig
@@ -421,99 +573,74 @@ def gen_det_saliency(input_img: np.ndarray,

    return sal_maps

- with gr.Blocks() as demo:
+ with gr.Blocks() as xaitk_demo:
    with Tab("Image Classification"):
        with Row():
-             with Column(scale=0.5):
+             with Column():
                drop_list = Dropdown(value=img_cls_model_name[-1],choices=["ResNet-18","ResNet-50"],label="Choose Model",interactive="True")
+                 input_img = Image(label="Input Image")
+                 num_classes = Slider(value=2,label="Top-N Class Labels", interactive=True,visible=True)
+                 classify = Button("Classify")
+                 class_label = Label(label="Predictions")
+                 class_name = Dropdown(label="Class to Compute Saliency",interactive=True,visible=True)
-             with Column(scale=0.5):
+             with Column():
                drop_list_sal = Dropdown(value=img_cls_saliency_algo_name[-1],choices=["SlidingWindowStack","RISE"],label="Choose Saliency Algorithm",interactive="True")
-         with Row():
-             with Column(scale=0.33):
                window_size = Textbox(value=window_size_state[-1],label="Tuple of window size values (Press Enter to submit the input)",interactive=True,visible=False)
                masks = Number(value=num_masks_state[-1],label="Number of Random Masks (Press Enter to submit the input)",interactive=True,visible=True,precision=0)
-             with Column(scale=0.33):
                stride = Textbox(value=stride_state[-1],label="Tuple of stride values (Press Enter to submit the input)" ,interactive=True,visible=False)
                spatial_res = Number(value=spatial_res_state[-1],label="Spatial Resolution of Masking Grid (Press Enter to submit the input)" ,interactive=True,visible=True,precision=0)
-             with Column(scale=0.33):
-                 threads = Slider(value=threads_state[-1],label="Threads",interactive=True,visible=True)
-         with Row():
-             with Column(scale=0.33):
+                 debiased = Checkbox(value=debiased_state[-1],label="Debiased", interactive=True, visible=True)
                seed = Number(value=seed_state[-1],label="Seed (Press Enter to submit the input)",interactive=True,visible=True,precision=0)
-             with Column(scale=0.33):
                p1 = Slider(value=p1_state[-1],label="P1",interactive=True,visible=True, minimum=0,maximum=1,step=0.1)
-             with Column(scale=0.33):
-                 debiased = Checkbox(value=debiased_state[-1],label="Debiased", interactive=True, visible=True)
-         with Row():
-             with Column():
-                 input_img = Image(label="Saliency Map Generation", width=640, height=480)
-                 num_classes = Slider(value=2,label="Top-N class labels", interactive=True,visible=True)
-                 classify = Button("Classify")
-             with Column():
-                 class_label = Label(label="Predicted Class")
-             with Column():
-                 with Row():
-                     class_name = Dropdown(label="Class to compute saliency",interactive=True,visible=True)
+                 threads = Slider(value=threads_state[-1],label="Threads",interactive=True,visible=True)
+                 with Tabs():
+                     with TabItem("Display Interpretation with Plot"):
+                         interpretation_plot = Plot()
                with Row():
                    img_alpha = Slider(value=0.7,label="Image Opacity",interactive=True,visible=True,minimum=0,maximum=1,step=0.1)
                    sal_alpha = Slider(value=0.3,label="Saliency Map Opacity",interactive=True,visible=True,minimum=0,maximum=1,step=0.1)
                with Row():
                    min_sal_range = Slider(value=0,label="Minimum Saliency Value",interactive=True,visible=True,minimum=-1,maximum=1,step=0.05)
                    max_sal_range = Slider(value=1,label="Maximum Saliency Value",interactive=True,visible=True,minimum=-1,maximum=1,step=0.05)
-         with Row():
-             generate_saliency = Button("Generate Saliency")
-         with Column():
-             with Tabs():
-                 with TabItem("Display interpretation with plot"):
-                     interpretation_plot = Plot()
+                 generate_saliency = Button("Generate Saliency")

    with Tab("Object Detection"):
        with Row():
-             with Column(scale=0.5):
-                 drop_list_detect_model = Dropdown(value=obj_det_model_name[-1],choices=["Faster-RCNN"],label="Choose Model",interactive="True")
-             with Column(scale=0.5):
+             with Column():
+                 drop_list_detect_model = Dropdown(value=obj_det_model_name[-1],choices=["Faster-RCNN", "TPH-YOLOv5"],label="Choose Model",interactive="True")
+                 input_img_detect = Image(label="Input Image")
+                 num_detections = Slider(value=2,label="Top-N Detections", interactive=True,visible=True)
+                 detection = Button("Run Detection Algorithm")
+                 detect_label = AnnotatedImage(label="Detections")
+                 class_name_det = Dropdown(label="Detection to Compute Saliency",interactive=True,visible=True)
+
+             with Column():
                drop_list_detect_sal = Dropdown(value=obj_det_saliency_algo_name[-1],choices=["RandomGridStack","DRISE"],label="Choose Saliency Algorithm",interactive="True")
-         with Row():
-             with Column(scale=0.33):
                masks_detect = Number(value=num_masks_state[-1],label="Number of Random Masks (Press Enter to submit the input)",interactive=True,visible=True,precision=0)
                occlusion_grid_size = Textbox(value=occlusion_grid_state[-1],label="Tuple of occlusion grid size values (Press Enter to submit the input)",interactive=True,visible=False)
                spatial_res_detect = Number(value=spatial_res_state[-1],label="Spatial Resolution of Masking Grid (Press Enter to submit the input)" ,interactive=True,visible=True,precision=0)
-             with Column(scale=0.33):
                seed_detect = Number(value=seed_state[-1],label="Seed (Press Enter to submit the input)",interactive=True,visible=True,precision=0)
                p1_detect = Slider(value=p1_state[-1],label="P1",interactive=True,visible=True, minimum=0,maximum=1,step=0.1)
-             with Column(scale=0.33):
                threads_detect = Slider(value=threads_state[-1],label="Threads",interactive=True,visible=True)
-         with Row():
-             with Column():
-                 input_img_detect = Image(label="Saliency Map Generation", width=640, height=480)
-                 num_detections = Slider(value=2,label="Top-N detections", interactive=True,visible=True)
-                 detection = Button("Run Detection Algorithm")
-             with Column():
-                 detect_label = AnnotatedImage(label="Detections")
-             with Column():
-                 with Row():
-                     class_name_det = Dropdown(label="Detection to compute saliency",interactive=True,visible=True)
+                 with Tabs():
+                     with TabItem("Display saliency map plot"):
+                         det_saliency_plot = Plot()
                with Row():
                    img_alpha_det = Slider(value=0.7,label="Image Opacity",interactive=True,visible=True,minimum=0,maximum=1,step=0.1)
                    sal_alpha_det = Slider(value=0.3,label="Saliency Map Opacity",interactive=True,visible=True,minimum=0,maximum=1,step=0.1)
                with Row():
                    min_sal_range_det = Slider(value=0.95,label="Minimum Saliency Value",interactive=True,visible=True,minimum=0.80,maximum=1,step=0.05)
                    max_sal_range_det = Slider(value=1,label="Maximum Saliency Value",interactive=True,visible=True,minimum=0.80,maximum=1,step=0.05)
-         with Row():
-             generate_det_saliency = Button("Generate Saliency")
-             with Column():
-                 with Tabs():
-                     with TabItem("Display saliency map plot"):
-                         det_saliency_plot = Plot()
+                 generate_det_saliency = Button("Generate Saliency")

-     # Image Classification dropdown list event listeners
+     # Image Classification dropdown list event listeners
    drop_list.select(select_img_cls_model,drop_list,drop_list)
    drop_list_sal.select(select_img_cls_saliency_algo,drop_list_sal,drop_list_sal)
    drop_list_sal.change(show_textbox_parameters,drop_list_sal,[window_size,stride,masks,spatial_res,seed])
    drop_list_sal.change(show_slider_parameters,drop_list_sal,[threads,p1])
    drop_list_sal.change(show_debiased_checkbox,drop_list_sal,debiased)

-     # Image Classification textbox, slider and checkbox event listeners
+     # Image Classification textbox, slider and checkbox event listeners
    window_size.submit(enter_window_size,window_size,window_size)
    masks.submit(enter_num_masks,masks,masks)
    stride.submit(enter_stride, stride, stride)
@@ -533,7 +660,7 @@ with gr.Blocks() as demo:
    drop_list_detect_sal.change(show_slider_parameters,drop_list_detect_sal,[threads_detect,p1_detect])
    drop_list_detect_sal.change(show_textbox_parameters,drop_list_detect_sal,[masks_detect,spatial_res_detect,seed_detect,occlusion_grid_size])

-     # Object detection textbox and slider event listeners
+     # Object detection textbox and slider event listeners
    masks_detect.submit(enter_num_masks,masks_detect,masks_detect)
    occlusion_grid_size.submit(enter_occlusion_grid_size,occlusion_grid_size,occlusion_grid_size)
    spatial_res_detect.submit(enter_spatial_res, spatial_res_detect, spatial_res_detect)
@@ -545,4 +672,5 @@
    detection.click(run_detect, [input_img_detect, num_detections], [detect_label,class_name_det])
    generate_det_saliency.click(run_detect_saliency,[input_img_detect, num_detections, class_name_det, img_alpha_det, sal_alpha_det, min_sal_range_det, max_sal_range_det],det_saliency_plot)

- demo.launch(share=True)
+
+ xaitk_demo.launch(show_error=True)
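Note: the new launch call drops share=True, which is redundant on a hosted Space (the app is already served publicly), and adds show_error=True so Python tracebacks from callbacks surface in the browser.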
 
 
requirements.txt CHANGED
@@ -2,4 +2,9 @@ xaitk-saliency
torch
torchvision
urllib3
- Pillow
+ Pillow
+ gitpython
+
+ # tph-yolov5
+ opencv-python
+ seaborn
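
A quick import check (a sketch, not part of the diff) that the new dependencies resolve: gitpython provides the `git` module used by Repo.clone_from in app.py, while opencv-python and seaborn are imported by the cloned tph-yolov5 code itself.

    import cv2       # provided by opencv-python
    import seaborn   # plotting dependency of the YOLOv5 codebase
    import git       # provided by gitpython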