File size: 3,714 Bytes
c6c7820
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
# AUTOGENERATED! DO NOT EDIT! File to edit: app.ipynb.

# %% auto 0
# Public names exported by this module (order matches the notebook cells).
__all__ = [
    'single_classifier',
    'multi_class_classifier',
    'multi_label_classifier',
    'binary_labels',
    'multi_class_labels',
    'multi_label_labels',
    'label_func',
    'single_classification',
    'multi_class_classification',
    'multi_label_classification',
]

# %% app.ipynb 1
import gradio as gr
import nbdev
from fastai.vision.all import *
import os

# %% app.ipynb 2
def label_func(f):
    """Map a pet image name to ``'Cat'`` or ``'Dog'``.

    A name whose first character is upper-case (e.g. ``'Bengal_132.jpg'``)
    is labelled ``'Cat'``; anything else (e.g. ``'shiba_inu_44.jpg'``) is
    labelled ``'Dog'``. Raises ``IndexError`` for an empty name.

    NOTE(review): this is presumably kept at module level because a pickled
    learner references it by name when loaded -- confirm before renaming.
    """
    if not f[0].isupper():
        return 'Dog'
    return 'Cat'

# %% app.ipynb 3
# Load the three exported fastai learners from the relative `models/`
# directory, in the order binary -> multi-class -> multi-label. The files are
# `.pkl` pickles read by `load_learner`, so only ship trusted model files.
single_classifier, multi_class_classifier, multi_label_classifier = map(
    load_learner,
    (
        'models/dog-cat-classifier.pkl',
        'models/breeds-classifier.pkl',
        'models/multi-label-classification.pkl',
    ),
)

# %% app.ipynb 4
binary_labels = single_classifier.dls.vocab  # class names, in the learner's output order

def single_classification(img):
    """Score ``img`` with the binary learner.

    ``img`` is anything ``PILImage.create`` accepts (presumably the array the
    Gradio image input supplies). Returns a dict mapping each label in
    ``binary_labels`` to its predicted probability as a plain ``float``,
    which is the shape ``gr.Label`` renders.
    """
    pil_img = PILImage.create(img)
    _decoded, _decoded_idx, class_probs = single_classifier.predict(pil_img)
    return {label: float(p) for label, p in zip(binary_labels, class_probs)}

# %% app.ipynb 5
multi_class_labels = multi_class_classifier.dls.vocab  # class names, in output order

def multi_class_classification(img):
    """Score ``img`` with the multi-class (breeds) learner.

    Converts ``img`` via ``PILImage.create``, runs ``predict``, and returns
    ``{label: probability}`` for every label in ``multi_class_labels``, with
    probabilities cast to ``float`` for ``gr.Label``.
    """
    prediction = multi_class_classifier.predict(PILImage.create(img))
    _, _, probs = prediction
    return {name: float(score) for name, score in zip(multi_class_labels, probs)}

# %% app.ipynb 6
multi_label_labels = multi_label_classifier.dls.vocab  # label names, in output order

def multi_label_classification(img):
    """Score ``img`` with the multi-label learner.

    Returns ``{label: probability}`` for every label in ``multi_label_labels``
    as plain floats for ``gr.Label``. The scores are presumably independent
    per-label probabilities (multi-label), so they need not sum to 1 --
    confirm against the training setup.
    """
    image = PILImage.create(img)
    _pred, _pred_idx, label_probs = multi_label_classifier.predict(image)
    return {lbl: float(prob) for lbl, prob in zip(multi_label_labels, label_probs)}

# %% app.ipynb 7
def _example_images(folder):
    """Return the paths of files in *folder* whose names end with ``jpg``.

    The listing is sorted because ``os.listdir`` returns entries in arbitrary,
    filesystem-dependent order; sorting keeps the example gallery stable
    across runs and machines.
    """
    return [os.path.join(folder, name)
            for name in sorted(os.listdir(folder))
            if name.endswith('jpg')]


# Gradio UI: one tab per model. Each tab holds an image input, a "Run" button,
# clickable example images, and a `gr.Label` showing per-class scores.
with gr.Blocks() as demo:
    # Adjacent string literals are joined at compile time. A backslash
    # continuation followed by a blank line inside a single-quoted string
    # leaves a bare newline in the literal, which is a SyntaxError.
    # NOTE(review): the text mentions segmentation, but no segmentation tab is
    # defined below -- add the tab or drop the phrase.
    gr.Markdown("This demo allows you to try different vision classification models - "
                "from binary classification through multi-class and multi-label classification "
                "and finally segmentation.")

    with gr.Tab("Binary"):
        with gr.Row():
            with gr.Column():
                # gr.Image is the Blocks component; the gr.inputs namespace is deprecated.
                b_image_input = gr.Image(shape=(460, 460))
                with gr.Row():
                    b_button = gr.Button("Run")

                b_examples = 'models/Examples/Pets'
                examples = gr.Examples(
                    examples=[os.path.join(b_examples, 'shiba_inu_44.jpg'),
                              os.path.join(b_examples, 'Bengal_132.jpg')],
                    inputs=b_image_input,
                )
            binary_out = gr.Label(num_top_classes=len(binary_labels))

    with gr.Tab("MultiClass"):
        with gr.Row():
            with gr.Column():
                m_image_input = gr.Image(shape=(460, 460))
                with gr.Row():
                    m_button = gr.Button("Run")

                m_examples = 'models/Examples/Pets'
                examples = gr.Examples(examples=_example_images(m_examples), inputs=m_image_input)
            multi_out = gr.Label(num_top_classes=len(multi_class_labels))

    with gr.Tab("MultiLabel"):
        with gr.Row():
            with gr.Column():
                ml_image_input = gr.Image(shape=(460, 460))
                with gr.Row():
                    ml_button = gr.Button("Run")

                ml_examples = 'models/Examples/Pascal'
                examples = gr.Examples(examples=_example_images(ml_examples), inputs=ml_image_input)

            multil_out = gr.Label(num_top_classes=len(multi_label_labels))

    # Wire each tab's button to its classifier and matching Label output.
    b_button.click(single_classification, inputs=b_image_input, outputs=binary_out)
    m_button.click(multi_class_classification, inputs=m_image_input, outputs=multi_out)
    ml_button.click(multi_label_classification, inputs=ml_image_input, outputs=multil_out)

demo.launch()