"""Gradio demo: overlay class-activation maps (CAMs) on an image using timm models."""

from io import BytesIO

import gradio as gr
import numpy as np
import requests
import timm
import torch
from PIL import Image
from pytorch_grad_cam import (
    AblationCAM,
    EigenCAM,
    FullGrad,
    GradCAM,
    GradCAMPlusPlus,
    HiResCAM,
    ScoreCAM,
    XGradCAM,
)
from pytorch_grad_cam.utils.image import show_cam_on_image
# NOTE(review): confirm that the installed pytorch-grad-cam actually exports
# `get_target_layer` from this module; the import is kept exactly as written.
from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget, get_target_layer
from timm.data import create_transform

# List of available timm models
# NOTE(review): this lists every architecture, while load_model() always asks
# for pretrained weights -- consider `timm.list_models(pretrained=True)` if
# entries without weights turn out to fail; verify against the timm version.
MODELS = timm.list_models()

# List of available GradCAM methods (display name -> CAM class)
CAM_METHODS = {
    "GradCAM": GradCAM,
    "HiResCAM": HiResCAM,
    "ScoreCAM": ScoreCAM,
    "GradCAM++": GradCAMPlusPlus,
    "AblationCAM": AblationCAM,
    "XGradCAM": XGradCAM,
    "EigenCAM": EigenCAM,
    "FullGrad": FullGrad,
}

# Upper bound (seconds) on how long a remote image download may block a request.
REQUEST_TIMEOUT_SECONDS = 30


def load_model(model_name):
    """Create the pretrained timm model called *model_name* and put it in eval mode."""
    model = timm.create_model(model_name, pretrained=True)
    model.eval()
    return model


def process_image(image_path, model):
    """Load an image and convert it to a batched input tensor for *model*.

    Args:
        image_path: Local file path, or a URL starting with ``http``.
        model: timm model whose ``pretrained_cfg`` supplies the preprocessing
            settings (input size, crop, mean/std, interpolation).

    Returns:
        The transformed image with a leading batch dimension of size 1.

    Raises:
        requests.HTTPError: If a remote image cannot be downloaded.
    """
    if image_path.startswith('http'):
        # A timeout keeps a stalled server from hanging the request forever,
        # and the status check avoids feeding an HTML error page to PIL.
        response = requests.get(image_path, timeout=REQUEST_TIMEOUT_SECONDS)
        response.raise_for_status()
        source = BytesIO(response.content)
    else:
        source = image_path

    # Force three channels: the mean/std normalisation below is 3-channel, so
    # RGBA, palette or greyscale inputs would otherwise fail. The context
    # manager also closes the underlying file once the pixels are copied.
    with Image.open(source) as opened:
        image = opened.convert('RGB')

    config = model.pretrained_cfg
    transform = create_transform(
        input_size=config['input_size'],
        crop_pct=config['crop_pct'],
        mean=config['mean'],
        std=config['std'],
        interpolation=config['interpolation'],
        is_training=False,
    )
    return transform(image).unsqueeze(0)


def get_cam_image(model, image, target_layer, cam_method):
    """Run the selected CAM method and return the heat-map overlay as a PIL image.

    Args:
        model: The timm model to explain.
        image: Normalised input tensor as produced by ``process_image``.
        target_layer: Module whose activations the CAM method inspects.
        cam_method: Key into ``CAM_METHODS``.

    Returns:
        A ``PIL.Image`` with the CAM blended over the de-normalised input.
    """
    cam_cls = CAM_METHODS[cam_method]
    use_cuda = torch.cuda.is_available()
    cam_input = image
    try:
        cam = cam_cls(model=model, target_layers=[target_layer], use_cuda=use_cuda)
    except TypeError:
        # Newer pytorch-grad-cam releases no longer accept ``use_cuda`` and run
        # on whatever device the model lives on, so place model and input
        # ourselves and retry without the keyword.
        if use_cuda:
            model = model.cuda()
            cam_input = image.cuda()
        cam = cam_cls(model=model, target_layers=[target_layer])

    grayscale_cam = cam(input_tensor=cam_input)

    # Undo the mean/std normalisation so the overlay is drawn on a viewable image.
    config = model.pretrained_cfg
    mean = torch.tensor(config['mean']).view(3, 1, 1)
    std = torch.tensor(config['std']).view(3, 1, 1)
    denormalised = image.detach().squeeze(0).cpu() * std + mean
    rgb_img = np.clip(denormalised.permute(1, 2, 0).numpy(), 0, 1)

    cam_image = show_cam_on_image(rgb_img, grayscale_cam[0, :], use_rgb=True)
    return Image.fromarray(cam_image)


def get_feature_info(model):
    """Return the ``'module'`` entry of each item in ``model.feature_info`` ([] if absent)."""
    if not hasattr(model, 'feature_info'):
        return []
    return [info['module'] for info in model.feature_info]


def explain_image(model_name, image_path, cam_method, feature_module):
    """Produce a CAM overlay for one image.

    Args:
        model_name: timm model identifier.
        image_path: Local path or ``http`` URL of the image.
        cam_method: Key into ``CAM_METHODS``.
        feature_module: Name of the module to explain; when empty, the last
            entry of the model's feature info is used, and failing that the
            last ``torch.nn.Conv2d`` found in the model.

    Returns:
        The overlay as a ``PIL.Image``.

    Raises:
        ValueError: If no image was supplied or no target layer can be found.
    """
    if not image_path:
        # Fail early with a readable message instead of an AttributeError on None.
        raise ValueError("No image provided.")

    model = load_model(model_name)
    image = process_image(image_path, model)

    # Start from None so the final check can never hit an unbound local when
    # the convolution search below finds nothing.
    target_layer = None

    if feature_module:
        target_layer = get_target_layer(model, feature_module)
        print(f"Using feature module: {feature_module}")
    else:
        # Fallback to the last feature module or last convolutional layer
        feature_info = get_feature_info(model)
        if feature_info:
            target_layer = get_target_layer(model, feature_info[-1])
            print(f"Using last feature module: {feature_info[-1]}")
        else:
            # Fallback to finding last convolutional layer
            for name, module in reversed(list(model.named_modules())):
                if isinstance(module, torch.nn.Conv2d):
                    target_layer = module
                    print(f"Fallback: Using last convolutional layer: {name}")
                    break

    if target_layer is None:
        raise ValueError("Could not find a suitable target layer.")

    return get_cam_image(model, image, target_layer, cam_method)


def update_feature_modules(model_name):
    """Return a dropdown update listing the feature modules of *model_name*.

    With no model selected the update carries an empty choice list rather
    than trying to build a model from an empty name.
    """
    if not model_name:
        return gr.update(choices=[], value=None)

    model = load_model(model_name)
    feature_modules = get_feature_info(model)
    # gr.update is used instead of the per-component ``Dropdown.update``
    # classmethod, which newer Gradio releases no longer provide.
    return gr.update(
        choices=feature_modules,
        value=feature_modules[-1] if feature_modules else None,
    )


iface = gr.Interface(
    fn=explain_image,
    inputs=[
        gr.Dropdown(choices=MODELS, label="Select Model"),
        gr.Image(type="filepath", label="Upload Image"),
        gr.Dropdown(choices=list(CAM_METHODS.keys()), label="Select CAM Method"),
        gr.Dropdown(label="Select Feature Module (optional)"),
    ],
    outputs=gr.Image(type="pil", label="Explained Image"),
    title="Explainable AI with timm models",
    description="Upload an image, select a model, CAM method, and optionally a specific feature module to visualize the explanation.",
    allow_flagging="never",
)

# NOTE(review): a `load` event fires when the page loads, i.e. before a model
# has been chosen. To refresh the feature-module choices whenever the model
# selection changes this presumably wants a `.change` listener on the model
# dropdown instead. Also confirm that `iface.inputs` exists and that event
# registration outside a Blocks context is supported in the installed Gradio.
iface.load(update_feature_modules, inputs=[iface.inputs[0]], outputs=[iface.inputs[3]])

iface.launch()