RodneyVG commited on
Commit
e07fa5a
1 Parent(s): d82b6ba

Upload object_detection.py

Browse files
Files changed (1) hide show
  1. object_detection.py +27 -0
object_detection.py ADDED
@@ -0,0 +1,27 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # from PIL import Image
2
+ from transformers import DetrFeatureExtractor
3
+ from transformers import DetrForObjectDetection
4
+ import torch
5
+ # import numpy as np
6
+
7
+ def object_count(picture):
8
+
9
+ feature_extractor = DetrFeatureExtractor.from_pretrained("facebook/detr-resnet-50")
10
+ encoding = feature_extractor(picture, return_tensors="pt")
11
+ model = DetrForObjectDetection.from_pretrained("facebook/detr-resnet-50")
12
+ outputs = model(**encoding)
13
+ # keep only predictions of queries with 0.9+ confidence (excluding no-object class)
14
+ probas = outputs.logits.softmax(-1)[0, :, :-1]
15
+ keep = probas.max(-1).values > 0.7
16
+ count = 0
17
+ for i in keep:
18
+ if i:
19
+ count=count+1
20
+
21
+ return "About " + str(count) +" common objects were detected"
22
+
23
+ # object_count("toothbrush.jpg")
24
+ import gradio as gr
25
+
26
+ interface = gr.Interface(object_count, gr.inputs.Image(shape=(640, 480)), "text").launch()
27
+