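# app.py -- uploaded by pratikshahp; commit d5348f1 ("Update app.py", verified), 2.4 kB.
# (Header recovered from the hosting site's file-viewer chrome: "raw" / "history blame" links.)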
import os
from PIL import Image
from transformers import AutoImageProcessor, AutoModelForObjectDetection
import streamlit as st
import torch
import requests
def prettier(results):
    """Print one human-readable line per detection in *results*.

    Args:
        results: Iterable of mappings. Each mapping is read for a numeric
            ``'score'``, a ``'label'`` and a ``'box'`` mapping whose values
            are numeric coordinates.

    Returns:
        None. The formatted lines are written to standard output.
    """
    for item in results:
        score = round(item['score'], 3)
        label = item['label']
        # Coordinates are reported in the iteration order of the 'box' mapping.
        location = [round(value, 2) for value in item['box'].values()]
        print(f'Detected {label} with confidence {score} at location {location}')
def input_image_setup(uploaded_file):
    """Package an uploaded file as a one-element list of MIME/bytes dicts.

    Args:
        uploaded_file: Object exposing ``getvalue()`` (the raw bytes) and a
            ``type`` attribute (used as the MIME type), or ``None``.

    Returns:
        A list holding a single dict with the keys ``"mime_type"`` and
        ``"data"``.

    Raises:
        FileNotFoundError: If *uploaded_file* is ``None``.
    """
    if uploaded_file is None:
        raise FileNotFoundError("No file uploaded")

    # Read the whole upload into memory as bytes.
    bytes_data = uploaded_file.getvalue()
    return [
        {
            "mime_type": uploaded_file.type,
            "data": bytes_data,
        }
    ]
# Streamlit app: choose a detection checkpoint, upload an image, list detections.
st.set_page_config(page_title="Image Detection")
st.header("Object Detection Application")

# Select your model. The processor and model are loaded for whichever
# checkpoint is currently selected.
models = ["facebook/detr-resnet-50", "ciasimbaya/ObjectDetection", "hustvl/yolos-tiny"]  # List of supported models
model_name = st.selectbox("Select model", models)
processor = AutoImageProcessor.from_pretrained(model_name)
model = AutoModelForObjectDetection.from_pretrained(model_name)

# Upload an image.
uploaded_file = st.file_uploader("choose an image...", type=["jpg", "jpeg", "png"])
image = None  # stays None until a file has been uploaded
if uploaded_file is not None:
    # Normalise to 3-channel RGB so RGBA / palette / grayscale uploads are
    # handled the same way as ordinary JPEGs.
    image = Image.open(uploaded_file).convert("RGB")
    st.image(image, caption="Uploaded Image.", use_column_width=True)

submit = st.button("Detect Objects ")
if submit:
    if image is None:
        # Without this guard the detection code below would run on a missing image.
        st.warning("Please upload an image before running detection.")
    else:
        st.subheader("The response is..")

        # Run the model on the PIL image; no gradients are needed for inference.
        inputs = processor(images=image, return_tensors="pt")
        with torch.no_grad():
            outputs = model(**inputs)

        # PIL reports size as (width, height); post-processing takes (height, width).
        target_sizes = torch.tensor([image.size[::-1]])
        results = processor.post_process_object_detection(
            outputs, threshold=0.9, target_sizes=target_sizes
        )[0]

        detections = list(zip(results["scores"], results["labels"], results["boxes"]))
        if not detections:
            st.write("No objects detected above the 0.9 confidence threshold.")

        # Write to the page rather than print(): stdout is not shown in the app.
        for score, label, box in detections:
            box = [round(i, 2) for i in box.tolist()]
            st.write(
                f"Detected {model.config.id2label[label.item()]} with confidence "
                f"{round(score.item(), 3)} at location {box}"
            )