kasper-boy commited on
Commit
223ea38
1 Parent(s): 66136c9

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +39 -35
app.py CHANGED
@@ -11,51 +11,55 @@ processor = DetrImageProcessor.from_pretrained('facebook/detr-resnet-101')
11
  model = DetrForObjectDetection.from_pretrained('facebook/detr-resnet-101')
12
 
13
  def object_detection(image, confidence_threshold):
14
- # Convert the input to a PIL Image object if it's not already
15
- if not isinstance(image, Image.Image):
16
- image = Image.open(io.BytesIO(image))
17
-
18
- # Preprocess the image
19
- inputs = processor(images=image, return_tensors="pt")
 
20
 
21
- # Perform object detection
22
- outputs = model(**inputs)
23
 
24
- # Extract bounding boxes and labels
25
- target_sizes = torch.tensor([image.size[::-1]])
26
- results = processor.post_process_object_detection(outputs, target_sizes=target_sizes, threshold=confidence_threshold)[0]
27
 
28
- # Plot the image with bounding boxes
29
- plt.figure(figsize=(16, 10))
30
- plt.imshow(image)
31
- ax = plt.gca()
32
 
33
- detected_objects = []
34
- for score, label, box in zip(results["scores"], results["labels"], results["boxes"]):
35
- box = [round(i, 2) for i in box.tolist()]
36
- xmin, ymin, xmax, ymax = box
37
- width, height = xmax - xmin, ymax - ymin
38
 
39
- ax.add_patch(plt.Rectangle((xmin, ymin), width, height, fill=False, color='red', linewidth=3))
40
- text = f'{model.config.id2label[label.item()]}: {round(score.item(), 3)}'
41
- ax.text(xmin, ymin, text, fontsize=15, bbox=dict(facecolor='yellow', alpha=0.5))
42
- detected_objects.append(text)
43
 
44
- plt.axis('off')
45
 
46
- # Save the plot to an image buffer
47
- buf = io.BytesIO()
48
- plt.savefig(buf, format='png')
49
- buf.seek(0)
50
- plt.close()
51
 
52
- # Convert buffer to an Image object
53
- result_image = Image.open(buf)
54
 
55
- # Join detected objects into a single string
56
- detected_objects_text = "\n".join(detected_objects)
57
 
58
- return result_image, detected_objects_text
 
 
 
59
 
60
  # Define the Gradio interface
61
  demo = gr.Interface(
 
11
  model = DetrForObjectDetection.from_pretrained('facebook/detr-resnet-101')
12
 
13
  def object_detection(image, confidence_threshold):
14
+ try:
15
+ # Convert the input to a PIL Image object if it's not already
16
+ if not isinstance(image, Image.Image):
17
+ image = Image.open(io.BytesIO(image))
18
+
19
+ # Preprocess the image
20
+ inputs = processor(images=image, return_tensors="pt")
21
 
22
+ # Perform object detection
23
+ outputs = model(**inputs)
24
 
25
+ # Extract bounding boxes and labels
26
+ target_sizes = torch.tensor([image.size[::-1]])
27
+ results = processor.post_process_object_detection(outputs, target_sizes=target_sizes, threshold=confidence_threshold)[0]
28
 
29
+ # Plot the image with bounding boxes
30
+ plt.figure(figsize=(16, 10))
31
+ plt.imshow(image)
32
+ ax = plt.gca()
33
 
34
+ detected_objects = []
35
+ for score, label, box in zip(results["scores"], results["labels"], results["boxes"]):
36
+ box = [round(i, 2) for i in box.tolist()]
37
+ xmin, ymin, xmax, ymax = box
38
+ width, height = xmax - xmin, ymax - ymin
39
 
40
+ ax.add_patch(plt.Rectangle((xmin, ymin), width, height, fill=False, color='red', linewidth=3))
41
+ text = f'{model.config.id2label[label.item()]}: {round(score.item(), 3)}'
42
+ ax.text(xmin, ymin, text, fontsize=15, bbox=dict(facecolor='yellow', alpha=0.5))
43
+ detected_objects.append(text)
44
 
45
+ plt.axis('off')
46
 
47
+ # Save the plot to an image buffer
48
+ buf = io.BytesIO()
49
+ plt.savefig(buf, format='png')
50
+ buf.seek(0)
51
+ plt.close()
52
 
53
+ # Convert buffer to an Image object
54
+ result_image = Image.open(buf)
55
 
56
+ # Join detected objects into a single string
57
+ detected_objects_text = "\n".join(detected_objects)
58
 
59
+ return result_image, detected_objects_text
60
+
61
+ except Exception as e:
62
+ return Image.new("RGB", (224, 224), color="white"), str(e)
63
 
64
  # Define the Gradio interface
65
  demo = gr.Interface(