File size: 3,915 Bytes
e0fc84b
a692c75
 
 
 
 
e0fc84b
a692c75
 
 
 
 
 
 
 
 
 
e0fc84b
a692c75
 
 
 
 
 
e0fc84b
a692c75
 
 
 
 
 
 
 
 
 
 
 
 
5a43470
 
a692c75
 
 
abee87a
a692c75
 
 
 
 
5a43470
a692c75
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e0fc84b
a692c75
 
 
 
e0fc84b
a692c75
 
5a43470
a692c75
5a43470
 
abee87a
 
 
a692c75
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
# AUTOGENERATED! DO NOT EDIT! File to edit: four-way-classifier.ipynb.

# %% auto 0
# Public names exported by this module (generated from the notebook above).
__all__ = [
    'learn', 'categories', 'model', 'description', 'image', 'label', 'examples', 'interf',
    'get_best_layout', 'create_plots', 'classify_image',
]

# %% four-way-classifier.ipynb 1
from fastai.vision.all import *
import gradio as gr
import PIL
import glob

import primefac

from torchviz import make_dot
from pytorch_grad_cam.utils.image import preprocess_image

# %% four-way-classifier.ipynb 2
# Silence every warning category for the demo app. One `simplefilter` call is
# sufficient: `filterwarnings('ignore')` would insert the identical filter entry,
# which a following `simplefilter('ignore')` just replaces.
warnings.simplefilter( 'ignore' )

# Have matplotlib apply its tight layout automatically whenever a figure is drawn,
# so suptitles and subplot grids do not overlap.
plt.rcParams[ 'figure.autolayout' ] = True

# %% four-way-classifier.ipynb 3
def get_best_layout( num_imgs, img_size ):
    """Pick a (rows, cols) subplot grid that holds exactly ``num_imgs`` panels.

    Every factorisation ``rows * cols == num_imgs`` is considered. The one whose
    ratio ``rows / cols`` is closest to ``img_size[0] / img_size[1]`` is chosen;
    ties go to the smaller row count. For square panels this picks the most
    square grid, e.g. 64 -> (8, 8) and 256 -> (16, 16).

    Parameters
    ----------
    num_imgs : int
        Number of panels to lay out; must be positive.
    img_size : sequence of two numbers
        Size of one panel, e.g. the ``.shape`` of a 2-D feature map.

    Returns
    -------
    tuple of int
        ``(rows, cols)`` with ``rows * cols == num_imgs``.

    Raises
    ------
    ValueError
        If ``num_imgs`` is not positive.
    """
    if num_imgs < 1:
        raise ValueError( f'num_imgs must be a positive integer, got {num_imgs}' )
    target = img_size[0] / img_size[1]
    # Enumerate every divisor pair rather than prefix products of the prime
    # factorisation. The prefix scan missed valid grids (12 -> only 1x12), had
    # no candidates at all for 1, primes and products of two primes, and could
    # return a float row count that `fig.subplots` rejects.
    layouts = [
        ( rows, num_imgs // rows )
        for rows in range( 1, num_imgs + 1 )
        if num_imgs % rows == 0
    ]
    # min() keeps the first minimum, i.e. the smallest row count on ties.
    return min( layouts, key=lambda rc: abs( target - rc[0] / rc[1] ) )

def create_plots( output_tensor, title, fontsize=24 ):
    """Render every channel of a feature-map tensor as one grayscale mosaic image.

    Parameters
    ----------
    output_tensor : torch.Tensor
        Activations; the first batch element is plotted, one panel per entry of
        ``output_tensor[0]``. Presumably shaped (batch, channels, H, W) -- TODO
        confirm for every layer passed in.
    title : str
        Figure super-title.
    fontsize : int, optional
        Font size of the super-title (default 24).

    Returns
    -------
    PIL.Image.Image
        The rendered figure, PNG-encoded into an in-memory buffer and re-opened
        with PIL.
    """
    # .cpu() makes this work even if the activations live on a GPU.
    img_list = [ chan.detach().cpu().numpy() for chan in output_tensor[0] ]
    fig = plt.figure()
    try:
        M, N = get_best_layout( len( img_list ), img_list[0].shape )
        # squeeze=False keeps `ax` 2-D even for a single-row/column grid,
        # so ax[m, n] indexing is always valid.
        ax = fig.subplots( M, N, squeeze=False )
        for idx, img in enumerate( img_list ):
            m, n = divmod( idx, N )
            ax[m, n].imshow( img, cmap='gray' )
            ax[m, n].axis( 'off' )

        fig.suptitle( title, fontsize=fontsize )
        fig.tight_layout()
        img_buf = io.BytesIO()
        fig.savefig( img_buf, format='png' )
    finally:
        # A new figure is created per call. Without closing it, pyplot keeps a
        # reference forever, leaking memory on every request to the app.
        plt.close( fig )
    img_buf.seek( 0 )
    return Image.open( img_buf )

def classify_image( img ):
    """Classify ``img`` and visualise intermediate feature maps of the network body.

    Parameters
    ----------
    img : array-like
        Image from the Gradio input component. Assumed to be an RGB uint8 array
        of shape (H, W, 3) -- TODO confirm: it is scaled by 1/255 and normalised
        with 3-channel mean/std below.

    Returns
    -------
    tuple
        ``(probabilities, im1, im2, im3)``: a ``{category: probability}`` dict,
        then three PIL images of the feature maps after ``model[0][:3]``,
        ``model[0][:5]`` and ``model[0][:7]``.
    """
    import torch

    _, _, probs = learn.predict( img )
    net = learn.model.eval()
    rgb_img = np.float32( np.array( img ) ) / 255
    input_tensor = preprocess_image(
        rgb_img,
        mean=[0.485, 0.456, 0.406],   # standard ImageNet channel means
        std=[0.229, 0.224, 0.225]     # standard ImageNet channel std devs
    )
    body = net[0]
    # Visualisation only: no_grad avoids building an autograd graph. Each stage
    # continues from the previous one instead of re-running the whole prefix,
    # since body[:5](x) == body[3:5](body[:3](x)) for a sequential container.
    # NOTE(review): assumes net[0] is an nn.Sequential (the slicing implies it).
    with torch.no_grad():
        feats3 = body[:3]( input_tensor )
        feats5 = body[3:5]( feats3 )
        feats7 = body[5:7]( feats5 )
    im1 = create_plots( feats3, title='Inferred features at model[0][:3]' )
    im2 = create_plots( feats5, title='Inferred features at model[0][:5]' )
    im3 = create_plots( feats7, title='Inferred features at model[0][:7]' )

    return ( dict( zip( categories, map( float, probs ) ) ), im1, im2, im3 )

# %% four-way-classifier.ipynb 4
# Exported fastai Learner. load_learner unpickles the file, so only ship/load
# a trusted .pkl here.
learn = load_learner( 'simple-image-classifier.pkl' )
# `model` aliases learn.model, switched to inference mode (eval() returns self).
model = learn.model
model.eval()
# classify_image pairs these labels positionally with learn.predict's
# probabilities. NOTE(review): the order must match the learner's vocab.
categories = [ 'bird', 'forest', 'otter', 'snake' ]

# %% four-way-classifier.ipynb 5
# Markdown text passed as `description` to the Gradio interface.
description='''
A simple 4-way classifier that categorizes images as 'snake', 'bird', 'otter' or 'forest'.
Fine-tuned from a pre-trained ResNet18 model downloaded from HuggingFace.

Classifying each test image takes very little time; most of the delay before the results appear comes from
the extra step of plotting the intermediate activation maps and inferred features with `matplotlib`.


**DISCLAIMER**: the images here are for demonstration purposes only. I don't own any of them and I'm not making money from them.
'''

# Gradio UI: one image input. Outputs are a Label with the class probabilities
# plus three images, matching the 4-tuple returned by classify_image.
image = gr.components.Image()
label = [ gr.components.Label(), gr.components.Image(), gr.components.Image(), gr.components.Image() ]
# glob's result order depends on the filesystem; sort it so the example
# gallery is shown in a deterministic order.
examples = sorted( glob.glob( './*.jpg' ) )

interf = gr.Interface(
    title='Simple 4-way image classifier',
    description=description,
    fn=classify_image,
    inputs=image,
    outputs=label,
    examples=examples,
    # NOTE(review): newer Gradio releases renamed `allow_flagging`; check the
    # pinned Gradio version before upgrading.
    allow_flagging='manual',
)
interf.launch( inline=True )