Spaces:
Sleeping
Sleeping
import io

import gradio as gr
from PIL import Image
from torchvision import transforms

from explain import do_explain
# ToTensor() produces values in [0, 1]; subtracting a per-channel mean of 0.5
# and dividing by a std of 0.5 rescales them to [-1, 1].
normalize = transforms.Normalize(mean=[0.5] * 3, std=[0.5] * 3)

# Preprocessing pipeline handed to do_explain: scale the shorter side to 256,
# take the central 224x224 crop, convert to a float tensor, then normalize.
_PIPELINE_STEPS = [
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    normalize,
]
TRANSFORM = transforms.Compose(_PIPELINE_STEPS)
def generate_viz(image, class_index=None):
    """Run the ViT explainability pipeline on one image and return its rendering.

    Args:
        image: The input image as a PIL image (the UI wires this to a
            ``gr.Image(type="pil")`` input), or ``None`` if nothing was uploaded.
        class_index: Optional class index to explain. Presumably a float coming
            from the ``gr.Number`` input, or ``None`` when left empty. A
            non-``None`` value is converted with ``int()`` before use;
            ``None`` is passed to ``do_explain`` unchanged.

    Returns:
        A ``(visualization, pred)`` tuple. ``visualization`` is the figure
        returned by ``do_explain``, rendered to PNG and decoded as an RGB PIL
        image. ``pred`` is the prediction value returned by ``do_explain``.

    Raises:
        gr.Error: If no image was provided. Gradio then shows a readable
            message instead of an ``AttributeError`` traceback.
    """
    if image is None:
        raise gr.Error("Please upload an input image.")
    if class_index is not None:
        class_index = int(class_index)
    print(f"Image: {image.size}")
    print(f"Class: {class_index}")
    viz, pred = do_explain(TRANSFORM, image, class_index=class_index)
    # Render into a per-call in-memory buffer instead of a fixed file on disk.
    # A shared "visualization.png" in the working directory lets concurrent
    # requests overwrite each other's output between savefig and open.
    with io.BytesIO() as buf:
        # NOTE(review): assumes `viz` is a matplotlib Figure (it exposes
        # savefig). The format must be given explicitly because a buffer has
        # no file extension to infer it from.
        viz.savefig(buf, format="png")
        buf.seek(0)
        # convert() forces PIL's lazy decode while the buffer is still open.
        rendered = Image.open(buf).convert("RGB")
    # NOTE(review): if `viz` is a pyplot-managed figure, it is never closed
    # here, so figures may accumulate across requests. Consider plt.close(viz).
    return rendered, pred
# Page title shown on the demo.
# NOTE(review): "π" may be a mis-encoded emoji from an earlier copy/paste.
# Confirm the intended character before changing this string.
title = "Explain ViT π"

# Input widgets, in the positional order Gradio passes their values to
# generate_viz: (image, class_index).
_input_components = [
    gr.Image(type="pil", label="Input Image"),
    gr.Number(label="Class Index", info="Class index to explain"),
]

# Output widgets, matching generate_viz's (visualization, pred) return tuple.
_output_components = [
    gr.Image(label="XAI-Image"),
    gr.Text(label="prob"),
]

# Each sample image appears twice: once with no class index (None, which
# generate_viz passes straight through to do_explain) and once with an
# explicit target class.
_example_rows = [
    [sample_path, chosen_class]
    for sample_path, target_class in (
        ("ViT_DeiT/samples/catdog.png", 243),
        ("ViT_DeiT/samples/el2.png", 340),
    )
    for chosen_class in (None, target_class)
]

iface = gr.Interface(
    fn=generate_viz,
    inputs=_input_components,
    outputs=_output_components,
    examples=_example_rows,
    cache_examples=True,  # let Gradio cache the example outputs
    allow_flagging="never",
    title=title,
)

# Starts the server as soon as this module is executed (there is no
# __main__ guard), in debug mode.
iface.launch(debug=True)