"""Streamlit app: object detection with the YOLOS-tiny model from Hugging Face."""

import random
from io import BytesIO

import matplotlib.patches as patches
import matplotlib.pyplot as plt
import requests
import streamlit as st
import torch
from PIL import Image, UnidentifiedImageError
from transformers import YolosForObjectDetection, YolosImageProcessor

# Constants
EXAMPLE_URL = 'https://huggingface.co/datasets/Xenova/transformers.js-docs/resolve/main/city-streets.jpg'
THRESHOLD = 0.2
# Seconds to wait for the example-image download before giving up, so a
# stalled connection cannot hang the app indefinitely.
REQUEST_TIMEOUT = 10


@st.cache_resource
def load_model():
    """Load the YOLOS-tiny image processor and object-detection model.

    Wrapped in ``st.cache_resource`` so the pair is created once and reused
    across Streamlit reruns instead of being reloaded on every interaction.

    Returns:
        A ``(processor, model)`` tuple.
    """
    model_id = 'hustvl/yolos-tiny'
    processor = YolosImageProcessor.from_pretrained(model_id)
    model = YolosForObjectDetection.from_pretrained(model_id)
    return processor, model


processor, model = load_model()


def _to_rgb(image):
    """Return *image* in RGB mode, converting only when it is not already RGB."""
    # The detector and imshow are fed 3-channel RGB here; grayscale, CMYK,
    # palette or RGBA files would otherwise be passed through in another
    # layout (a single-channel image would even be drawn with a colormap).
    return image if image.mode == "RGB" else image.convert("RGB")


def detect(image):
    """Run YOLOS object detection on a PIL image.

    Args:
        image: PIL image; converted to RGB first if it is in another mode.

    Returns:
        The post-processed detection dict for this single image, with
        ``"scores"``, ``"labels"`` and ``"boxes"`` entries. Only detections
        scoring above ``THRESHOLD`` are kept; boxes are in pixel coordinates.
    """
    image = _to_rgb(image)
    inputs = processor(images=image, return_tensors="pt")

    # Inference only: no gradients needed.
    with torch.no_grad():
        outputs = model(**inputs)

    # PIL's ``size`` is (width, height); post-processing wants (height, width).
    target_sizes = torch.tensor([image.size[::-1]])
    return processor.post_process_object_detection(
        outputs, target_sizes=target_sizes, threshold=THRESHOLD
    )[0]


def render_box(image, results):
    """Draw labelled bounding boxes over *image* and show it in the app.

    Args:
        image: the PIL image the detections were computed on.
        results: a detection dict as returned by ``detect``. Each box is
            unpacked as ``(xmin, ymin, xmax, ymax)``.
    """
    fig, ax = plt.subplots(figsize=(10, 10))
    ax.imshow(image)

    for score, label, box in zip(results["scores"], results["labels"], results["boxes"]):
        # Work with plain Python numbers rather than 0-d tensors.
        score = float(score)
        if score < THRESHOLD:
            continue
        color = tuple(random.random() for _ in range(3))  # Random color for each box
        xmin, ymin, xmax, ymax = (float(v) for v in box)
        rect = patches.Rectangle(
            (xmin, ymin),
            xmax - xmin,
            ymax - ymin,
            linewidth=2,
            edgecolor=color,
            facecolor='none',
        )
        ax.add_patch(rect)
        ax.text(
            xmin,
            ymin,
            f"{model.config.id2label[int(label)]}: {score:.2f}",
            color=color,
            fontsize=12,
            bbox=dict(facecolor='white', alpha=0.5),
        )

    ax.axis('off')
    st.pyplot(fig)
    # pyplot keeps every figure alive until it is closed; without this each
    # Streamlit rerun would leak another figure.
    plt.close(fig)


def _detect_and_render(source):
    """Open *source* as an image, run detection and render the result.

    Args:
        source: a file-like object (uploaded file or in-memory bytes).
    """
    try:
        image = _to_rgb(Image.open(source))
    except UnidentifiedImageError:
        st.error("The selected file could not be read as an image.")
        return
    render_box(image, detect(image))


def _fetch_example_image():
    """Download the example image bytes, or report an error and return None."""
    try:
        response = requests.get(EXAMPLE_URL, timeout=REQUEST_TIMEOUT)
        # Avoid handing an HTTP error page to PIL as if it were an image.
        response.raise_for_status()
    except requests.RequestException as err:
        st.error(f"Could not download the example image: {err}")
        return None
    return BytesIO(response.content)


# Streamlit app
st.title("Object Detection with Hugging Face Transformers")
uploaded_file = st.file_uploader("Choose an image...", type=["jpg", "jpeg", "png"])
if uploaded_file is not None:
    _detect_and_render(uploaded_file)
elif st.button("Try Example Image"):
    example = _fetch_example_image()
    if example is not None:
        _detect_and_render(example)