haydenbanz commited on
Commit
7f3e41d
1 Parent(s): 5133224

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +19 -51
app.py CHANGED
@@ -1,16 +1,16 @@
1
- import streamlit as st
2
- from transformers import DetrImageProcessor, DetrForObjectDetection
3
  import torch
4
  from PIL import Image
5
- import requests
6
 
7
- st.set_page_config(page_title="SnapSpot", page_icon="📸", layout="wide", initial_sidebar_state="collapsed")
 
 
8
 
9
- # Function to perform object detection
10
- def detect_objects(image):
11
- # Load DETR model and processor
12
- processor = DetrImageProcessor.from_pretrained("facebook/detr-resnet-50", revision="no_timm")
13
- model = DetrForObjectDetection.from_pretrained("facebook/detr-resnet-50", revision="no_timm")
14
 
15
  # Preprocess the image
16
  inputs = processor(images=image, return_tensors="pt")
@@ -22,48 +22,16 @@ def detect_objects(image):
22
  target_sizes = torch.tensor([image.size[::-1]])
23
  results = processor.post_process_object_detection(outputs, target_sizes=target_sizes, threshold=0.9)[0]
24
 
25
- return results
26
-
27
- # Main Streamlit app
28
- def main():
29
- st.title("SnapSpot")
30
- st.markdown(
31
- """
32
- <style>
33
- .reportview-container {
34
- background: #0e1117;
35
- color: #f0f6fc;
36
- }
37
- .st-bq {
38
- background-color: #0e1117;
39
- }
40
- .st-bm {
41
- padding-top: 2rem;
42
- }
43
- </style>
44
- """,
45
- unsafe_allow_html=True,
46
- )
47
-
48
- # Upload image
49
- uploaded_image = st.file_uploader("Upload an image", type=["jpg", "jpeg", "png"])
50
-
51
- if uploaded_image is not None:
52
- # Display uploaded image
53
- image = Image.open(uploaded_image)
54
- st.image(image, caption="Uploaded Image", use_column_width=True)
55
 
56
- # Perform object detection
57
- results = detect_objects(image)
58
 
59
- # Display detection results
60
- st.subheader("Detection Results:")
61
- for score, label, box in zip(results["scores"], results["labels"], results["boxes"]):
62
- box = [round(i, 2) for i in box.tolist()]
63
- st.write(
64
- f"Detected {model.config.id2label[label.item()]} with confidence "
65
- f"{round(score.item(), 3)} at location {box}"
66
- )
67
 
68
- if __name__ == "__main__":
69
- main()
 
1
+ import io
2
+ import json
3
  import torch
4
  from PIL import Image
5
+ from transformers import DetrImageProcessor, DetrForObjectDetection
6
 
7
+ # Initialize the DETR model and processor
8
+ processor = DetrImageProcessor.from_pretrained("facebook/detr-resnet-50", revision="no_timm")
9
+ model = DetrForObjectDetection.from_pretrained("facebook/detr-resnet-50", revision="no_timm")
10
 
11
+ def predict(inputs):
12
+ # Load the image from the provided inputs
13
+ image = Image.open(io.BytesIO(inputs["image"]))
 
 
14
 
15
  # Preprocess the image
16
  inputs = processor(images=image, return_tensors="pt")
 
22
  target_sizes = torch.tensor([image.size[::-1]])
23
  results = processor.post_process_object_detection(outputs, target_sizes=target_sizes, threshold=0.9)[0]
24
 
25
+ # Prepare the results in a dictionary format
26
+ detections = [{"label": model.config.id2label[label.item()], "confidence": score.item(), "box": box.tolist()}
27
+ for score, label, box in zip(results["scores"], results["labels"], results["boxes"])]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
28
 
29
+ return detections
 
30
 
31
+ # Define the API endpoint for Hugging Face Spaces
32
+ def huggingface_spaces_endpoint(inputs):
33
+ # Call the predict function with the provided inputs
34
+ detections = predict(inputs)
 
 
 
 
35
 
36
+ # Return the detections as a JSON object
37
+ return json.dumps(detections)