"""Gradio demo that renders ViT explanations produced by ``explain.do_explain``."""

import io

import gradio as gr
from PIL import Image
from torchvision import transforms

from explain import do_explain

# (x - 0.5) / 0.5 maps ToTensor's [0, 1] channel values onto [-1, 1].
normalize = transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])

# Preprocessing handed to do_explain: resize (shorter side to 256), take the
# central 224x224 crop, convert to a tensor, then normalize.
TRANSFORM = transforms.Compose(
    [
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        normalize,
    ]
)


def generate_viz(image, class_index=None):
    """Run the explainer on *image* and return the rendered visualization.

    Args:
        image: PIL image from the ``gr.Image(type="pil")`` input, or ``None``
            when the user submits without uploading anything.
        class_index: Optional class to explain. ``gr.Number`` may deliver a
            float, so a non-``None`` value is truncated with ``int()``.

    Returns:
        A ``(visualization, prediction)`` tuple: the ``viz`` object returned
        by ``do_explain`` rasterized to an RGB PIL image, and ``do_explain``'s
        second return value passed through unchanged (shown in the "prob"
        text box).

    Raises:
        gr.Error: If no image was provided.
    """
    if image is None:
        # Show a readable message in the UI instead of failing with an
        # AttributeError on ``image.size`` below.
        raise gr.Error("Please upload an input image.")
    if class_index is not None:
        class_index = int(class_index)
    print(f"Image: {image.size}")
    print(f"Class: {class_index}")
    viz, pred = do_explain(TRANSFORM, image, class_index=class_index)

    # Render into an in-memory buffer instead of a fixed file on disk, so
    # requests never share (and overwrite) the same file and nothing is
    # written to the working directory.
    # NOTE(review): assumes ``viz`` is a matplotlib Figure (it exposes
    # ``savefig``); if do_explain builds it through pyplot, it should also be
    # closed (plt.close(viz)) so figures do not accumulate — confirm.
    buffer = io.BytesIO()
    viz.savefig(buffer, format="png")
    buffer.seek(0)
    # Decode and copy to RGB inside ``with`` so the source image is closed.
    with Image.open(buffer) as rendered:
        return rendered.convert("RGB"), pred


title = "Explain ViT 😊"

iface = gr.Interface(
    fn=generate_viz,
    inputs=[
        gr.Image(type="pil", label="Input Image"),
        gr.Number(label="Class Index", info="Class index to explain"),
    ],
    outputs=[
        gr.Image(label="XAI-Image"),
        gr.Text(label="prob"),
    ],
    title=title,
    allow_flagging="never",
    # With cache_examples=True Gradio precomputes the outputs for the examples
    # below by calling generate_viz on each of them.
    cache_examples=True,
    examples=[
        ["ViT_DeiT/samples/catdog.png", None],
        ["ViT_DeiT/samples/catdog.png", 243],
        ["ViT_DeiT/samples/el2.png", None],
        ["ViT_DeiT/samples/el2.png", 340],
    ],
)

# Launches the server as soon as this module is executed.
iface.launch(debug=True)