Spaces:
Sleeping
Sleeping
split endpoint views
Browse files
app.py
CHANGED
|
@@ -167,4 +167,127 @@ async def explain(file: UploadFile = File(...)):
|
|
| 167 |
|
| 168 |
# 6. Stream Result
|
| 169 |
_, buffer = cv2.imencode('.jpg', overlay)
|
| 170 |
-
return StreamingResponse(io.BytesIO(buffer.tobytes()), media_type="image/jpeg")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 167 |
|
| 168 |
# 6. Stream Result
|
| 169 |
_, buffer = cv2.imencode('.jpg', overlay)
|
| 170 |
+
return StreamingResponse(io.BytesIO(buffer.tobytes()), media_type="image/jpeg")
|
| 171 |
+
|
| 172 |
+
|
| 173 |
+
@app.post("/explain/tiled")
async def explain_tiled(file: UploadFile = File(...)):
    """Stream a 2x2 JPEG grid: the detections panel plus one saliency panel
    for each of the top three detections (black tiles fill unused slots).
    """
    # ---- Decode the upload into both a float array (model) and tensor ----
    payload = await file.read()
    pil_image = Image.open(io.BytesIO(payload)).convert("RGB")
    rgb_float = np.array(pil_image).astype(np.float32)
    net_input = tf.convert_to_tensor(np.expand_dims(rgb_float, 0), dtype=tf.float32)

    # ---- One detection pass tells us which classes are worth explaining ----
    results = detect_fn(net_input)
    scores = results['detection_scores'][0].numpy()
    classes = results['detection_classes'][0].numpy().astype(int)
    boxes = results['detection_boxes'][0].numpy()

    # Top-left panel: original image with up to three confident boxes drawn.
    boxes_panel = cv2.cvtColor(rgb_float.astype(np.uint8), cv2.COLOR_RGB2BGR)
    height, width, _ = boxes_panel.shape

    for score, box in zip(scores[:3], boxes[:3]):
        if score > 0.4:
            ymin, xmin, ymax, xmax = box
            cv2.rectangle(boxes_panel, (int(xmin*width), int(ymin*height)),
                          (int(xmax*width), int(ymax*height)), (255, 255, 0), 2)

    # ---- One saliency panel per top-3 detection ----
    grid_panels = [boxes_panel]

    for i in range(3):
        if i >= len(scores) or scores[i] <= 0.4:
            # Keep the grid a strict 2x2 even when fewer detections exist.
            grid_panels.append(np.zeros_like(boxes_panel))
            continue

        cls_idx = classes[i]

        # Fresh tape per detection: gradient of that class's strongest
        # anchor logit with respect to the input pixels.
        with tf.GradientTape() as tape:
            tape.watch(net_input)
            pre_img, pre_shapes = detection_model.preprocess(net_input)
            preds = detection_model.predict(pre_img, pre_shapes)
            logits = preds['class_predictions_with_background'][0]
            objective = tf.reduce_max(logits[:, cls_idx])

        grad_tensor = tape.gradient(objective, net_input)
        sal_map = np.max(np.abs(grad_tensor.numpy()), axis=-1)[0]

        # Percentile clipping (5th/95th) suppresses pixel-level noise before
        # the JET colormap is applied.
        lo, hi = np.percentile(sal_map, (5, 95))
        sal_map = np.clip((sal_map - lo) / (hi - lo + 1e-8), 0, 1)
        heat = cv2.applyColorMap(np.uint8(255 * sal_map), cv2.COLORMAP_JET)

        # 60/40 blend of the original frame and the heatmap.
        panel_img = cv2.addWeighted(cv2.cvtColor(rgb_float.astype(np.uint8), cv2.COLOR_RGB2BGR), 0.6, heat, 0.4, 0)

        # NOTE(review): the +1 here assumes detection class indices are
        # offset by one relative to the label map — confirm against the
        # model's label-map convention.
        class_name = category_index.get(cls_idx + 1, {}).get('name', 'unknown')
        cv2.putText(panel_img, f"Top {i+1}: {class_name}", (10, 30),
                    cv2.FONT_HERSHEY_SIMPLEX, 0.7, (255, 255, 255), 2)
        grid_panels.append(panel_img)

    # ---- Assemble [base | top1] over [top2 | top3] and stream as JPEG ----
    top_row = np.hstack((grid_panels[0], grid_panels[1]))
    bottom_row = np.hstack((grid_panels[2], grid_panels[3]))
    tiled_output = np.vstack((top_row, bottom_row))

    _, jpg_buf = cv2.imencode('.jpg', tiled_output)
    return StreamingResponse(io.BytesIO(jpg_buf.tobytes()), media_type="image/jpeg")
|
| 241 |
+
|
| 242 |
+
|
| 243 |
+
@app.post("/explain/global")
async def explain_global(file: UploadFile = File(...)):
    """Stream a single JPEG heatmap of the model's overall attention.

    The loss is the sum, over every anchor, of the strongest non-background
    class logit, so the resulting gradient highlights everything the model
    considers "foreground" anywhere in the frame.
    """
    # ---- Decode the upload once; keep a uint8 BGR copy for the overlay ----
    body = await file.read()
    pil_img = Image.open(io.BytesIO(body)).convert("RGB")
    float_rgb = np.array(pil_img).astype(np.float32)
    bgr_frame = cv2.cvtColor(np.array(pil_img), cv2.COLOR_RGB2BGR)

    tensor_in = tf.convert_to_tensor(np.expand_dims(float_rgb, 0), dtype=tf.float32)

    # ---- Forward pass under a tape so we can differentiate w.r.t. pixels ----
    with tf.GradientTape() as tape:
        tape.watch(tensor_in)

        proc_img, proc_shapes = detection_model.preprocess(tensor_in)
        pred_out = detection_model.predict(proc_img, proc_shapes)

        # Raw per-anchor logits, shape [num_anchors, num_classes]
        # (batch index 0 already stripped).
        anchor_logits = pred_out['class_predictions_with_background'][0]

        # Drop column 0 (background); take the best foreground class per
        # anchor and sum across anchors for a single global scalar.
        damage_logits = anchor_logits[:, 1:]
        attention_loss = tf.reduce_sum(tf.reduce_max(damage_logits, axis=-1))

    # ---- Pixel-space saliency: max |gradient| across RGB channels ----
    grad = tape.gradient(attention_loss, tensor_in)
    attention = np.max(np.abs(grad.numpy()), axis=-1)[0]

    # Clip to the 5th-95th percentile band so a few extreme pixels don't
    # wash out the rest of the map; 1e-8 guards a flat gradient.
    p_lo, p_hi = np.percentile(attention, (5, 95))
    attention = np.clip((attention - p_lo) / (p_hi - p_lo + 1e-8), 0, 1)

    color_map = cv2.applyColorMap(np.uint8(255 * attention), cv2.COLORMAP_JET)

    # 60% original / 40% heatmap keeps the underlying image legible.
    blended = cv2.addWeighted(bgr_frame, 0.6, color_map, 0.4, 0)

    # Caption in cyan (BGR 255,255,0) in the top-left corner.
    cv2.putText(blended, "Global Model Attention", (20, 40),
                cv2.FONT_HERSHEY_SIMPLEX, 1.0, (255, 255, 0), 2)

    # ---- Encode and stream ----
    _, encoded = cv2.imencode('.jpg', blended)
    return StreamingResponse(io.BytesIO(encoded.tobytes()), media_type="image/jpeg")
|