|
import gradio |
|
from fastai.vision.all import * |
|
|
|
# Directory (relative to the current working directory) that holds the
# exported model file loaded below.
MODELS_PATH = Path("./models")
# Directory (relative to the current working directory) whose files are
# offered as clickable examples in the Gradio UI.
EXAMPLES_PATH = Path("./examples")
|
|
|
|
|
|
|
|
|
|
|
def label_func(filepath):
    """Return the class label for *filepath*: the name of its parent folder.

    NOTE(review): this function is never called in this file; it is presumably
    kept because the pickled learner refers to it by name during unpickling.
    Confirm before renaming or removing it.
    """
    containing_dir = filepath.parent
    return containing_dir.name
|
|
|
# Load the exported fastai Learner once, at import time, so every prediction
# request reuses the same model object.
# NOTE(review): load_learner unpickles the file, so only ship a model file you trust.
LEARN = load_learner(Path(MODELS_PATH, 'flowers-fruits-resnet50-model.pkl'))
# Class vocabulary of the model's DataLoaders; gradio_predict pairs these
# positionally with the predicted probabilities.
LABELS = LEARN.dls.vocab
|
|
|
def gradio_predict(img):
    """Classify one image and return a probability for every known label.

    Args:
        img: the value supplied by the Gradio ``Image`` input; it is passed
            straight to ``PILImage.create``.

    Returns:
        dict mapping each entry of ``LABELS`` to its predicted probability as a
        plain Python ``float``. This is the mapping the ``Label`` output renders.
    """
    img = PILImage.create(img)
    # Only the per-class probabilities are needed; the decoded prediction and
    # its index are discarded.
    _pred, _pred_idx, probs = LEARN.predict(img)
    # Pair each label with its probability by position. float() converts each
    # element (presumably a tensor scalar) into a plain number Gradio can serialize.
    return {label: float(prob) for label, prob in zip(LABELS, probs)}
|
|
|
# Long-form markdown shown below the demo. Read it explicitly as UTF-8 so the
# result does not depend on the platform's default text encoding.
with open('gradio_article.md', encoding='utf-8') as f:
    article = f.read()

interface_options = {
    "title": "flowers-and-fruits-classifier (ResNet50|fast.ai)",
    "description": "A Flowers-Fruits image classifier trained on the 'https://duckduckgo.com/' dataset, using ResNet50 via fast.ai.",
    "article": article,
    # Example images for the UI. Path.iterdir() yields entries in arbitrary
    # order, so sort them for a stable display. Skip sub-directories and hidden
    # files (e.g. .DS_Store), which are not usable example images.
    "examples": sorted(
        f'{EXAMPLES_PATH}/{example.name}'
        for example in EXAMPLES_PATH.iterdir()
        if example.is_file() and not example.name.startswith('.')
    ),
    "layout": "horizontal",
    "theme": "default",
}

# NOTE(review): gradio.inputs / gradio.outputs, the "layout" and "theme"
# options, and launch(enable_queue=...) are legacy Gradio API that newer
# releases removed. Keep the gradio version pinned accordingly.
demo = gradio.Interface(
    fn=gradio_predict,
    # Requested input image size (legacy `shape` argument).
    inputs=gradio.inputs.Image(shape=(512, 512)),
    # Show the five highest-probability labels.
    outputs=gradio.outputs.Label(num_top_classes=5),
    **interface_options,
)

launch_options = {
    "enable_queue": True,
    "share": False,  # serve locally only; no public share link
}

demo.launch(**launch_options)