khizon commited on
Commit
beb576f
1 Parent(s): 9272812

Created app.py

Browse files
Files changed (1) hide show
  1. app.py +53 -0
app.py ADDED
@@ -0,0 +1,53 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers import pipeline
2
+ from transformers import DetrFeatureExtractor, DetrForObjectDetection
3
+ from PIL import Image, ImageDraw, ImageFont
4
+ import gradio as gr
5
+
6
+ # Initialize another model and feature extractor
7
+ feature_extractor = DetrFeatureExtractor.from_pretrained('facebook/detr-resnet-50')
8
+ model = DetrForObjectDetection.from_pretrained('facebook/detr-resnet-50')
9
+
10
+ # Initialize the object detection pipeline
11
+ object_detector = pipeline("object-detection", model = model, feature_extractor = feature_extractor)
12
+
13
+ # Draw bounding box definition
14
+ def draw_bounding_box(im, score, label, xmin, ymin, xmax, ymax, index, num_boxes):
15
+ """ Draw a bounding box. """
16
+ # Draw the actual bounding box
17
+ outline = ' '
18
+ if label in ['truck', 'car', 'motorcycle', 'bus']:
19
+ outline = 'red'
20
+ elif label in ['person', 'bicycle']:
21
+ outline = 'green'
22
+ else:
23
+ outline = 'blue'
24
+ im_with_rectangle = ImageDraw.Draw(im)
25
+ im_with_rectangle.rounded_rectangle((xmin, ymin, xmax, ymax), outline = outline, width = 3, radius = 10)
26
+
27
+ # Return the result
28
+ return im
29
+
30
+ def detect_image(im):
31
+ # Perform object detection
32
+ bounding_boxes = object_detector(im)
33
+
34
+ # Iteration elements
35
+ num_boxes = len(bounding_boxes)
36
+ index = 0
37
+
38
+ # Draw bounding box for each result
39
+ for bounding_box in bounding_boxes:
40
+ if bounding_box['label'] in ['person','motorcycle','bicycle', 'truck', 'car','bus']:
41
+ box = bounding_box['box']
42
+
43
+ #Draw the bounding box
44
+ output_image = draw_bounding_box(im, bounding_box['score'],
45
+ bounding_box['label'],
46
+ box['xmin'], box['ymin'],
47
+ box['xmax'], box['ymax'],
48
+ index, num_boxes)
49
+ index += 1
50
+
51
+ return output_image
52
+
53
+ iface = gr.Interface(detect_image, gr.inputs.Image(type = 'pil'), gr.outputs.Image()).launch(debug = True)