# (removed: non-Python extraction residue — byte-count header, git blame hashes,
#  and a run of display line numbers that were never part of this script)
from fastai.vision.all import *
import cv2
import gradio as gr
import glob
class Hook():
    """Forward-hook target: captures the output activations of the layer it
    is registered on, for later Grad-CAM-style visualization."""

    def hook_func(self, m, i, o):
        # Detach and clone so we keep a snapshot without retaining the
        # autograd graph of the forward pass.
        self.stored = o.detach().clone()
# Load the exported classifier for CPU inference.
learn = load_learner("resnet152_fit_one_cycle_91acc.pkl".replace("fit_one_cycle", "fit_one_cycle_freeze"), cpu=True)

# Rebuild the augmented dataloaders: 128px random-resized crops,
# standard augmentations, ImageNet normalization. Train/valid split is
# taken from the grandparent folder names.
path = "train_val_cropped"
dblock = DataBlock(
    blocks=(ImageBlock, CategoryBlock),
    get_items=get_image_files,
    get_y=parent_label,
    splitter=GrandparentSplitter(train_name="train", valid_name="valid"),
    item_tfms=RandomResizedCrop(128, min_scale=0.7),
    batch_tfms=[*aug_transforms(), Normalize.from_stats(*imagenet_stats)],
)
dls_augmented = dblock.dataloaders(path, shuffle=True)
# NOTE(review): this is an exact duplicate of the Hook class defined earlier
# in this file; it simply rebinds the same name to an identical class.
# Redundant but harmless — kept as-is to avoid changing module behavior.
class Hook():
    # Forward-hook target: stores a detached copy of the hooked layer's output.
    def hook_func(self, m, i, o): self.stored = o.detach().clone()
def gradcam(img_create):
    """Classify `img_create` and return a Grad-CAM-style heatmap overlay.

    Runs the learner on the image, captures the backbone's final activations
    via a forward hook, pools every 4 consecutive channels down to the head's
    `in_features` channels, and weights them by the final linear layer to get
    a class-activation map. The map for the *predicted* class is drawn over
    the decoded input image and saved to "gcam.jpg".

    Returns:
        (im, probs_dict): the RGB heatmap image as a numpy array, and a
        {category: probability} dict over all classes.
    """
    pred, idx, probs = learn.predict(img_create)
    x, = first(dls_augmented.test_dl([img_create]))

    # Capture the backbone output for this single forward pass.
    hook_output = Hook()
    hook = learn.model[0].register_forward_hook(hook_output.hook_func)
    try:
        with torch.no_grad():
            learn.model.eval()(x)
    finally:
        # Always detach the hook, even if the forward pass raises.
        hook.remove()
    act = hook_output.stored[0]

    input_size = act.shape[0]                      # backbone channels
    out_size = learn.model[1][-1].in_features      # head input channels
    kernel_size = act.shape[1]                     # spatial size of the map

    # Average each group of 4 consecutive channels so the channel count
    # matches the head's in_features (input_size / 4 groups).
    new_act = tensor(np.zeros((out_size, kernel_size, kernel_size)))
    for i in range(0, input_size, 4):
        acc = tensor(np.zeros((1, kernel_size, kernel_size)))  # was `sum`: avoid shadowing the builtin
        for j in range(i, i + 4):
            acc = acc + act[j, :, :]
        new_act[i // 4, :, :] = acc / 4

    # One CAM per class: weight pooled activations by the final linear layer.
    cam_map = torch.einsum('ck,kij->cij', learn.model[1][-1].weight, new_act)

    x_dec = TensorImage(dls_augmented.train.decode((x,))[0][0])
    fig, ax = plt.subplots()
    x_dec.show(ctx=ax)
    # BUG FIX: overlay the CAM for the predicted class (idx), not the
    # hard-coded class 1 the original displayed for every input.
    ax.imshow(cam_map[int(idx)].detach().cpu(), alpha=0.6, extent=(0, 128, 128, 0),
              interpolation='bilinear', cmap='magma')
    plt.tight_layout()
    fig.savefig("gcam.jpg", dpi=199)

    im = cv2.imread("gcam.jpg", 1)
    im = cv2.cvtColor(im, cv2.COLOR_BGR2RGB)  # OpenCV loads BGR; Gradio wants RGB
    return im, dict(zip(categories, map(float, probs)))
def classify_img(img):
    """Classify an image and return a {category: probability} mapping."""
    _, _, probs = learn.predict(img)
    return {category: float(p) for category, p in zip(categories, probs)}
# Class labels, in the order the model's output probabilities are emitted.
categories = ('arbanasi', 'filibe', 'gjirokoster', 'iskodra', 'kula', 'kuzguncuk', 'larissa_ampelakia', 'mardin', 'ohrid', 'pristina', 'safranbolu', 'selanik', 'sozopol_suzebolu', 'tiran', 'varna')

# Gradio UI: a 128x128 image input mapped to a label output.
image = gr.inputs.Image(shape=(128, 128))
label = gr.outputs.Label()

# Sample images shipped alongside the app.
examples = ["filibe-1-1.jpg",
            "ohrid-3-1.jpg",
            "varna-1-1.jpg"]

demo = gr.Interface(fn=classify_img, inputs=image, outputs=label, examples=examples)
# BUG FIX: removed the stray trailing "|" after this call (a syntax error),
# and dropped the dead commented-out duplicate of classify_img / the unused
# glob-based example collector.
demo.launch(inline=False)