File size: 5,014 Bytes
c46d8ad
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
import streamlit as st
from transformers import DetrImageProcessor, DetrForObjectDetection
from PIL import Image, ImageDraw
import torch
import re

@st.cache_resource
def load_detection_model(model_name="facebook/detr-resnet-50"):
    """Load the DETR image processor and object-detection model.

    The function is decorated with ``st.cache_resource`` so the loaded
    objects can be reused instead of being rebuilt on every call with the
    same argument.

    Args:
        model_name: Checkpoint identifier handed to both ``from_pretrained``
            calls, so the processor and the model always come from the same
            checkpoint. Defaults to ``"facebook/detr-resnet-50"``.

    Returns:
        A ``(processor, model)`` tuple.
    """
    processor = DetrImageProcessor.from_pretrained(model_name)
    model = DetrForObjectDetection.from_pretrained(model_name)
    return processor, model

def parse_detection_text(detection_text):
    """Parse detection lines of the form ``[xmin, ymin, xmax, ymax] label score``.

    Each non-blank line is matched independently. Lines that cannot be
    parsed (no match, non-integer coordinates, a coordinate count other than
    four, or an invalid score) are reported with ``st.warning`` and skipped,
    so one bad line never discards the others.

    Args:
        detection_text: Newline-separated detection lines. Leading and
            trailing whitespace on a line is ignored.

    Returns:
        A list of dicts, one per parsed line, each with the keys ``'box'``
        (a dict with integer ``'xmin'``, ``'ymin'``, ``'xmax'``, ``'ymax'``),
        ``'label'`` (str) and ``'score'`` (float).
    """
    detections = []
    # Coordinates may carry a sign (a box can extend slightly past the image
    # origin), and the score may be written in exponent notation.
    line_re = re.compile(
        r'\[([-+\d\s,]+)\]\s+([a-zA-Z\s]+)\s+([\d.]+(?:[eE][-+]?\d+)?)'
    )

    for line in detection_text.split('\n'):
        stripped = line.strip()
        if not stripped:
            continue

        # Match against the stripped line so indentation does not hide an
        # otherwise valid detection.
        match = line_re.match(stripped)
        if match is None:
            st.warning(f"Skipping malformed detection line: {line}")
            continue

        try:
            # int() tolerates surrounding whitespace but rejects empty or
            # multi-token fields such as "1 2".
            coords = [int(part) for part in match.group(1).split(',')]
            score = float(match.group(3))
        except ValueError:
            st.warning(f"Skipping malformed detection line: {line}")
            continue

        if len(coords) != 4:
            st.warning(f"Skipping malformed detection line: {line}")
            continue

        xmin, ymin, xmax, ymax = coords
        detections.append({
            'box': {'xmin': xmin, 'ymin': ymin,
                    'xmax': xmax, 'ymax': ymax},
            'label': match.group(2).strip(),
            'score': score,
        })

    return detections

def detect_objects(image, processor, model, threshold=0.7):
    """Run DETR object detection on an image and format the detections.

    Args:
        image: PIL image to analyse. It is converted to RGB when it is in
            any other mode; the caller's image object is not modified.
        processor: Image processor providing ``__call__`` and
            ``post_process_object_detection``.
        model: Detection model; ``model.config.id2label`` is used to turn
            class ids into label strings.
        threshold: Minimum confidence passed to the post-processing step.
            Defaults to ``0.7``.

    Returns:
        ``(detection_text, results)``. ``detection_text`` has one line per
        detection in the form ``[xmin, ymin, xmax, ymax] label score`` and is
        empty when nothing passes the threshold. ``results`` is the
        post-processed dict with ``"scores"``, ``"labels"`` and ``"boxes"``.
        On any failure the error is shown with ``st.error`` and
        ``("", None)`` is returned instead of raising.
    """
    try:
        # Uploads may be RGBA, palette or grayscale; hand the processor a
        # plain 3-channel image in every case.
        if image.mode != "RGB":
            image = image.convert("RGB")

        inputs = processor(images=image, return_tensors="pt")
        # Inference only: no gradients are needed, so skip autograd tracking.
        with torch.no_grad():
            outputs = model(**inputs)

        # image.size is (width, height); it is reversed to (height, width)
        # for the post-processing target size.
        target_sizes = torch.tensor([image.size[::-1]])
        results = processor.post_process_object_detection(
            outputs,
            target_sizes=target_sizes,
            threshold=threshold,
        )[0]

        lines = []
        for score, label, box in zip(results["scores"], results["labels"], results["boxes"]):
            # int() truncates toward zero, matching the integer pixel
            # coordinates expected by the text format.
            xmin, ymin, xmax, ymax = (int(value) for value in box.tolist())
            name = model.config.id2label[label.item()]
            lines.append(f"[{xmin}, {ymin}, {xmax}, {ymax}] {name} {score.item()}\n")

        return "".join(lines), results

    except Exception as e:
        # UI boundary: report the problem to the user rather than crash the app.
        st.error(f"Detection failed: {str(e)}")
        return "", None

def draw_boxes(image, detections):
    """Draw labelled bounding boxes on an image, coloured by class.

    The image is drawn on in place and also returned, so pass a copy when
    the original must stay untouched.

    Args:
        image: PIL image to annotate.
        detections: Iterable of dicts with ``'box'`` (``'xmin'``, ``'ymin'``,
            ``'xmax'``, ``'ymax'``), ``'label'`` and ``'score'`` keys.

    Returns:
        The same ``image`` object, with a rectangle and a
        ``"label (score)"`` caption drawn for every detection.
    """
    draw = ImageDraw.Draw(image)
    color_map = {
        'person': 'red',
        'cell phone': 'blue',
        'default': 'green'
    }
    caption_height = 15

    for det in detections:
        box = det['box']
        label = det['label']
        color = color_map.get(label.lower(), color_map['default'])

        # Order the corners explicitly; newer Pillow releases raise when the
        # second corner lies above or left of the first.
        left, right = sorted((box['xmin'], box['xmax']))
        top, bottom = sorted((box['ymin'], box['ymax']))

        draw.rectangle(
            [(left, top), (right, bottom)],
            outline=color,
            width=3
        )

        # Put the caption above the box when there is room; otherwise place
        # it just inside the top edge so it is not drawn off the image.
        if top >= caption_height:
            text_y = top - caption_height
        else:
            text_y = max(top, 0) + 3

        draw.text(
            (left, text_y),
            f"{label} ({det['score']:.2f})",
            fill=color
        )
    return image

def main():
    """Render the Streamlit page: upload an image, detect objects, show results.

    Flow: load the processor/model, let the user upload a JPG/PNG image,
    and on "Detect Objects" run detection, then show the raw output, an
    annotated image and a table of detections. A failed detection is
    reported by ``detect_objects`` itself; a successful run with no
    detections is reported here.
    """
    st.title("Object Detection with DETR")
    processor, model = load_detection_model()

    uploaded_file = st.file_uploader("Upload image", type=["jpg", "png", "jpeg"])
    if not uploaded_file:
        return

    image = Image.open(uploaded_file)
    st.image(image, caption="Original Image", use_column_width=True)

    if not st.button("Detect Objects"):
        return

    with st.spinner("Detecting objects..."):
        detection_text, results = detect_objects(image, processor, model)

        if results is None:
            # detect_objects has already shown the error message.
            return

        if not detection_text:
            # Detection ran but nothing passed the confidence threshold;
            # say so instead of leaving the page blank.
            st.info("No objects detected above the confidence threshold")
            return

        st.subheader("Detection Results")

        # Show raw detections
        with st.expander("Raw Detection Output"):
            st.text(detection_text)

        # Show parsed results
        detections = parse_detection_text(detection_text)
        if not detections:
            st.warning("No valid detections found")
            return

        # Annotate an RGB copy so the class colours render as intended even
        # for grayscale or palette uploads; the original image is untouched.
        annotated_image = draw_boxes(image.convert("RGB"), detections)
        st.image(annotated_image, caption="Detected Objects", use_column_width=True)

        # Display in table
        st.subheader("Detected Objects")
        st.table([
            {
                "Object": d["label"],
                "Confidence": f"{d['score']:.2%}",
                "Position": f"({d['box']['xmin']}, {d['box']['ymin']}) to ({d['box']['xmax']}, {d['box']['ymax']})"
            }
            for d in detections
        ])

# Entry point: build the page only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()