arsath-sm committed on
Commit
e4dcfbb
·
verified ·
1 Parent(s): de391b1

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +23 -48
app.py CHANGED
@@ -5,14 +5,11 @@ import onnxruntime as ort
5
  from PIL import Image
6
  import tempfile
7
 
8
- # Dynamically assign colors to classes
9
- def get_color(class_id):
10
- """Generate a color for any class ID"""
11
- np.random.seed(class_id) # For consistent colors
12
- return tuple(map(int, np.random.randint(0, 255, 3)))
13
-
14
- # Class labels - will be populated dynamically
15
- CLASSES = {}
16
 
17
  # Load the ONNX model
18
  @st.cache_resource
@@ -41,11 +38,6 @@ def postprocess_results(output, original_shape, confidence_threshold=0.25, iou_t
41
  else:
42
  raise ValueError(f"Unexpected output type: {type(output)}")
43
 
44
- # Debug: Print the shape and first few entries of predictions
45
- st.write(f"Debug - Predictions shape: {predictions.shape}")
46
- if len(predictions) > 0:
47
- st.write(f"Debug - First prediction entry: {predictions[0]}")
48
-
49
  if len(predictions.shape) == 4:
50
  predictions = predictions.squeeze((0, 1))
51
  elif len(predictions.shape) == 3:
@@ -56,15 +48,6 @@ def postprocess_results(output, original_shape, confidence_threshold=0.25, iou_t
56
  scores = predictions[:, 4]
57
  class_ids = predictions[:, 5]
58
 
59
- # Debug: Print unique class IDs
60
- unique_classes = np.unique(class_ids)
61
- st.write(f"Debug - Unique class IDs found: {unique_classes}")
62
-
63
- # Update CLASSES dictionary with any new class IDs
64
- for class_id in unique_classes:
65
- if int(class_id) not in CLASSES:
66
- CLASSES[int(class_id)] = f"Class_{int(class_id)}"
67
-
68
  # Filter by confidence
69
  mask = scores > confidence_threshold
70
  boxes = boxes[mask]
@@ -113,32 +96,36 @@ def process_image(image):
113
 
114
  # Draw bounding boxes on the image
115
  for x1, y1, x2, y2, score, class_id in results:
116
- # Get color dynamically
117
- color = get_color(class_id)
118
- cv2.rectangle(orig_image, (x1, y1), (x2, y2), color, 2)
 
 
 
119
 
120
- label = f"{CLASSES[class_id]}: {score:.2f}"
121
- # Calculate text size for better positioning
122
  (text_width, text_height), _ = cv2.getTextSize(
123
- label, cv2.FONT_HERSHEY_SIMPLEX, 0.9, 2
124
  )
125
- # Draw background rectangle for text
 
126
  cv2.rectangle(
127
  orig_image,
128
- (x1, y1 - text_height - 10),
129
  (x1 + text_width, y1),
130
- color,
131
  -1
132
  )
133
- # Draw text
 
134
  cv2.putText(
135
  orig_image,
136
  label,
137
  (x1, y1 - 5),
138
  cv2.FONT_HERSHEY_SIMPLEX,
139
- 0.9,
140
  (255, 255, 255),
141
- 2
142
  )
143
 
144
  return cv2.cvtColor(orig_image, cv2.COLOR_BGR2RGB)
@@ -180,7 +167,7 @@ def process_video(video_path):
180
  return temp_file.name
181
 
182
  # Streamlit UI
183
- st.title("Object Detection")
184
 
185
  # Add confidence threshold slider
186
  confidence_threshold = st.slider(
@@ -219,16 +206,4 @@ if uploaded_file is not None:
219
  processed_video = process_video(tfile.name)
220
  st.video(processed_video)
221
 
222
- # Add legend after processing to include all detected classes
223
- if CLASSES:
224
- st.markdown("### Detection Legend")
225
- for class_id, class_name in CLASSES.items():
226
- color = get_color(class_id)
227
- st.markdown(
228
- f'<div style="display: flex; align-items: center;">'
229
- f'<div style="width: 20px; height: 20px; background-color: rgb{color}; margin-right: 10px;"></div>'
230
- f'<span>{class_name}</span></div>',
231
- unsafe_allow_html=True
232
- )
233
-
234
- st.write("Upload an image or video to detect objects.")
 
5
  from PIL import Image
6
  import tempfile
7
 
8
+ # Define class labels
9
+ CLASSES = {
10
+ 0: "Vehicle",
11
+ 1: "License_Plate"
12
+ }
 
 
 
13
 
14
  # Load the ONNX model
15
  @st.cache_resource
 
38
  else:
39
  raise ValueError(f"Unexpected output type: {type(output)}")
40
 
 
 
 
 
 
41
  if len(predictions.shape) == 4:
42
  predictions = predictions.squeeze((0, 1))
43
  elif len(predictions.shape) == 3:
 
48
  scores = predictions[:, 4]
49
  class_ids = predictions[:, 5]
50
 
 
 
 
 
 
 
 
 
 
51
  # Filter by confidence
52
  mask = scores > confidence_threshold
53
  boxes = boxes[mask]
 
96
 
97
  # Draw bounding boxes on the image
98
  for x1, y1, x2, y2, score, class_id in results:
99
+ # Draw rectangle with white color
100
+ cv2.rectangle(orig_image, (x1, y1), (x2, y2), (255, 255, 255), 2)
101
+
102
+ # Get class name
103
+ class_name = CLASSES.get(class_id, f"Class_{class_id}")
104
+ label = f"{class_name}: {score:.2f}"
105
 
106
+ # Add label background and text
 
107
  (text_width, text_height), _ = cv2.getTextSize(
108
+ label, cv2.FONT_HERSHEY_SIMPLEX, 0.6, 1
109
  )
110
+
111
+ # Draw black background for text
112
  cv2.rectangle(
113
  orig_image,
114
+ (x1, y1 - text_height - 4),
115
  (x1 + text_width, y1),
116
+ (0, 0, 0),
117
  -1
118
  )
119
+
120
+ # Draw white text
121
  cv2.putText(
122
  orig_image,
123
  label,
124
  (x1, y1 - 5),
125
  cv2.FONT_HERSHEY_SIMPLEX,
126
+ 0.6,
127
  (255, 255, 255),
128
+ 1
129
  )
130
 
131
  return cv2.cvtColor(orig_image, cv2.COLOR_BGR2RGB)
 
167
  return temp_file.name
168
 
169
  # Streamlit UI
170
+ st.title("Vehicle and License Plate Detection")
171
 
172
  # Add confidence threshold slider
173
  confidence_threshold = st.slider(
 
206
  processed_video = process_video(tfile.name)
207
  st.video(processed_video)
208
 
209
+ st.write("Upload an image or video to detect vehicles and license plates.")