# draft / app.py
# khan994's picture
# Update app.py
# 87b8eae verified
from fastai.vision.all import *
import cv2
import gradio as gr
import glob
class Hook:
    """Forward-hook receiver that remembers the last output of a module."""

    def hook_func(self, m, i, o):
        """Store a detached copy of the hooked module's output.

        The (m, i, o) arguments are the module, its inputs and its output,
        in the order a forward hook is called with. Only the output is
        kept, on `self.stored`; the copy is detached and cloned so it does
        not share storage or autograd history with `o`.
        """
        snapshot = o.detach().clone()
        self.stored = snapshot
# Exported learner used for every prediction; `cpu=True` is passed to
# load_learner so the model is loaded for CPU inference.
learn = load_learner("resnet152_fit_one_cycle_freeze_91acc.pkl", cpu=True)

# Image folder the DataLoaders below are built from. The splitter expects
# "train" and "valid" grandparent folders, and labels are taken from each
# image's parent folder name.
path = "train_val_cropped"

dblock = DataBlock(
    blocks=(ImageBlock, CategoryBlock),
    get_items=get_image_files,
    splitter=GrandparentSplitter(train_name="train", valid_name="valid"),
    get_y=parent_label,
    item_tfms=RandomResizedCrop(128, min_scale=0.7),
    batch_tfms=[
        *aug_transforms(),
        Normalize.from_stats(*imagenet_stats),
    ],
)

# gradcam() uses these DataLoaders to build a test batch (`test_dl`) and to
# decode the normalised batch back into a displayable image.
dls_augmented = dblock.dataloaders(path, shuffle=True)
# NOTE: this re-declares the `Hook` class that is already defined near the
# top of the file with the same body; this later definition is the one
# bound to the name when gradcam() runs.
class Hook:
    """Forward-hook receiver that remembers the last output of a module."""

    def hook_func(self, m, i, o):
        """Store a detached copy of the hooked module's output.

        The (m, i, o) arguments are the module, its inputs and its output,
        in the order a forward hook is called with. Only the output is
        kept, on `self.stored`; the copy is detached and cloned so it does
        not share storage or autograd history with `o`.
        """
        snapshot = o.detach().clone()
        self.stored = snapshot
def _pool_channel_groups(act, group=4):
    """Average every `group` consecutive channels of a (C, H, W) tensor.

    Returns a tensor of shape (C // group, H, W) with the same dtype and
    device as `act`. `reshape` raises if C is not a multiple of `group`.
    """
    channels, height, width = act.shape
    return act.reshape(channels // group, group, height, width).mean(dim=1)


def gradcam(img_create):
    """Classify `img_create` and render a class-activation heat-map over it.

    The activations of the model body (`learn.model[0]`) are captured with
    a forward hook, averaged in groups of 4 channels so that their channel
    count matches the input width of the final layer of the head
    (`learn.model[1][-1]`), and then projected through that layer's weights
    to obtain one spatial map per class. The map of the *predicted* class
    is drawn on top of the decoded input image.

    Args:
        img_create: an image in any form accepted by `learn.predict` and
            `dls_augmented.test_dl` (presumably the array handed over by
            the Gradio image input; verify against the caller).

    Returns:
        A tuple `(im, scores)` where `im` is the rendered overlay as an RGB
        array decoded by OpenCV and `scores` maps each name in `categories`
        to its predicted probability as a float.

    Raises:
        RuntimeError: from `reshape`/`einsum` when the body's channel count
            is not 4x the final layer's `in_features`.
    """
    import io  # local import: only needed for the in-memory render below

    _, idx, probs = learn.predict(img_create)
    x, = first(dls_augmented.test_dl([img_create]))

    # Capture the body's output for this single forward pass. The hook is
    # removed even if the forward pass fails, so hooks never pile up on the
    # shared model.
    hook_output = Hook()
    hook = learn.model[0].register_forward_hook(hook_output.hook_func)
    try:
        with torch.no_grad():
            learn.model.eval()(x)
    finally:
        hook.remove()
    act = hook_output.stored[0]  # first (only) item of the batch: (C, H, W)

    # Pool in the activation's own dtype/device. Buffers created from
    # np.zeros are float64 and do not match the model tensors in einsum.
    pooled = _pool_channel_groups(act, group=4)
    head_weight = learn.model[1][-1].weight
    cam_map = torch.einsum('ck,kij->cij', head_weight, pooled)

    x_dec = TensorImage(dls_augmented.train.decode((x,))[0][0])
    fig, ax = plt.subplots()
    try:
        x_dec.show(ctx=ax)
        # Show the map of the predicted class rather than a fixed class index.
        ax.imshow(
            cam_map[int(idx)].detach().cpu(),
            alpha=0.6,
            extent=(0, 128, 128, 0),
            interpolation='bilinear',
            cmap='magma',
        )
        fig.tight_layout()
        # Render to memory instead of a fixed "gcam.jpg" on disk, which
        # concurrent requests would overwrite for each other.
        buf = io.BytesIO()
        fig.savefig(buf, format="jpg", dpi=199)
    finally:
        # Figures created through pyplot stay registered until closed.
        plt.close(fig)

    encoded = np.frombuffer(buf.getvalue(), dtype=np.uint8)
    im = cv2.imdecode(encoded, cv2.IMREAD_COLOR)
    im = cv2.cvtColor(im, cv2.COLOR_BGR2RGB)
    return im, dict(zip(categories, map(float, probs)))
def classify_img(img):
    """Return a {category name: probability} mapping for `img`.

    `img` is passed straight to `learn.predict`; the returned probabilities
    are paired positionally with the names in `categories` and converted to
    plain floats.
    """
    _, _, class_probs = learn.predict(img)
    return {name: float(p) for name, p in zip(categories, class_probs)}
# Class names, paired positionally with the probabilities returned by
# learn.predict (presumably in the same order as the model's class
# vocabulary; verify against the exported learner).
categories = (
    'arbanasi',
    'filibe',
    'gjirokoster',
    'iskodra',
    'kula',
    'kuzguncuk',
    'larissa_ampelakia',
    'mardin',
    'ohrid',
    'pristina',
    'safranbolu',
    'selanik',
    'sozopol_suzebolu',
    'tiran',
    'varna',
)

# Gradio input/output components: a 128x128 image in, a label widget out.
image = gr.inputs.Image(shape=(128, 128))
label = gr.outputs.Label()

# Example images offered in the UI.
# Alternative kept for reference: collect every validation image instead,
#   examples = glob.glob("valid/**/*.jpg", recursive=True)
examples = [
    "filibe-1-1.jpg",
    "ohrid-3-1.jpg",
    "varna-1-1.jpg",
]

# Only classify_img is exposed here; gradcam() is not wired into the UI.
demo = gr.Interface(
    fn=classify_img,
    inputs=image,
    outputs=label,
    examples=examples,
)
demo.launch(inline=False)