age-predictor / app.py
zachwormgoor@gmail.com
Description update
a476c73
# Kept getting "No module named 'fastai'" on Hugging Face, so dependencies are installed at runtime as a workaround:
# https://stackoverflow.com/a/50255019
import subprocess
import sys
def install(package):
    """Install *package* with pip into the interpreter that is running this script.

    Invokes ``sys.executable -m pip install <package>`` so the package lands in
    the same environment as this process. Raises
    ``subprocess.CalledProcessError`` if pip exits with a non-zero status.
    """
    pip_command = [sys.executable, "-m", "pip", "install", package]
    subprocess.run(pip_command, check=True)
# Install the model dependencies before they are imported below.
# timm is installed before fastai is imported on purpose. Per the forum thread
# below, the "NameError: name 'timm' is not defined" raised inside fastai shows
# up when timm is not available at the moment fastai itself is imported
# (presumably because fastai only picks up timm during its own import).
# https://forums.fast.ai/t/nameerror-name-timm-is-not-defined/96158
for _package in ("fastai", "timm"):
    install(_package)
from fastai.vision.all import *
import timm
# Exported fastai learners shipped alongside this app. load_learner unpickles
# these files, so only ever point it at trusted model files.
learn_resnet = load_learner('model - resnet18.pkl')
learn_convnext = load_learner('model - convnext_tiny - pad - 150imgs each - cleaned.pkl')

# Class labels in the order the models' probability outputs use.
# They are read from the learner's vocab rather than hard-coded. fastai sorts
# category vocabularies by default, so string labels such as '4' and '12' are
# stored lexicographically ('1', '12', '16', '20s', '30s', '4', ...), not in the
# numeric age order a hand-written tuple suggests. A hand-written tuple would
# silently attach probabilities to the wrong ages.
# NOTE(review): assumes both models were trained on the same label set, which
# the previous shared hard-coded tuple also assumed.
categories = tuple(learn_resnet.dls.vocab)
def classify_image_resnet(img):
    """Classify an uploaded image with the ResNet-18 learner.

    Returns the ``{label: probability}`` dict produced by ``classify_image``.
    """
    model = learn_resnet
    return classify_image(learn=model, img=img)
def classify_image_convnext(img):
    """Classify an uploaded image with the ConvNeXt-Tiny learner.

    Returns the ``{label: probability}`` dict produced by ``classify_image``.
    """
    model = learn_convnext
    return classify_image(learn=model, img=img)
def classify_image(learn, img):
    """Run *learn* on *img* and return a ``{label: probability}`` dict for ``gr.Label``.

    Args:
        learn: a fastai Learner, as returned by ``load_learner``.
        img: the image to classify. This is either the value of a ``gr.Image``
            input or the PIL thumbnail built by the URL handlers.

    Returns:
        dict mapping every class label of *learn* to its predicted probability
        as a plain float.
    """
    # Wrapping the input in fastai's tensor() was found necessary after
    # fastai 2.7.11 was released.
    tens = tensor(img)
    _decoded, _idx, probs = learn.predict(tens)
    # Label the probabilities with the learner's own vocab, since that is the
    # order its outputs use. fastai sorts category vocabularies by default, so
    # string labels such as '4' and '12' are stored lexicographically. A
    # hand-written, numerically ordered tuple would mislabel them.
    labels = learn.dls.vocab
    return {str(label): float(prob) for label, prob in zip(labels, probs)}
from fastdownload import download_url
import os
def classify_image_url_resnet(url_text):
    """Classify the image found at *url_text* with the ResNet-18 learner.

    Returns the ``(predictions, preview)`` pair produced by ``classify_image_url``.
    """
    model = learn_resnet
    return classify_image_url(learn=model, url_text=url_text)
def classify_image_url_convnext(url_text):
    """Classify the image found at *url_text* with the ConvNeXt-Tiny learner.

    Returns the ``(predictions, preview)`` pair produced by ``classify_image_url``.
    """
    model = learn_convnext
    return classify_image_url(learn=model, url_text=url_text)
def classify_image_url(learn, url_text):
    """Download the image at *url_text*, classify it with *learn*, and return a preview.

    Args:
        learn: the fastai Learner to classify with.
        url_text: URL of an image file, as typed by the user.

    Returns:
        ``(predictions, thumbnail)``. *predictions* is the ``{label: probability}``
        dict from ``classify_image``. *thumbnail* is the downloaded image shrunk
        to fit within 256x256.

        On any failure (invalid URL, download error, unreadable image, model
        error) the function instead returns a placeholder dict with the first two
        labels at 0.0 and ``None``, so the Gradio outputs always receive a value.
        The failure is logged rather than silently discarded.
    """
    import logging
    import tempfile

    try:
        # Download into a private temporary directory instead of a fixed
        # 'temp.jpg' in the working directory. Concurrent requests can no longer
        # overwrite or delete each other's file, and the file is removed even
        # when a later step raises.
        with tempfile.TemporaryDirectory() as tmp_dir:
            dest = os.path.join(tmp_dir, 'temp.jpg')
            download_url(url_text, dest, show_progress=False)
            # Close the file explicitly. to_thumb returns an independent copy,
            # so the thumbnail stays usable after the file is closed.
            with Image.open(dest) as im:
                img = im.to_thumb(256, 256)
        return classify_image(learn, img), img
    except Exception:
        # Best-effort by design: report the problem in the log, but still hand
        # Gradio a well-formed result. Exception (not a bare except) lets
        # KeyboardInterrupt/SystemExit propagate.
        logging.getLogger(__name__).exception("Could not classify image from URL %r", url_text)
        return {categories[0]: 0.0, categories[1]: 0.0}, None
def classify_image_url_debug(url_text):
    """Debug variant of the URL path that reports the outcome as text.

    Downloads the image at *url_text*, thumbnails it to fit 256x256, and
    classifies it with the ResNet-18 learner.

    Returns:
        ``"Success: <predictions dict>"`` when everything works. Otherwise it
        returns ``"<ExceptionType> was raised: <message>"`` describing the
        failure, so the underlying error is visible in the UI.
    """
    import tempfile

    try:
        # Private temporary directory: avoids clashing with other requests over
        # a shared 'temp.jpg', and is cleaned up even if a later step fails.
        with tempfile.TemporaryDirectory() as tmp_dir:
            dest = os.path.join(tmp_dir, 'temp.jpg')
            download_url(url_text, dest, show_progress=False)
            with Image.open(dest) as im:
                img = im.to_thumb(256, 256)
        temp = classify_image(learn_resnet, img)
        return "Success: " + str(temp)
    except Exception as ex:
        error = f"{type(ex).__name__} was raised: {ex}"
        return error
import gradio as gr
demo = gr.Blocks()
with demo:
    # Header: description, known limitations, and a link to the training notebook.
    gr.Markdown(" ")
    gr.Markdown("Rudimentary age predictor. No refinement, just a hacked together experiment to try multiple output classes with fast.ai. Many ways the accuracy could be improved.")
    gr.Markdown("Note that training images were taken from DuckDuckGo and results for '[n] year old' were majority women, so accuracy is expected to be reduced for men. (Could easily be updated.)")
    gr.Markdown("See: [https://www.kaggle.com/code/zachwormgoor/age-predictor](https://www.kaggle.com/code/zachwormgoor/age-predictor) ")
    gr.Markdown(" ")
    gr.Markdown("---")
    gr.Markdown("Predict age from an uploaded image or from a provided URL to an image file: ")
    gr.Markdown(" ")
    # One tab per (input kind, model) pair, plus a debug tab that reports errors
    # as text. gr.Label replaces the deprecated gr.outputs.Label; the gr.outputs
    # namespace no longer exists in Gradio 4.
    with gr.Tabs():
        with gr.TabItem("Image - resnet18"):
            with gr.Row():
                img_r_input = gr.Image()
                img_r_output = gr.Label()
            image_r_button = gr.Button("Predict")
        with gr.TabItem("URL - resnet18"):
            with gr.Row():
                text_r_input = gr.Textbox()
                text_r_output = gr.Label()
                text_r_preview = gr.Image()
            url_r_button = gr.Button("Predict")
        with gr.TabItem("Image - convnext_tiny"):
            with gr.Row():
                img_c_input = gr.Image()
                img_c_output = gr.Label()
            image_c_button = gr.Button("Predict")
        with gr.TabItem("URL - convnext_tiny"):
            with gr.Row():
                text_c_input = gr.Textbox()
                text_c_output = gr.Label()
                text_c_preview = gr.Image()
            text_c_button = gr.Button("Predict")
        with gr.TabItem("URL - debug"):
            with gr.Row():
                text_d_input = gr.Textbox()
                text_d_output = gr.Label()
            text_d_button = gr.Button("Predict - debug")
    # Event listeners are attached inside the Blocks context. The URL handlers
    # return (predictions, preview image), hence the two outputs.
    image_r_button.click(classify_image_resnet, inputs=img_r_input, outputs=img_r_output)
    url_r_button.click(classify_image_url_resnet, inputs=text_r_input, outputs=[text_r_output, text_r_preview])
    image_c_button.click(classify_image_convnext, inputs=img_c_input, outputs=img_c_output)
    text_c_button.click(classify_image_url_convnext, inputs=text_c_input, outputs=[text_c_output, text_c_preview])
    text_d_button.click(classify_image_url_debug, inputs=text_d_input, outputs=text_d_output)
demo.launch()