JohnAlexander23's picture
Update app.py
b12b77a verified
raw
history blame contribute delete
No virus
2.4 kB
import streamlit as st
import torch
from transformers import YolosImageProcessor, YolosForObjectDetection
from PIL import Image
import requests
from io import BytesIO
import matplotlib.pyplot as plt
import matplotlib.patches as patches
import random
# Constants
# Detections whose confidence score falls below this value are discarded.
THRESHOLD: float = 0.2
# Sample street-scene photo used by the "Try Example Image" button.
EXAMPLE_URL: str = 'https://huggingface.co/datasets/Xenova/transformers.js-docs/resolve/main/city-streets.jpg'
# Load model and processor
@st.cache_resource
def load_model(model_id='hustvl/yolos-tiny'):
    """Load a YOLOS image processor and object-detection model.

    Decorated with ``st.cache_resource`` so the (comparatively expensive)
    download/construction happens once and the same objects are reused on
    every Streamlit rerun instead of being rebuilt each time.

    Args:
        model_id: Hugging Face Hub identifier of the YOLOS checkpoint to
            load. Defaults to the small ``hustvl/yolos-tiny`` model.

    Returns:
        A ``(processor, model)`` tuple.
    """
    processor = YolosImageProcessor.from_pretrained(model_id)
    model = YolosForObjectDetection.from_pretrained(model_id)
    return processor, model


processor, model = load_model()
# Function to detect objects in the image
def detect(image, threshold=THRESHOLD):
    """Run YOLOS object detection on a single PIL image.

    Args:
        image: PIL image to analyse. Images in any mode are accepted; non-RGB
            images (grayscale, CMYK, RGBA, palette) are converted to RGB first
            because the model works on three colour channels and other modes
            may otherwise be mis-handled by the preprocessor.
        threshold: minimum confidence score a detection needs to be returned.

    Returns:
        The post-processed detections for this image: a dict with
        ``"scores"``, ``"labels"`` and ``"boxes"`` entries, where each box is
        ``(xmin, ymin, xmax, ymax)`` rescaled to the size of ``image``.
    """
    # Avoid an unnecessary copy when the image is already RGB.
    rgb_image = image if image.mode == "RGB" else image.convert("RGB")

    inputs = processor(images=rgb_image, return_tensors="pt")
    # Inference only: skip autograd bookkeeping.
    with torch.no_grad():
        outputs = model(**inputs)

    # PIL reports size as (width, height); the post-processor expects
    # (height, width) per image to rescale boxes back to pixel coordinates.
    target_sizes = torch.tensor([rgb_image.size[::-1]])
    return processor.post_process_object_detection(
        outputs, target_sizes=target_sizes, threshold=threshold
    )[0]
# Function to render bounding boxes
def render_box(image, results, threshold=THRESHOLD):
    """Draw detection boxes and labels over ``image`` and display it in Streamlit.

    Args:
        image: the PIL image the detections were computed on.
        results: detections as returned by ``detect`` -- a dict with
            ``"scores"``, ``"labels"`` and ``"boxes"``, boxes given as
            ``(xmin, ymin, xmax, ymax)``.
        threshold: detections scoring below this value are not drawn.
            ``detect`` already filters with the same default, so this only
            matters when ``results`` were produced with a lower threshold.
    """
    # Draw on an explicit Figure rather than pyplot's implicit global one:
    # st.pyplot(fig) is the supported call, and closing the figure afterwards
    # keeps figures from accumulating across Streamlit reruns.
    fig, ax = plt.subplots(figsize=(10, 10))
    try:
        ax.imshow(image)
        for score, label, box in zip(results["scores"], results["labels"], results["boxes"]):
            score = float(score)
            if score < threshold:
                continue
            color = tuple(random.random() for _ in range(3))  # Random color for each box
            # Plain floats keep matplotlib from holding on to torch tensors.
            xmin, ymin, xmax, ymax = (float(v) for v in box)
            ax.add_patch(
                patches.Rectangle(
                    (xmin, ymin), xmax - xmin, ymax - ymin,
                    linewidth=2, edgecolor=color, facecolor='none',
                )
            )
            ax.text(
                xmin, ymin,
                f"{model.config.id2label[int(label)]}: {score:.2f}",
                color=color, fontsize=12,
                bbox=dict(facecolor='white', alpha=0.5),
            )
        ax.axis('off')
        st.pyplot(fig)
    finally:
        plt.close(fig)
# Streamlit app
def _show_detections(image):
    """Run object detection on ``image`` and render the annotated result."""
    results = detect(image)
    render_box(image, results)


st.title("Object Detection with Hugging Face Transformers")
uploaded_file = st.file_uploader("Choose an image...", type=["jpg", "jpeg", "png"])
if uploaded_file is not None:
    try:
        # Image.open is lazy; convert("RGB") forces decoding (so broken files
        # fail here) and normalises grayscale/CMYK/RGBA/palette images. This
        # also stops imshow from rendering grayscale with a false-colour map.
        image = Image.open(uploaded_file).convert("RGB")
    except OSError as err:  # PIL.UnidentifiedImageError is an OSError subclass
        st.error(f"Could not read the uploaded file as an image: {err}")
    else:
        _show_detections(image)
elif st.button("Try Example Image"):
    try:
        # Bounded wait so a stalled download cannot hang the app forever.
        response = requests.get(EXAMPLE_URL, timeout=30)
        # Fail loudly on HTTP errors instead of feeding an error page to PIL.
        response.raise_for_status()
        image = Image.open(BytesIO(response.content)).convert("RGB")
    except (requests.RequestException, OSError) as err:
        st.error(f"Could not load the example image: {err}")
    else:
        _show_detections(image)