Spaces:
Running
Running
File size: 5,014 Bytes
c46d8ad |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 |
import streamlit as st
from transformers import DetrImageProcessor, DetrForObjectDetection
from PIL import Image, ImageDraw
import torch
import re
@st.cache_resource
def load_detection_model():
    """Load the DETR ResNet-50 image processor and detection model.

    The function is wrapped in ``st.cache_resource``. Streamlit therefore keeps
    the returned objects and reuses them instead of reloading the checkpoint on
    every script rerun.

    Returns:
        tuple: ``(processor, model)``. ``processor`` is a
        ``DetrImageProcessor`` and ``model`` is a ``DetrForObjectDetection``,
        both built from the ``facebook/detr-resnet-50`` checkpoint.
    """
    checkpoint = "facebook/detr-resnet-50"
    detr_processor = DetrImageProcessor.from_pretrained(checkpoint)
    detr_model = DetrForObjectDetection.from_pretrained(checkpoint)
    return detr_processor, detr_model
# One detection per line: "[xmin, ymin, xmax, ymax] label score".
# - Coordinates may be negative: post-processed DETR boxes are not guaranteed
#   to lie inside the image.
# - Labels may contain spaces or punctuation (e.g. "cell phone", "N/A").
# - Scores may be written in plain or scientific notation.
_DETECTION_LINE_RE = re.compile(
    r"\[([-\d\s,]+)\]"                 # bracketed, comma-separated coordinates
    r"\s+(.+?)"                        # class label (non-greedy, may contain spaces)
    r"\s+([\d.]+(?:[eE][-+]?\d+)?)"    # confidence score
)


def _parse_detection_line(line):
    """Return the detection dict for a single line, or None if it is malformed."""
    match = _DETECTION_LINE_RE.match(line)
    if match is None:
        return None
    try:
        coords = [int(part) for part in match.group(1).split(",")]
        score = float(match.group(3))
    except ValueError:
        # e.g. an empty coordinate ("1,,2") or a score like "1.2.3".
        return None
    if len(coords) != 4:
        return None
    xmin, ymin, xmax, ymax = coords
    return {
        'box': {'xmin': xmin, 'ymin': ymin, 'xmax': xmax, 'ymax': ymax},
        'label': match.group(2).strip(),
        'score': score,
    }


def parse_detection_text(detection_text):
    """Parse newline-separated detection lines into detection dicts.

    Each non-blank line is expected to look like
    ``"[xmin, ymin, xmax, ymax] label score"``, which is the format emitted by
    ``detect_objects``. Blank lines are ignored. Any other line that cannot be
    parsed is skipped and reported with ``st.warning``.

    Args:
        detection_text: text with one detection per line. ``None`` or an empty
            string yields an empty list.

    Returns:
        list[dict]: one dict per parsed line, with these keys:

        - ``'box'``: a dict of integer ``xmin``, ``ymin``, ``xmax``, ``ymax``.
        - ``'label'``: the class label (str).
        - ``'score'``: the confidence (float).
    """
    detections = []
    if not detection_text:
        return detections
    for line in detection_text.splitlines():
        if not line.strip():
            continue
        detection = _parse_detection_line(line)
        if detection is None:
            st.warning(f"Skipping malformed detection line: {line}")
            continue
        detections.append(detection)
    return detections
def detect_objects(image, processor, model, *, threshold=0.7):
    """Run DETR object detection on a PIL image.

    Args:
        image: PIL image to analyse. A converted RGB copy is used for
            inference, so RGBA, palette or grayscale uploads still reach the
            model as 3-channel input. The caller's image is not modified.
        processor: ``DetrImageProcessor``. It is used for pre-processing and
            for ``post_process_object_detection``.
        model: ``DetrForObjectDetection``. Its ``config.id2label`` maps class
            ids to label names.
        threshold: minimum confidence a detection needs in order to be kept.
            Defaults to 0.7, the value this function previously hard-coded.

    Returns:
        tuple: ``(detection_text, results)``.

        - ``detection_text`` has one line per detection, formatted as
          ``"[xmin, ymin, xmax, ymax] label score"`` with the coordinates
          truncated to ints.
        - ``results`` is the post-processed dict for the image, with
          ``"scores"``, ``"labels"`` and ``"boxes"``.

        If anything fails, the error is shown with ``st.error`` and
        ``("", None)`` is returned.
    """
    try:
        rgb_image = image.convert("RGB")
        inputs = processor(images=rgb_image, return_tensors="pt")
        # Inference only: skip autograd bookkeeping to save memory and time.
        with torch.no_grad():
            outputs = model(**inputs)
        # Post-processing expects (height, width); PIL's .size is (width, height).
        target_sizes = torch.tensor([rgb_image.size[::-1]])
        results = processor.post_process_object_detection(
            outputs,
            target_sizes=target_sizes,
            threshold=threshold,
        )[0]
        lines = [
            f"[{int(box[0])}, {int(box[1])}, {int(box[2])}, {int(box[3])}] "
            f"{model.config.id2label[label.item()]} {score.item()}\n"
            for score, label, box in zip(
                results["scores"], results["labels"], results["boxes"]
            )
        ]
        return "".join(lines), results
    except Exception as e:
        # UI boundary: report any failure to the user instead of crashing the app.
        st.error(f"Detection failed: {str(e)}")
        return "", None
def draw_boxes(image, detections):
    """Draw labelled bounding boxes onto ``image`` and return it.

    Boxes are coloured by class with a case-insensitive lookup: ``person`` is
    red, ``cell phone`` is blue, and every other label is green. Each box gets
    a ``"label (score)"`` caption about 15 px above its top-left corner. That
    position is clamped to the image, so boxes touching the top or left edge
    still get a visible caption instead of one drawn off-canvas.

    Args:
        image: PIL image to annotate. It is drawn on in place, so callers that
            need the original should pass a copy.
        detections: iterable of dicts as produced by ``parse_detection_text``.
            Each has ``'box'`` (``xmin``, ``ymin``, ``xmax``, ``ymax``),
            ``'label'`` and ``'score'``.

    Returns:
        The same ``image`` object, annotated.
    """
    draw = ImageDraw.Draw(image)
    color_map = {
        'person': 'red',
        'cell phone': 'blue',
        'default': 'green'
    }
    caption_offset = 15  # vertical gap between caption and box top, in pixels
    for det in detections:
        box = det['box']
        label = det['label']
        color = color_map.get(label.lower(), color_map['default'])
        draw.rectangle(
            [(box['xmin'], box['ymin']), (box['xmax'], box['ymax'])],
            outline=color,
            width=3
        )
        # Keep the caption on-canvas; a negative origin would clip it away.
        caption_xy = (max(box['xmin'], 0), max(box['ymin'] - caption_offset, 0))
        draw.text(
            caption_xy,
            f"{label} ({det['score']:.2f})",
            fill=color
        )
    return image
def main():
    """Streamlit page: upload an image, run DETR on demand, and show results.

    The flow is:

    1. Load the cached model.
    2. Accept a JPG/PNG upload and display it.
    3. When "Detect Objects" is pressed, run detection and show:
       - the raw detection text in an expander,
       - an annotated copy of the image,
       - a table of the detected objects.

    A successful run that finds nothing shows a warning. A failed run is
    reported by ``detect_objects`` itself.
    """
    st.title("Object Detection with DETR")
    processor, model = load_detection_model()
    uploaded_file = st.file_uploader("Upload image", type=["jpg", "png", "jpeg"])
    if not uploaded_file:
        return

    # Normalise to RGB once: PNG uploads may be RGBA, palette ("P") or
    # grayscale. Those modes may not suit the 3-channel detector or the
    # named-colour box drawing below.
    image = Image.open(uploaded_file).convert("RGB")
    # NOTE(review): `use_column_width` is deprecated in newer Streamlit
    # releases; switch it once the pinned Streamlit version is confirmed.
    st.image(image, caption="Original Image", use_column_width=True)

    if not st.button("Detect Objects"):
        return

    with st.spinner("Detecting objects..."):
        detection_text, results = detect_objects(image, processor, model)

    if results is None:
        # detect_objects() has already reported the failure via st.error.
        return

    if detection_text:
        st.subheader("Detection Results")
        # Show raw detections
        with st.expander("Raw Detection Output"):
            st.text(detection_text)

    detections = parse_detection_text(detection_text)
    if not detections:
        # Covers both "nothing above the threshold" and "nothing parseable".
        st.warning("No valid detections found")
        return

    annotated_image = draw_boxes(image.copy(), detections)
    st.image(annotated_image, caption="Detected Objects", use_column_width=True)

    st.subheader("Detected Objects")
    st.table([
        {
            "Object": d["label"],
            "Confidence": f"{d['score']:.2%}",
            "Position": f"({d['box']['xmin']}, {d['box']['ymin']}) to ({d['box']['xmax']}, {d['box']['ymax']})"
        }
        for d in detections
    ])
# Script entry point. The stray " |" that followed main() was page-scrape
# residue and a SyntaxError; it has been removed.
if __name__ == "__main__":
    main()