"""Gradio web app that classifies latte art images with an exported fastai learner."""

import gradio as gr
from fastai.vision.all import *  # provides load_learner and PILImage used below
import skimage  # NOTE(review): not referenced in this file; possibly kept for the deployment environment — confirm before removing

# Load the exported learner once at import time; every call to predict reuses it.
learn = load_learner("model/model.pkl")
# Class names, zipped positionally with the probability vector from learn.predict.
labels = learn.dls.vocab


def predict(img):
    """Classify one image and return the probability of each class.

    Args:
        img: The value produced by the ``gr.Image`` input component, or
            ``None`` when the user submits without providing an image.

    Returns:
        A ``{label: probability}`` dict suitable for ``gr.Label``, or ``None``
        when no image was supplied (this clears the label output instead of
        raising inside ``PILImage.create``).
    """
    if img is None:
        # PILImage.create(None) raises, which would surface as an error in the UI.
        return None
    # learn.predict returns (decoded prediction, predicted index, probabilities);
    # only the probability vector is needed here.
    _, _, probs = learn.predict(PILImage.create(img))
    return dict(zip(labels, map(float, probs)))


title = "Latte Art Classifier"
description = """ A latte art classifier trained with fastai. Currently supports 4 classes: heart, rosetta, swan, tulip. The model is trained with resnet18 and achieved 91% accuracy on validation set. Dataset is created by myself storing on kaggle - https://www.kaggle.com/datasets/mingchenadam/latte-art-train. """
article = "<p style='text-align: center'><a href='https://github.com/mchen50' target='_blank'>My Github</a></p>"
examples = [
    "examples/heart.jpg",
    "examples/rosetta.jpg",
    "examples/swan.jpg",
    "examples/tulip.jpg",
]
interpretation = "default"

# NOTE(review): `shape=` on gr.Image and `interpretation=` on gr.Interface are
# Gradio 3.x options that were removed in Gradio 4 — pin gradio<4 or update
# these arguments when upgrading.
app = gr.Interface(
    fn=predict,
    inputs=gr.Image(shape=(512, 512)),
    outputs=gr.Label(num_top_classes=3),
    title=title,
    description=description,
    article=article,
    examples=examples,
    interpretation=interpretation,
)

if __name__ == "__main__":
    # Enable Gradio's request queue, then start the web server.
    app.queue()
    app.launch()