LF-netizen
change pathlib depending on os
cb6f6ef
import timm
from fastai.vision.all import *
import gradio as gr
import os
import platform
if platform.system() == 'Windows':
    # The exported .pkl learners loaded below contain pathlib.PosixPath
    # objects (presumably exported on Linux/macOS — confirm). PosixPath cannot
    # be instantiated on Windows, so alias it to WindowsPath for unpickling.
    # `temp` holds the original class so the alias can be undone afterwards.
    import pathlib
    temp, pathlib.PosixPath = pathlib.PosixPath, pathlib.WindowsPath
title = 'LEGO sets&creations theme classifier'
# Markdown rendered at the top of the app; the title doubles as its H1 heading.
description = f'''
# {title}
This demo showcases the LEGO theme classifier built with the help of fast.ai. A model was trained using over 1800 images of sets released in 2005-19 scraped from the Brickset LEGO database.
To test how much overfitting might be present due to the model memorizing the color(s) associated with a particular theme, I ran the training again using the same set of images, but in grayscale. Hence, two models are available.
I was especially interested in how the model would do on MOCs, a.k.a. community creations, since the boundaries between themes are not well-defined. Enjoy!
'''
# Class names paired positionally with the probability vector returned by
# Learner.predict in `classify`. Sorted on the assumption that the learners'
# vocab is sorted alphabetically too (fastai's default) — NOTE(review): confirm
# this matches `learn.dls.vocab` of both exported models.
themes = sorted(('City', 'Technic', 'Star-Wars', 'Creator', 'Ninjago', 'Architecture', 'Duplo', 'Friends', 'DC-Comics-Super-Heroes'))
# Load both exported fastai learners: one trained on the color set images and
# one trained on grayscale versions of the same images.
try:
    learn_color = load_learner('models/lego_convnext_small_4ep_sets05-19.pkl')
    learn_gray = load_learner('models/lego_convnext_small_4ep_grayscale.pkl')
finally:
    # On Windows, pathlib.PosixPath was aliased to WindowsPath (with the
    # original saved in `temp`) only so these pickles could be loaded. Put
    # the real class back once unpickling is over, even if it failed, so the
    # global patch does not leak into the rest of the process.
    if platform.system() == 'Windows':
        pathlib.PosixPath = temp
def classify(img, is_color):
    """Predict LEGO theme probabilities for an image.

    Args:
        img: value of the ``gr.Image`` input (presumably a numpy array, the
            Gradio default — confirm); ``None`` when no image was uploaded.
        is_color: the selected radio option. ``'Grayscale model'`` selects
            the grayscale learner; any other value selects the color learner.

    Returns:
        A dict mapping each theme name to its probability as a ``float``, for
        the ``gr.Label`` output. Returns an empty dict when ``img`` is ``None``.
    """
    if img is None:
        # "Predict!" clicked before an image was provided: learn.predict would
        # raise, so return an empty result instead of an error.
        return {}
    learner = learn_gray if is_color == 'Grayscale model' else learn_color
    _, _, probs = learner.predict(img)
    # Pair probabilities with `themes` positionally; this assumes `themes` is
    # in the same order as the learner's vocab.
    return dict(zip(themes, map(float, probs)))
def _set_example(img_name):
    """Build one official-set example row ``[path, theme, year]`` from a filename.

    The theme is everything before the first '2' in the filename, capitalized.
    The year is the last 4 characters before the first '.'. This assumes names
    shaped like ``<theme><year>.<ext>`` — TODO confirm against images/sets.
    """
    theme = img_name.split('2', 1)[0].capitalize()
    year = img_name.split('.', 1)[0][-4:]
    return [f'images/sets/{img_name}', theme, year]


# os.listdir returns entries in arbitrary order; sort so the examples appear in
# a stable order. Hidden files (e.g. .DS_Store) are not images and are skipped.
examples_sets = [
    _set_example(img_name)
    for img_name in sorted(os.listdir('images/sets'))
    if not img_name.startswith('.')
]
# Community creations (MOCs): [image path, display name].
examples_mocs = [['images/mocs/modernlibrary.jpg', 'Modern library MOC'],
                 ['images/mocs/keanu.jpg', 'Keanu Reeves himself'],
                 ['images/mocs/solaris.jfif', 'Solaris Urbino articulated bus'],
                 ['images/mocs/aroundtheworld.jpg', '"Around the World" MOC'],
                 ['images/mocs/walkingminicooper.jpg', 'Walking mini cooper. Yes, walking mini cooper']]
# Layout: description on top; input controls and prediction side by side;
# two example galleries (official sets, community creations) below.
with gr.Blocks() as app:
    gr.Markdown(description)
    with gr.Row(equal_height=True):
        with gr.Column():
            img = gr.components.Image(shape=(192, 192), label="Input image")
            is_color = gr.components.Radio(['Color model', 'Grayscale model'], value='Color model', show_label=False)
            real_label = gr.components.Textbox("", label='Real theme', interactive=False)
            run_btn = gr.Button("Predict!")
            # Hidden fields filled by the example rows (set release year /
            # MOC name); they are inputs of gr.Examples but never displayed.
            name = gr.components.Textbox("", label='Name', visible=False)
            year = gr.components.Textbox("", label='Release year', visible=False)
        with gr.Column():
            prediction = gr.components.Label(label='Prediction')
    with gr.Row():
        with gr.Column():
            ex_sets = gr.Examples(examples_sets, inputs=[img, real_label, year], outputs=prediction, label='Examples - official sets')
        with gr.Column():
            # MOCs have no official theme. Pad each row with '' for real_label
            # so that choosing a MOC example clears a "Real theme" left over
            # from a previously chosen set example.
            ex_mocs = gr.Examples([[*row, ''] for row in examples_mocs], inputs=[img, name, real_label], outputs=prediction, label='Examples - community creations')
    run_btn.click(fn=classify, inputs=[img, is_color], outputs=prediction)

if __name__ == '__main__':
    app.launch()