|
import gradio |
|
from fastai.vision.all import * |
|
|
|
# Locations of the exported model and of example inputs, relative to the
# working directory the app is started from.
MODELS_PATH = Path('./models')
EXAMPLES_PATH = Path('./examples')

# load_learner unpickles the exported Learner: only point it at a trusted file.
learn = load_learner(MODELS_PATH.joinpath('model.pkl'))

# Class names in the order of the model's output columns.
labels = learn.dls.vocab
|
|
|
class Hook():
    """Context manager that records a module's forward output.

    While active, every forward call through the wrapped module stores a
    detached copy of its output in ``self.stored``. The hook is unregistered
    when the ``with`` block exits.
    """

    def __init__(self, m):
        # Keep the handle so __exit__ can unregister the hook.
        self.hook = m.register_forward_hook(self.hook_func)

    def hook_func(self, m, i, o):
        """Forward-hook callback: store a copy of output ``o`` cut off from autograd."""
        detached = o.detach()
        self.stored = detached.clone()

    def __enter__(self, *args):
        return self

    def __exit__(self, *args):
        self.hook.remove()
|
|
|
class HookBwd():
    """Context manager that records the gradient w.r.t. a module's output.

    Wrap both the forward and the backward pass. During backward, a detached
    copy of the gradient flowing into the module's (first) output is stored in
    ``self.stored``. The hook is unregistered when the ``with`` block exits.
    """

    def __init__(self, m):
        # register_full_backward_hook supersedes the deprecated
        # register_backward_hook, which is documented as unreliable for modules
        # whose forward runs several autograd operations (e.g. a ResNet stage).
        self.hook = m.register_full_backward_hook(self.hook_func)

    def hook_func(self, m, gi, go):
        """Backward-hook callback: store the gradient w.r.t. the first output."""
        self.stored = go[0].detach().clone()

    def __enter__(self, *args):
        return self

    def __exit__(self, *args):
        self.hook.remove()
|
|
|
def predict(img):
    """Classify an image and render a Grad-CAM heatmap for the predicted class.

    Args:
        img: Input image in any form ``PILImage.create`` accepts (whatever the
            Gradio image input delivers).

    Returns:
        ``(labels_probs, path)``. ``labels_probs`` maps every class name in
        ``labels`` to its predicted probability as a float. ``path`` is
        ``Path("gradcam.jpg")``. When the predicted class is ``"Negative"``,
        that file holds the input image unchanged. Otherwise it holds the
        decoded model input with the Grad-CAM map of ``layer4`` overlaid.
    """
    out_path = Path("gradcam.jpg")
    img = PILImage.create(img)

    _pred, _pred_idx, probs = learn.predict(img)
    labels_probs = {label: float(prob) for label, prob in zip(labels, probs)}

    x, = first(learn.dls.test_dl([img]))
    model = learn.model.eval()
    target_layer = model[0].model.layer4

    # A single grad-enabled forward pass yields both the predicted class and
    # the activations/gradients Grad-CAM needs.
    with HookBwd(target_layer) as hookg, Hook(target_layer) as hook:
        output = model(x)
        cls = int(output.argmax())
        if labels[cls] == "Negative":
            # No heatmap for this class: return the input image itself and
            # skip the backward pass entirely.
            img.save(out_path, format="JPEG")
            return labels_probs, out_path
        output[0, cls].backward()
        act = hook.stored
        grad = hookg.stored
    # backward() also fills parameter .grad buffers; drop them so they do not
    # accumulate across requests.
    model.zero_grad(set_to_none=True)

    # Grad-CAM: weight each activation channel by its spatially averaged
    # gradient, then sum over channels.
    # NOTE(review): assumes act/grad are (batch, channels, H, W) feature maps.
    weights = grad[0].mean(dim=[1, 2], keepdim=True)
    cam_map = (weights * act[0]).sum(0)

    x_dec = TensorImage(learn.dls.train.decode((x,))[0][0])
    # Size the overlay to the decoded image. imshow re-limits the axes to the
    # given extent, so a fixed size would crop or misalign other resolutions.
    height, width = x_dec.shape[-2:]

    fig, ax = plt.subplots()
    try:
        x_dec.show(ctx=ax)
        ax.imshow(cam_map.detach().cpu(), alpha=0.7, extent=(0, width, height, 0),
                  interpolation='bilinear', cmap='magma')
        # savefig overwrites any previous file, so no pre-delete is needed.
        fig.savefig(out_path, format="jpg", bbox_inches='tight')
    finally:
        plt.close(fig)

    return labels_probs, out_path
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Gradio UI: one image input, a top-5 label output and the Grad-CAM image.
# NOTE(review): gradio.inputs / gradio.outputs are the legacy component
# namespaces; migrating requires knowing the installed Gradio version.
demo = gradio.Interface(
    fn=predict,
    inputs=gradio.inputs.Image(shape=(512, 512), label="Chest X-ray"),
    outputs=[
        gradio.outputs.Label(num_top_classes=5, label="Detected Class"),
        gradio.outputs.Image(type="filepath", label="GradCAM"),
    ],
)

launch_options = {
    "enable_queue": True,
    "share": True,
}

if __name__ == "__main__":
    # Start the server only when run as a script, so importing this module
    # (e.g. to reuse predict or demo) does not launch it as a side effect.
    demo.launch(**launch_options)
|
|