MankiratSingh1315's picture
Update app.py
882a036
raw
history blame contribute delete
No virus
2.76 kB
import gradio
from fastai.vision.all import *
# Directory holding the exported learner file loaded below.
MODELS_PATH = Path('./models')
# Directory of example images; in this file it is only referenced by the
# commented-out interface options further down.
EXAMPLES_PATH = Path('./examples')
# Load the exported learner once at import time so every request reuses it.
# NOTE(review): the .pkl suffix suggests a pickle file -- only ship model
# files from a trusted source.
learn = load_learner(MODELS_PATH/'model.pkl')
# Class vocabulary of the learner's DataLoaders; predict() indexes it by
# class position to name each probability.
labels = learn.dls.vocab
class Hook():
    """Context manager that records the forward output of a module.

    On construction a forward hook is registered on ``m``. Each time the
    hook fires, a detached copy of the module output is kept in
    ``self.stored``. Leaving the ``with`` block removes the hook again.
    """

    def __init__(self, m):
        # Keep the handle returned by the registration so __exit__ can
        # unregister the hook.
        handle = m.register_forward_hook(self.hook_func)
        self.hook = handle

    def hook_func(self, m, i, o):
        """Forward-hook callback: keep a detached clone of output ``o``."""
        snapshot = o.detach().clone()
        self.stored = snapshot

    def __enter__(self, *args):
        return self

    def __exit__(self, *args):
        # Always unregister, whether or not the body raised.
        self.hook.remove()
class HookBwd():
    """Context manager that records the gradient flowing out of a module.

    On construction a backward hook is registered on ``m``. When the hook
    fires during a backward pass, a detached copy of the first
    grad-output tensor is kept in ``self.stored``. Leaving the ``with``
    block removes the hook again.
    """

    def __init__(self, m):
        # ``register_backward_hook`` is deprecated in PyTorch in favour of
        # ``register_full_backward_hook``. Prefer the full variant and fall
        # back to the legacy one on modules that do not provide it.
        register = getattr(m, "register_full_backward_hook", None)
        if register is None:
            register = m.register_backward_hook
        self.hook = register(self.hook_func)

    def hook_func(self, m, gi, go):
        """Backward-hook callback: keep a detached clone of ``go[0]``."""
        self.stored = go[0].detach().clone()

    def __enter__(self, *args):
        return self

    def __exit__(self, *args):
        # Always unregister, whether or not the body raised.
        self.hook.remove()
def _grad_cam_map(x, cls):
    """Return the Grad-CAM map of ``learn.model`` on batch ``x`` for class ``cls``.

    Runs one forward pass with a forward hook and a backward hook on
    ``learn.model[0].model.layer4``, back-propagates the score of class
    ``cls`` for the first item of the batch, and combines the captured
    activations with the channel-wise mean of the captured gradients.
    """
    target_layer = learn.model[0].model.layer4
    with HookBwd(target_layer) as hookg:
        with Hook(target_layer) as hook:
            output = learn.model.eval()(x)
            act = hook.stored
        output[0, cls].backward()
        grad = hookg.stored
    # One weight per channel: the gradient averaged over the two trailing
    # dimensions of the first item's gradient.
    w = grad[0].mean(dim=[1, 2], keepdim=True)
    return (w * act[0]).sum(0)


def predict(img):
    """Classify ``img`` and write an explanatory image to ``gradcam.jpg``.

    Args:
        img: Anything ``PILImage.create`` accepts (the Gradio image input).

    Returns:
        A ``(labels_probs, path)`` tuple: ``labels_probs`` maps every label
        in ``labels`` to its predicted probability as a float, and ``path``
        is ``Path("gradcam.jpg")``. The file holds the Grad-CAM overlay,
        or the unmodified input image when the top class is "Negative".
    """
    img = PILImage.create(img)
    _pred, _pred_w_idx, probs = learn.predict(img)
    labels_probs = {label: float(p) for label, p in zip(labels, probs)}

    out_path = Path("gradcam.jpg")

    # The top class is read from the probabilities already computed by
    # learn.predict, which avoids a separate forward pass just for argmax
    # and keeps the heatmap class consistent with the labels shown.
    cls = int(probs.argmax())

    # For a "Negative" prediction the output is just the input image, so
    # skip the Grad-CAM computation and plotting entirely.
    if learn.dls.vocab[cls] == "Negative":
        img.save(str(out_path), format="JPEG")
        return labels_probs, out_path

    x, = first(learn.dls.test_dl([img]))
    cam_map = _grad_cam_map(x, cls)
    x_dec = TensorImage(learn.dls.train.decode((x,))[0][0])

    # Stretch the heatmap over the whole decoded image instead of assuming
    # a fixed 128x128 size.
    # NOTE(review): assumes x_dec is channel-first (C, H, W) -- confirm.
    height, width = x_dec.shape[-2:]

    fig, ax = plt.subplots()
    try:
        x_dec.show(ctx=ax)
        ax.imshow(cam_map.detach().cpu(), alpha=0.7,
                  extent=(0, width, height, 0),
                  interpolation='bilinear', cmap='magma')
        # Saving overwrites any previous file, so no explicit removal is
        # needed beforehand.
        fig.savefig(str(out_path), format="jpg", bbox_inches='tight')
    finally:
        # Release the figure even if drawing or saving failed.
        plt.close(fig)

    return labels_probs, out_path
# with open('gradio_article.md') as f:
# article = f.read()
# interface_options = {
# "title": "RSNA Pneumonia Detection",
# "description": "An algorithm that automatically detects potential pneumonia cases. Upload an image or select from the examples below.",
# "examples": [f'{EXAMPLES_PATH}/{f.name}' for f in EXAMPLES_PATH.iterdir()],
# "article": article,
# "layout": "horizontal",
# "theme": "default",
# }
# Build the UI components first, in the same order the original inline
# call constructed them, then wire them to predict().
# NOTE(review): shape=(512, 512) presumably resizes uploads before they
# reach predict() -- confirm against the installed gradio version.
_xray_input = gradio.inputs.Image(shape=(512, 512), label="Chest X-ray")
_class_output = gradio.outputs.Label(num_top_classes=5, label="Detected Class")
_cam_output = gradio.outputs.Image(type="filepath", label="GradCAM")

demo = gradio.Interface(
    fn=predict,
    inputs=_xray_input,
    outputs=[_class_output, _cam_output],
)

# Keyword arguments forwarded to demo.launch().
launch_options = dict(enable_queue=True, share=False)

demo.launch(**launch_options)