"""Gradio app serving a fastai image classifier.

The learner exported to ``resnet152_fit_one_cycle_freeze_91acc.pkl`` is loaded
on the CPU and exposed through a Gradio interface (``classify_img``).
``gradcam`` additionally renders a class-activation overlay for an image; it is
not wired into the interface below.
"""
from fastai.vision.all import *
import cv2
import gradio as gr
import glob
import io


class Hook():
    """Forward hook that keeps a detached copy of a module's output."""

    def hook_func(self, m, i, o):
        # Signature follows torch's forward-hook convention: (module, inputs, output).
        self.stored = o.detach().clone()


# NOTE: load_learner unpickles the file, so only load trusted model files.
learn = load_learner("resnet152_fit_one_cycle_freeze_91acc.pkl", cpu=True)

# These DataLoaders are only used by ``gradcam``: to build a test batch with the
# same transforms and to decode it back into a displayable image. Building them
# requires the ``train_val_cropped`` directory to exist at start-up.
path = "train_val_cropped"
dblock = DataBlock(
    blocks=(ImageBlock, CategoryBlock),
    get_items=get_image_files,
    splitter=GrandparentSplitter(train_name="train", valid_name="valid"),
    get_y=parent_label,
    item_tfms=RandomResizedCrop(128, min_scale=0.7),
    batch_tfms=[*aug_transforms(), Normalize.from_stats(*imagenet_stats)],
)
dls_augmented = dblock.dataloaders(path, shuffle=True)


def gradcam(img_create):
    """Render a CAM-style heat-map overlay for the predicted class of an image.

    Args:
        img_create: an image accepted by ``learn.predict`` and ``test_dl``.

    Returns:
        ``(overlay, probs)``: ``overlay`` is an RGB ``uint8`` array of the
        rendered figure, and ``probs`` maps each name in ``categories`` to its
        predicted probability.

    Raises:
        ValueError: if the body's channel count is not a multiple of the final
            linear layer's input size.
    """
    _, idx, probs = learn.predict(img_create)
    x, = first(dls_augmented.test_dl([img_create]))

    # Capture the body's (learn.model[0]) feature maps during one forward pass;
    # always detach the hook, even if the forward pass fails.
    hook_output = Hook()
    hook = learn.model[0].register_forward_hook(hook_output.hook_func)
    try:
        with torch.no_grad():
            learn.model.eval()(x)
    finally:
        hook.remove()
    act = hook_output.stored[0]  # feature maps of the single image in the batch

    # The final linear layer of the head takes fewer features than the body
    # has channels, so average consecutive groups of channels down to that size
    # (the original code hard-coded a group size of 4). This is a heuristic
    # approximation of CAM, not an exact decomposition of the head.
    final_linear = learn.model[1][-1]
    out_size = final_linear.in_features
    n_channels, height, width = act.shape
    if n_channels % out_size:
        raise ValueError(
            f"cannot pool {n_channels} feature channels into {out_size} groups"
        )
    group = n_channels // out_size
    # Stays in the activations' dtype, matching the linear weights for einsum.
    pooled = act.reshape(out_size, group, height, width).mean(dim=1)

    # Weight the pooled maps by each class's final-layer weights: (classes, H, W).
    cam_map = torch.einsum("ck,kij->cij", final_linear.weight, pooled)
    x_dec = TensorImage(dls_augmented.train.decode((x,))[0][0])
    img_h, img_w = x_dec.shape[-2:]

    fig, ax = plt.subplots()
    try:
        x_dec.show(ctx=ax)
        # Overlay the activation map of the predicted class, stretched over the image.
        ax.imshow(
            cam_map[int(idx)].detach().cpu(),
            alpha=0.6,
            extent=(0, img_w, img_h, 0),
            interpolation="bilinear",
            cmap="magma",
        )
        plt.tight_layout()
        # Encode to JPEG in memory instead of a shared file on disk, so
        # concurrent calls cannot overwrite each other's output.
        buf = io.BytesIO()
        fig.savefig(buf, format="jpg", dpi=199)
    finally:
        plt.close(fig)  # release the figure; pyplot keeps open figures alive

    encoded = np.frombuffer(buf.getvalue(), dtype=np.uint8)
    im = cv2.imdecode(encoded, cv2.IMREAD_COLOR)
    im = cv2.cvtColor(im, cv2.COLOR_BGR2RGB)
    return im, dict(zip(categories, map(float, probs)))


def classify_img(img):
    """Return a mapping from each category name to its predicted probability."""
    _, _, probs = learn.predict(img)
    return dict(zip(categories, map(float, probs)))


# NOTE(review): zipped positionally with learn.predict's probabilities, so this
# order must match the learner's vocabulary (learn.dls.vocab) — verify.
categories = ('arbanasi', 'filibe', 'gjirokoster', 'iskodra', 'kula', 'kuzguncuk',
              'larissa_ampelakia', 'mardin', 'ohrid', 'pristina', 'safranbolu',
              'selanik', 'sozopol_suzebolu', 'tiran', 'varna')

# NOTE: gr.inputs / gr.outputs are Gradio's legacy namespaces (removed in
# Gradio 4); port to gr.Image / gr.Label when upgrading Gradio.
image = gr.inputs.Image(shape=(128, 128))
label = gr.outputs.Label()
examples = ["filibe-1-1.jpg", "ohrid-3-1.jpg", "varna-1-1.jpg"]

demo = gr.Interface(fn=classify_img, inputs=image, outputs=label, examples=examples)
demo.launch(inline=False)