Spaces:
Sleeping
Sleeping
added XAI endpoint
Browse files
app.py
CHANGED
|
@@ -101,4 +101,70 @@ async def predict(file: UploadFile = File(...)):
|
|
| 101 |
|
| 102 |
# Encode the image back to JPEG
|
| 103 |
_, buffer = cv2.imencode('.jpg', image_cv)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 104 |
return StreamingResponse(io.BytesIO(buffer.tobytes()), media_type="image/jpeg")
|
|
|
|
| 101 |
|
| 102 |
# Encode the image back to JPEG
|
| 103 |
_, buffer = cv2.imencode('.jpg', image_cv)
|
| 104 |
+
return StreamingResponse(io.BytesIO(buffer.tobytes()), media_type="image/jpeg")
|
| 105 |
+
|
| 106 |
+
|
| 107 |
+
def get_top_prediction(detections, threshold=0.4):
    """Return (index, class_id) for the most confident detection.

    Args:
        detections: Output dict of the TFOD ``postprocess`` step. Must contain
            batched ``'detection_scores'`` and ``'detection_classes'`` entries,
            sorted by descending score (the TFOD postprocess convention).
        threshold: Minimum confidence for the top detection to count.
            Defaults to 0.4, matching the previously hard-coded cutoff.

    Returns:
        ``(0, class_id)`` — index 0 (the top score) and its integer class id —
        when the best detection clears ``threshold``; ``(None, None)`` otherwise.
    """
    # np.asarray accepts both TF eager tensors (via __array__) and plain
    # numpy arrays, so the helper stays usable in unit tests without TF.
    scores = np.asarray(detections['detection_scores'])[0]
    if len(scores) > 0 and scores[0] > threshold:
        # Index 0 is the top score because postprocess sorts descending.
        classes = np.asarray(detections['detection_classes'])[0]
        return 0, int(classes[0])
    return None, None
|
| 114 |
+
|
| 115 |
+
|
| 116 |
+
@app.post("/explain")
async def explain(file: UploadFile = File(...)):
    """Explain the top detection via a vanilla-gradients saliency heatmap.

    Runs the detection model's forward pass under a GradientTape, takes the
    gradient of the top class's raw score w.r.t. the input pixels, and streams
    back the original image with a JET-colormap saliency overlay as JPEG.

    Returns a JSON ``{"error": ...}`` dict instead of an image when no
    detection clears the confidence threshold.
    """
    # 1. Prepare Image
    contents = await file.read()
    image_pil = Image.open(io.BytesIO(contents)).convert("RGB")
    image_np = np.array(image_pil).astype(np.float32)
    # Add batch dim: model expects [1, H, W, 3] float32.
    input_tensor = tf.convert_to_tensor(np.expand_dims(image_np, 0), dtype=tf.float32)

    # 2. Gradient Tape for Saliency
    # Everything the loss depends on must execute inside this block so the
    # tape records it; input_tensor is watched explicitly because it is a
    # constant, not a tf.Variable.
    with tf.GradientTape() as tape:
        tape.watch(input_tensor)

        # Manually run the forward pass through the detection model
        image, shapes = detection_model.preprocess(input_tensor)
        prediction_dict = detection_model.predict(image, shapes)

        # 'class_predictions_with_background' is standard for TFOD SSD/FasterRCNN models
        # It usually has shape [1, num_anchors, num_classes]
        raw_scores = prediction_dict['class_predictions_with_background'][0]

        # We need a reference detection to know which class to compute gradients for
        detections = detection_model.postprocess(prediction_dict, shapes)
        _, top_class = get_top_prediction(detections)

        if top_class is None:
            # Early-out: nothing confident enough to explain.
            return {"error": "No object detected with sufficient confidence to explain."}

        # Focus loss on the max score for that specific class across all anchors
        loss = tf.reduce_max(raw_scores[:, top_class])

    # 3. Compute Gradients
    # d(loss)/d(pixels): large magnitude = pixel strongly influences the score.
    grads = tape.gradient(loss, input_tensor)
    # Take absolute max across color channels
    saliency = np.max(np.abs(grads.numpy()), axis=-1)[0]

    # 4. Normalize and Create Heatmap
    # Using 95th percentile to reduce noise/outliers
    v_min, v_max = np.percentile(saliency, (5, 95))
    # +1e-8 guards division by zero when the saliency map is flat.
    saliency = np.clip((saliency - v_min) / (v_max - v_min + 1e-8), 0, 1)

    # Create the JET heatmap (Blue = low, Red = high)
    heatmap = cv2.applyColorMap(np.uint8(255 * saliency), cv2.COLORMAP_JET)

    # 5. Overlay on original image (Convert original to BGR first)
    # NOTE(review): this assumes the saliency map has the same H×W as the
    # uploaded image (i.e. gradients are w.r.t. the pre-resize input tensor);
    # confirm preprocess does not change spatial dims, else sizes mismatch.
    original_bgr = cv2.cvtColor(image_np.astype(np.uint8), cv2.COLOR_RGB2BGR)
    overlay = cv2.addWeighted(original_bgr, 0.6, heatmap, 0.4, 0)

    # Add text label for what we are explaining
    # NOTE(review): top_class + 1 presumes 0-based model class ids vs a
    # 1-based category_index (common TFOD label-map offset) — confirm.
    class_name = category_index.get(top_class + 1, {}).get('name', 'unknown')
    cv2.putText(overlay, f"Explaining: {class_name}", (10, 30),
                cv2.FONT_HERSHEY_SIMPLEX, 0.8, (255, 255, 255), 2)

    # 6. Stream Result
    _, buffer = cv2.imencode('.jpg', overlay)
    return StreamingResponse(io.BytesIO(buffer.tobytes()), media_type="image/jpeg")
|