Spaces:
Runtime error
Runtime error
# Kept getting "No module named 'fastai'" from Hugging Face; workaround:
# https://stackoverflow.com/a/50255019
import subprocess | |
import sys | |
def install(package):
    """Install *package* into the running interpreter's environment with pip.

    Workaround for "No module named ..." errors on the hosting platform.
    ``sys.executable`` is used so pip targets the same interpreter that is
    executing this script.

    Raises:
        subprocess.CalledProcessError: if pip exits with a non-zero status.
    """
    import importlib

    subprocess.check_call([sys.executable, "-m", "pip", "install", package])
    # Modules installed after interpreter start-up may be invisible to the
    # import system until the path finders' caches are invalidated.
    importlib.invalidate_caches()
install("fastai")
# Errors raised from within the libraries suggest that timm must be imported
# before fast.ai, e.g.:
# https://forums.fast.ai/t/nameerror-name-timm-is-not-defined/96158
install("timm")
import timm  # imported before fastai, as required by the note above
from fastai.vision.all import *

# load_learner unpickles the given file, so only load trusted model exports
# (these are the ones bundled with the app).
learn_resnet = load_learner('model - resnet18.pkl')
learn_convnext = load_learner('model - convnext_tiny - pad - 150imgs each - cleaned.pkl')
# Age-bucket labels in numeric order.
# NOTE(review): fastai sorts a string vocab lexicographically by default, so
# confirm this order against learn.dls.vocab before pairing it with predictions.
categories = ('1', '4', '8', '12', '16', '20s', '30s', '40s', '50s', '60s', '70s', '80s', '90s')
def classify_image_resnet(img):
    """Gradio callback: classify an uploaded image with the ResNet-18 learner."""
    model = learn_resnet
    return classify_image(learn=model, img=img)
def classify_image_convnext(img):
    """Gradio callback: classify an uploaded image with the ConvNeXt-tiny learner."""
    model = learn_convnext
    return classify_image(learn=model, img=img)
def classify_image(learn, img):
    """Classify *img* with *learn* and return a label -> probability mapping.

    Args:
        learn: a fastai ``Learner`` (as returned by ``load_learner``).
        img: the image to classify; converted with fastai's ``tensor`` first.

    Returns:
        dict mapping each class label to its predicted probability (float),
        suitable for a Gradio ``Label`` output.
    """
    tens = tensor(img)  # fix apparently needed after fastai 2.7.11 released
    _pred, _idx, probs = learn.predict(tens)
    # ``probs`` follows the learner's own vocab order. fastai's CategoryMap
    # sorts the vocab by default ('1', '12', '16', ..., '4', ...), which differs
    # from the numeric order of a hand-written label tuple, so label each
    # probability with the vocab entry it actually belongs to.
    labels = learn.dls.vocab
    return {str(label): float(p) for label, p in zip(labels, probs)}
from fastdownload import download_url | |
import os | |
def classify_image_url_resnet(url_text):
    """Gradio callback: fetch the image at *url_text* and classify it with ResNet-18.

    Returns whatever ``classify_image_url`` returns for the ResNet-18 learner.
    """
    model = learn_resnet
    return classify_image_url(learn=model, url_text=url_text)
def classify_image_url_convnext(url_text):
    """Gradio callback: fetch the image at *url_text* and classify it with ConvNeXt-tiny.

    Returns whatever ``classify_image_url`` returns for the ConvNeXt-tiny learner.
    """
    model = learn_convnext
    return classify_image_url(learn=model, url_text=url_text)
def classify_image_url(learn, url_text):
    """Download the image at *url_text* and classify it with *learn*.

    Returns:
        ``(predictions, thumbnail)`` on success, where *predictions* is the
        dict from ``classify_image`` and *thumbnail* is the downloaded image
        after ``to_thumb(256, 256)``. On any error (invalid URL, failed
        download, unreadable image, ...) returns a two-entry all-zero dict and
        ``None`` so the Gradio UI still receives renderable outputs.
    """
    import tempfile

    try:
        # A private temporary directory per call stops concurrent requests from
        # clobbering a shared 'temp.jpg', and is removed even when the
        # download or decoding fails.
        with tempfile.TemporaryDirectory() as tmp_dir:
            dest = os.path.join(tmp_dir, 'temp.jpg')
            download_url(url_text, dest, show_progress=False)
            # Close the file handle before the directory is deleted;
            # to_thumb works on a copy, so the thumbnail outlives the file.
            with Image.open(dest) as im:
                img = im.to_thumb(256, 256)
        return classify_image(learn, img), img
    except Exception:
        # Deliberate best effort: not sure how Gradio handles a runtime
        # exception, so report "no prediction" instead. Catch Exception rather
        # than using a bare except so KeyboardInterrupt/SystemExit propagate.
        return {categories[0]: 0.0, categories[1]: 0.0}, None
def classify_image_url_debug(url_text):
    """Debug variant of the URL classifier that reports the outcome as text.

    Downloads the image at *url_text*, classifies it with the ResNet-18
    learner and returns ``"Success: <predictions>"``. On any error it returns
    ``"<ExceptionType> was raised: <message>"`` instead of raising.
    """
    import tempfile

    try:
        # Per-call temporary directory: no shared 'temp.jpg' between requests
        # and guaranteed clean-up on failure.
        with tempfile.TemporaryDirectory() as tmp_dir:
            dest = os.path.join(tmp_dir, 'temp.jpg')
            download_url(url_text, dest, show_progress=False)
            with Image.open(dest) as im:
                img = im.to_thumb(256, 256)
        predictions = classify_image(learn_resnet, img)
        return "Success: " + str(predictions)
    except Exception as ex:
        error = f"{type(ex).__name__} was raised: {ex}"
        return error
import gradio as gr

# Gradio Blocks UI: one tab per (input kind, model) pair plus a debug tab.
# Output components use gr.Label; the legacy gr.outputs namespace is
# deprecated and no longer exists in current Gradio releases.
demo = gr.Blocks()
with demo:
    gr.Markdown(" ")
    gr.Markdown("Rudimentary age predictor. No refinement, just a hacked together experiment to try multiple output classes with fast.ai. Many ways the accuracy could be improved.")
    gr.Markdown("Note that training images were taken from DuckDuckGo and results for '[n] year old' were majority women, so accuracy is expected to be reduced for men. (Could easily be updated.)")
    gr.Markdown("See: [https://www.kaggle.com/code/zachwormgoor/age-predictor](https://www.kaggle.com/code/zachwormgoor/age-predictor) ")
    gr.Markdown(" ")
    gr.Markdown("---")
    gr.Markdown("Predict age from an uploaded image or from a URL to an image file: ")
    gr.Markdown(" ")
    with gr.Tabs():
        with gr.TabItem("Image - resnet18"):
            with gr.Row():
                img_r_input = gr.Image()
                img_r_output = gr.Label()
            image_r_button = gr.Button("Predict")
        with gr.TabItem("URL - resnet18"):
            with gr.Row():
                text_r_input = gr.Textbox()
                text_r_output = gr.Label()
                text_r_preview = gr.Image()
            url_r_button = gr.Button("Predict")
        with gr.TabItem("Image - convnext_tiny"):
            with gr.Row():
                img_c_input = gr.Image()
                img_c_output = gr.Label()
            image_c_button = gr.Button("Predict")
        with gr.TabItem("URL - convnext_tiny"):
            with gr.Row():
                text_c_input = gr.Textbox()
                text_c_output = gr.Label()
                text_c_preview = gr.Image()
            url_c_button = gr.Button("Predict")
        with gr.TabItem("URL - debug"):
            with gr.Row():
                text_d_input = gr.Textbox()
                # Label also accepts a plain string, used here for the debug text.
                text_d_output = gr.Label()
            text_d_button = gr.Button("Predict - debug")

    # URL tabs output both the predictions and a preview of the fetched image.
    image_r_button.click(classify_image_resnet, inputs=img_r_input, outputs=img_r_output)
    url_r_button.click(classify_image_url_resnet, inputs=text_r_input, outputs=[text_r_output, text_r_preview])
    image_c_button.click(classify_image_convnext, inputs=img_c_input, outputs=img_c_output)
    url_c_button.click(classify_image_url_convnext, inputs=text_c_input, outputs=[text_c_output, text_c_preview])
    text_d_button.click(classify_image_url_debug, inputs=text_d_input, outputs=text_d_output)
demo.launch()