Spaces:
Runtime error
Runtime error
zachwormgoor@gmail.com
commited on
Commit
β’
39d70e1
1
Parent(s):
95f3d81
Initial upload
Browse files- README.md +1 -1
- app.py +132 -0
- model - convnext_tiny - pad - 150imgs each - cleaned.pkl +3 -0
- model - resnet18.pkl +3 -0
README.md
CHANGED
@@ -1,6 +1,6 @@
|
|
1 |
---
|
2 |
title: Age Predictor
|
3 |
-
emoji:
|
4 |
colorFrom: yellow
|
5 |
colorTo: purple
|
6 |
sdk: gradio
|
|
|
1 |
---
|
2 |
title: Age Predictor
|
3 |
+
emoji: π―π
|
4 |
colorFrom: yellow
|
5 |
colorTo: purple
|
6 |
sdk: gradio
|
app.py
ADDED
@@ -0,0 +1,132 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
|
2 |
+
|
3 |
+
|
4 |
+
# Kept getting "No module named 'fastai'" from huggingface..workaround:
|
5 |
+
# https://stackoverflow.com/a/50255019
|
6 |
+
import subprocess
|
7 |
+
import sys
|
8 |
+
|
9 |
+
def install(package):
|
10 |
+
subprocess.check_call([sys.executable, "-m", "pip", "install", package])
|
11 |
+
|
12 |
+
install("fastai")
|
13 |
+
|
14 |
+
from fastai.vision.all import *
|
15 |
+
|
16 |
+
|
17 |
+
learn_resnet = load_learner('model - resnet18.pkl')
|
18 |
+
learn_convnext = load_learner('model - convnext_tiny - pad - 150imgs each - cleaned.pkl')
|
19 |
+
|
20 |
+
categories = ('1', '4', '8', '12', '16', '20s', '30s', '40s','50s', '60s', '70s', '80s', '90s')
|
21 |
+
|
22 |
+
|
23 |
+
|
24 |
+
def classify_image_resnet(img):
|
25 |
+
return classify_image(learn_resnet, img)
|
26 |
+
|
27 |
+
def classify_image_convnext(img):
|
28 |
+
return classify_image(learn_convnext, img)
|
29 |
+
|
30 |
+
|
31 |
+
def classify_image(learn, img):
|
32 |
+
tens = tensor(img) #fix apparently needed after fastai 2.7.11 released
|
33 |
+
pred,idx,probs = learn.predict(tens)
|
34 |
+
return dict(zip(categories, map(float,probs)))
|
35 |
+
|
36 |
+
|
37 |
+
|
38 |
+
from fastdownload import download_url
|
39 |
+
import os
|
40 |
+
|
41 |
+
|
42 |
+
|
43 |
+
def classify_image_url_resnet(url_text):
|
44 |
+
return classify_image_url(learn_resnet, url_text)
|
45 |
+
|
46 |
+
def classify_image_url_convnext(url_text):
|
47 |
+
return classify_image_url(learn_convnext, url_text)
|
48 |
+
|
49 |
+
def classify_image_url(learn, url_text):
|
50 |
+
try:
|
51 |
+
dest = 'temp.jpg'
|
52 |
+
download_url(url_text, dest, show_progress=False)
|
53 |
+
im = Image.open(dest)
|
54 |
+
img = im.to_thumb(256,256)
|
55 |
+
#resize_images(dest, max_size=400, dest=dest)
|
56 |
+
os.remove(dest)
|
57 |
+
return classify_image(learn, img)
|
58 |
+
except:
|
59 |
+
# in case there is any error, invalid URL or invalid image, etc., not sure how Gradio will handle a runtime exception so catching it to be safe
|
60 |
+
return { categories[0]: 0.0, categories[1]: 0.0 }
|
61 |
+
|
62 |
+
|
63 |
+
def classify_image_url_debug(url_text):
|
64 |
+
try:
|
65 |
+
dest = 'temp.jpg'
|
66 |
+
download_url(url_text, dest, show_progress=False)
|
67 |
+
im = Image.open(dest)
|
68 |
+
img = im.to_thumb(256,256)
|
69 |
+
#resize_images(dest, max_size=400, dest=dest)
|
70 |
+
os.remove(dest)
|
71 |
+
temp = classify_image(learn_resnet, img)
|
72 |
+
return "Success: " + str(temp)
|
73 |
+
except Exception as ex:
|
74 |
+
error = f"{type(ex).__name__} was raised: {ex}"
|
75 |
+
return error;
|
76 |
+
|
77 |
+
|
78 |
+
|
79 |
+
|
80 |
+
|
81 |
+
import gradio as gr
|
82 |
+
|
83 |
+
|
84 |
+
|
85 |
+
demo = gr.Blocks()
|
86 |
+
|
87 |
+
with demo:
|
88 |
+
gr.Markdown("Rudimentary age predictor. No refinement, just a hacked together experiment to try multiple output classes with fast.ai. Many ways the accuracy could be improved. See: https://www.kaggle.com/code/zachwormgoor/age-predictor \r\n")
|
89 |
+
gr.Markdown("Predict age from uploaded image or from provide URL to image file:")
|
90 |
+
with gr.Tabs():
|
91 |
+
with gr.TabItem("Image - resnet18"):
|
92 |
+
with gr.Row():
|
93 |
+
img_r_input = gr.Image()
|
94 |
+
img_r_output = gr.outputs.Label()
|
95 |
+
image_r_button = gr.Button("Predict")
|
96 |
+
with gr.TabItem("URL - resnet18"):
|
97 |
+
with gr.Row():
|
98 |
+
text_r_input = gr.Textbox()
|
99 |
+
text_r_output = gr.outputs.Label()
|
100 |
+
url_r_button = gr.Button("Predict")
|
101 |
+
with gr.TabItem("Image - convnext_tiny"):
|
102 |
+
with gr.Row():
|
103 |
+
img_c_input = gr.Image()
|
104 |
+
img_c_output = gr.outputs.Label()
|
105 |
+
image_c_button = gr.Button("Predict")
|
106 |
+
with gr.TabItem("URL - convnext_tiny"):
|
107 |
+
with gr.Row():
|
108 |
+
text_c_input = gr.Textbox()
|
109 |
+
text_c_output = gr.outputs.Label()
|
110 |
+
text_c_button = gr.Button("Predict")
|
111 |
+
with gr.TabItem("URL - debug"):
|
112 |
+
with gr.Row():
|
113 |
+
text_d_input = gr.Textbox()
|
114 |
+
text_d_output = gr.outputs.Label()
|
115 |
+
text_d_button = gr.Button("Predict - debug")
|
116 |
+
|
117 |
+
image_r_button.click(classify_image_resnet, inputs=img_r_input, outputs=img_r_output)
|
118 |
+
url_r_button.click( classify_image_url_resnet, inputs=text_r_input, outputs=text_r_output)
|
119 |
+
image_c_button.click(classify_image_convnext, inputs=img_c_input, outputs=img_c_output)
|
120 |
+
text_c_button.click( classify_image_url_convnext, inputs=text_c_input, outputs=text_d_output)
|
121 |
+
text_d_button.click( classify_image_url_debug, inputs=text_d_input, outputs=text_d_output)
|
122 |
+
|
123 |
+
demo.launch()
|
124 |
+
|
125 |
+
|
126 |
+
|
127 |
+
|
128 |
+
|
129 |
+
|
130 |
+
|
131 |
+
|
132 |
+
|
model - convnext_tiny - pad - 150imgs each - cleaned.pkl
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:a8ab30e8f418a5feb43a594f56b6cdfdceb8f19084f73ac3ea15a26fbc4628ac
|
3 |
+
size 114640499
|
model - resnet18.pkl
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:4d7659296c5f6b608a344640ee5f573f9884a65ca7166107965f16c2c69152b2
|
3 |
+
size 47021807
|