SaniaE committed on
Commit
037bb79
·
verified ·
1 Parent(s): 1a81fb2

added XAI endpoint

Browse files
Files changed (1) hide show
  1. app.py +66 -0
app.py CHANGED
@@ -101,4 +101,70 @@ async def predict(file: UploadFile = File(...)):
101
 
102
  # Encode the image back to JPEG
103
  _, buffer = cv2.imencode('.jpg', image_cv)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
104
  return StreamingResponse(io.BytesIO(buffer.tobytes()), media_type="image/jpeg")
 
101
 
102
  # Encode the image back to JPEG
103
  _, buffer = cv2.imencode('.jpg', image_cv)
104
+ return StreamingResponse(io.BytesIO(buffer.tobytes()), media_type="image/jpeg")
105
+
106
+
107
def get_top_prediction(detections, score_threshold=0.4):
    """Return the index and class ID of the most confident detection.

    Parameters
    ----------
    detections : dict
        Postprocessed TFOD output. Must contain batched
        ``'detection_scores'`` and ``'detection_classes'`` entries whose
        ``[0]`` element exposes a ``.numpy()`` method (scores sorted
        descending, as TFOD emits them).
    score_threshold : float, optional
        Minimum confidence the best detection must exceed to be reported.
        Defaults to 0.4, preserving the original hard-coded cutoff.

    Returns
    -------
    tuple
        ``(0, class_id)`` for the top-scoring detection, or
        ``(None, None)`` when there are no detections or the best score
        does not exceed ``score_threshold``.
    """
    scores = detections['detection_scores'][0].numpy()
    # TFOD sorts detections by confidence, so index 0 is always the best.
    if len(scores) > 0 and scores[0] > score_threshold:
        return 0, int(detections['detection_classes'][0].numpy()[0])
    return None, None
114
+
115
+
116
@app.post("/explain")
async def explain(file: UploadFile = File(...)):
    """Gradient-saliency XAI endpoint for the top detection.

    Runs the uploaded image through ``detection_model`` inside a
    ``tf.GradientTape``, takes the gradient of the top class's maximum
    anchor score with respect to the input pixels, and streams back the
    original image overlaid with a JET-colormap saliency heatmap as JPEG.
    Returns a JSON error dict instead when no detection clears the
    confidence threshold used by ``get_top_prediction``.
    """
    # 1. Prepare Image: decode upload to a float32 RGB array and add a
    # batch dimension — the model expects a [1, H, W, 3] tensor.
    contents = await file.read()
    image_pil = Image.open(io.BytesIO(contents)).convert("RGB")
    image_np = np.array(image_pil).astype(np.float32)
    input_tensor = tf.convert_to_tensor(np.expand_dims(image_np, 0), dtype=tf.float32)

    # 2. Gradient Tape for Saliency
    with tf.GradientTape() as tape:
        # Watch the raw input so d(loss)/d(pixels) is available below.
        tape.watch(input_tensor)

        # Manually run the forward pass through the detection model
        # (preprocess -> predict) so the tape records every op.
        image, shapes = detection_model.preprocess(input_tensor)
        prediction_dict = detection_model.predict(image, shapes)

        # 'class_predictions_with_background' is standard for TFOD SSD/FasterRCNN models
        # It usually has shape [1, num_anchors, num_classes]
        raw_scores = prediction_dict['class_predictions_with_background'][0]

        # We need a reference detection to know which class to compute
        # gradients for, so also run the (non-differentiated) postprocess.
        detections = detection_model.postprocess(prediction_dict, shapes)
        _, top_class = get_top_prediction(detections)

        if top_class is None:
            # Nothing confident enough to explain — bail out with JSON.
            return {"error": "No object detected with sufficient confidence to explain."}

        # Focus the loss on the max score for that specific class across
        # all anchors; this must happen inside the tape to be recorded.
        loss = tf.reduce_max(raw_scores[:, top_class])

    # 3. Compute Gradients of the class score w.r.t. the input pixels.
    grads = tape.gradient(loss, input_tensor)
    # Take absolute max across color channels -> one [H, W] saliency map.
    saliency = np.max(np.abs(grads.numpy()), axis=-1)[0]

    # 4. Normalize and Create Heatmap
    # Clip to the 5th–95th percentile range to reduce noise/outliers;
    # the +1e-8 guards against a zero range.
    v_min, v_max = np.percentile(saliency, (5, 95))
    saliency = np.clip((saliency - v_min) / (v_max - v_min + 1e-8), 0, 1)

    # Create the JET heatmap (Blue = low, Red = high)
    heatmap = cv2.applyColorMap(np.uint8(255 * saliency), cv2.COLORMAP_JET)

    # 5. Overlay on original image (convert original RGB -> BGR first,
    # since OpenCV works in BGR).
    original_bgr = cv2.cvtColor(image_np.astype(np.uint8), cv2.COLOR_RGB2BGR)
    overlay = cv2.addWeighted(original_bgr, 0.6, heatmap, 0.4, 0)

    # Add text label for what we are explaining.
    # NOTE(review): the +1 assumes category_index keys are 1-based while
    # the raw class axis is 0-based (background offset) — confirm against
    # the label map used elsewhere in this app.
    class_name = category_index.get(top_class + 1, {}).get('name', 'unknown')
    cv2.putText(overlay, f"Explaining: {class_name}", (10, 30),
                cv2.FONT_HERSHEY_SIMPLEX, 0.8, (255, 255, 255), 2)

    # 6. Stream Result as JPEG bytes.
    _, buffer = cv2.imencode('.jpg', overlay)
    return StreamingResponse(io.BytesIO(buffer.tobytes()), media_type="image/jpeg")