from fastai.vision.all import *
import cv2
import gradio as gr


class Hook():
    "Stores a detached copy of the hooked layer's output for later inspection."
    def hook_func(self, m, i, o): self.stored = o.detach().clone()


# Exported learner (ResNet-152; filename indicates ~91% accuracy), loaded on CPU for inference.
learn = load_learner("resnet152_fit_one_cycle_freeze_91acc.pkl", cpu=True)

# Rebuild the DataLoaders the model was trained with, so that test images get the
# same item/batch transforms and predictions can be decoded for display.
path = "train_val_cropped"
dblock = DataBlock(blocks=(ImageBlock, CategoryBlock),
                   get_items=get_image_files,
                   splitter=GrandparentSplitter(train_name="train", valid_name="valid"),
                   get_y=parent_label,
                   item_tfms=RandomResizedCrop(128, min_scale=0.7),
                   batch_tfms=[*aug_transforms(),
                               Normalize.from_stats(*imagenet_stats)])
dls_augmented = dblock.dataloaders(path, shuffle=True)

def gradcam(img_create):
    "Classify an image and return a CAM overlay together with the class probabilities."
    pred, idx, probs = learn.predict(img_create)

    # Forward pass with a hook on the body to capture the final convolutional
    # activations for this image.
    x, = first(dls_augmented.test_dl([img_create]))
    hook_output = Hook()
    hook = learn.model[0].register_forward_hook(hook_output.hook_func)
    with torch.no_grad():
        learn.model.eval()(x)
    act = hook_output.stored[0]
    hook.remove()

    # The backbone produces more channels than the head's last linear layer
    # expects (2048 vs. 512 for this learner), so average consecutive channel
    # groups down to in_features. This only approximates a class-activation map,
    # since the fastai head mixes features through intermediate layers.
    n_channels = act.shape[0]
    out_size = learn.model[1][-1].in_features
    map_size = act.shape[1]
    group = n_channels // out_size
    new_act = torch.zeros(out_size, map_size, map_size)
    for i in range(0, n_channels, group):
        new_act[i // group] = act[i:i + group].mean(dim=0)
    # Weight the pooled activations by each class's classifier weights.
    cam_map = torch.einsum('ck,kij->cij', learn.model[1][-1].weight, new_act)

    # Overlay the map for the predicted class on the decoded input image.
    x_dec = TensorImage(dls_augmented.train.decode((x,))[0][0])
    fig, ax = plt.subplots()
    x_dec.show(ctx=ax)
    ax.imshow(cam_map[int(idx)].detach().cpu(), alpha=0.6, extent=(0, 128, 128, 0),
              interpolation='bilinear', cmap='magma')
    plt.tight_layout()
    fig.savefig("gcam.jpg", dpi=199)
    im = cv2.imread("gcam.jpg", 1)
    im = cv2.cvtColor(im, cv2.COLOR_BGR2RGB)

    return im, dict(zip(categories, map(float, probs)))

def classify_img(img):
    "Return the class probabilities for a single image."
    pred, idx, probs = learn.predict(img)
    return dict(zip(categories, map(float, probs)))


# Class labels; the order must match the learner's vocab so that each probability
# is paired with the correct label.
categories = ('arbanasi', 'filibe', 'gjirokoster', 'iskodra', 'kula', 'kuzguncuk',
              'larissa_ampelakia', 'mardin', 'ohrid', 'pristina', 'safranbolu',
              'selanik', 'sozopol_suzebolu', 'tiran', 'varna')

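# Optional sanity check before wiring up the UI: run the Grad-CAM pipeline on one
# of the bundled example images. A minimal sketch (it assumes "filibe-1-1.jpg"
# sits next to this script, as the Gradio examples below do); uncomment to use.
#
# test_img = PILImage.create("filibe-1-1.jpg")
# overlay, preds = gradcam(test_img)
# print(sorted(preds.items(), key=lambda kv: kv[1], reverse=True)[:3])
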
# Gradio components (legacy gr.inputs / gr.outputs API; newer Gradio releases use
# gr.Image() and gr.Label() instead).
image = gr.inputs.Image(shape=(128, 128))
label = gr.outputs.Label()

examples = ["filibe-1-1.jpg",
            "ohrid-3-1.jpg",
            "varna-1-1.jpg"]

demo = gr.Interface(fn=classify_img, inputs=image, outputs=label, examples=examples)
demo.launch(inline=False)
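
# gradcam() returns both the CAM overlay and the probability dict, so it can be
# exposed through its own interface. A minimal sketch using Gradio's string
# shortcuts for the input and output components; uncomment to launch it instead
# of (or after) the classifier-only demo above.
#
# gradcam_demo = gr.Interface(fn=gradcam,
#                             inputs="image",
#                             outputs=["image", "label"],
#                             examples=examples)
# gradcam_demo.launch(inline=False)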