"""Gradio demo for a LEGO theme classifier built with fast.ai.

Loads two exported learners (one for color images, one for grayscale
images), builds a Blocks UI with example galleries and launches the app.
"""
# NOTE: the star import stays ahead of the explicit imports below so that the
# explicitly imported names win if fastai happens to export the same names.
import timm  # not referenced directly; presumably required to unpickle the convnext learners -- keep
from fastai.vision.all import *
import gradio as gr
import os
import platform

if platform.system() == 'Windows':
    import pathlib

    # Alias PosixPath to WindowsPath before the learners are loaded.
    # Presumably the .pkl files were exported on a POSIX machine and contain
    # PosixPath objects, which cannot be instantiated on Windows.
    temp = pathlib.PosixPath  # original class, kept so it can be restored
    pathlib.PosixPath = pathlib.WindowsPath

title = 'LEGO sets&creations theme classifier'
description = f'''
# {title}

This demo showcases the LEGO theme classifier built with the help of fast.ai. A model was trained using over 1800 images of sets released in 2005-19 scraped from the Brickset LEGO database.

To test how much overfitting might be present due to the model memorizing the color(s) associated with a particular theme, I ran the training again using the same set of images, but in grayscale. Hence two available models.

I was especially interested in how the model will do on MOCS a.k.a. community creations, since the boundaries between themes are not well-defined. Enjoy!
'''

# Labels of the model selector; shared by the UI and by classify() so the two
# cannot drift apart.
COLOR_MODEL = 'Color model'
GRAYSCALE_MODEL = 'Grayscale model'

themes = sorted(('City', 'Technic', 'Star-Wars', 'Creator', 'Ninjago',
                 'Architecture', 'Duplo', 'Friends', 'DC-Comics-Super-Heroes'))

learn_color = load_learner('models/lego_convnext_small_4ep_sets05-19.pkl')
learn_gray = load_learner('models/lego_convnext_small_4ep_grayscale.pkl')


def classify(img, is_color):
    """Return the predicted probability of every theme for ``img``.

    Args:
        img: Image delivered by the Gradio ``Image`` input, or ``None`` when
            the user has not provided one.
        is_color: Value of the model radio button. ``'Grayscale model'``
            selects ``learn_gray``; any other value selects ``learn_color``.

    Returns:
        A dict mapping each name in ``themes`` to its probability as a
        ``float``, or ``None`` when no image was supplied (which clears the
        prediction label instead of raising inside the callback).
    """
    if img is None:
        return None

    learner = learn_gray if is_color == GRAYSCALE_MODEL else learn_color
    _, _, probs = learner.predict(img)
    # NOTE(review): assumes the alphabetical order of `themes` matches the
    # class order of both learners -- confirm against the learners' vocab.
    return dict(zip(themes, map(float, probs)))


def _set_example(img_name):
    """Build one ``[path, theme, year]`` example row from a set image name.

    The theme is the text before the first ``'2'`` (capitalized) and the year
    is the last four characters before the first ``'.'``.
    Presumably the files are named ``<theme><year>.<ext>`` -- verify against
    the contents of ``images/sets``.
    """
    theme = img_name.partition('2')[0].capitalize()
    year = img_name.partition('.')[0][-4:]
    return [f'images/sets/{img_name}', theme, year]


# os.listdir() returns entries in arbitrary order; sort them so the example
# gallery is laid out the same way on every machine.
examples_sets = [_set_example(img_name)
                 for img_name in sorted(os.listdir('images/sets'))]

examples_mocs = [
    ['images/mocs/modernlibrary.jpg', 'Modern library MOC'],
    ['images/mocs/keanu.jpg', 'Keanu Reeves himself'],
    ['images/mocs/solaris.jfif', 'Solaris Urbino articulated bus'],
    ['images/mocs/aroundtheworld.jpg', '"Around the World" MOC'],
    ['images/mocs/walkingminicooper.jpg', 'Walking mini cooper. Yes, walking mini cooper'],
]

with gr.Blocks() as app:
    gr.Markdown(description)

    with gr.Row(equal_height=True):
        with gr.Column():
            img = gr.components.Image(shape=(192, 192), label="Input image")
            is_color = gr.components.Radio([COLOR_MODEL, GRAYSCALE_MODEL],
                                           value=COLOR_MODEL, show_label=False)
            real_label = gr.components.Textbox("", label='Real theme', interactive=False)
            run_btn = gr.Button("Predict!")
            # placeholders for additional info: hidden textboxes that receive
            # the extra columns of the example rows
            name = gr.components.Textbox("", label='Name', visible=False)
            year = gr.components.Textbox("", label='Release year', visible=False)
        with gr.Column():
            prediction = gr.components.Label(label='Prediction')

    with gr.Row():
        with gr.Column():
            ex_sets = gr.Examples(examples_sets, inputs=[img, real_label, year],
                                  outputs=prediction,
                                  label='Examples - official sets')
        with gr.Column():
            ex_mocs = gr.Examples(examples_mocs, inputs=[img, name],
                                  outputs=prediction,
                                  label='Examples - community creations')

    run_btn.click(fn=classify, inputs=[img, is_color], outputs=prediction)

app.launch()