File size: 2,175 Bytes
08d4e3b
 
57576d3
 
08d4e3b
57576d3
08d4e3b
57576d3
08d4e3b
 
 
 
 
57576d3
 
 
 
 
 
 
 
 
 
 
08d4e3b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
57576d3
08d4e3b
 
57576d3
08d4e3b
 
 
 
 
 
 
 
57576d3
08d4e3b
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
import streamlit as st
from transformers import DetrImageProcessor, DetrForObjectDetection
import torch
from PIL import Image
import requests

# Browser-tab title/icon, wide layout, and a sidebar that starts collapsed.
st.set_page_config(
    page_title="SnapSpot",
    page_icon="📸",
    layout="wide",
    initial_sidebar_state="collapsed",
)

# Hugging Face checkpoint (and pinned revision) used for object detection.
_DETR_CHECKPOINT = "facebook/detr-resnet-50"
_DETR_REVISION = "no_timm"


@st.cache_resource
def _load_detr():
    """Load the DETR processor and model once and reuse them across reruns.

    Streamlit re-executes the script on every interaction; caching the
    loaded objects as a shared resource avoids re-reading the weights on
    each upload.

    Returns:
        A ``(processor, model)`` tuple.
    """
    processor = DetrImageProcessor.from_pretrained(_DETR_CHECKPOINT, revision=_DETR_REVISION)
    model = DetrForObjectDetection.from_pretrained(_DETR_CHECKPOINT, revision=_DETR_REVISION)
    return processor, model


# Function to perform object detection
def detect_objects(image, threshold=0.9):
    """Run DETR object detection on a PIL image.

    Args:
        image: A ``PIL.Image.Image`` in any mode; it is converted to RGB
            before being handed to the processor.
        threshold: Minimum confidence score a detection needs to be kept.
            Defaults to 0.9, the value that was previously hard-coded.

    Returns:
        The post-processed result for the single input image: a dict whose
        ``"scores"``, ``"labels"`` and ``"boxes"`` entries are iterated in
        parallel by the caller.
    """
    processor, model = _load_detr()

    # Uploaded PNGs may be RGBA, palette or grayscale; normalise to
    # three-channel RGB before preprocessing.
    rgb_image = image.convert("RGB")
    inputs = processor(images=rgb_image, return_tensors="pt")

    # Inference only: skip autograd bookkeeping.
    with torch.no_grad():
        outputs = model(**inputs)

    # PIL reports (width, height); reverse it to (height, width) so the boxes
    # are rescaled to the original image size.
    target_sizes = torch.tensor([rgb_image.size[::-1]])
    detections = processor.post_process_object_detection(
        outputs, target_sizes=target_sizes, threshold=threshold
    )
    return detections[0]


# Main Streamlit app
def main():
    """Render the SnapSpot page: upload an image, detect objects, list them."""
    st.title("SnapSpot")
    st.markdown(
        """
        <style>
            .reportview-container {
                background: #0e1117;
                color: #f0f6fc;
            }
            .st-bq {
                background-color: #0e1117;
            }
            .st-bm {
                padding-top: 2rem;
            }
        </style>
        """,
        unsafe_allow_html=True,
    )

    # Upload image
    uploaded_image = st.file_uploader("Upload an image", type=["jpg", "jpeg", "png"])
    if uploaded_image is None:
        return

    # Display uploaded image
    image = Image.open(uploaded_image)
    st.image(image, caption="Uploaded Image", use_column_width=True)

    # Perform object detection
    results = detect_objects(image)

    # The class-id -> name mapping lives on the model config. The model is a
    # local of the detection code, so fetch it from the cached loader instead
    # of referencing an undefined global (the original raised NameError here).
    _, model = _load_detr()
    id2label = model.config.id2label

    # Display detection results
    st.subheader("Detection Results:")
    if len(results["scores"]) == 0:
        st.info("No objects were detected above the confidence threshold.")
        return

    for score, label, box in zip(results["scores"], results["labels"], results["boxes"]):
        box = [round(i, 2) for i in box.tolist()]
        st.write(
            f"Detected {id2label[label.item()]} with confidence "
            f"{round(score.item(), 3)} at location {box}"
        )

# Entry point: build the page only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()