Spaces:
Build error
Build error
import timm | |
from fastai.vision.all import * | |
import gradio as gr | |
import os | |
import platform | |
# Workaround for loading the exported learners on Windows: the .pkl files were
# presumably exported on a POSIX system and reference pathlib.PosixPath, which
# cannot be instantiated on Windows. Aliasing it to WindowsPath lets them
# unpickle there. The alias is never undone; `temp` keeps the original class.
if platform.system() == 'Windows':
    import pathlib
    temp, pathlib.PosixPath = pathlib.PosixPath, pathlib.WindowsPath
title = 'LEGO sets&creations theme classifier'
# Markdown intro shown at the top of the app (passed to gr.Markdown in the UI).
# Leading/trailing newlines are part of the f-string; Markdown ignores them.
description = f'''
# {title}
This demo showcases the LEGO theme classifier built with the help of fast.ai. A model was trained using over 1800 images of sets released in 2005-19 scraped from the Brickset LEGO database.
To test how much overfitting might be present due to the model memorizing the color(s) associated with a particular theme, I ran the training again using the same set of images, but in grayscale. Hence two available models.
I was especially interested in how the model would do on MOCs, a.k.a. community creations, since the boundaries between themes are not well-defined. Enjoy!
'''
# Display names of the nine LEGO themes, sorted alphabetically.
# NOTE(review): classify() zips these names with the learner's probabilities,
# so this order presumably has to match the learner's class vocabulary —
# confirm against learn_color.dls.vocab / learn_gray.dls.vocab.
themes = sorted([
    'City',
    'Technic',
    'Star-Wars',
    'Creator',
    'Ninjago',
    'Architecture',
    'Duplo',
    'Friends',
    'DC-Comics-Super-Heroes',
])
# Load both exported fastai learners at import time: one trained on colour
# images and one trained on grayscale copies of the same images (see the
# description text). load_learner unpickles the file, so only load trusted .pkl files.
learn_color, learn_gray = (
    load_learner('models/lego_convnext_small_4ep_sets05-19.pkl'),
    load_learner('models/lego_convnext_small_4ep_grayscale.pkl'),
)
def classify(img, is_color):
    """Predict LEGO-theme probabilities for a single image.

    Args:
        img: value of the "Input image" component, or None when the user
            has not provided an image yet.
        is_color: the selected radio option. 'Grayscale model' selects the
            grayscale learner; any other value selects the colour learner.

    Returns:
        A dict mapping each name in ``themes`` to its probability as a plain
        Python float (the format a ``gr.Label`` output displays), or None
        when ``img`` is None, which leaves the prediction empty.
    """
    # Clicking "Predict!" before uploading an image passes None, which the
    # learner is not expected to handle; clear the output instead of erroring.
    if img is None:
        return None
    learner = learn_gray if is_color == 'Grayscale model' else learn_color
    _, _, probs = learner.predict(img)
    # NOTE(review): assumes `themes` (sorted) is in the same order as the
    # learner's class vocabulary — confirm against learner.dls.vocab.
    return dict(zip(themes, map(float, probs)))
# Canonical display name for each theme, keyed case-insensitively, so the
# "Real theme" box uses the same spelling as the prediction labels.
_THEME_BY_LOWER = {theme.lower(): theme for theme in themes}


def _set_example(img_name):
    """Build one examples row [image path, real theme, release year].

    The theme is the part of the file name before the first '2', and the year
    is the last four characters of the name before the first '.'. That
    presumably means names like 'star-wars2019.jpg' — TODO confirm against
    images/sets. Unknown theme prefixes fall back to ``str.capitalize()``.
    """
    prefix = img_name.split('2', 1)[0]
    real_theme = _THEME_BY_LOWER.get(prefix.lower(), prefix.capitalize())
    year = img_name.split('.', 1)[0][-4:]
    return [f'images/sets/{img_name}', real_theme, year]


# os.listdir returns entries in arbitrary order; sort for a stable example
# table, and skip hidden files (e.g. .DS_Store) that are not example images.
examples_sets = [
    _set_example(img_name)
    for img_name in sorted(os.listdir('images/sets'))
    if not img_name.startswith('.')
]
# Community-creation (MOC) examples as (image path, caption) pairs. They are
# converted to lists below because each gr.Examples row is a list of input values.
_MOC_EXAMPLES = (
    ('images/mocs/modernlibrary.jpg', 'Modern library MOC'),
    ('images/mocs/keanu.jpg', 'Keanu Reeves himself'),
    ('images/mocs/solaris.jfif', 'Solaris Urbino articulated bus'),
    ('images/mocs/aroundtheworld.jpg', '"Around the World" MOC'),
    ('images/mocs/walkingminicooper.jpg', 'Walking mini cooper. Yes, walking mini cooper'),
)
examples_mocs = [list(pair) for pair in _MOC_EXAMPLES]
# Build the Gradio Blocks UI and launch it. Components appear on the page in
# the order they are created inside the nested Row/Column context managers,
# so statement order here determines the layout.
with gr.Blocks() as app:
    gr.Markdown(description)
    # Top row: inputs and the Predict button on the left, prediction on the right.
    with gr.Row(equal_height=True):
        with gr.Column():
            # NOTE(review): `shape` is a gradio 3.x Image argument (crop/resize
            # target); it is not accepted by later gradio versions — confirm
            # the pinned gradio version.
            img = gr.components.Image(shape=(192, 192), label="Input image")
            # The chosen option string is passed to classify(), which checks
            # for 'Grayscale model' to pick the grayscale learner.
            is_color = gr.components.Radio(['Color model', 'Grayscale model'], value='Color model', show_label=False)
            # Read-only; filled in when an "official sets" example is clicked.
            real_label = gr.components.Textbox("", label='Real theme', interactive=False)
            run_btn = gr.Button("Predict!")
            # placeholders for additional info
            # Hidden (visible=False) Textboxes that give the example tables an
            # extra column (MOC name / set release year); they are not inputs
            # to classify().
            name = gr.components.Textbox("", label='Name', visible=False)
            year = gr.components.Textbox("", label='Release year', visible=False)
        with gr.Column():
            prediction = gr.components.Label(label='Prediction')
    # Bottom row: the two example tables side by side.
    with gr.Row():
        with gr.Column():
            # Each examples_sets row is [image path, real theme, year],
            # matching inputs=[img, real_label, year].
            ex_sets = gr.Examples(examples_sets, inputs=[img, real_label, year], outputs=prediction, label='Examples - official sets')
        with gr.Column():
            # Each examples_mocs row is [image path, caption] -> [img, name].
            # NOTE(review): these rows do not set real_label, so a "Real theme"
            # left over from a previously clicked set example stays visible —
            # consider clearing it for MOCs.
            ex_mocs = gr.Examples(examples_mocs, inputs=[img, name], outputs=prediction, label='Examples - community creations')
    # No fn is given to gr.Examples, so clicking an example presumably only
    # fills the inputs; the prediction itself runs on the button click.
    run_btn.click(fn=classify, inputs=[img, is_color], outputs=prediction)
app.launch()