# Page header captured from the Hugging Face file view (not part of the program):
# author: sayakpaul (HF staff) | commit: 2e6234b ("debug") | size: 2.21 kB | scan: no virus
import io

import gradio as gr
from PIL import Image
from torchvision import transforms

from gradcam import do_gradcam
from lrp import do_lrp, do_partial_lrp
from rollout import do_rollout
from tiba import do_tiba
# Normalization applied as the last preprocessing step: mean 0.5 and std 0.5
# for each of the three channels.
normalize = transforms.Normalize(mean=[0.5] * 3, std=[0.5] * 3)

# Preprocessing steps, in application order: Resize(256), CenterCrop(224),
# conversion to a tensor, then the normalization defined above.
_PREPROCESSING_STEPS = [
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    normalize,
]

# Composed pipeline that generate_viz hands to the selected explainability
# method together with the input image.
TRANSFORM = transforms.Compose(_PREPROCESSING_STEPS)
# Registry of explainability methods: UI-facing method name -> callable that
# produces the visualization. Insertion order is preserved because the keys
# are also used to populate the "Method" dropdown.
METHOD_MAP = dict(
    (
        ("tiba", do_tiba),
        ("gradcam", do_gradcam),
        ("lrp", do_lrp),
        ("partial_lrp", do_partial_lrp),
        ("rollout", do_rollout),
    )
)
def generate_viz(image, method, class_index=None):
    """Render an explainability visualization for ``image``.

    Args:
        image: Input image; its ``size`` attribute is logged. The interface
            is configured to deliver a PIL image here.
        method: Key into ``METHOD_MAP`` selecting the explainability method.
        class_index: Optional class index to explain. Any non-``None`` value
            is coerced with ``int()`` (the number widget may deliver a
            float). ``None`` is forwarded unchanged to the method.

    Returns:
        The rendered visualization as an RGB ``PIL.Image.Image``.

    Raises:
        KeyError: If ``method`` is not a key of ``METHOD_MAP``.
    """
    if class_index is not None:
        class_index = int(class_index)

    print(f"Image: {image.size}")
    print(f"Method: {method}")
    print(f"Class: {class_index}")

    viz_method = METHOD_MAP[method]
    # The returned object only needs to expose ``savefig``
    # (presumably a matplotlib figure -- TODO confirm against the method modules).
    viz = viz_method(TRANSFORM, image, class_index=class_index)

    # Render into an in-memory buffer instead of a fixed "visualization.png"
    # on disk: a shared filename lets concurrent requests overwrite each
    # other's output and requires a writable working directory.
    buffer = io.BytesIO()
    viz.savefig(buffer, format="png")
    buffer.seek(0)

    # convert() fully decodes the image, so the result does not depend on the
    # buffer staying alive after this function returns.
    return Image.open(buffer).convert("RGB")
# Heading shown at the top of the demo page. The trailing character is the
# robot emoji (U+1F916); it had been corrupted into mojibake by a wrong
# byte-to-text decoding.
title = "Compare different methods of explaining ViTs 🤖"

# Markdown blurb shown with the demo, crediting the paper whose methods are
# compared. Split across lines for readability only; the resulting string is
# unchanged.
article = (
    "Different methods for explaining Vision Transformers as explored by "
    "Chefer et al. in [Transformer Interpretability Beyond Attention "
    "Visualization, a novel method to visualize classifications by "
    "Transformer based networks](https://arxiv.org/abs/2012.09838)."
)
# Input widgets, listed in the positional order of generate_viz's parameters
# (image, method, class_index).
image_input = gr.Image(type="pil", label="Input Image")
method_input = gr.Dropdown(
    choices=list(METHOD_MAP),
    label="Method",
    info="Explainability method to investigate.",
)
class_index_input = gr.Number(label="Class Index", info="Class index to inspect")

# Example rows: [image path, method name, class index]. A class index of None
# exercises generate_viz's default.
example_rows = [
    ["Transformer-Explainability/samples/catdog.png", "tiba", None],
    ["Transformer-Explainability/samples/catdog.png", "rollout", 243],
    ["Transformer-Explainability/samples/el2.png", "tiba", None],
    ["Transformer-Explainability/samples/el2.png", "gradcam", 340],
    ["Transformer-Explainability/samples/dogbird.png", "lrp", 161],
]

# Assemble the demo: one image output, flagging turned off, example caching
# turned on.
iface = gr.Interface(
    fn=generate_viz,
    inputs=[image_input, method_input, class_index_input],
    outputs=gr.Image(),
    title=title,
    article=article,
    allow_flagging="never",
    cache_examples=True,
    examples=example_rows,
)

# Start serving the app with Gradio's debug flag enabled.
iface.launch(debug=True)