zachwormgoor@gmail.com commited on
Commit
39d70e1
β€’
1 Parent(s): 95f3d81

Initial upload

Browse files
README.md CHANGED
@@ -1,6 +1,6 @@
1
  ---
2
  title: Age Predictor
3
- emoji: πŸ“Š
4
  colorFrom: yellow
5
  colorTo: purple
6
  sdk: gradio
 
1
  ---
2
  title: Age Predictor
3
+ emoji: πŸ•―πŸ“Š
4
  colorFrom: yellow
5
  colorTo: purple
6
  sdk: gradio
app.py ADDED
@@ -0,0 +1,132 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+
3
+
4
+ # Kept getting "No module named 'fastai'" from huggingface..workaround:
5
+ # https://stackoverflow.com/a/50255019
6
+ import subprocess
7
+ import sys
8
+
9
+ def install(package):
10
+ subprocess.check_call([sys.executable, "-m", "pip", "install", package])
11
+
12
+ install("fastai")
13
+
14
+ from fastai.vision.all import *
15
+
16
+
17
+ learn_resnet = load_learner('model - resnet18.pkl')
18
+ learn_convnext = load_learner('model - convnext_tiny - pad - 150imgs each - cleaned.pkl')
19
+
20
+ categories = ('1', '4', '8', '12', '16', '20s', '30s', '40s','50s', '60s', '70s', '80s', '90s')
21
+
22
+
23
+
24
+ def classify_image_resnet(img):
25
+ return classify_image(learn_resnet, img)
26
+
27
+ def classify_image_convnext(img):
28
+ return classify_image(learn_convnext, img)
29
+
30
+
31
+ def classify_image(learn, img):
32
+ tens = tensor(img) #fix apparently needed after fastai 2.7.11 released
33
+ pred,idx,probs = learn.predict(tens)
34
+ return dict(zip(categories, map(float,probs)))
35
+
36
+
37
+
38
+ from fastdownload import download_url
39
+ import os
40
+
41
+
42
+
43
+ def classify_image_url_resnet(url_text):
44
+ return classify_image_url(learn_resnet, url_text)
45
+
46
+ def classify_image_url_convnext(url_text):
47
+ return classify_image_url(learn_convnext, url_text)
48
+
49
+ def classify_image_url(learn, url_text):
50
+ try:
51
+ dest = 'temp.jpg'
52
+ download_url(url_text, dest, show_progress=False)
53
+ im = Image.open(dest)
54
+ img = im.to_thumb(256,256)
55
+ #resize_images(dest, max_size=400, dest=dest)
56
+ os.remove(dest)
57
+ return classify_image(learn, img)
58
+ except:
59
+ # in case there is any error, invalid URL or invalid image, etc., not sure how Gradio will handle a runtime exception so catching it to be safe
60
+ return { categories[0]: 0.0, categories[1]: 0.0 }
61
+
62
+
63
+ def classify_image_url_debug(url_text):
64
+ try:
65
+ dest = 'temp.jpg'
66
+ download_url(url_text, dest, show_progress=False)
67
+ im = Image.open(dest)
68
+ img = im.to_thumb(256,256)
69
+ #resize_images(dest, max_size=400, dest=dest)
70
+ os.remove(dest)
71
+ temp = classify_image(learn_resnet, img)
72
+ return "Success: " + str(temp)
73
+ except Exception as ex:
74
+ error = f"{type(ex).__name__} was raised: {ex}"
75
+ return error;
76
+
77
+
78
+
79
+
80
+
81
+ import gradio as gr
82
+
83
+
84
+
85
+ demo = gr.Blocks()
86
+
87
+ with demo:
88
+ gr.Markdown("Rudimentary age predictor. No refinement, just a hacked together experiment to try multiple output classes with fast.ai. Many ways the accuracy could be improved. See: https://www.kaggle.com/code/zachwormgoor/age-predictor \r\n")
89
+ gr.Markdown("Predict age from uploaded image or from provide URL to image file:")
90
+ with gr.Tabs():
91
+ with gr.TabItem("Image - resnet18"):
92
+ with gr.Row():
93
+ img_r_input = gr.Image()
94
+ img_r_output = gr.outputs.Label()
95
+ image_r_button = gr.Button("Predict")
96
+ with gr.TabItem("URL - resnet18"):
97
+ with gr.Row():
98
+ text_r_input = gr.Textbox()
99
+ text_r_output = gr.outputs.Label()
100
+ url_r_button = gr.Button("Predict")
101
+ with gr.TabItem("Image - convnext_tiny"):
102
+ with gr.Row():
103
+ img_c_input = gr.Image()
104
+ img_c_output = gr.outputs.Label()
105
+ image_c_button = gr.Button("Predict")
106
+ with gr.TabItem("URL - convnext_tiny"):
107
+ with gr.Row():
108
+ text_c_input = gr.Textbox()
109
+ text_c_output = gr.outputs.Label()
110
+ text_c_button = gr.Button("Predict")
111
+ with gr.TabItem("URL - debug"):
112
+ with gr.Row():
113
+ text_d_input = gr.Textbox()
114
+ text_d_output = gr.outputs.Label()
115
+ text_d_button = gr.Button("Predict - debug")
116
+
117
+ image_r_button.click(classify_image_resnet, inputs=img_r_input, outputs=img_r_output)
118
+ url_r_button.click( classify_image_url_resnet, inputs=text_r_input, outputs=text_r_output)
119
+ image_c_button.click(classify_image_convnext, inputs=img_c_input, outputs=img_c_output)
120
+ text_c_button.click( classify_image_url_convnext, inputs=text_c_input, outputs=text_d_output)
121
+ text_d_button.click( classify_image_url_debug, inputs=text_d_input, outputs=text_d_output)
122
+
123
+ demo.launch()
124
+
125
+
126
+
127
+
128
+
129
+
130
+
131
+
132
+
model - convnext_tiny - pad - 150imgs each - cleaned.pkl ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:a8ab30e8f418a5feb43a594f56b6cdfdceb8f19084f73ac3ea15a26fbc4628ac
3
+ size 114640499
model - resnet18.pkl ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:4d7659296c5f6b608a344640ee5f573f9884a65ca7166107965f16c2c69152b2
3
+ size 47021807