"""Gradio demo: predict ImageNet classes with a pretrained B-cos model and
visualize the model's explanations for the top-k predicted classes."""

import functools
from typing import Tuple, Union

import gradio as gr
import matplotlib.pyplot as plt
import torch
from PIL import Image

import bcos.models.pretrained as pretrained
from bcos.data.categories import IMAGENET_CATEGORIES

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


# Loading pretrained weights is expensive, so reuse the model across requests
# instead of reloading it on every click. Only the most recently used model is
# kept (maxsize=1) to bound memory, since the dropdown may offer many models.
@functools.lru_cache(maxsize=1)
def get_model(model_name):
    """Load pretrained model ``model_name``, move it to ``device``, set eval mode."""
    model = getattr(pretrained, model_name)(pretrained=True)
    model = model.to(device)
    model.eval()
    return model


MODEL_NAMES = pretrained.list_available()


class NormalizationMode:
    """String options controlling how explanation opacity (alpha) is scaled.

    This is normalization for the explanations, not for the input image.
    """

    INDIVIDUAL = "individual"  # each explanation keeps its own alpha
    WRT_PREDICTION = "wrt prediction's confidence"  # alpha *= p_i / p_top1
    INDIVIDUAL_X_CONFIDENCE = "individual×confidence"  # alpha *= p_i

    @classmethod
    def all(cls):
        """Return all modes, in the order they are shown in the UI."""
        return [cls.WRT_PREDICTION, cls.INDIVIDUAL_X_CONFIDENCE, cls.INDIVIDUAL]


def freeze(model):
    """Disable gradients for all parameters; only the input needs gradients."""
    for param in model.parameters():
        param.requires_grad = False


def _confidences_by_label(preds, probabilities):
    """Build a ``{category name: probability}`` dict, in prediction order."""
    confidences = {}
    for idx, prob in zip(preds.tolist(), probabilities.tolist()):
        label = IMAGENET_CATEGORIES[idx]
        if label in confidences:
            # Category names are not guaranteed to be unique across class
            # indices; without this a later class would overwrite an earlier one.
            label = f"{label} (class {idx})"
        confidences[label] = prob
    return confidences


def run(
    model_name: str,
    input_image: Image.Image,
    do_resize: bool,
    do_center_crop: bool,
    normalization_mode: str,
    smooth: int,
    alpha_percentile: Union[int, float],
    plot_dpi: int,
    topk: int = 5,
) -> Tuple[dict, plt.Figure]:
    """Predict ``input_image`` with ``model_name`` and explain the top-k classes.

    Args:
        model_name: name of a pretrained model in ``bcos.models.pretrained``.
        input_image: the uploaded PIL image (``None`` if nothing was uploaded).
        do_resize: forwarded as ``resize`` to the model's input transform.
        do_center_crop: forwarded as ``center_crop`` to the model's transform.
        normalization_mode: one of ``NormalizationMode.all()``.
        smooth: smoothing kernel size passed to ``gradient_to_image``.
        alpha_percentile: percentile passed to ``gradient_to_image``.
        plot_dpi: DPI of the output figure.
        topk: number of top predictions to explain (must be >= 1).

    Returns:
        A ``{category name: probability}`` dict for the top-k classes, and a
        figure showing the input followed by one explanation per class.

    Raises:
        gr.Error: if no image was provided.
        ValueError: if ``topk`` < 1.
    """
    if input_image is None:
        raise gr.Error("Please upload an image first.")
    if topk < 1:
        raise ValueError(f"topk must be at least 1, got {topk}")

    # cleanup previous stuff
    plt.close("all")
    torch.cuda.empty_cache()

    # preprocess - get model and transform input image
    model = get_model(model_name)
    freeze(model)
    x = model.transform.transform_with_options(
        input_image,
        center_crop=do_center_crop,
        resize=do_resize,
    )
    x = x.unsqueeze(0).to(device).requires_grad_()

    # predict and explain
    with model.explanation_mode():
        out = model(x)
        topk_values, topk_preds = torch.topk(out, topk, dim=1)
        topk_values, topk_preds = topk_values[0], topk_preds[0]
        dynamic_weights = []  # per-class input gradients, batch dim dropped
        for i in range(topk):
            # keep the graph alive for every backward pass except the last
            topk_values[i].backward(inputs=[x], retain_graph=i < topk - 1)
            dynamic_weights.append(x.grad.detach().cpu()[0])
            x.grad = None  # reset so gradients don't accumulate across classes

    # Gather probabilities at the predicted indices rather than running a
    # second topk, so labels and probabilities stay aligned even on ties.
    probabilities = model.to_probabilities(out.detach())[0]
    topk_probabilities = probabilities[topk_preds].cpu()
    confidences = _confidences_by_label(topk_preds.cpu(), topk_probabilities)

    # output plot of images
    output_fig, axs = plt.subplots(
        1, topk + 1, dpi=plot_dpi, figsize=((topk + 1) * 2.1, 2)
    )

    # visualize input image (first three channels of the model input)
    image = x.detach().cpu()[0]
    axs[0].imshow(image[:3].permute(1, 2, 0).numpy())
    axs[0].set_xlabel("Input Image")

    # visualize explanations
    pred_confidence = topk_probabilities[0]  # first one is pred
    for i, ax in enumerate(axs[1:]):
        expl = model.gradient_to_image(
            image,
            dynamic_weights[i],
            # slider values may arrive as floats; a kernel size must be an int
            smooth=int(smooth),
            alpha_percentile=alpha_percentile,
        )
        if normalization_mode == NormalizationMode.INDIVIDUAL_X_CONFIDENCE:
            expl[:, :, -1] *= topk_probabilities[i].item()
        elif normalization_mode == NormalizationMode.WRT_PREDICTION and i > 0:
            # i == 0 is the prediction itself, whose ratio would be 1
            expl[:, :, -1] *= (topk_probabilities[i] / pred_confidence).item()
        else:  # NormalizationMode.INDIVIDUAL
            pass
        ax.imshow(expl)
        ax.set_xlabel(IMAGENET_CATEGORIES[topk_preds[i].item()])

    for ax in axs:
        ax.set_xticks([])
        ax.set_yticks([])
    output_fig.tight_layout()

    return confidences, output_fig


with gr.Blocks() as demo:
    # basic info
    gr.Markdown(
        """# B-cos Explanation Generation Demo
        [Repository](https://github.com/B-cos/B-cos-v2/)
        """
    )
    with gr.Row():
        selected_model = gr.Dropdown(
            MODEL_NAMES, value="densenet121_long", label="Select model"
        )
    with gr.Accordion("Options", open=False):
        do_resize = gr.Checkbox(
            label="Resize input image's shorter side to 256", value=True
        )
        do_center_crop = gr.Checkbox(
            label="Center crop input image to 224x224", value=False
        )
        normalization_mode = gr.Radio(
            NormalizationMode.all(),
            value=NormalizationMode.WRT_PREDICTION,
            label="Normalization Mode",
        )
        smooth = gr.Slider(1, 51, value=15, step=2, label="Smoothing kernel size")
        alpha_percentile = gr.Number(value=99.99, label="Percentile")
        plot_dpi = gr.Number(value=100, label="Plot DPI")

    input_image = gr.Image(type="pil", label="Image")
    run_button = gr.Button("Predict and Explain", variant="primary")

    # will contain all outputs in a plot
    output = gr.Plot(label="Explanations")

    # labels
    output_labels = gr.Label(label="Top-5 Predictions")

    run_button.click(
        fn=run,
        inputs=[
            selected_model,
            input_image,
            do_resize,
            do_center_crop,
            normalization_mode,
            smooth,
            alpha_percentile,
            plot_dpi,
        ],
        outputs=[output_labels, output],
        scroll_to_output=True,
    )

demo.launch()