|
import gradio as gr |
|
import timm |
|
import torch |
|
from PIL import Image |
|
import requests |
|
from io import BytesIO |
|
import numpy as np |
|
from pytorch_grad_cam import GradCAM, HiResCAM, ScoreCAM, GradCAMPlusPlus, AblationCAM, XGradCAM, EigenCAM, FullGrad |
|
from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget, get_target_layer |
|
from pytorch_grad_cam.utils.image import show_cam_on_image |
|
from timm.data import create_transform |
|
|
|
|
|
# Only offer architectures that ship pretrained weights: load_model() always
# requests pretrained=True, so any other entry would fail when selected.
MODELS = timm.list_models(pretrained=True)
|
|
|
|
|
# Label shown in the UI -> pytorch_grad_cam class that implements the method.
# Insertion order is the order in which the options are listed in the dropdown.
CAM_METHODS = dict(
    (
        ("GradCAM", GradCAM),
        ("HiResCAM", HiResCAM),
        ("ScoreCAM", ScoreCAM),
        ("GradCAM++", GradCAMPlusPlus),
        ("AblationCAM", AblationCAM),
        ("XGradCAM", XGradCAM),
        ("EigenCAM", EigenCAM),
        ("FullGrad", FullGrad),
    )
)
|
|
|
def load_model(model_name):
    """Create the named timm model with pretrained weights, in eval mode.

    Args:
        model_name: Architecture name passed to ``timm.create_model``.

    Returns:
        The instantiated model after ``eval()`` has been called on it.
    """
    # nn.Module.eval() returns the module itself, so the call can be chained.
    return timm.create_model(model_name, pretrained=True).eval()
|
|
|
def process_image(image_path, model):
    """Load an image and preprocess it with the model's own eval transform.

    Args:
        image_path: Local file path, or an ``http://`` / ``https://`` URL.
        model: timm model whose ``pretrained_cfg`` supplies the input size,
            crop percentage, mean/std and interpolation for the transform.

    Returns:
        The transformed image with a leading batch dimension of 1 added.

    Raises:
        ValueError: If ``image_path`` is empty or ``None``.
        requests.HTTPError: If a URL is given and the server answers with an
            error status.
    """
    if not image_path:
        raise ValueError("No image provided.")

    if image_path.startswith(('http://', 'https://')):
        # A timeout keeps a stalled server from hanging the request forever;
        # raise_for_status stops an error page from being parsed as an image.
        response = requests.get(image_path, timeout=30)
        response.raise_for_status()
        source = BytesIO(response.content)
    else:
        source = image_path

    # Force three channels: the normalisation below uses a 3-element mean/std,
    # which fails on grayscale, palette or RGBA inputs. convert() also loads
    # the pixel data, so the underlying file can be closed right away.
    with Image.open(source) as raw_image:
        image = raw_image.convert('RGB')

    config = model.pretrained_cfg
    transform = create_transform(
        input_size=config['input_size'],
        crop_pct=config['crop_pct'],
        mean=config['mean'],
        std=config['std'],
        interpolation=config['interpolation'],
        is_training=False,
    )

    tensor = transform(image).unsqueeze(0)
    return tensor
|
|
|
def get_cam_image(model, image, target_layer, cam_method):
    """Run the chosen CAM method and overlay its heatmap on the input image.

    Args:
        model: timm model; its ``pretrained_cfg`` mean/std are used to undo
            the input normalisation for display.
        image: Preprocessed input with a leading batch dimension; only the
            first item of the batch is visualised.
        target_layer: Module whose activations the CAM method inspects.
        cam_method: Key into ``CAM_METHODS`` selecting the algorithm.

    Returns:
        A PIL image of the heatmap blended over the de-normalised input.

    Raises:
        KeyError: If ``cam_method`` is not a key of ``CAM_METHODS``.
    """
    cam_cls = CAM_METHODS[cam_method]

    # No `use_cuda` argument: recent pytorch-grad-cam releases removed it and
    # run on whatever device the model already lives on. The context manager
    # releases the hooks the CAM object registers on the model.
    # No explicit targets are passed, so the library applies its default
    # target selection (the top-scoring class, per its documentation).
    with cam_cls(model=model, target_layers=[target_layer]) as cam:
        grayscale_cam = cam(input_tensor=image)

    config = model.pretrained_cfg
    mean = torch.tensor(config['mean']).view(3, 1, 1)
    std = torch.tensor(config['std']).view(3, 1, 1)

    # Undo the normalisation on the CPU so this also works when the caller
    # passes a tensor that lives on another device or carries gradients.
    chw = image.squeeze(0).detach().cpu()
    rgb_img = (chw * std + mean).permute(1, 2, 0).numpy()
    # Clamp to [0, 1]; de-normalised values can drift slightly outside it.
    rgb_img = np.clip(rgb_img, 0, 1)

    cam_image = show_cam_on_image(rgb_img, grayscale_cam[0, :], use_rgb=True)
    return Image.fromarray(cam_image)
|
|
|
def get_feature_info(model):
    """Return the ``'module'`` entry of each item in ``model.feature_info``.

    Args:
        model: Any object; it is expected, but not required, to expose a
            ``feature_info`` sequence of mappings that have a ``'module'`` key.

    Returns:
        A list of the ``'module'`` values in order, or an empty list when the
        attribute is missing, ``None`` or empty.
    """
    # A missing attribute and an explicit None are treated the same way, so
    # callers can rely on always getting a list back.
    feature_info = getattr(model, 'feature_info', None)
    if not feature_info:
        return []
    return [entry['module'] for entry in feature_info]
|
|
|
def explain_image(model_name, image_path, cam_method, feature_module):
    """Produce a CAM visualisation for one image with the selected model.

    The target layer is chosen in this order: the explicitly requested
    ``feature_module``; otherwise the last module listed in the model's
    ``feature_info``; otherwise the last ``Conv2d`` found in the model.

    Args:
        model_name: timm architecture name to load.
        image_path: Local path or URL of the image to explain.
        cam_method: Key into ``CAM_METHODS``.
        feature_module: Optional dotted module name to use as the target
            layer; falsy values trigger the automatic selection.

    Returns:
        The PIL image returned by ``get_cam_image``.

    Raises:
        ValueError: If ``cam_method`` is unknown, or no target layer is found.
        AttributeError: If ``feature_module`` does not name a submodule.
    """
    # Fail before the expensive model load when nothing valid was selected.
    if cam_method not in CAM_METHODS:
        raise ValueError(f"Unknown CAM method: {cam_method!r}")

    model = load_model(model_name)
    image = process_image(image_path, model)

    # Start from None so the final check works on every path, including a
    # fallback search that finds no convolution at all.
    target_layer = None

    if feature_module:
        # get_submodule resolves dotted names such as "layer4" or "blocks.11".
        target_layer = model.get_submodule(feature_module)
        print(f"Using feature module: {feature_module}")
    else:
        feature_info = get_feature_info(model)
        if feature_info:
            target_layer = model.get_submodule(feature_info[-1])
            print(f"Using last feature module: {feature_info[-1]}")
        else:
            # Walk the modules back to front to find the last convolution.
            for name, module in reversed(list(model.named_modules())):
                if isinstance(module, torch.nn.Conv2d):
                    target_layer = module
                    print(f"Fallback: Using last convolutional layer: {name}")
                    break

    if target_layer is None:
        raise ValueError("Could not find a suitable target layer.")

    cam_image = get_cam_image(model, image, target_layer, cam_method)
    return cam_image
|
|
|
def update_feature_modules(model_name):
    """Build a dropdown update listing the feature modules of a model.

    Args:
        model_name: timm architecture name, or a falsy value when no model
            is selected.

    Returns:
        A Gradio update that sets the dropdown's choices to the model's
        feature-module names and selects the last one; an empty update when
        no model is selected or the model lists no feature modules.
    """
    # The model dropdown has no default, so this can be called with None.
    if not model_name:
        return gr.update(choices=[], value=None)

    model = load_model(model_name)
    feature_modules = get_feature_info(model)
    default_choice = feature_modules[-1] if feature_modules else None

    # gr.update is the version-neutral spelling; the per-component
    # Dropdown.update classmethod no longer exists in newer Gradio releases.
    return gr.update(choices=feature_modules, value=default_choice)
|
|
|
# The input components are created up front so the event wiring below can
# refer to them by name instead of by position inside the Interface.
model_dropdown = gr.Dropdown(choices=MODELS, label="Select Model")
image_input = gr.Image(type="filepath", label="Upload Image")
cam_method_dropdown = gr.Dropdown(choices=list(CAM_METHODS.keys()), label="Select CAM Method")
feature_module_dropdown = gr.Dropdown(label="Select Feature Module (optional)")

iface = gr.Interface(
    fn=explain_image,
    inputs=[
        model_dropdown,
        image_input,
        cam_method_dropdown,
        feature_module_dropdown,
    ],
    outputs=gr.Image(type="pil", label="Explained Image"),
    title="Explainable AI with timm models",
    description="Upload an image, select a model, CAM method, and optionally a specific feature module to visualize the explanation.",
    allow_flagging="never",
)

# Refresh the feature-module choices whenever a different model is picked.
# A page-load event would fire only once, before any model is selected.
# Event listeners must be registered inside a Blocks context, so the
# Interface (itself a Blocks) is re-entered here.
with iface:
    model_dropdown.change(
        update_feature_modules,
        inputs=[model_dropdown],
        outputs=[feature_module_dropdown],
    )

iface.launch()