File size: 5,110 Bytes
39d70e1
 
 
66b0b77
39d70e1
 
 
 
 
 
 
 
 
 
66b0b77
 
 
 
 
 
 
39d70e1
66b0b77
39d70e1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
54dd2be
39d70e1
 
54dd2be
39d70e1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a476c73
 
 
481c015
 
 
54dd2be
 
39d70e1
 
 
 
 
 
 
 
 
 
4faf144
39d70e1
 
 
 
 
 
 
 
 
 
4faf144
39d70e1
 
 
 
 
 
 
 
4faf144
39d70e1
4faf144
39d70e1
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150




# Hugging Face Spaces kept failing with "No module named 'fastai'", so as a
# workaround the required packages are pip-installed at runtime, see:
# https://stackoverflow.com/a/50255019
import subprocess
import sys
import tempfile

def install(package):
    """Install *package* with pip into the interpreter running this script.

    Invoking ``sys.executable -m pip`` installs into the same environment
    that will later import the package. Raises
    ``subprocess.CalledProcessError`` if pip exits with a non-zero status.
    """
    pip_command = [sys.executable, "-m", "pip", "install", package]
    subprocess.check_call(pip_command)

# Install runtime dependencies before any of them are imported below.
install("fastai")

# timm is installed here, before `from fastai.vision.all import *` runs below:
# errors raised from inside the libraries suggested fastai needs timm to be
# available when fastai itself is imported (NOTE(review): based on this
# thread -- confirm): https://forums.fast.ai/t/nameerror-name-timm-is-not-defined/96158
# Notebook equivalent of these two calls: `!pip install -Uqq fastai timm`.
install("timm")


from fastai.vision.all import *
import timm


# Exported fastai learners, loaded once at import time and shared by every
# Gradio handler below. The paths are relative, so they resolve against the
# current working directory. load_learner unpickles these files, so they
# must only ever come from a trusted source.
learn_resnet = load_learner('model - resnet18.pkl')
learn_convnext = load_learner(
    'model - convnext_tiny - pad - 150imgs each - cleaned.pkl'
)

# Human-readable age buckets, youngest first.
# NOTE(review): a learner's probability vector follows its own vocab order
# (learn.dls.vocab), which need not match this tuple -- confirm before
# zipping the two together.
categories = (
    '1', '4', '8', '12', '16',
    '20s', '30s', '40s', '50s', '60s', '70s', '80s', '90s',
)



def classify_image_resnet(img):
    """Gradio handler: classify an uploaded image with the resnet18 learner."""
    return classify_image(learn=learn_resnet, img=img)
    
def classify_image_convnext(img):
    """Gradio handler: classify an uploaded image with the convnext_tiny learner."""
    return classify_image(learn=learn_convnext, img=img)


def classify_image(learn, img):
    """Run *learn* on *img* and return a ``{label: probability}`` dict.

    Args:
        learn: a fastai Learner (as returned by ``load_learner``).
        img: the image to classify -- a numpy array when it comes from a
            ``gr.Image()`` input (Gradio's default type), or a PIL image when
            called from the URL handlers.

    Returns:
        dict mapping each class label (str) to its predicted probability
        (float), the shape ``gr.Label`` expects.
    """
    # Explicit tensor conversion; noted as needed after fastai 2.7.11.
    tens = tensor(img)
    _pred, _idx, probs = learn.predict(tens)
    # `probs` is ordered like the learner's own vocab, not like the
    # hand-ordered `categories` tuple. By default fastai sorts string labels
    # lexicographically ('1', '12', '16', '20s', ..., '4', ...), so zipping
    # with `categories` would attach probabilities to the wrong ages.
    # Fall back to `categories` only if the learner exposes no vocab.
    vocab = getattr(getattr(learn, 'dls', None), 'vocab', None)
    labels = categories if vocab is None else vocab
    return {str(label): float(p) for label, p in zip(labels, probs)}
    
    
    
from fastdownload import download_url
import os

    

def classify_image_url_resnet(url_text):
    """Gradio handler: classify the image at *url_text* with the resnet18 learner."""
    return classify_image_url(learn=learn_resnet, url_text=url_text)
    
def classify_image_url_convnext(url_text):
    """Gradio handler: classify the image at *url_text* with the convnext_tiny learner."""
    return classify_image_url(learn=learn_convnext, url_text=url_text)

def _download_thumbnail(url_text):
    """Download the image at *url_text* and return a 256x256-bounded thumbnail.

    The file is written into a fresh temporary directory that is deleted on
    exit, including when the download or image decoding fails. Concurrent
    requests therefore no longer share (and overwrite) one fixed 'temp.jpg',
    and failed requests no longer leave the file behind.
    """
    with tempfile.TemporaryDirectory() as tmp_dir:
        dest = os.path.join(tmp_dir, 'temp.jpg')
        download_url(url_text, dest, show_progress=False)
        with Image.open(dest) as im:
            # fastai's to_thumb resizes a copy, so the returned image stays
            # usable after the file is closed and its directory removed.
            return im.to_thumb(256, 256)


def classify_image_url(learn, url_text):
    """Download the image at *url_text* and classify it with *learn*.

    Returns:
        ``(label_probabilities, thumbnail)`` on success. On any failure
        (invalid URL, download error, undecodable image, prediction error)
        returns a placeholder dict of zero probabilities and ``None`` instead
        of raising into Gradio; the error is reported on stderr.
    """
    try:
        img = _download_thumbnail(url_text)
        return classify_image(learn, img), img
    except Exception as ex:
        # Boundary handler for a Gradio callback: keep the UI responsive, but
        # (unlike a bare `except:`) let KeyboardInterrupt/SystemExit through,
        # and leave a trace in the logs instead of failing silently.
        print(f"classify_image_url: {type(ex).__name__}: {ex}", file=sys.stderr)
        return {categories[0]: 0.0, categories[1]: 0.0}, None


def classify_image_url_debug(url_text):
    """Classify the image at *url_text* with the resnet18 learner, reporting as text.

    Returns ``"Success: <label->probability dict>"`` or a one-line
    description of whatever exception was raised, so failures are visible in
    the UI.
    """
    try:
        temp = classify_image(learn_resnet, _download_thumbnail(url_text))
        return "Success: " + str(temp)
    except Exception as ex:
        return f"{type(ex).__name__} was raised: {ex}"





import gradio as gr



# Gradio UI: one tab per (input kind, model) pair, plus a debug tab that shows
# the raw outcome of a URL prediction as text.
# gr.Label replaces the deprecated gr.outputs.Label (removed in Gradio 4).
with gr.Blocks() as demo:
    gr.Markdown("  ")
    gr.Markdown("Rudimentary age predictor.  No refinement, just a hacked together experiment to try multiple output classes with fast.ai.  Many ways the accuracy could be improved.")
    gr.Markdown("Note that training images were taken from DuckDuckGo and results for '[n] year old' were majority women, so accuracy is expected to be reduced for men.  (Could easily be updated.)")
    gr.Markdown("See: [https://www.kaggle.com/code/zachwormgoor/age-predictor](https://www.kaggle.com/code/zachwormgoor/age-predictor)  ")
    gr.Markdown("  ")
    gr.Markdown("---")
    gr.Markdown("Predict age from an uploaded image or from a URL to an image file:  ")
    gr.Markdown("  ")
    with gr.Tabs():
        with gr.TabItem("Image - resnet18"):
            with gr.Row():
                img_r_input = gr.Image()
                img_r_output = gr.Label()
            image_r_button = gr.Button("Predict")
        with gr.TabItem("URL - resnet18"):
            with gr.Row():
                text_r_input = gr.Textbox()
                text_r_output = gr.Label()
                text_r_preview = gr.Image()
            url_r_button = gr.Button("Predict")
        with gr.TabItem("Image - convnext_tiny"):
            with gr.Row():
                img_c_input = gr.Image()
                img_c_output = gr.Label()
            image_c_button = gr.Button("Predict")
        with gr.TabItem("URL - convnext_tiny"):
            with gr.Row():
                text_c_input = gr.Textbox()
                text_c_output = gr.Label()
                text_c_preview = gr.Image()
            text_c_button = gr.Button("Predict")
        with gr.TabItem("URL - debug"):
            with gr.Row():
                text_d_input = gr.Textbox()
                text_d_output = gr.Label()
            text_d_button = gr.Button("Predict - debug")

    # Wire each button to its handler; the URL handlers also return the
    # downloaded thumbnail for the preview image.
    image_r_button.click(classify_image_resnet,       inputs=img_r_input,  outputs=img_r_output)
    url_r_button.click(  classify_image_url_resnet,   inputs=text_r_input, outputs=[text_r_output, text_r_preview])
    image_c_button.click(classify_image_convnext,     inputs=img_c_input,  outputs=img_c_output)
    text_c_button.click( classify_image_url_convnext, inputs=text_c_input, outputs=[text_c_output, text_c_preview])
    text_d_button.click( classify_image_url_debug,    inputs=text_d_input, outputs=text_d_output)

demo.launch()