File size: 4,031 Bytes
e0fc84b
a692c75
 
 
 
 
e0fc84b
a692c75
 
 
 
 
 
 
 
 
 
e0fc84b
a692c75
 
 
 
 
 
e0fc84b
a692c75
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
abee87a
a692c75
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e0fc84b
a692c75
 
 
 
e0fc84b
a692c75
 
 
abee87a
a692c75
abee87a
 
 
 
 
a692c75
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
# AUTOGENERATED! DO NOT EDIT! File to edit: four-way-classifier.ipynb.

# %% auto 0
__all__ = [
    # module-level objects built at import time
    'learn', 'categories', 'model', 'description', 'image', 'label', 'examples', 'interf',
    # helper functions
    'get_best_layout', 'create_plots', 'classify_image',
]

# %% four-way-classifier.ipynb 1
from fastai.vision.all import *
import gradio as gr
import PIL
import glob

import primefac

from torchviz import make_dot
from pytorch_grad_cam.utils.image import preprocess_image

# %% four-way-classifier.ipynb 2
# Suppress every warning process-wide. `filterwarnings('ignore')` and
# `simplefilter('ignore')` both install the same catch-all "ignore" filter,
# so a single call leaves the warnings filter list in the same state.
warnings.simplefilter( 'ignore' )

# Have matplotlib apply its tight layout automatically to every figure.
plt.rcParams.update( { 'figure.autolayout': True } )

# %% four-way-classifier.ipynb 3
def get_best_layout( num_imgs, img_size ):
    """Choose a (rows, cols) grid that tiles exactly ``num_imgs`` images.

    Every factor pair with ``rows * cols == num_imgs`` is considered. The pair
    whose ``rows / cols`` is closest to ``img_size[0] / img_size[1]`` is
    returned. Ties go to the pair with fewer rows.

    Args:
        num_imgs: number of tiles to lay out; must be a positive integer.
        img_size: a sequence such as an array ``.shape``; only its first two
            entries are used, as ``img_size[0] / img_size[1]``.

    Returns:
        ``(rows, cols)`` as plain Python ints.

    Raises:
        ValueError: if ``num_imgs`` is smaller than 1.
    """
    if num_imgs < 1:
        raise ValueError( f'num_imgs must be a positive integer, got {num_imgs!r}' )
    target_ratio = img_size[0] / img_size[1]
    # Enumerate every divisor directly. Prefixes of the prime factorisation
    # miss pairs such as 3 x 4 for 12, and give no candidates at all for counts
    # like 4, 6 or 7.
    candidates = [
        ( rows, num_imgs // rows )
        for rows in range( 1, num_imgs + 1 )
        if num_imgs % rows == 0
    ]
    # min() keeps the first minimum, i.e. the smallest row count on ties.
    return min( candidates, key=lambda rc: abs( target_ratio - rc[0] / rc[1] ) )

def create_plots( output_tensor, title, fontsize=24 ):
    """Render each channel of a feature-map tensor as a grid of grayscale tiles.

    Only batch element 0 is drawn. All tiles share one colour scale, spanning
    the minimum to the maximum of the drawn activations.

    Args:
        output_tensor: a torch tensor indexed as ``output_tensor[0][channel]``
            -> 2-D map (presumably (batch, channels, H, W); TODO confirm).
        title: figure super-title.
        fontsize: font size of the super-title.

    Returns:
        A PIL image holding the rendered PNG.
    """
    # detach() drops autograd history; cpu() lets .numpy() also work for CUDA tensors.
    feats = output_tensor[0].detach().cpu().numpy()
    vmin, vmax = feats.min(), feats.max()
    n_rows, n_cols = get_best_layout( feats.shape[0], feats.shape[1:] )

    fig = plt.figure()
    try:
        # squeeze=False keeps `axes` 2-D even for a 1 x N or N x 1 grid.
        axes = fig.subplots( n_rows, n_cols, squeeze=False )
        for idx, channel in enumerate( feats ):
            ax = axes[ idx // n_cols, idx % n_cols ]
            ax.imshow( channel, cmap='gray', vmin=vmin, vmax=vmax )
            ax.axis( 'off' )
        fig.suptitle( title, fontsize=fontsize )
        fig.tight_layout()
        img_buf = io.BytesIO()
        fig.savefig( img_buf, format='png' )
    finally:
        # pyplot keeps every figure alive until closed; without this each call leaks one.
        plt.close( fig )

    img_buf.seek( 0 )
    return Image.open( img_buf )

def classify_image( img ):
    """Classify ``img`` and render feature maps from the network body.

    Args:
        img: the value supplied by the Gradio ``Image`` input. This is
            presumably an RGB uint8 array of shape (H, W, 3); confirm for the
            installed Gradio version.

    Returns:
        A tuple ``(probs_by_category, im1, im2, im3)``:
        - ``probs_by_category`` maps each name in ``categories``, in order, to
          its predicted probability.
        - ``im1``, ``im2`` and ``im3`` are PIL images of the activations after
          ``model[0][:3]``, ``model[0][:5]`` and ``model[0][:7]``.
    """
    import torch

    _, _, probs = learn.predict( img )
    body = learn.model.eval()[0]

    rgb_img = np.float32( np.array( img ) ) / 255
    # Standard ImageNet channel statistics; presumably these match the
    # pretrained backbone's normalization (confirm if the learner changes).
    input_tensor = preprocess_image(
        rgb_img,
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225]
    )
    input_tensor = input_tensor.to( next( body.parameters() ).device )

    # The activations are only visualized, so skip building an autograd graph.
    # Chaining slices of the Sequential reuses the earlier outputs instead of
    # re-running the first layers for each plot. body[3:5](body[:3](x)) is
    # equivalent to body[:5](x).
    with torch.no_grad():
        feats_3 = body[:3]( input_tensor )
        feats_5 = body[3:5]( feats_3 )
        feats_7 = body[5:7]( feats_5 )

    im1 = create_plots( feats_3, title='Inferred features at model[0][:3]' )
    im2 = create_plots( feats_5, title='Inferred features at model[0][:5]' )
    im3 = create_plots( feats_7, title='Inferred features at model[0][:7]' )

    return ( dict( zip( categories, map( float, probs ) ) ), im1, im2, im3 )

# %% four-way-classifier.ipynb 4
# Labels zipped, in order, with the learner's probability vector in
# classify_image. Presumably this must match the learner's vocab order;
# verify whenever the model is re-exported.
categories = [ 'bird', 'forest', 'otter', 'snake' ]

# load_learner unpickles the export, so only point it at a trusted file.
learn = load_learner( 'simple-image-classifier.pkl' )
model = learn.model
model.eval()

# %% four-way-classifier.ipynb 5
# Markdown text shown with the Gradio interface (passed as `description=`).
description='''
A simple 4-way classifier that categorizes images as 'snake', 'bird', 'otter' or 'forest'. 
Refined from an initial ResNet18 model downloaded from HuggingFace.
The test images given here are chosen to demonstrate the effect of lack of training data on the classification outcome. 

The actual classification for each test image actually takes a very short time; the delay in predicting results here is due to 
the extra step of plotting the intermediate activation maps and inferred features in `matplotlib`. 


**DISCLAIMER**: the images here are merely for demonstration purposes. I don't own any of them and I'm not making money from them.  
'''

# Gradio input: the image to classify.
image = gr.components.Image()
# Gradio outputs, in the same order as the tuple returned by classify_image:
# the per-class probabilities, then the three feature-map renderings.
label = [ gr.components.Label(), gr.components.Image(), gr.components.Image(), gr.components.Image() ]
# glob returns matches in arbitrary, filesystem-dependent order; sort them so
# the example gallery is the same on every machine and restart.
examples = sorted( glob.glob( './*.jpg' ) )

interf = gr.Interface(
    title='Simple 4-way image classifier',
    description=description,
    fn=classify_image,
    inputs=image,
    outputs=label,
    examples=examples,
    allow_flagging='manual'
)
interf.launch( inline=True )