|
from typing import Tuple, Union |
|
|
|
import gradio as gr |
|
import matplotlib.pyplot as plt |
|
import torch |
|
from PIL import Image |
|
|
|
import bcos.models.pretrained as pretrained |
|
from bcos.data.categories import IMAGENET_CATEGORIES |
|
|
|
# Run on the GPU when one is visible to PyTorch, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
|
|
|
|
|
def get_model(model_name):
    """Build the pretrained model called ``model_name`` ready for inference.

    The constructor is looked up by name on ``bcos.models.pretrained`` and
    called with ``pretrained=True``. The resulting module is moved to the
    module-level ``device`` and switched to eval mode before being returned.
    """
    factory = getattr(pretrained, model_name)
    net = factory(pretrained=True).to(device)
    net.eval()
    return net
|
|
|
|
|
# Model names offered in the UI's model dropdown. `get_model` resolves each one
# as an attribute of `bcos.models.pretrained`. The exact container type returned
# by `list_available()` is not visible here -- presumably a list of str; verify.
MODEL_NAMES = pretrained.list_available()
|
|
|
|
|
class NormalizationMode:
    """String constants selecting how explanation opacity is scaled across classes.

    The values are shown verbatim as choices in the UI's radio button and are
    compared against the selected choice in ``run``, so they must not change.
    """

    # No extra scaling: each explanation is shown as produced.
    INDIVIDUAL = "individual"
    # Opacity of the i-th explanation (i > 0) is scaled by p_i / p_top1.
    WRT_PREDICTION = "wrt prediction's confidence"
    # Opacity of every explanation is scaled by its own class probability p_i.
    INDIVIDUAL_X_CONFIDENCE = "individual×confidence"

    @classmethod
    def all(cls):
        """Return every mode as a new list, in the order offered in the UI."""
        ordered = (cls.WRT_PREDICTION, cls.INDIVIDUAL_X_CONFIDENCE, cls.INDIVIDUAL)
        return list(ordered)
|
|
|
|
|
def freeze(model):
    """Turn off gradient tracking for every parameter of ``model``, in place.

    Returns ``None``; gradients can then only flow to non-parameter tensors
    such as an input that has ``requires_grad`` enabled.
    """
    for weight in model.parameters():
        weight.requires_grad_(False)
|
|
|
|
|
def _explain_topk(model, x, topk):
    """Compute one input-gradient map per top-k class of ``model(x)``.

    The forward pass runs inside ``model.explanation_mode()``. Each of the
    top-k logits of batch element 0 is then back-propagated to ``x`` on its
    own, giving one gradient map per class (called "dynamic weights" here).

    Args:
        model: Model exposing an ``explanation_mode()`` context manager.
        x: Batched input tensor with ``requires_grad`` enabled; only batch
            element 0 is explained.
        topk: Number of top-scoring classes to explain.

    Returns:
        ``(out, topk_preds, dynamic_weights)``:
        - ``out``: the raw model output.
        - ``topk_preds``: the top-k class indices for batch element 0, highest
          score first.
        - ``dynamic_weights``: a list of ``topk`` CPU gradient tensors, aligned
          with ``topk_preds``.
    """
    with model.explanation_mode():
        out = model(x)

    topk_values, topk_preds = torch.topk(out, topk, dim=1)
    topk_values, topk_preds = topk_values[0], topk_preds[0]

    dynamic_weights = []
    for i, logit in enumerate(topk_values):
        # Keep the autograd graph alive for every backward pass except the last.
        logit.backward(inputs=[x], retain_graph=i < topk - 1)
        dynamic_weights.append(x.grad.detach().cpu()[0])
        # Reset so the next class's gradient is not accumulated on top of this one.
        x.grad = None

    return out, topk_preds, dynamic_weights


def run(
    model_name: str,
    input_image: Image.Image,
    do_resize: bool,
    do_center_crop: bool,
    normalization_mode: str,
    smooth: int,
    alpha_percentile: Union[int, float],
    plot_dpi: int,
    topk: int = 5,
) -> Tuple[dict, plt.Figure]:
    """Predict the top-k classes of an image and plot one explanation per class.

    Args:
        model_name: Name passed to ``get_model``.
        input_image: The uploaded PIL image. Gradio passes ``None`` when no
            image was provided.
        do_resize: Forwarded as ``resize`` to the model's input transform.
        do_center_crop: Forwarded as ``center_crop`` to the model's input transform.
        normalization_mode: One of the ``NormalizationMode`` values; controls
            how each explanation's opacity channel is rescaled.
        smooth: Smoothing size forwarded to ``model.gradient_to_image``.
        alpha_percentile: Percentile forwarded to ``model.gradient_to_image``.
        plot_dpi: DPI of the returned figure.
        topk: Number of top classes to explain.

    Returns:
        A pair ``(confidences, figure)``:
        - ``confidences``: maps each predicted category name to its probability.
        - ``figure``: shows the input image followed by one explanation per
          predicted class.

    Raises:
        gr.Error: If no image was supplied.
    """
    if input_image is None:
        # Surface a readable message in the UI instead of a transform traceback.
        raise gr.Error("Please upload an image first.")

    # Drop figures and cached GPU memory left over from previous requests.
    plt.close("all")
    torch.cuda.empty_cache()

    model = get_model(model_name)
    freeze(model)
    x = model.transform.transform_with_options(
        input_image,
        center_crop=do_center_crop,
        resize=do_resize,
    )
    x = x.unsqueeze(0).to(device).requires_grad_()

    out, topk_preds, dynamic_weights = _explain_topk(model, x, topk)

    # Read the probabilities *of the predicted classes* instead of running a
    # second top-k over the probabilities. With tied or saturated scores, two
    # independent top-k calls are not guaranteed to order classes the same way.
    probabilities = model.to_probabilities(out.detach())
    topk_probabilities = probabilities[0, topk_preds].cpu()

    topk_class_ids = topk_preds.tolist()
    confidences = {}
    for class_id, probability in zip(topk_class_ids, topk_probabilities.tolist()):
        label = IMAGENET_CATEGORIES[class_id]
        if label in confidences:
            # Category names are not guaranteed to be unique. Disambiguate with
            # the class index so no prediction is silently dropped from the dict.
            label = f"{label} ({class_id})"
        confidences[label] = probability

    output_fig, axs = plt.subplots(
        1, topk + 1, dpi=plot_dpi, figsize=((topk + 1) * 2.1, 2)
    )

    image = x.detach().cpu()[0]
    # The first three channels of the model input are displayed as an RGB image.
    axs[0].imshow(image[:3].permute(1, 2, 0).numpy())
    axs[0].set_xlabel("Input Image")

    pred_confidence = topk_probabilities[0]
    for i, (ax, class_id) in enumerate(zip(axs[1:], topk_class_ids)):
        expl = model.gradient_to_image(
            image,
            dynamic_weights[i],
            # Gradio widgets may deliver numbers as floats; the size is meant to be an int.
            smooth=int(smooth),
            alpha_percentile=alpha_percentile,
        )

        # `expl` is presumably an RGBA array (H, W, 4), so the last channel is
        # the opacity that imshow uses. Rescale it per the selected mode.
        if normalization_mode == NormalizationMode.INDIVIDUAL_X_CONFIDENCE:
            expl[:, :, -1] *= topk_probabilities[i].item()
        elif normalization_mode == NormalizationMode.WRT_PREDICTION and i > 0:
            expl[:, :, -1] *= (topk_probabilities[i] / pred_confidence).item()
        # Otherwise (NormalizationMode.INDIVIDUAL, or the top-1 class under
        # WRT_PREDICTION) the explanation is left unscaled.

        ax.imshow(expl)
        ax.set_xlabel(IMAGENET_CATEGORIES[class_id])

    for ax in axs:
        ax.set_xticks([])
        ax.set_yticks([])
    output_fig.tight_layout()

    return confidences, output_fig
|
|
|
|
|
with gr.Blocks() as demo:
    gr.Markdown(
        """# B-cos Explanation Generation Demo
[Repository](https://github.com/B-cos/B-cos-v2/)
"""
    )

    with gr.Row():
        selected_model = gr.Dropdown(
            MODEL_NAMES, value="densenet121_long", label="Select model"
        )

    # Preprocessing and visualization settings, collapsed by default.
    with gr.Accordion("Options", open=False):
        do_resize = gr.Checkbox(
            label="Resize input image's shorter side to 256", value=True
        )
        do_center_crop = gr.Checkbox(
            label="Center crop input image to 224x224", value=False
        )
        normalization_mode = gr.Radio(
            NormalizationMode.all(),
            value=NormalizationMode.WRT_PREDICTION,
            label="Normalization Mode",
        )

        smooth = gr.Slider(1, 51, value=15, step=2, label="Smoothing kernel size")
        alpha_percentile = gr.Number(value=99.99, label="Percentile")
        plot_dpi = gr.Number(value=100, label="Plot DPI")

    input_image = gr.Image(type="pil", label="Image")
    run_button = gr.Button("Predict and Explain", variant="primary")

    output = gr.Plot(label="Explanations")
    output_labels = gr.Label(label="Top-5 Predictions")

    # `inputs` are passed positionally, so their order must match `run`'s
    # parameters. `topk` is not wired up and keeps its default of 5, which
    # matches the "Top-5" label above.
    run_button.click(
        fn=run,
        inputs=[
            selected_model,
            input_image,
            do_resize,
            do_center_crop,
            normalization_mode,
            smooth,
            alpha_percentile,
            plot_dpi,
        ],
        outputs=[output_labels, output],
        scroll_to_output=True,
    )


if __name__ == "__main__":
    # Gradio enables request queuing through `Blocks.queue()`; `launch()` does
    # not document a `queue` argument.
    demo.queue().launch()
|
|