import gradio as gr
from fastai.vision.all import *
import skimage

# Exported fastai Learner, loaded once at import time; predict() reads the
# `learn` and `labels` globals rather than reloading the model per request.
MODEL_PATH = "model/model.pkl"
learn = load_learner(MODEL_PATH)

# Class names from the Learner's DataLoaders vocab, used to key prediction scores.
labels = learn.dls.vocab


def predict(img):
    """Classify an image and return a score for every class.

    Args:
        img: Image from the ``gr.Image`` input; anything ``PILImage.create``
            accepts.

    Returns:
        dict mapping each label in ``labels`` to its probability as a plain
        Python ``float`` (the mapping form ``gr.Label`` displays).
    """
    img = PILImage.create(img)
    # Learner.predict returns a 3-tuple; only the third item (the per-class
    # probabilities) is needed, so the decoded label and index are discarded.
    _, _, probs = learn.predict(img)
    # Pair each vocab label with its score; float() turns tensor elements
    # into plain floats.
    return {label: float(prob) for label, prob in zip(labels, probs)}


# Text and sample inputs passed to the gr.Interface call further down.
title = "Latte Art Classifier"
# NOTE(review): the class list and accuracy figure below are hand-written;
# keep them in sync with `labels` and the currently deployed model.
description = """
A latte art classifier trained with fastai.
Currently supports 4 classes: heart, rosetta, swan, tulip.

The model is trained with resnet18 and achieved 91% accuracy on validation set.
Dataset is created by myself storing on kaggle - https://www.kaggle.com/datasets/mingchenadam/latte-art-train. 
"""
# HTML footer linking to the author's GitHub profile.
article = "<p style='text-align: center'><a href='https://github.com/mchen50' target='_blank'>My Github</a></p>"
# One sample image per class. These are relative paths, presumably resolved
# against the app's working directory; confirm the files ship with the app.
examples = [
    "examples/heart.jpg",
    "examples/rosetta.jpg",
    "examples/swan.jpg",
    "examples/tulip.jpg"
]
# Passed as gr.Interface(interpretation=...). NOTE(review): interpretation is a
# Gradio 3.x feature that was removed in Gradio 4; confirm the pinned version.
interpretation = "default"

# I/O components: an image input with shape=(512, 512) and a label output
# that shows the three highest-scoring classes.
# NOTE(review): `shape=` on gr.Image and `interpretation=` on gr.Interface are
# Gradio 3.x arguments that were removed in Gradio 4; confirm the pinned
# gradio version before upgrading.
image_input = gr.Image(shape=(512, 512))
label_output = gr.Label(num_top_classes=3)

# Wrap the classifier function in a single-function Interface, together with
# the display text and sample images defined above.
app = gr.Interface(
    fn=predict,
    inputs=image_input,
    outputs=label_output,
    title=title,
    description=description,
    article=article,
    examples=examples,
    interpretation=interpretation,
)

# Turn on request queuing, then start the web server.
app.queue()
app.launch()