# Spaces: Running
import streamlit as st | |
from transformers import DetrImageProcessor, DetrForObjectDetection | |
from PIL import Image, ImageDraw | |
import torch | |
import re | |
@st.cache_resource
def load_detection_model():
    """Load the DETR image processor and detection model, cached across reruns.

    Streamlit re-executes the whole script on every widget interaction, so
    without caching the pretrained weights would be loaded again each time
    the user uploads a file or presses a button. ``st.cache_resource`` keeps
    a single shared instance alive for the server process instead.

    Returns:
        tuple: ``(processor, model)`` for the ``facebook/detr-resnet-50``
        checkpoint.
    """
    checkpoint = "facebook/detr-resnet-50"
    processor = DetrImageProcessor.from_pretrained(checkpoint)
    model = DetrForObjectDetection.from_pretrained(checkpoint)
    return processor, model
# One detection per line: "[xmin, ymin, xmax, ymax] label score".
# The coordinate group accepts "-" so that boxes with negative coordinates
# are parsed instead of being dropped.
_DETECTION_LINE_RE = re.compile(r'\[([-\d\s,]+)\]\s+([a-zA-Z\s]+)\s+([\d.]+)')


def parse_detection_text(detection_text):
    """Parse textual detections into a list of detection dictionaries.

    Each non-blank line is expected to look like
    ``[xmin, ymin, xmax, ymax] label score``. Lines that cannot be parsed
    are reported with ``st.warning`` and skipped; they never raise.

    Args:
        detection_text: Newline-separated detection lines.

    Returns:
        list[dict]: One entry per valid line, shaped as
        ``{'box': {'xmin', 'ymin', 'xmax', 'ymax'}, 'label': str,
        'score': float}``, in input order.
    """
    detections = []
    for raw_line in detection_text.split('\n'):
        line = raw_line.strip()
        if not line:
            continue

        match = _DETECTION_LINE_RE.match(line)
        if match is None:
            st.warning(f"Skipping malformed detection line: {raw_line}")
            continue

        try:
            coords = [int(part.strip()) for part in match.group(1).split(',')]
            score = float(match.group(3))
        except ValueError:
            # e.g. an empty coordinate ("1,,3,4") or a score such as "1.2.3".
            st.warning(f"Skipping malformed detection line: {raw_line}")
            continue

        if len(coords) != 4:
            st.warning(f"Skipping malformed detection line: {raw_line}")
            continue

        xmin, ymin, xmax, ymax = coords
        detections.append({
            'box': {'xmin': xmin, 'ymin': ymin, 'xmax': xmax, 'ymax': ymax},
            'label': match.group(2).strip(),
            'score': score,
        })
    return detections
def detect_objects(image, processor, model):
    """Run DETR object detection on an image and format the detections.

    Args:
        image: PIL image to analyse. Images in a mode other than RGB are
            converted to RGB before being handed to the processor.
        processor: Image processor providing ``__call__`` and
            ``post_process_object_detection``.
        model: Detection model; ``model.config.id2label`` maps class ids to
            label strings.

    Returns:
        tuple: ``(detection_text, results)``. ``detection_text`` holds one
        line per detection as ``[xmin, ymin, xmax, ymax] label score``
        (empty when nothing passes the 0.7 threshold) and ``results`` is the
        post-processed dict with ``scores``, ``labels`` and ``boxes``. On any
        failure the error is shown with ``st.error`` and ``("", None)`` is
        returned.
    """
    try:
        # The uploader accepts PNGs, which may be RGBA or palette based; make
        # sure the processor always receives a 3-channel image.
        if image.mode != "RGB":
            image = image.convert("RGB")

        inputs = processor(images=image, return_tensors="pt")
        # Inference only: skip autograd bookkeeping to save memory and time.
        with torch.no_grad():
            outputs = model(**inputs)

        # PIL reports (width, height); reverse it to (height, width).
        target_sizes = torch.tensor([image.size[::-1]])
        results = processor.post_process_object_detection(
            outputs,
            target_sizes=target_sizes,
            threshold=0.7
        )[0]

        lines = []
        for score, label, box in zip(results["scores"], results["labels"], results["boxes"]):
            # int() truncates toward zero, matching the text format consumers expect.
            xmin, ymin, xmax, ymax = (int(value) for value in box.tolist())
            lines.append(
                f"[{xmin}, {ymin}, {xmax}, {ymax}] "
                f"{model.config.id2label[label.item()]} {score.item()}\n"
            )
        return "".join(lines), results
    except Exception as e:
        # UI boundary: report the failure to the user instead of crashing the app.
        st.error(f"Detection failed: {str(e)}")
        return "", None
def draw_boxes(image, detections):
    """Draw labelled bounding boxes on ``image`` and return it.

    The image is modified in place; pass a copy to keep the original intact.
    'person' boxes are red, 'cell phone' boxes are blue and every other
    label is green.

    Args:
        image: PIL image to draw on.
        detections: Iterable of dicts with ``'box'`` (``xmin``, ``ymin``,
            ``xmax``, ``ymax``), ``'label'`` and ``'score'`` keys.

    Returns:
        The same ``image`` object, annotated.
    """
    draw = ImageDraw.Draw(image)
    color_map = {
        'person': 'red',
        'cell phone': 'blue',
        'default': 'green'
    }
    caption_offset = 15  # vertical room reserved for the caption above a box
    outline_width = 3

    for det in detections:
        box = det['box']
        label = det['label']
        color = color_map.get(label.lower(), color_map['default'])
        draw.rectangle(
            [(box['xmin'], box['ymin']), (box['xmax'], box['ymax'])],
            outline=color,
            width=outline_width
        )

        # Normally the caption sits just above the box. When the box is too
        # close to the top edge that position would be off-canvas, so draw the
        # caption just inside the top of the box instead.
        if box['ymin'] >= caption_offset:
            caption_y = box['ymin'] - caption_offset
        else:
            caption_y = max(box['ymin'], 0) + outline_width
        draw.text(
            (box['xmin'], caption_y),
            f"{label} ({det['score']:.2f})",
            fill=color
        )
    return image
def main():
    """Streamlit entry point: upload an image, run detection, show results.

    Flow: load the model, wait for an upload, display it, and once the
    "Detect Objects" button is pressed show the raw detection text, the
    annotated image and a table of detections. Unreadable uploads and runs
    that find no objects are reported to the user instead of failing
    silently.
    """
    st.title("Object Detection with DETR")
    processor, model = load_detection_model()
    uploaded_file = st.file_uploader("Upload image", type=["jpg", "png", "jpeg"])
    if not uploaded_file:
        return

    try:
        # Normalise to RGB so RGBA/palette/grayscale uploads are displayed and
        # annotated consistently (coloured outlines need a colour image).
        image = Image.open(uploaded_file).convert("RGB")
    except OSError as exc:
        # Pillow signals truncated or unrecognised files with OSError subclasses.
        st.error(f"Could not read image: {exc}")
        return

    st.image(image, caption="Original Image", use_column_width=True)
    if not st.button("Detect Objects"):
        return

    with st.spinner("Detecting objects..."):
        detection_text, results = detect_objects(image, processor, model)

    if not detection_text:
        # results is None when detection failed (already reported there);
        # otherwise the model simply found nothing above the threshold.
        if results is not None:
            st.info("No objects detected above the confidence threshold")
        return

    st.subheader("Detection Results")
    # Show raw detections
    with st.expander("Raw Detection Output"):
        st.text(detection_text)

    # Show parsed results
    detections = parse_detection_text(detection_text)
    if not detections:
        st.warning("No valid detections found")
        return

    # Draw on a copy so the displayed original stays untouched.
    annotated_image = draw_boxes(image.copy(), detections)
    st.image(annotated_image, caption="Detected Objects", use_column_width=True)

    # Display in table
    st.subheader("Detected Objects")
    st.table([
        {
            "Object": d["label"],
            "Confidence": f"{d['score']:.2%}",
            "Position": f"({d['box']['xmin']}, {d['box']['ymin']}) to ({d['box']['xmax']}, {d['box']['ymax']})"
        }
        for d in detections
    ])


if __name__ == "__main__":
    main()