File size: 2,398 Bytes
0ad2d2a
 
cf24a38
0ad2d2a
 
 
 
 
 
 
 
 
 
 
 
 
 
cf24a38
 
 
0ad2d2a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b12b77a
0ad2d2a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b12b77a
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
import streamlit as st
import torch
from transformers import YolosImageProcessor, YolosForObjectDetection
from PIL import Image
import requests
from io import BytesIO
import matplotlib.pyplot as plt
import matplotlib.patches as patches
import random

# Constants
# Sample street-scene photo fetched when the user clicks "Try Example Image".
EXAMPLE_URL = "https://huggingface.co/datasets/Xenova/transformers.js-docs/resolve/main/city-streets.jpg"
# Minimum confidence score a detection must reach to be kept and drawn.
THRESHOLD = 0.2

# Load model and processor
@st.cache_resource
def load_model(model_id="hustvl/yolos-tiny"):
    """Load a YOLOS image processor and object-detection model.

    Wrapped in ``st.cache_resource`` so the weights are loaded once and
    reused across Streamlit reruns instead of on every interaction.

    Args:
        model_id: Hugging Face Hub identifier of a YOLOS checkpoint.
            Defaults to ``"hustvl/yolos-tiny"``, the previously hard-coded
            model, so existing ``load_model()`` calls are unchanged.

    Returns:
        A ``(processor, model)`` tuple.
    """
    processor = YolosImageProcessor.from_pretrained(model_id)
    model = YolosForObjectDetection.from_pretrained(model_id)
    return processor, model


processor, model = load_model()

# Function to detect objects in the image
def detect(image, threshold=THRESHOLD):
    """Run YOLOS object detection on a PIL image.

    Args:
        image: A PIL image in any mode. It is converted to RGB first, so
            grayscale, CMYK or RGBA inputs are accepted.
        threshold: Minimum confidence score for a detection to be kept.
            Defaults to the module-level ``THRESHOLD``.

    Returns:
        The post-processed detection dict for this single image, containing
        ``"scores"``, ``"labels"`` and ``"boxes"``; boxes are expressed as
        ``(xmin, ymin, xmax, ymax)`` in the input image's pixel coordinates.
    """
    # The processor expects 3-channel input; JPEGs may be grayscale or CMYK
    # and PNGs may carry an alpha channel.
    rgb_image = image.convert("RGB")
    inputs = processor(images=rgb_image, return_tensors="pt")

    # Inference only: skip autograd bookkeeping.
    with torch.no_grad():
        outputs = model(**inputs)

    # PIL's size is (width, height); post-processing wants (height, width)
    # so boxes are rescaled to the original image's pixels.
    target_sizes = torch.tensor([rgb_image.size[::-1]])
    results = processor.post_process_object_detection(
        outputs, target_sizes=target_sizes, threshold=threshold
    )[0]

    return results

# Function to render bounding boxes
def render_box(image, results, threshold=THRESHOLD):
    """Draw labelled detection boxes over *image* and display it in Streamlit.

    Args:
        image: The image the detections were computed on (anything
            ``Axes.imshow`` accepts, e.g. a PIL image).
        results: Detection dict as returned by ``detect``, with parallel
            ``"scores"``, ``"labels"`` and ``"boxes"`` sequences; each box
            unpacks to ``(xmin, ymin, xmax, ymax)`` in pixel coordinates.
        threshold: Detections scoring below this value are skipped.
            Defaults to the module-level ``THRESHOLD``.
    """
    # Draw on an explicit Figure rather than pyplot's global state: st.pyplot
    # renders the Figure object directly, and closing it afterwards keeps each
    # Streamlit rerun from leaking another open matplotlib figure.
    fig, ax = plt.subplots(figsize=(10, 10))
    try:
        ax.imshow(image)

        for score, label, box in zip(results["scores"], results["labels"], results["boxes"]):
            score = float(score)  # plain float for comparison and formatting
            if score < threshold:
                continue

            color = tuple(random.random() for _ in range(3))  # Random color for each box
            xmin, ymin, xmax, ymax = (float(v) for v in box)

            rect = patches.Rectangle(
                (xmin, ymin),
                xmax - xmin,
                ymax - ymin,
                linewidth=2,
                edgecolor=color,
                facecolor="none",
            )
            ax.add_patch(rect)
            ax.text(
                xmin,
                ymin,
                f"{model.config.id2label[int(label)]}: {score:.2f}",
                color=color,
                fontsize=12,
                bbox=dict(facecolor="white", alpha=0.5),
            )

        ax.axis("off")
        st.pyplot(fig)
    finally:
        plt.close(fig)

# Streamlit app
st.title("Object Detection with Hugging Face Transformers")

uploaded_file = st.file_uploader("Choose an image...", type=["jpg", "jpeg", "png"])

# Image to analyse on this run; stays None if nothing was provided or loading failed.
image = None

if uploaded_file is not None:
    try:
        # Normalise to RGB so grayscale/CMYK JPEGs and RGBA PNGs all give the
        # detector 3-channel input.
        image = Image.open(uploaded_file).convert("RGB")
    except OSError as err:  # PIL's UnidentifiedImageError subclasses OSError
        st.error(f"Could not read the uploaded file as an image: {err}")
elif st.button("Try Example Image"):
    try:
        # Bounded wait so a stalled server cannot hang the app indefinitely;
        # raise_for_status stops an HTTP error page being handed to PIL.
        response = requests.get(EXAMPLE_URL, timeout=30)
        response.raise_for_status()
        image = Image.open(BytesIO(response.content)).convert("RGB")
    except (requests.RequestException, OSError) as err:
        st.error(f"Could not load the example image: {err}")

if image is not None:
    results = detect(image)
    render_box(image, results)