jacksonwambali's picture
Rename app to app.py
f8348fb verified
import gradio as gr
import torch
import matplotlib.pyplot as plt
import numpy as np
from fastai.vision.all import load_learner, PILImage
import io
# Ensure custom classes exist before loading the model
class Hook:
    """Forward hook that passes a module's output to a callback.

    After construction, ``func`` is called with the output of ``module`` on
    every forward pass until ``remove`` is called.
    """

    def __init__(self, module, func):
        def _on_forward(_module, _inputs, output):
            # The callback's result is returned to PyTorch as-is. PyTorch
            # treats a non-None return from a forward hook as a replacement
            # output, so pure observers should return None.
            return func(output)

        # Keep the handle so the hook can be detached later.
        self.hook = module.register_forward_hook(_on_forward)

    def remove(self):
        """Detach the hook from the module."""
        self.hook.remove()
class HookBwd:
    """Backward hook that passes a module's output gradient to a callback.

    After construction, ``func`` is called with the first element of
    ``grad_output`` whenever gradients flow back through ``module``, until
    ``remove`` is called.
    """

    def __init__(self, module, func):
        def _on_backward(_module, _grad_input, grad_output):
            # Only the gradient w.r.t. the first output is forwarded. The
            # callback's result is returned to PyTorch as-is; a non-None
            # return from a full backward hook replaces grad_input, so pure
            # observers should return None.
            return func(grad_output[0])

        # Keep the handle so the hook can be detached later.
        self.hook = module.register_full_backward_hook(_on_backward)

    def remove(self):
        """Detach the hook from the module."""
        self.hook.remove()
# Load the exported learner once, when the module is imported.
# NOTE(review): a failed load is only printed, so `learn` stays unbound and
# the first call to predict_with_cam will raise NameError -- confirm that
# continuing to launch the UI in that state is intended.
try:
    learn = load_learner('export.pkl')
except Exception as load_error:
    print(f"Error loading model: {load_error}")
else:
    print("Model loaded successfully!")
# Function to predict + generate a gradient-weighted Class Activation Map (CAM)
def predict_with_cam(img):
    """Classify ``img`` and render a gradient-weighted CAM overlay for it.

    Args:
        img: Anything ``PILImage.create`` accepts; the Gradio input component
            in this file passes a PIL image.

    Returns:
        A ``(scores, overlay)`` tuple. ``scores`` maps every entry of
        ``learn.dls.vocab`` to its predicted probability as a ``float``.
        ``overlay`` is the input image with the heatmap drawn on top, returned
        as an image object (not a raw byte buffer) so that ``gr.Image`` can
        display it.

    Raises:
        Whatever ``learn.predict``, the model's forward/backward pass, or
        Matplotlib raise; hooks and the figure are released in every case.
    """
    img = PILImage.create(img)

    # Get model and target layer (modify as needed for your architecture)
    model = learn.model
    target_layer = model[0][-1]  # Adjust based on model architecture

    # Run the prediction BEFORE attaching hooks: learn.predict performs its
    # own forward pass, which would otherwise be recorded as activations[0]
    # instead of the pass that is actually back-propagated below.
    _, pred_idx, probs = learn.predict(img)

    # Build a one-image batch for the gradient pass.
    img_tensor = learn.dls.test_dl([img]).one_batch()[0]
    img_tensor.requires_grad_()

    activations, gradients = [], []
    h1 = Hook(target_layer, activations.append)
    h2 = HookBwd(target_layer, gradients.append)
    try:
        output = model(img_tensor)
        output[0, pred_idx].backward()
    finally:
        # Always detach, otherwise a failed request would leave hooks on the
        # shared global model and they would pile up across calls.
        h1.remove()
        h2.remove()

    # Generate the CAM: weight each channel by its spatially averaged gradient.
    act = activations[0].detach().cpu().squeeze(0)
    grad = gradients[0].detach().cpu().squeeze(0)
    weights = grad.mean(dim=(1, 2), keepdim=True)
    cam = (weights * act).sum(0).clamp(min=0).numpy()

    # Normalize to [0, 1]; a constant map (e.g. all zeros after the clamp)
    # stays at zero instead of turning into NaN through a 0/0 division.
    cam = cam - cam.min()
    peak = cam.max()
    if peak > 0:
        cam = cam / peak

    # Plot the image with the CAM stretched over it.
    fig, ax = plt.subplots()
    try:
        ax.imshow(img)
        width, height = img.size
        # `extent` maps the low-resolution CAM onto the image's pixel grid.
        # NOTE(review): alignment assumes the learner's item transforms keep
        # the whole image in frame (no cropping) -- TODO confirm.
        ax.imshow(cam, alpha=0.5, cmap='jet', extent=(0, width, height, 0))
        ax.axis('off')

        # Save the CAM image from this specific figure.
        buf = io.BytesIO()
        fig.savefig(buf, format='png', bbox_inches='tight', pad_inches=0)
    finally:
        # pyplot keeps every figure alive until it is closed explicitly.
        plt.close(fig)

    # Decode the PNG bytes into an image object for the gr.Image output.
    overlay = PILImage.create(buf.getvalue())

    scores = {learn.dls.vocab[i]: float(p) for i, p in enumerate(probs)}
    return scores, overlay
# Build the Gradio UI: one image in, top-3 label scores and a CAM overlay out.
image_input = gr.Image(type='pil')
label_output = gr.Label(num_top_classes=3)
overlay_output = gr.Image(type='pil')

interface = gr.Interface(
    fn=predict_with_cam,
    inputs=image_input,
    outputs=[label_output, overlay_output],
    title="Image Classifier with CAM",
)

# Start serving the app.
interface.launch()