SaniaE committed on
Commit
55df2c0
·
verified ·
1 Parent(s): 037bb79

split endpoint views

Browse files
Files changed (1) hide show
  1. app.py +124 -1
app.py CHANGED
@@ -167,4 +167,127 @@ async def explain(file: UploadFile = File(...)):
167
 
168
  # 6. Stream Result
169
  _, buffer = cv2.imencode('.jpg', overlay)
170
- return StreamingResponse(io.BytesIO(buffer.tobytes()), media_type="image/jpeg")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
167
 
168
  # 6. Stream Result
169
  _, buffer = cv2.imencode('.jpg', overlay)
170
+ return StreamingResponse(io.BytesIO(buffer.tobytes()), media_type="image/jpeg")
171
+
172
+
173
+ @app.post("/explain/tiled")
174
+ async def explain_tiled(file: UploadFile = File(...)):
175
+ # 1. Prepare Base Image
176
+ contents = await file.read()
177
+ image_pil = Image.open(io.BytesIO(contents)).convert("RGB")
178
+ image_np = np.array(image_pil).astype(np.float32)
179
+ input_tensor = tf.convert_to_tensor(np.expand_dims(image_np, 0), dtype=tf.float32)
180
+
181
+ # 2. Get Initial Detections to know what to "Explain"
182
+ detections = detect_fn(input_tensor)
183
+ scores = detections['detection_scores'][0].numpy()
184
+ classes = detections['detection_classes'][0].numpy().astype(int)
185
+ boxes = detections['detection_boxes'][0].numpy()
186
+
187
+ # Create the Top-Left "Base" image with all boxes
188
+ base_image = cv2.cvtColor(image_np.astype(np.uint8), cv2.COLOR_RGB2BGR)
189
+ h_img, w_img, _ = base_image.shape
190
+
191
+ for i in range(min(len(scores), 3)):
192
+ if scores[i] > 0.4:
193
+ ymin, xmin, ymax, xmax = boxes[i]
194
+ cv2.rectangle(base_image, (int(xmin*w_img), int(ymin*h_img)),
195
+ (int(xmax*w_img), int(ymax*h_img)), (255, 255, 0), 2)
196
+
197
+ # 3. Generate Saliency Maps for the Top 3 detections
198
+ panels = [base_image]
199
+
200
+ for i in range(3):
201
+ if i < len(scores) and scores[i] > 0.4:
202
+ target_class = classes[i]
203
+
204
+ with tf.GradientTape() as tape:
205
+ tape.watch(input_tensor)
206
+ image, shapes = detection_model.preprocess(input_tensor)
207
+ prediction_dict = detection_model.predict(image, shapes)
208
+ raw_scores = prediction_dict['class_predictions_with_background'][0]
209
+ # Target the specific class at its most active anchor
210
+ loss = tf.reduce_max(raw_scores[:, target_class])
211
+
212
+ grads = tape.gradient(loss, input_tensor)
213
+ saliency = np.max(np.abs(grads.numpy()), axis=-1)[0]
214
+
215
+ # Normalize and Colorize
216
+ v_min, v_max = np.percentile(saliency, (5, 95))
217
+ saliency = np.clip((saliency - v_min) / (v_max - v_min + 1e-8), 0, 1)
218
+ heatmap = cv2.applyColorMap(np.uint8(255 * saliency), cv2.COLORMAP_JET)
219
+
220
+ # Overlay
221
+ overlay = cv2.addWeighted(cv2.cvtColor(image_np.astype(np.uint8), cv2.COLOR_RGB2BGR), 0.6, heatmap, 0.4, 0)
222
+
223
+ # Label the panel
224
+ class_name = category_index.get(target_class + 1, {}).get('name', 'unknown')
225
+ cv2.putText(overlay, f"Top {i+1}: {class_name}", (10, 30),
226
+ cv2.FONT_HERSHEY_SIMPLEX, 0.7, (255, 255, 255), 2)
227
+ panels.append(overlay)
228
+ else:
229
+ # Placeholder for empty slots if fewer than 3 detections exist
230
+ panels.append(np.zeros_like(base_image))
231
+
232
+ # 4. Assemble the 2x2 Grid
233
+ # Panels are: [0:Base, 1:Top1, 2:Top2, 3:Top3]
234
+ top_row = np.hstack((panels[0], panels[1]))
235
+ bottom_row = np.hstack((panels[2], panels[3]))
236
+ tiled_output = np.vstack((top_row, bottom_row))
237
+
238
+ # 5. Stream Result
239
+ _, buffer = cv2.imencode('.jpg', tiled_output)
240
+ return StreamingResponse(io.BytesIO(buffer.tobytes()), media_type="image/jpeg")
241
+
242
+
243
+ @app.post("/explain/global")
244
+ async def explain_global(file: UploadFile = File(...)):
245
+ # 1. Read and Prepare Image
246
+ contents = await file.read()
247
+ image_pil = Image.open(io.BytesIO(contents)).convert("RGB")
248
+ image_np = np.array(image_pil).astype(np.float32)
249
+ # Keeping a uint8 copy for the final BGR overlay
250
+ image_bgr = cv2.cvtColor(np.array(image_pil), cv2.COLOR_RGB2BGR)
251
+
252
+ input_tensor = tf.convert_to_tensor(np.expand_dims(image_np, 0), dtype=tf.float32)
253
+
254
+ # 2. Gradient Tape for Global Activation
255
+ with tf.GradientTape() as tape:
256
+ tape.watch(input_tensor)
257
+
258
+ # Forward pass
259
+ image, shapes = detection_model.preprocess(input_tensor)
260
+ prediction_dict = detection_model.predict(image, shapes)
261
+
262
+ # 'class_predictions_with_background' shape: [1, num_anchors, num_classes]
263
+ raw_scores = prediction_dict['class_predictions_with_background'][0]
264
+
265
+ # We ignore index 0 (Background/Clear) and look at all damage classes
266
+ # We take the max score at each anchor point, then sum them for the global loss
267
+ foreground_scores = raw_scores[:, 1:]
268
+ loss = tf.reduce_sum(tf.reduce_max(foreground_scores, axis=-1))
269
+
270
+ # 3. Compute and Process Gradients
271
+ grads = tape.gradient(loss, input_tensor)
272
+ saliency = np.max(np.abs(grads.numpy()), axis=-1)[0]
273
+
274
+ # 4. Refine Saliency Visualization
275
+ # Using the 95th percentile helps ignore "pixel noise" and highlights the actual damage
276
+ v_min, v_max = np.percentile(saliency, (5, 95))
277
+ saliency = np.clip((saliency - v_min) / (v_max - v_min + 1e-8), 0, 1)
278
+
279
+ # Create the heatmap overlay
280
+ heatmap = cv2.applyColorMap(np.uint8(255 * saliency), cv2.COLORMAP_JET)
281
+
282
+ # Blend: 60% original image, 40% heatmap
283
+ # This maintains the "Pinterest-chic" aesthetic without washing out the car details
284
+ overlay = cv2.addWeighted(image_bgr, 0.6, heatmap, 0.4, 0)
285
+
286
+ # 5. Add Branding/Label
287
+ # Teal text to match your office setup/portfolio theme
288
+ cv2.putText(overlay, "Global Model Attention", (20, 40),
289
+ cv2.FONT_HERSHEY_SIMPLEX, 1.0, (255, 255, 0), 2)
290
+
291
+ # 6. Stream Result
292
+ _, buffer = cv2.imencode('.jpg', overlay)
293
+ return StreamingResponse(io.BytesIO(buffer.tobytes()), media_type="image/jpeg")