Siddharth Maddali commited on
Commit
a692c75
1 Parent(s): 3455f46

Complete classifier app with activation map plotting.

Browse files
app.py DELETED
@@ -1,44 +0,0 @@
1
- # AUTOGENERATED! DO NOT EDIT! File to edit: four-way-classifier.ipynb.
2
-
3
- # %% auto 0
4
- __all__ = ['learn', 'categories', 'description', 'image', 'label', 'examples', 'interf', 'classify_image']
5
-
6
- # %% four-way-classifier.ipynb 1
7
- from fastai.vision.all import *
8
- import gradio as gr
9
- import PIL
10
- import glob
11
-
12
- # %% four-way-classifier.ipynb 3
13
- learn = load_learner( 'simple-image-classifier.pkl' )
14
-
15
- # %% four-way-classifier.ipynb 5
16
- categories = [ 'bird', 'forest', 'otter', 'snake' ]
17
-
18
- def classify_image( img ):
19
- pred, idx, probs = learn.predict( img )
20
- # print( 'This is a %s'%pred )
21
- return dict( zip( categories, map( float, probs ) ) )
22
-
23
- # %% four-way-classifier.ipynb 7
24
- description='''
25
- A simple 4-way classifier that categorizes images as 'snake', 'bird', 'otter' or 'forest'.
26
- Refined from an initial ResNet18 model downloaded from HuggingFace.
27
-
28
- **DISCLAIMER**: the images here are merely for demonstration purposes. I don't own any of them
29
- and I'm not making money from them.
30
- '''
31
- image = gr.components.Image()
32
- label = gr.components.Label()
33
- examples = glob.glob( './*.jpg' )
34
-
35
- interf = gr.Interface(
36
- title='Simple 4-way image classifier',
37
- description=description,
38
- fn=classify_image,
39
- inputs=image,
40
- outputs=label,
41
- examples=examples,
42
- allow_flagging='manual'
43
- )
44
- interf.launch( inline=True )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
simple_image_classifier/.ipynb_checkpoints/app-checkpoint.py ADDED
@@ -0,0 +1,15 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # AUTOGENERATED! DO NOT EDIT! File to edit: ../gradio_test.ipynb.
2
+
3
+ # %% auto 0
4
+ __all__ = ['iface', 'greet']
5
+
6
+ # %% ../gradio_test.ipynb 1
7
+ import gradio as gr
8
+
9
+ # %% ../gradio_test.ipynb 3
10
+ def greet( name ):
11
+ return 'Hello %s!!'%name
12
+
13
+ iface = gr.Interface( fn=greet, inputs='text', outputs='text' )
14
+ iface.launch( inline=True )
15
+
simple_image_classifier/app.py ADDED
@@ -0,0 +1,105 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # AUTOGENERATED! DO NOT EDIT! File to edit: ../four-way-classifier.ipynb.
2
+
3
+ # %% auto 0
4
+ __all__ = ['learn', 'categories', 'model', 'description', 'image', 'label', 'examples', 'interf', 'get_best_layout',
5
+ 'create_plots', 'classify_image']
6
+
7
+ # %% ../four-way-classifier.ipynb 1
8
+ from fastai.vision.all import *
9
+ import gradio as gr
10
+ import PIL
11
+ import glob
12
+
13
+ import primefac
14
+
15
+ from torchviz import make_dot
16
+ from pytorch_grad_cam.utils.image import preprocess_image
17
+
18
+ # %% ../four-way-classifier.ipynb 2
19
+ warnings.filterwarnings( 'ignore' )
20
+ warnings.simplefilter( 'ignore' )
21
+
22
+ # plt.rcParams[ 'figure.figsize' ] = [ 15, 15 ]
23
+ plt.rcParams[ 'figure.autolayout' ] = True
24
+
25
+ # %% ../four-way-classifier.ipynb 3
26
+ def get_best_layout( num_imgs, img_size ):
27
+ prime_factors = list( primefac.primefac( num_imgs ) )
28
+ aspect_ratios = []
29
+ for n in range( len( prime_factors[1:-1] ) ):
30
+ n1 = np.prod( prime_factors[:n] )
31
+ n2 = np.prod( prime_factors[n:] )
32
+ assert n1 * n2 == num_imgs
33
+ aspect_ratios.append( [ n1, n2, np.abs( img_size[0]/img_size[1] - ( n1/n2 ) ) ] )
34
+ # print( aspect_ratios )
35
+ n1_final, n2_final = tuple( aspect_ratios[ np.argmin( [ elem[2] for elem in aspect_ratios ] ) ] )[:2]
36
+ return n1_final, n2_final
37
+
38
+ def create_plots( output_tensor, title, fontsize=24 ):
39
+ temp = output_tensor.detach().numpy()
40
+ mn, mx = temp.min(), temp.max()
41
+ img_list = [ output_tensor[0][n].detach().numpy() for n in range( output_tensor.shape[1] ) ]
42
+ fig = plt.figure()
43
+ M, N = get_best_layout( len( img_list ), img_list[0].shape )
44
+ ax = fig.subplots( M, N, gridspec_kw={ 'wspace':0.01, 'hspace':0.01 } )
45
+ for N0, img in enumerate( img_list ):
46
+ m = N0//N
47
+ n = N0%N
48
+ im = ax[m,n].imshow( img, cmap='gray' )
49
+ ax[m,n].axis( 'off' )
50
+ im.set_clim( [ mn, mx ] )
51
+
52
+ plt.suptitle( title, fontsize=fontsize )
53
+ # plt.subplots_adjust( hspace=0.01, wspace=0.01 )
54
+ plt.tight_layout()
55
+ img_buf = io.BytesIO()
56
+ plt.savefig(img_buf, format='png')
57
+ im = Image.open(img_buf)
58
+ return im
59
+
60
+ def classify_image( img ):
61
+ pred, idx, probs = learn.predict( img )
62
+ # print( 'This is a %s'%pred )
63
+ model = learn.model.eval()
64
+ rgb_img = np.float32( np.array( img ) ) / 255
65
+ input_tensor = preprocess_image(
66
+ rgb_img,
67
+ mean=[0.485, 0.456, 0.406],
68
+ std=[0.229, 0.224, 0.225]
69
+ )
70
+ im1 = create_plots( model[0][:3]( input_tensor ), title='Inferred features at model[0][:3]' )
71
+ im2 = create_plots( model[0][:5]( input_tensor ), title='Inferred features at model[0][:5]' )
72
+ im3 = create_plots( model[0][:7]( input_tensor ), title='Inferred features at model[0][:7]' )
73
+
74
+ return ( dict( zip( categories, map( float, probs ) ) ), im1, im2, im3 )
75
+
76
+ # %% ../four-way-classifier.ipynb 4
77
+ learn = load_learner( 'simple-image-classifier.pkl' )
78
+ categories = [ 'bird', 'forest', 'otter', 'snake' ]
79
+ model = learn.model.eval()
80
+
81
+ # %% ../four-way-classifier.ipynb 5
82
+ description='''
83
+ A simple 4-way classifier that categorizes images as 'snake', 'bird', 'otter' or 'forest'.
84
+ Refined from an initial ResNet18 model downloaded from HuggingFace.
85
+
86
+ **DISCLAIMER**: the images here are merely for demonstration purposes. I don't own any of them
87
+ and I'm not making money from them.
88
+ '''
89
+
90
+ # with gr.Blocks() as layout:
91
+ image = gr.components.Image()
92
+ label = [ gr.components.Label(), gr.components.Image(), gr.components.Image(), gr.components.Image() ]
93
+ examples = glob.glob( './*.jpg' )
94
+
95
+ # with gr.Row():
96
+ interf = gr.Interface(
97
+ title='Simple 4-way image classifier',
98
+ description=description,
99
+ fn=classify_image,
100
+ inputs=image,
101
+ outputs=label,
102
+ examples=examples,
103
+ allow_flagging='manual'
104
+ )
105
+ interf.launch( inline=True )