# Author: SaniaE
# Commit: 55df2c0 "split endpoint views" (verified)
import os
import io
import cv2
import numpy as np
import os
os.environ['TF_USE_LEGACY_KERAS'] = '1'
import tf_keras as keras
import tensorflow as tf
from fastapi import FastAPI, UploadFile, File
from PIL import Image
from huggingface_hub import snapshot_download
from fastapi.responses import StreamingResponse
from object_detection.utils import label_map_util, config_util
from object_detection.builders import model_builder
app = FastAPI()

# 1. Download the model artifacts from the Hugging Face Hub.
# The token is read from the environment variable named exactly "HF_Token"
# (note the casing). If it is unset, token=None is passed and huggingface_hub
# falls back to its own default token lookup; a token is only required when
# the repository is private.
HF_TOKEN = os.getenv("HF_Token")
REPO_ID = "SaniaE/Car_Damage_Detection"
print("Downloading models from Hugging Face...")
# Runs at import time: importing this module blocks until the snapshot has
# been downloaded into ./models_data. The return value is the local folder
# path of the snapshot.
model_dir = snapshot_download(
    repo_id=REPO_ID,
    token=HF_TOKEN,
    local_dir="./models_data"
)

# 2. Resolve artifact paths inside the downloaded snapshot.
PIPELINE_CONFIG = os.path.join(model_dir, "object_detection_model/pipeline.config")
# TF checkpoints are addressed by prefix (not by a single file name).
CHECKPOINT_PATH = os.path.join(model_dir, "object_detection_model/ckpt-37")
LABEL_MAP_PATH = os.path.join(model_dir, "object_detection_model/label_map.pbtxt")
CNN_MODEL_PATH = os.path.join(model_dir, "cnn_filter.h5")

# 3. Load models.
# Load the binary CNN filter (Clear vs Damaged). compile=False loads it for
# inference only; no optimizer/loss is compiled. With TF_USE_LEGACY_KERAS=1
# set before the tensorflow import, tf.keras resolves to the tf_keras package.
cnn_filter = tf.keras.models.load_model(CNN_MODEL_PATH, compile=False)
# Build the object-detection model from its pipeline config (inference mode)
# and restore the trained weights.
configs = config_util.get_configs_from_pipeline_file(PIPELINE_CONFIG)
detection_model = model_builder.build(model_config=configs['model'], is_training=False)
ckpt = tf.compat.v2.train.Checkpoint(model=detection_model)
# expect_partial() silences warnings about checkpoint values that are not
# matched by the inference-only model (e.g. training-only state).
ckpt.restore(CHECKPOINT_PATH).expect_partial()
# Maps label-map ids to {'id': ..., 'name': ...} dicts; the endpoints look up
# display names with category_index.get(class_id + 1, {}).get('name', ...).
category_index = label_map_util.create_category_index_from_labelmap(LABEL_MAP_PATH)
@tf.function
def detect_fn(image):
    """Run the detector end to end (preprocess -> predict -> postprocess).

    Compiled as a ``tf.function`` graph around the module-level
    ``detection_model``. Callers pass a batched float32 image tensor and read
    ``detection_scores``, ``detection_classes`` and ``detection_boxes`` from
    the returned dict.
    """
    resized, true_shapes = detection_model.preprocess(image)
    return detection_model.postprocess(
        detection_model.predict(resized, true_shapes),
        true_shapes,
    )
@app.get("/")
def read_root():
    """Health-check endpoint: report that the service is up and which repo it serves."""
    payload = {"status": "Model is Online"}
    payload["model_repo"] = REPO_ID
    return payload
@app.post("/predict")
async def predict(file: UploadFile = File(...)):
    """Classify an uploaded car photo and, if damaged, draw detected damage.

    Step 1 runs the binary CNN filter on a 64x64 resize of the image.
    Step 2 runs only when the filter says 'Damaged': the object-detection
    model is applied and every detection scoring above 0.4 is drawn with its
    class name and score.

    Returns:
        A JPEG ``StreamingResponse`` of the (possibly annotated) image. The
        filter verdict ('Clear' or 'Damaged') is also sent in the
        ``X-Damage-Status`` response header, so clients can distinguish a
        clear image from a damaged one that produced no boxes.
    """
    # Read image
    contents = await file.read()
    image_pil = Image.open(io.BytesIO(contents)).convert("RGB")
    image_np = np.array(image_pil)
    # OpenCV draws and encodes in BGR channel order.
    image_cv = cv2.cvtColor(image_np, cv2.COLOR_RGB2BGR)
    height, width, _ = image_cv.shape

    # Step 1: CNN filter (Clear vs Damaged)
    img_cnn = image_pil.resize((64, 64))
    x = tf.keras.preprocessing.image.img_to_array(img_cnn)
    x = np.expand_dims(x, axis=0)
    # verbose=0: Keras' predict() otherwise prints a progress bar on every request.
    cnn_pred = cnn_filter.predict(x, verbose=0)
    is_damage_labels = ['Clear', 'Damaged']
    status = is_damage_labels[np.argmax(cnn_pred)]

    # Step 2: Object detection (only if damaged)
    if status == 'Damaged':
        input_tensor = tf.convert_to_tensor(np.expand_dims(image_np, 0), dtype=tf.float32)
        detections = detect_fn(input_tensor)
        scores = detections['detection_scores'][0].numpy()
        classes = detections['detection_classes'][0].numpy().astype(int)
        boxes = detections['detection_boxes'][0].numpy()
        for score, cls, box in zip(scores, classes, boxes):
            # Written as "not >" so NaN scores are skipped, as before.
            if not score > 0.4:
                continue
            # TFOD boxes are [ymin, xmin, ymax, xmax] in normalized coordinates.
            ymin, xmin, ymax, xmax = box
            left, right = int(xmin * width), int(xmax * width)
            top, bottom = int(ymin * height), int(ymax * height)
            # (255, 255, 0) in BGR is cyan/teal.
            cv2.rectangle(image_cv, (left, top), (right, bottom), (255, 255, 0), 2)
            # detection_classes are 0-based; category_index ids start at 1.
            name = category_index.get(cls + 1, {}).get('name', 'unknown')
            label = f"{name}: {int(score * 100)}%"
            # putText's origin is the text baseline; keep it at least 15 px
            # down so labels of boxes touching the top edge stay visible.
            text_y = max(top - 10, 15)
            cv2.putText(image_cv, label, (left, text_y),
                        cv2.FONT_HERSHEY_SIMPLEX, 0.5, (255, 255, 0), 2)

    # Encode the image back to JPEG
    _, buffer = cv2.imencode('.jpg', image_cv)
    return StreamingResponse(
        io.BytesIO(buffer.tobytes()),
        media_type="image/jpeg",
        headers={"X-Damage-Status": status},
    )
def get_top_prediction(detections, threshold=0.4):
    """Return ``(index, class_id)`` of the most confident detection.

    Args:
        detections: Post-processed detection dict whose ``detection_scores``
            and ``detection_classes`` entries carry a leading batch axis; only
            batch element 0 is inspected. Eager TF tensors and NumPy arrays
            are both accepted (converted with ``np.asarray``).
        threshold: The best score must be strictly greater than this value.

    Returns:
        ``(index, class_id)`` as plain ``int``s, or ``(None, None)`` when
        there are no detections or the best score does not exceed
        ``threshold``.
    """
    scores = np.asarray(detections['detection_scores'])[0]
    if scores.size == 0:
        return None, None
    # Use argmax instead of assuming index 0: postprocess output is normally
    # score-sorted (then argmax is 0), but this does not depend on it.
    best = int(np.argmax(scores))
    # "not >" keeps NaN scores rejected.
    if not scores[best] > threshold:
        return None, None
    classes = np.asarray(detections['detection_classes'])[0]
    return best, int(classes[best])
@app.post("/explain")
async def explain(file: UploadFile = File(...)):
    """Return a gradient-saliency overlay for the most confident detection.

    The image is run through the detector under a ``tf.GradientTape``. The
    highest logit of the top detected class (max over anchors) is
    differentiated with respect to the input pixels. The per-pixel gradient
    magnitude is rendered as a JET heatmap and blended over the photo.

    Returns:
        A JPEG ``StreamingResponse``, or a JSON error dict when no detection
        clears the confidence threshold applied by ``get_top_prediction``.
    """
    # 1. Prepare image
    contents = await file.read()
    image_pil = Image.open(io.BytesIO(contents)).convert("RGB")
    image_np = np.array(image_pil).astype(np.float32)
    input_tensor = tf.convert_to_tensor(np.expand_dims(image_np, 0), dtype=tf.float32)

    # 2. Gradient tape for saliency
    with tf.GradientTape() as tape:
        # input_tensor is a constant, so it must be watched explicitly.
        tape.watch(input_tensor)
        image, shapes = detection_model.preprocess(input_tensor)
        prediction_dict = detection_model.predict(image, shapes)
        # Assumed SSD-style head: [num_anchors, num_classes + 1] after taking
        # batch element 0, with column 0 = background. TODO confirm for other
        # meta-architectures.
        raw_scores = prediction_dict['class_predictions_with_background'][0]

        # Post-processing only decides *which* class to explain; it does not
        # need gradients, so keep it off the tape.
        with tape.stop_recording():
            detections = detection_model.postprocess(prediction_dict, shapes)
            _, top_class = get_top_prediction(detections)
        if top_class is None:
            return {"error": "No object detected with sufficient confidence to explain."}

        # detection_classes are 0-based foreground ids (hence the +1 in the
        # category_index lookup below), while raw_scores keeps background in
        # column 0. The matching logit column is therefore top_class + 1;
        # indexing with top_class alone would explain the wrong class
        # (background for class 0).
        loss = tf.reduce_max(raw_scores[:, top_class + 1])

    # 3. Compute gradients; take the max absolute value over color channels.
    grads = tape.gradient(loss, input_tensor)
    saliency = np.max(np.abs(grads.numpy()), axis=-1)[0]

    # 4. Normalize with a 5th-95th percentile stretch so a few extreme
    # gradients do not flatten the map, then colorize (blue = low, red = high).
    v_min, v_max = np.percentile(saliency, (5, 95))
    saliency = np.clip((saliency - v_min) / (v_max - v_min + 1e-8), 0, 1)
    heatmap = cv2.applyColorMap(np.uint8(255 * saliency), cv2.COLORMAP_JET)

    # 5. Overlay on the original image (converted to BGR for OpenCV).
    original_bgr = cv2.cvtColor(image_np.astype(np.uint8), cv2.COLOR_RGB2BGR)
    overlay = cv2.addWeighted(original_bgr, 0.6, heatmap, 0.4, 0)

    # Caption naming the explained class
    class_name = category_index.get(top_class + 1, {}).get('name', 'unknown')
    cv2.putText(overlay, f"Explaining: {class_name}", (10, 30),
                cv2.FONT_HERSHEY_SIMPLEX, 0.8, (255, 255, 255), 2)

    # 6. Stream result
    _, buffer = cv2.imencode('.jpg', overlay)
    return StreamingResponse(io.BytesIO(buffer.tobytes()), media_type="image/jpeg")
@app.post("/explain/tiled")
async def explain_tiled(file: UploadFile = File(...)):
    """Return a 2x2 grid: detections plus a saliency map for each of the top 3.

    Panels in row-major order:
      [0] the image with boxes for the explained detections;
      [1..3] gradient-saliency overlays for detections 1-3 whose score exceeds
      0.4. Slots without a qualifying detection are black.

    A single forward pass is recorded on a persistent ``GradientTape``, and
    every per-detection gradient is taken from it. The previous version ran
    the detector once for detection and again for each explained detection.
    """
    # 1. Prepare base image
    contents = await file.read()
    image_pil = Image.open(io.BytesIO(contents)).convert("RGB")
    image_np = np.array(image_pil).astype(np.float32)
    input_tensor = tf.convert_to_tensor(np.expand_dims(image_np, 0), dtype=tf.float32)

    # 2. One recorded forward pass; postprocess only selects the targets.
    with tf.GradientTape(persistent=True) as tape:
        # input_tensor is a constant, so it must be watched explicitly.
        tape.watch(input_tensor)
        image, shapes = detection_model.preprocess(input_tensor)
        prediction_dict = detection_model.predict(image, shapes)
        # Assumed SSD-style head: [num_anchors, num_classes + 1] after taking
        # batch element 0, with column 0 = background. TODO confirm for other
        # meta-architectures.
        raw_scores = prediction_dict['class_predictions_with_background'][0]
        with tape.stop_recording():
            detections = detection_model.postprocess(prediction_dict, shapes)
            scores = detections['detection_scores'][0].numpy()
            classes = detections['detection_classes'][0].numpy().astype(int)
            boxes = detections['detection_boxes'][0].numpy()
        # Slot i shows detection i; only those scoring above 0.4 are explained.
        targets = [i for i in range(min(len(scores), 3)) if scores[i] > 0.4]
        # detection_classes are 0-based foreground ids (hence the +1 in the
        # category_index lookup), while raw_scores keeps background in column
        # 0, so the matching logit column is class + 1.
        losses = {i: tf.reduce_max(raw_scores[:, int(classes[i]) + 1]) for i in targets}
    grads = {i: tape.gradient(loss, input_tensor) for i, loss in losses.items()}
    del tape  # release the activations held by the persistent tape

    # 3. Base panel: boxes for the explained detections on a copy of the image.
    clean_bgr = cv2.cvtColor(image_np.astype(np.uint8), cv2.COLOR_RGB2BGR)
    base_image = clean_bgr.copy()
    h_img, w_img, _ = base_image.shape
    for i in targets:
        ymin, xmin, ymax, xmax = boxes[i]
        cv2.rectangle(base_image, (int(xmin * w_img), int(ymin * h_img)),
                      (int(xmax * w_img), int(ymax * h_img)), (255, 255, 0), 2)

    # 4. Saliency panels for slots 0-2
    panels = [base_image]
    for i in range(3):
        if i not in grads:
            # Placeholder when fewer than 3 detections qualify
            panels.append(np.zeros_like(base_image))
            continue
        # Max absolute gradient over color channels
        saliency = np.max(np.abs(grads[i].numpy()), axis=-1)[0]
        # 5th-95th percentile stretch so a few extreme gradients don't flatten the map
        v_min, v_max = np.percentile(saliency, (5, 95))
        saliency = np.clip((saliency - v_min) / (v_max - v_min + 1e-8), 0, 1)
        heatmap = cv2.applyColorMap(np.uint8(255 * saliency), cv2.COLORMAP_JET)
        # Blend onto the box-free image
        overlay = cv2.addWeighted(clean_bgr, 0.6, heatmap, 0.4, 0)
        class_name = category_index.get(classes[i] + 1, {}).get('name', 'unknown')
        cv2.putText(overlay, f"Top {i+1}: {class_name}", (10, 30),
                    cv2.FONT_HERSHEY_SIMPLEX, 0.7, (255, 255, 255), 2)
        panels.append(overlay)

    # 5. Assemble the 2x2 grid: [Base, Top1] over [Top2, Top3]
    top_row = np.hstack((panels[0], panels[1]))
    bottom_row = np.hstack((panels[2], panels[3]))
    tiled_output = np.vstack((top_row, bottom_row))

    # 6. Stream result
    _, buffer = cv2.imencode('.jpg', tiled_output)
    return StreamingResponse(io.BytesIO(buffer.tobytes()), media_type="image/jpeg")
@app.post("/explain/global")
async def explain_global(file: UploadFile = File(...)):
    """Render class-agnostic saliency: where the detector responds to any damage class.

    For every anchor the largest foreground logit (column 0, background, is
    excluded) is taken, and these maxima are summed into one scalar. Its
    gradient with respect to the input pixels is shown as a JET heatmap
    blended over the photo, with a caption.

    Returns:
        A JPEG ``StreamingResponse`` containing the overlay.
    """
    raw_bytes = await file.read()
    pil_img = Image.open(io.BytesIO(raw_bytes)).convert("RGB")
    rgb_u8 = np.array(pil_img)
    # uint8 BGR copy used as the blend base for OpenCV
    image_bgr = cv2.cvtColor(rgb_u8, cv2.COLOR_RGB2BGR)
    batch = tf.convert_to_tensor(np.expand_dims(rgb_u8.astype(np.float32), 0), dtype=tf.float32)

    with tf.GradientTape() as tape:
        # batch is a constant, so it must be watched explicitly.
        tape.watch(batch)
        resized, true_shapes = detection_model.preprocess(batch)
        logits = detection_model.predict(resized, true_shapes)['class_predictions_with_background'][0]
        # Best foreground logit per anchor, summed into a single objective
        per_anchor_best = tf.reduce_max(logits[:, 1:], axis=-1)
        objective = tf.reduce_sum(per_anchor_best)

    pixel_grads = tape.gradient(objective, batch).numpy()
    # Max absolute gradient across color channels -> one value per pixel
    saliency = np.abs(pixel_grads).max(axis=-1)[0]

    # 5th-95th percentile stretch keeps a few extreme gradients from
    # dominating the color scale.
    lo, hi = np.percentile(saliency, (5, 95))
    saliency = np.clip((saliency - lo) / (hi - lo + 1e-8), 0, 1)
    heatmap = cv2.applyColorMap(np.uint8(255 * saliency), cv2.COLORMAP_JET)

    # 60% photo / 40% heatmap so the underlying image stays readable
    overlay = cv2.addWeighted(image_bgr, 0.6, heatmap, 0.4, 0)
    # (255, 255, 0) in BGR is cyan/teal
    cv2.putText(overlay, "Global Model Attention", (20, 40),
                cv2.FONT_HERSHEY_SIMPLEX, 1.0, (255, 255, 0), 2)

    _, jpeg = cv2.imencode('.jpg', overlay)
    return StreamingResponse(io.BytesIO(jpeg.tobytes()), media_type="image/jpeg")