|
|
|
|
|
|
|
# Public API of this module: the loaded learner, its label vocabulary,
# the Gradio UI objects, and the layout/plotting/classification helpers.
__all__ = ['learn', 'categories', 'model', 'description', 'image', 'label', 'examples', 'interf', 'get_best_layout',

'create_plots', 'classify_image']
|
|
|
|
|
from fastai.vision.all import * |
|
import gradio as gr |
|
import PIL |
|
import glob |
|
|
|
import primefac |
|
|
|
from torchviz import make_dot |
|
from pytorch_grad_cam.utils.image import preprocess_image |
|
|
|
|
|
# Silence all warnings: fastai/torch emit many deprecation notices that
# would otherwise clutter the app's console output.
warnings.filterwarnings( 'ignore' )

warnings.simplefilter( 'ignore' )

# Let matplotlib auto-adjust subplot spacing so the dense activation-map
# grids produced by create_plots() don't overlap their suptitle.
plt.rcParams[ 'figure.autolayout' ] = True
|
|
|
|
|
def get_best_layout( num_imgs, img_size ):
    """Choose a (rows, cols) grid for laying out `num_imgs` images.

    Every factor pair (n1, n2) with n1 * n2 == num_imgs is considered,
    and the pair whose ratio n1/n2 is closest to the per-image aspect
    ratio img_size[0] / img_size[1] is returned.

    Parameters
    ----------
    num_imgs : int
        Number of images to lay out; must be >= 1.
    img_size : sequence
        Per-image (height, width); only the ratio of the first two
        entries is used.

    Returns
    -------
    (int, int)
        Grid dimensions (n1, n2) satisfying n1 * n2 == num_imgs.

    Raises
    ------
    ValueError
        If num_imgs < 1.
    """
    if num_imgs < 1:
        raise ValueError( f'num_imgs must be >= 1, got {num_imgs}' )

    target_ratio = img_size[0] / img_size[1]

    # Enumerate ALL divisor pairs rather than splitting the sorted prime
    # factorization: the previous approach skipped valid grids (e.g. it
    # never considered (3, 4) for 12 images) and produced zero candidates
    # for prime counts, which crashed argmin on an empty list.
    best = None
    for n1 in range( 1, num_imgs + 1 ):
        if num_imgs % n1:
            continue
        n2 = num_imgs // n1
        diff = abs( target_ratio - n1 / n2 )
        if best is None or diff < best[2]:
            best = ( n1, n2, diff )

    return best[0], best[1]
|
|
|
def create_plots( output_tensor, title, fontsize=24 ):
    """Render every channel of an activation tensor as a grayscale grid.

    Parameters
    ----------
    output_tensor : torch.Tensor
        Activation tensor indexed as output_tensor[0][c] per channel.
        # assumes shape (1, C, H, W) with batch size 1 — TODO confirm
    title : str
        Figure-level title drawn above the grid.
    fontsize : int, optional
        Font size for the title (default 24).

    Returns
    -------
    PIL.Image.Image
        The rendered figure, round-tripped through an in-memory PNG.
    """
    # One 2-D numpy image per channel of the (detached) activation tensor.
    img_list = [ output_tensor[0][n].detach().numpy() for n in range( output_tensor.shape[1] ) ]

    M, N = get_best_layout( len( img_list ), img_list[0].shape )

    fig = plt.figure()
    # squeeze=False keeps `ax` 2-D even when M == 1 or N == 1; with the
    # default squeeze=True a single-row layout returns a 1-D array and
    # the ax[m, n] indexing below raises IndexError.
    ax = fig.subplots( M, N, squeeze=False )

    for idx, img in enumerate( img_list ):
        m, n = divmod( idx, N )
        ax[m,n].imshow( img, cmap='gray' )
        ax[m,n].axis( 'off' )

    plt.suptitle( title, fontsize=fontsize )
    plt.tight_layout()

    img_buf = io.BytesIO()
    plt.savefig( img_buf, format='png' )
    # Close the figure explicitly: otherwise each classification request
    # leaks a live matplotlib figure and memory grows without bound.
    plt.close( fig )

    img_buf.seek( 0 )
    im = Image.open( img_buf )
    return im
|
|
|
def classify_image( img ):
    """Classify `img` and visualize its intermediate feature maps.

    Returns a 4-tuple: a {category: probability} mapping, followed by
    three PIL images showing the inferred features after the first 3,
    5 and 7 stages of the backbone (model[0]).
    """
    _, _, probs = learn.predict( img )

    net = learn.model.eval()

    # Normalize to [0, 1] floats, then apply the ImageNet mean/std.
    rgb_img = np.float32( np.array( img ) ) / 255
    input_tensor = preprocess_image(
        rgb_img,
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225]
    )

    # One feature-map figure per backbone depth, shallow to deep.
    feature_imgs = [
        create_plots( net[0][:depth]( input_tensor ), title=f'Inferred features at model[0][:{depth}]' )
        for depth in ( 3, 5, 7 )
    ]

    scores = dict( zip( categories, map( float, probs ) ) )
    return ( scores, feature_imgs[0], feature_imgs[1], feature_imgs[2] )
|
|
|
|
|
# Restore the exported fastai learner (trained model + preprocessing
# pipeline) from disk.
learn = load_learner( 'simple-image-classifier.pkl' )

# Class labels in the order the model's output probabilities follow.
# NOTE(review): assumed to match the learner's dls.vocab — confirm.
categories = [ 'bird', 'forest', 'otter', 'snake' ]

# Put the network in inference mode at module level.
model = learn.model.eval()
|
|
|
|
|
# Markdown blurb rendered at the top of the Gradio interface.
description='''

A simple 4-way classifier that categorizes images as 'snake', 'bird', 'otter' or 'forest'.

Refined from a pre-trained ResNet18 model downloaded from HuggingFace.



The actual classification for each test image actually takes a very short time; the delay in displaying results is due to

the extra step of plotting the intermediate activation maps and inferred features in `matplotlib`.




**DISCLAIMER**: the images here are merely for demonstration purposes. I don't own any of them and I'm not making money from them.

'''
|
|
|
|
|
# Gradio input component: a single uploaded image.
image = gr.components.Image()

# Output components: class probabilities plus three feature-map figures —
# must match the 4-tuple returned by classify_image().
label = [ gr.components.Label(), gr.components.Image(), gr.components.Image(), gr.components.Image() ]

# Example inputs: every .jpg file sitting next to this script.
examples = glob.glob( './*.jpg' )

# Wire everything into a Gradio interface and start serving.
interf = gr.Interface(
    title='Simple 4-way image classifier',
    description=description,
    fn=classify_image,
    inputs=image,
    outputs=label,
    examples=examples,
    # NOTE(review): `allow_flagging` is deprecated in Gradio 4+ in favour
    # of `flagging_mode` — confirm the Gradio version this app pins.
    allow_flagging='manual'
)

interf.launch( inline=True )
|
|