# SnapSpot / app.py
# Source: Hugging Face Space file view (raw / history / blame).
# Last commit: 08d4e3b (verified) by haydenbanz, "Update app.py".
# File size: 2.18 kB; scan status: no virus.
import streamlit as st
from transformers import DetrImageProcessor, DetrForObjectDetection
import torch
from PIL import Image
import requests
# Configure the browser tab title/icon, the wide layout, and a collapsed sidebar.
st.set_page_config(
    page_title="SnapSpot",
    page_icon="📸",
    layout="wide",
    initial_sidebar_state="collapsed",
)
@st.cache_resource
def _load_detr():
    """Load the DETR image processor and model once and reuse them.

    Streamlit re-executes the script on every interaction; caching the
    loaded objects avoids re-reading the checkpoint for each uploaded image.

    Returns:
        A ``(processor, model)`` tuple for ``facebook/detr-resnet-50``.
    """
    processor = DetrImageProcessor.from_pretrained("facebook/detr-resnet-50", revision="no_timm")
    model = DetrForObjectDetection.from_pretrained("facebook/detr-resnet-50", revision="no_timm")
    return processor, model


# Function to perform object detection
def detect_objects(image, threshold=0.9):
    """Run DETR object detection on a PIL image.

    Args:
        image: A ``PIL.Image.Image``. It is converted to RGB before being
            passed to the processor, so RGBA, palette and grayscale uploads
            are accepted as well.
        threshold: Minimum confidence score for a detection to be kept.
            Defaults to 0.9, the value previously hard-coded here.

    Returns:
        The post-processed result for the single input image: a dict with
        ``"scores"``, ``"labels"`` and ``"boxes"`` entries, with boxes scaled
        to the size of the input image.
    """
    processor, model = _load_detr()

    # The model expects three colour channels; PNG uploads may carry an
    # alpha channel or be single-channel.
    rgb_image = image.convert("RGB")

    # Preprocess the image
    inputs = processor(images=rgb_image, return_tensors="pt")

    # Perform object detection; gradients are not needed for inference.
    with torch.no_grad():
        outputs = model(**inputs)

    # PIL reports size as (width, height); post-processing wants (height, width).
    target_sizes = torch.tensor([rgb_image.size[::-1]])
    results = processor.post_process_object_detection(
        outputs, target_sizes=target_sizes, threshold=threshold
    )[0]
    return results
@st.cache_resource
def _load_id2label():
    """Return the class-id -> label-name mapping of the DETR checkpoint.

    Only the checkpoint's configuration is loaded, not the model weights.
    The result is cached so the lookup is not repeated on every rerun.
    """
    # Function-scope import keeps this block self-contained; ``transformers``
    # is already a dependency of this file.
    from transformers import DetrConfig

    config = DetrConfig.from_pretrained("facebook/detr-resnet-50", revision="no_timm")
    return config.id2label


# Main Streamlit app
def main():
    """Render the SnapSpot page: upload an image and list the detected objects."""
    st.title("SnapSpot")
    # Inject custom CSS for the dark theme; raw HTML must be explicitly allowed.
    st.markdown(
        """
<style>
.reportview-container {
background: #0e1117;
color: #f0f6fc;
}
.st-bq {
background-color: #0e1117;
}
.st-bm {
padding-top: 2rem;
}
</style>
""",
        unsafe_allow_html=True,
    )

    # Upload image
    uploaded_image = st.file_uploader("Upload an image", type=["jpg", "jpeg", "png"])
    if uploaded_image is None:
        return

    # Display uploaded image
    image = Image.open(uploaded_image)
    st.image(image, caption="Uploaded Image", use_column_width=True)

    # Perform object detection
    results = detect_objects(image)

    # The model object lives inside detect_objects(), so the label names are
    # looked up from the checkpoint configuration instead of an undefined
    # ``model`` variable.
    id2label = _load_id2label()

    # Display detection results
    st.subheader("Detection Results:")
    if len(results["scores"]) == 0:
        st.write("No objects detected above the confidence threshold.")
        return

    for score, label, box in zip(results["scores"], results["labels"], results["boxes"]):
        box = [round(i, 2) for i in box.tolist()]
        st.write(
            f"Detected {id2label[label.item()]} with confidence "
            f"{round(score.item(), 3)} at location {box}"
        )
# Script entry point: build the page only when this file is run directly.
if __name__ == "__main__":
    main()