# COCO-Counter / object_detection.py
# RodneyVG's picture
# Upload object_detection.py
# e07fa5a
# raw history blame
# No virus
# 934 Bytes
# from PIL import Image
from transformers import DetrFeatureExtractor
from transformers import DetrForObjectDetection
import torch
# import numpy as np
_DETR_CHECKPOINT = "facebook/detr-resnet-50"

# Filled on first use so the checkpoint is loaded once per process instead of
# on every call to object_count().
_detr_cache = {}


def _load_detr():
    """Return the ``(feature_extractor, model)`` pair, loading it on first use."""
    if not _detr_cache:
        # Load both into locals first so a failure on the second load does not
        # leave a half-populated cache behind.
        feature_extractor = DetrFeatureExtractor.from_pretrained(_DETR_CHECKPOINT)
        model = DetrForObjectDetection.from_pretrained(_DETR_CHECKPOINT)
        _detr_cache.update(feature_extractor=feature_extractor, model=model)
    return _detr_cache["feature_extractor"], _detr_cache["model"]


def object_count(picture, threshold=0.7):
    """Count the DETR detections in *picture* that clear a confidence threshold.

    Args:
        picture: Image to analyse. It is handed unchanged to the DETR feature
            extractor (presumably a PIL image or NumPy array, as supplied by
            the Gradio image input; confirm against the caller).
        threshold: A query is counted when its highest class probability is
            strictly greater than this value. Defaults to 0.7, the value that
            used to be hard-coded.

    Returns:
        The sentence ``"About N common objects were detected"`` where ``N`` is
        the number of counted queries.
    """
    feature_extractor, model = _load_detr()
    encoding = feature_extractor(picture, return_tensors="pt")

    # Inference only: no gradients are needed, so skip autograd bookkeeping.
    with torch.no_grad():
        outputs = model(**encoding)

    # Class probabilities per query for the first batch entry, with the
    # trailing "no object" class dropped.
    probas = outputs.logits.softmax(-1)[0, :, :-1]
    # Keep only the queries whose best class clears the threshold.
    keep = probas.max(-1).values > threshold
    count = int(keep.sum())
    return f"About {count} common objects were detected"
# object_count("toothbrush.jpg")
import gradio as gr

# Build the web UI: one image input (requested at shape (640, 480)) fed to
# object_count, with the returned sentence shown as text. The server is
# started immediately.
# NOTE: `interface` ends up bound to whatever launch() returns, not to the
# gr.Interface object itself.
image_input = gr.inputs.Image(shape=(640, 480))
interface = gr.Interface(
    fn=object_count,
    inputs=image_input,
    outputs="text",
).launch()