mrdbourke commited on
Commit
bb0d52f
โ€ข
1 Parent(s): df8b8a4

Uploading Trashify box detection model v3 app.py with NMS post processing

Browse files
.gitattributes CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ examples/trashify_example_2.jpeg filter=lfs diff=lfs merge=lfs -text
.gradio/cached_examples/21/Image Output no filtering/fee2a07231fec8287609/image.webp ADDED
.gradio/cached_examples/21/log.csv ADDED
@@ -0,0 +1 @@
 
 
1
+ Image Output (no filtering),Text Output (no filtering),Image Output (with max score per class box filtering),Text Output (with max score per class box filtering),timestamp
README.md CHANGED
@@ -1,13 +1,25 @@
1
  ---
2
- title: Trashify V3
3
- emoji: โšก
4
- colorFrom: red
5
- colorTo: pink
6
  sdk: gradio
7
- sdk_version: 4.41.0
8
  app_file: app.py
9
  pinned: false
10
  license: apache-2.0
11
  ---
12
 
13
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  ---
2
+ title: Trashify Demo V3 ๐Ÿšฎ
3
+ emoji: ๐Ÿ—‘๏ธ
4
+ colorFrom: purple
5
+ colorTo: blue
6
  sdk: gradio
7
+ sdk_version: 4.40.0
8
  app_file: app.py
9
  pinned: false
10
  license: apache-2.0
11
  ---
12
 
13
+ # ๐Ÿšฎ Trashify Object Detector Demo V3
14
+
15
+ Object detection demo to detect `trash`, `bin`, `hand`, `trash_arm`, `not_trash`, `not_bin`, `not_hand`.
16
+
17
+ Used as example for encouraging people to cleanup their local area.
18
+
19
+ If `trash`, `hand`, `bin` all detected = +1 point.
20
+
21
+ * V1 = model trained *without* data augmentation
22
+ * V2 = model trained *with* data augmentation
23
+ * V3 = model trained *with* data augmentation & NMS ([Non Maximum Suppression](https://paperswithcode.com/method/non-maximum-suppression)) post processing step
24
+
25
+ TK - finish the README.md + update with links to materials
app.py CHANGED
@@ -1,29 +1,105 @@
1
  import gradio as gr
2
  import torch
3
- from PIL import Image, ImageDraw
4
 
5
  from transformers import AutoImageProcessor
6
  from transformers import AutoModelForObjectDetection
7
 
8
- from PIL import Image
9
-
10
- model_save_path = "mrdbourke/detr_finetuned_trashify_box_detector_synthetic_and_real_data"
11
 
 
12
  image_processor = AutoImageProcessor.from_pretrained(model_save_path)
13
  model = AutoModelForObjectDetection.from_pretrained(model_save_path)
14
 
 
 
 
 
15
  id2label = model.config.id2label
16
- color_dict = {
17
- "not_trash": "red",
 
18
  "bin": "green",
19
  "trash": "blue",
20
- "hand": "purple"
 
 
 
 
21
  }
22
 
23
- device = "cuda" if torch.cuda.is_available() else "cpu"
24
- model = model.to(device)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
25
 
26
- def predict_on_image(image, conf_threshold=0.25):
27
  with torch.no_grad():
28
  inputs = image_processor(images=[image], return_tensors="pt")
29
  outputs = model(**inputs.to(device))
@@ -43,13 +119,37 @@ def predict_on_image(image, conf_threshold=0.25):
43
  # Can return results as plotted on a PIL image (then display the image)
44
  draw = ImageDraw.Draw(image)
45
 
46
- for box, score, label in zip(results["boxes"], results["scores"], results["labels"]):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
47
  # Create coordinates
48
  x, y, x2, y2 = tuple(box.tolist())
49
 
50
  # Get label_name
51
  label_name = id2label[label.item()]
52
  targ_color = color_dict[label_name]
 
53
 
54
  # Draw the rectangle
55
  draw.rectangle(xy=(x, y, x2, y2),
@@ -62,23 +162,70 @@ def predict_on_image(image, conf_threshold=0.25):
62
  # Draw the text on the image
63
  draw.text(xy=(x, y),
64
  text=text_string_to_show,
65
- fill="white")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
66
 
67
  # Remove the draw each time
68
  del draw
 
 
 
 
 
69
 
70
- return image
71
 
 
72
  demo = gr.Interface(
73
  fn=predict_on_image,
74
  inputs=[
75
- gr.Image(type="pil", label="Upload Target Image"),
76
  gr.Slider(minimum=0, maximum=1, value=0.25, label="Confidence Threshold")
77
  ],
78
- outputs=gr.Image(type="pil"),
79
- title="๐Ÿšฎ Trashify Object Detection Demo (real and synthetic data)",
80
- description="Upload an image to detect whether there's a bin, a hand or trash in it. Trained on a mixture of real and synthetic data."
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
81
  )
82
 
83
- if __name__ == "__main__":
84
- demo.launch()
 
1
  import gradio as gr
2
  import torch
3
+ from PIL import Image, ImageDraw, ImageFont
4
 
5
  from transformers import AutoImageProcessor
6
  from transformers import AutoModelForObjectDetection
7
 
8
+ # Note: Can load from Hugging Face or can load from local.
9
+ # You will have to replace {mrdbourke} for your own username if the model is on your Hugging Face account.
10
+ model_save_path = "mrdbourke/detr_finetuned_trashify_box_detector_with_data_aug"
11
 
12
+ # Load the model and preprocessor
13
  image_processor = AutoImageProcessor.from_pretrained(model_save_path)
14
  model = AutoModelForObjectDetection.from_pretrained(model_save_path)
15
 
16
+ device = "cuda" if torch.cuda.is_available() else "cpu"
17
+ model = model.to(device)
18
+
19
+ # Get the id2label dictionary from the model
20
  id2label = model.config.id2label
21
+
22
+ # Set up a colour dictionary for plotting boxes with different colours
23
+ color_dict = {
24
  "bin": "green",
25
  "trash": "blue",
26
+ "hand": "purple",
27
+ "trash_arm": "yellow",
28
+ "not_trash": "red",
29
+ "not_bin": "red",
30
+ "not_hand": "red",
31
  }
32
 
33
+ # Create helper functions for seeing if items from one list are in another
34
+ def any_in_list(list_a, list_b):
35
+ "Returns True if any item from list_a is in list_b, otherwise False."
36
+ return any(item in list_b for item in list_a)
37
+
38
+ def all_in_list(list_a, list_b):
39
+ "Returns True if all items from list_a are in list_b, otherwise False."
40
+ return all(item in list_b for item in list_a)
41
+
42
+ def filter_highest_scoring_box_per_class(boxes, labels, scores):
43
+ """
44
+ Perform NMS (Non-max Supression) to only keep the top scoring box per class.
45
+
46
+ Args:
47
+ boxes: tensor of shape (N, 4)
48
+ labels: tensor of shape (N,)
49
+ scores: tensor of shape (N,)
50
+ Returns:
51
+ boxes: tensor of shape (N, 4) filtered for max scoring item per class
52
+ labels: tensor of shape (N,) filtered for max scoring item per class
53
+ scores: tensor of shape (N,) filtered for max scoring item per class
54
+ """
55
+ # Start with a blank keep mask (e.g. all False and then update the boxes to keep with True)
56
+ keep_mask = torch.zeros(len(boxes), dtype=torch.bool)
57
+
58
+ # For each unique class
59
+ for class_id in labels.unique():
60
+ # Get the indicies for the target class
61
+ class_mask = labels == class_id
62
+
63
+ # If any of the labels match the current class_id
64
+ if class_mask.any():
65
+ # Find the index of highest scoring box for this specific class
66
+ class_scores = scores[class_mask]
67
+ highest_score_idx = class_scores.argmax()
68
+
69
+ # Convert back to the original index
70
+ original_idx = torch.where(class_mask)[0][highest_score_idx]
71
+
72
+ # Update the index in the keep mask to keep the highest scoring box
73
+ keep_mask[original_idx] = True
74
+
75
+ return boxes[keep_mask], labels[keep_mask], scores[keep_mask]
76
+
77
+ def create_return_string(list_of_predicted_labels, target_items=["trash", "bin", "hand"]):
78
+ # Setup blank string to print out
79
+ return_string = ""
80
+
81
+ # If no items detected or trash, bin, hand not in list, return notification
82
+ if (len(list_of_predicted_labels) == 0) or not (any_in_list(list_a=target_items, list_b=list_of_predicted_labels)):
83
+ return_string = f"No trash, bin or hand detected at confidence threshold {conf_threshold}. Try another image or lowering the confidence threshold."
84
+ return return_string
85
+
86
+ # If there are some missing, print the ones which are missing
87
+ elif not all_in_list(list_a=target_items, list_b=list_of_predicted_labels):
88
+ missing_items = []
89
+ for item in target_items:
90
+ if item not in list_of_predicted_labels:
91
+ missing_items.append(item)
92
+ return_string = f"Detected the following items: {list_of_predicted_labels} (total: {len(list_of_predicted_labels)}). But missing the following in order to get +1: {missing_items}. If this is an error, try another image or altering the confidence threshold. Otherwise, the model may need to be updated with better data."
93
+
94
+ # If all 3 trash, bin, hand occur = + 1
95
+ if all_in_list(list_a=target_items, list_b=list_of_predicted_labels):
96
+ return_string = f"+1! Found the following items: {list_of_predicted_labels} (total: {len(list_of_predicted_labels)}), thank you for cleaning up the area!"
97
+
98
+ print(return_string)
99
+
100
+ return return_string
101
 
102
+ def predict_on_image(image, conf_threshold):
103
  with torch.no_grad():
104
  inputs = image_processor(images=[image], return_tensors="pt")
105
  outputs = model(**inputs.to(device))
 
119
  # Can return results as plotted on a PIL image (then display the image)
120
  draw = ImageDraw.Draw(image)
121
 
122
+ # Create a copy of the image to draw on it for NMS
123
+ image_nms = image.copy()
124
+ draw_nms = ImageDraw.Draw(image_nms)
125
+
126
+ # Get a font from ImageFont
127
+ font = ImageFont.load_default(size=20)
128
+
129
+ # Get class names as text for print out
130
+ class_name_text_labels = []
131
+
132
+ # TK - update this for NMS
133
+ class_name_text_labels_nms = []
134
+
135
+ # Get original boxes, scores, labels
136
+ original_boxes = results["boxes"]
137
+ original_labels = results["labels"]
138
+ original_scores = results["scores"]
139
+
140
+ # Filter boxes and only keep 1x of each label with highest score
141
+ filtered_boxes, filtered_labels, filtered_scores = filter_highest_scoring_box_per_class(boxes=original_boxes,
142
+ labels=original_labels,
143
+ scores=original_scores)
144
+ # TODO: turn this into a function so it's cleaner?
145
+ for box, label, score in zip(original_boxes, original_labels, original_scores):
146
  # Create coordinates
147
  x, y, x2, y2 = tuple(box.tolist())
148
 
149
  # Get label_name
150
  label_name = id2label[label.item()]
151
  targ_color = color_dict[label_name]
152
+ class_name_text_labels.append(label_name)
153
 
154
  # Draw the rectangle
155
  draw.rectangle(xy=(x, y, x2, y2),
 
162
  # Draw the text on the image
163
  draw.text(xy=(x, y),
164
  text=text_string_to_show,
165
+ fill="white",
166
+ font=font)
167
+
168
+ # TODO: turn this into a function so it's cleaner?
169
+ for box, label, score in zip(filtered_boxes, filtered_labels, filtered_scores):
170
+ # Create coordinates
171
+ x, y, x2, y2 = tuple(box.tolist())
172
+
173
+ # Get label_name
174
+ label_name = id2label[label.item()]
175
+ targ_color = color_dict[label_name]
176
+ class_name_text_labels_nms.append(label_name)
177
+
178
+ # Draw the rectangle
179
+ draw_nms.rectangle(xy=(x, y, x2, y2),
180
+ outline=targ_color,
181
+ width=3)
182
+
183
+ # Create a text string to display
184
+ text_string_to_show = f"{label_name} ({round(score.item(), 3)})"
185
+
186
+ # Draw the text on the image
187
+ draw_nms.text(xy=(x, y),
188
+ text=text_string_to_show,
189
+ fill="white",
190
+ font=font)
191
+
192
 
193
  # Remove the draw each time
194
  del draw
195
+ del draw_nms
196
+
197
+ # Create the return string
198
+ return_string = create_return_string(list_of_predicted_labels=class_name_text_labels)
199
+ return_string_nms = create_return_string(list_of_predicted_labels=class_name_text_labels_nms)
200
 
201
+ return image, return_string, image_nms, return_string_nms
202
 
203
+ # Create the interface
204
  demo = gr.Interface(
205
  fn=predict_on_image,
206
  inputs=[
207
+ gr.Image(type="pil", label="Target Image"),
208
  gr.Slider(minimum=0, maximum=1, value=0.25, label="Confidence Threshold")
209
  ],
210
+ outputs=[
211
+ gr.Image(type="pil", label="Image Output (no filtering)"),
212
+ gr.Text(label="Text Output (no filtering)"),
213
+ gr.Image(type="pil", label="Image Output (with max score per class box filtering)"),
214
+ gr.Text(label="Text Output (with max score per class box filtering)")
215
+
216
+ ],
217
+ title="๐Ÿšฎ Trashify Object Detection Demo V3",
218
+ description="""Help clean up your local area! Upload an image and get +1 if there is all of the following items detected: trash, bin, hand.
219
+ Model in V3 has been trained with data augmentation and has an additional post-processing step to filter classes for only the highest scoring box of each class. (tk - add link to model).
220
+ """,
221
+ # Examples come in the form of a list of lists, where each inner list contains elements to prefill the `inputs` parameter with
222
+ examples=[
223
+ ["examples/trashify_example_1.jpeg", 0.25],
224
+ ["examples/trashify_example_2.jpeg", 0.25],
225
+ ["examples/trashify_example_3.jpeg", 0.25]
226
+ ],
227
+ cache_examples=True
228
  )
229
 
230
+ # Launch the demo
231
+ demo.launch()
examples/trashify_example_1.jpeg ADDED
examples/trashify_example_2.jpeg ADDED

Git LFS Details

  • SHA256: 89ed8acec03b7890e5d2e6fa509c7e842e70a6dd9f6ad4e37d5d1431a1081be7
  • Pointer size: 132 Bytes
  • Size of remote file: 1.07 MB
examples/trashify_example_3.jpeg ADDED