"""Gradio demo: classify a chest X-ray and overlay a Grad-CAM heat map.

A fastai learner is loaded from ``MODELS_PATH/'model.pkl'`` at import time and
served through a ``gradio.Interface`` whose callback is :func:`predict`.
"""
import gradio
from fastai.vision.all import *

MODELS_PATH = Path('./models')
EXAMPLES_PATH = Path('./examples')

# File the Grad-CAM (or pass-through) image is written to on every request.
GRADCAM_PATH = Path("gradcam.jpg")
# Vocabulary entry for which the unmodified input is returned instead of an overlay.
NEGATIVE_LABEL = "Negative"

learn = load_learner(MODELS_PATH/'model.pkl')
labels = learn.dls.vocab


class Hook():
    """Context manager that stores a detached copy of a module's forward output.

    The copy is available as ``self.stored`` after a forward pass has run
    through ``m``; the hook is removed when the ``with`` block exits.
    """

    def __init__(self, m):
        self.hook = m.register_forward_hook(self.hook_func)

    def hook_func(self, m, i, o):
        self.stored = o.detach().clone()

    def __enter__(self, *args):
        return self

    def __exit__(self, *args):
        self.hook.remove()


class HookBwd():
    """Context manager that stores a detached copy of ``grad_output[0]`` of a module.

    The copy is available as ``self.stored`` after a backward pass has run
    through ``m``; the hook is removed when the ``with`` block exits.
    """

    def __init__(self, m):
        # NOTE(review): recent PyTorch releases deprecate register_backward_hook in
        # favour of register_full_backward_hook. Only grad_output is read here;
        # confirm the target architecture has no conflicting in-place ops before
        # switching.
        self.hook = m.register_backward_hook(self.hook_func)

    def hook_func(self, m, gi, go):
        self.stored = go[0].detach().clone()

    def __enter__(self, *args):
        return self

    def __exit__(self, *args):
        self.hook.remove()


def _grad_cam_map(x, cls):
    """Return the Grad-CAM map of class index ``cls`` for the batch ``x``.

    Activations and output-gradients are captured on
    ``learn.model[0].model.layer4``; the gradients of the first batch item are
    averaged over their last two dimensions and used to weight its activations.
    """
    target_layer = learn.model[0].model.layer4
    with HookBwd(target_layer) as hookg:
        with Hook(target_layer) as hook:
            output = learn.model.eval()(x)
            act = hook.stored
        # Back-propagate only the score of the class of interest.
        output[0, cls].backward()
        grad = hookg.stored
    w = grad[0].mean(dim=[1, 2], keepdim=True)
    return (w * act[0]).sum(0)


def _save_overlay(x_dec, cam_map, path):
    """Draw ``cam_map`` on top of the decoded image ``x_dec`` and save it to ``path``."""
    # Stretch the low-resolution map over the full decoded image instead of a
    # hard-coded 128x128 square, so other input sizes stay aligned.
    height, width = x_dec.shape[-2:]
    fig, ax = plt.subplots()
    try:
        x_dec.show(ctx=ax)
        ax.imshow(cam_map.detach().cpu(), alpha=0.7,
                  extent=(0, width, height, 0),
                  interpolation='bilinear', cmap='magma')
        # savefig overwrites an existing file, so no prior removal is needed.
        fig.savefig(path, format="jpg", bbox_inches='tight')
    finally:
        # Always release the figure, even if drawing or saving fails.
        plt.close(fig)


def predict(img):
    """Classify ``img`` and produce an explanation image.

    Args:
        img: Anything accepted by ``PILImage.create`` (the Gradio image input).

    Returns:
        A tuple ``(labels_probs, path)`` where ``labels_probs`` maps every
        vocabulary label to its probability as a ``float`` and ``path`` is
        ``Path("gradcam.jpg")``. The file holds the Grad-CAM overlay, or the
        unmodified input image when the top class is ``"Negative"``.
    """
    img = PILImage.create(img)
    _pred, _pred_w_idx, probs = learn.predict(img)
    labels_probs = {label: float(prob) for label, prob in zip(labels, probs)}

    x, = first(learn.dls.test_dl([img]))
    with torch.no_grad():
        output = learn.model.eval()(x)
    cls = int(output.argmax())

    if labels[cls] == NEGATIVE_LABEL:
        # Nothing to highlight: return the input as-is and skip the extra
        # forward/backward pass and the plotting.
        img.save(GRADCAM_PATH, format="JPEG")
        return labels_probs, GRADCAM_PATH

    x_dec = TensorImage(learn.dls.train.decode((x,))[0][0])
    cam_map = _grad_cam_map(x, cls)
    _save_overlay(x_dec, cam_map, GRADCAM_PATH)
    return labels_probs, GRADCAM_PATH


# Optional interface settings, currently disabled:
# with open('gradio_article.md') as f:
#     article = f.read()
# interface_options = {
#     "title": "RSNA Pneumonia Detection",
#     "description": "An algorithm that automatically detects potential pneumonia cases. Upload an image or select from the examples below.",
#     "examples": [f'{EXAMPLES_PATH}/{f.name}' for f in EXAMPLES_PATH.iterdir()],
#     "article": article,
#     "layout": "horizontal",
#     "theme": "default",
# }

demo = gradio.Interface(
    fn=predict,
    inputs=gradio.inputs.Image(shape=(512, 512), label="Chest X-ray"),
    outputs=[
        gradio.outputs.Label(num_top_classes=5, label="Detected Class"),
        gradio.outputs.Image(type="filepath", label="GradCAM"),
    ],
)

launch_options = {
    "enable_queue": True,
    "share": False,
}

demo.launch(**launch_options)