# Page-scrape residue kept as comments so the module parses:
# author: msivanes — commit c0ed310 ("Update app.py")
from fastai.vision.core import PILImageBW, TensorImageBW
from datasets import ClassLabel
import gradio as gr
from fastai.learner import load_learner
from PIL import Image
from numpy import array
def get_image_attr(x):
    """Return the ``'image'`` entry of a dataset row *x*.

    Raises KeyError when *x* has no ``'image'`` key.
    """
    image = x['image']
    return image
def get_target_attr(x):
    """Look up and return ``x['target']`` (KeyError if absent)."""
    target = x['target']
    return target
def get_label_attr(x):
    """Fetch the raw ``'label'`` value from row *x*; KeyError if missing."""
    label = x['label']
    return label
def img2tensor(im: Image.Image):
    """Convert a PIL image into a fastai ``TensorImageBW`` with a new leading axis.

    The pixels are copied into a NumPy array, wrapped as ``TensorImageBW`` and
    given an extra dimension at position 0 via ``unsqueeze(0)``.
    NOTE(review): for a mode-'L' image this presumably turns (H, W) into
    (1, H, W), i.e. a channel axis — confirm against the training pipeline.
    Not called in the visible code; presumably referenced by the pickled
    learner, so the name must stay at module level.
    """
    pixels = array(im)
    as_tensor = TensorImageBW(pixels)
    return as_tensor.unsqueeze(0)
# Fashion-MNIST class names in class-id order: ClassLabel maps id i to names[i].
classLabel = ClassLabel(
    names=[
        'T - shirt / top',
        'Trouser',
        'Pullover',
        'Dress',
        'Coat',
        'Sandal',
        'Shirt',
        'Sneaker',
        'Bag',
        'Ankle boot',
    ],
    id=None,
)
# The same class names as a plain list.
labels = classLabel.names
def add_target(x: dict):
    """Store the class name for ``x['label']`` under ``x['target']``.

    The row is modified in place and also returned, so the function can be
    used as a dataset ``map`` callback.

    Raises KeyError when *x* has no ``'label'`` key.
    """
    label_id = x['label']
    x['target'] = classLabel.int2str(label_id)
    return x
# Load the exported fastai Learner from 'export.pkl' (a relative path, so it is
# resolved against the current working directory); cpu=True keeps inference
# on the CPU.
# NOTE(review): load_learner unpickles this file, which can execute arbitrary
# code — only ever point it at a trusted export.
learn = load_learner('export.pkl', cpu=True)
def classify(inp):
    """Predict Fashion-MNIST class probabilities for a single image.

    Args:
        inp: the image delivered by the Gradio ``Image`` input (configured with
            ``image_mode='L'``); anything ``PILImageBW.create`` accepts.

    Returns:
        dict mapping every class name to its predicted probability (a float),
        the format ``gr.outputs.Label`` expects.
    """
    img = PILImageBW.create(inp)
    # Wrap the image in a dict with an 'image' key, matching the row shape
    # the learner was presumably trained on (cf. get_image_attr).
    _, _, probs = learn.predict(dict(image=img))
    # probs[i] is aligned with learn.dls.vocab[i], NOT with ClassLabel order:
    # fastai's Categorize sorts the vocab by default, so a model trained on
    # string targets has an alphabetical vocab. Pair each probability with
    # its own vocab entry, translating integer class ids to names if needed.
    names = [
        v if isinstance(v, str) else classLabel.int2str(int(v))
        for v in learn.dls.vocab
    ]
    return {name: float(p) for name, p in zip(names, probs)}
# Sample images offered under the input widget (relative paths, resolved
# against the current working directory).
examples = ['shoes.jpg', 't-shirt.jpg']
# Gradio's built-in interpretation mode for explaining predictions.
interpretation='default'
# NOTE(review): gr.inputs / gr.outputs and `interpretation` are the legacy
# Gradio API (deprecated in 3.x, removed in 4.x) — pin gradio accordingly.
iface = gr.Interface(
    fn=classify,
    inputs=gr.inputs.Image(image_mode='L'),
    outputs=gr.outputs.Label(num_top_classes=3),
    title="Fashion Mnist Classifier",
    description="fastai deployment in Gradio.",
    examples=examples,
    interpretation=interpretation,
)
# Launch separately so `iface` stays bound to the Interface itself;
# Interface.launch() returns server details, not the Interface.
iface.launch()