|
# Public names exported by `from <module> import *`; the UI strings and the
# Gradio `demo` object are intentionally not listed here.
__all__ = [
    'modelname',
    'pokemon_types',
    'pokemon_types_en',
    'examplespath',
    'learn_inf',
    'lang',
    'prob_threshold',
    'classify_image',
]
|
|
|
|
|
import pandas as pd |
|
|
|
# File name of the exported fastai learner inside the Hugging Face repo.
modelname = 'model_gen0.pkl'

# Directory (relative to the working directory) holding the example images
# offered under the upload widget.
examplespath = 'images/'

# Pokemon name table, read relative to the working directory. It has an 'en'
# column, and `lang` is also used as a column name (presumably 'fr' too —
# confirm against pokemon.csv).
pokemon_types = pd.read_csv('pokemon.csv')

# English names: predictions from the learner are matched against this column
# to find the pokemon's row.
pokemon_types_en = pokemon_types['en']
|
|
|
|
|
from huggingface_hub import hf_hub_download |
|
from fastai.learner import load_learner |
|
|
|
# Fetch the exported learner file `modelname` from the Okkoman/PokeFace Hub
# repo (hf_hub_download returns a local path) and load it for inference.
# NOTE(review): load_learner unpickles the file, so this fully trusts the
# remote repo; consider pinning `revision=` to a known commit.
learn_inf = load_learner(
    hf_hub_download(repo_id="Okkoman/PokeFace", filename=modelname)
)
|
|
|
|
|
import gradio as gr |
|
|
|
# Language used for pokemon names and UI strings. Stays 'en' unless a
# supported language is negotiated from an active Flask request below.
lang = 'en'

# Minimum top-class probability required to report a prediction; otherwise
# the localized "unknown" string is shown.
prob_threshold = 0.75

try:
    from flask import has_request_context, request
except ImportError:
    # Gradio apps run on FastAPI, not Flask, so Flask may be absent. Without
    # Flask there can be no Flask request context, so `lang` stays 'en' —
    # the same outcome as when Flask is installed but no request is active.
    request = None

    def has_request_context():
        """Fallback used when Flask is unavailable: never inside a request."""
        return False


def _negotiate_lang(header, supported=('en', 'fr'), default='en'):
    """Pick the preferred supported language from an Accept-Language header.

    ``header`` is the raw header value (e.g. ``"fr-FR,fr;q=0.9,en;q=0.8"``)
    or ``None``. Entries are ranked by their ``q`` weight (default 1.0,
    ties keep header order); entries with ``q <= 0`` or an unparsable ``q``
    are ignored. Only the primary subtag is compared, case-insensitively
    (``"FR-ca"`` -> ``"fr"``).

    Returns the first ranked primary tag found in ``supported``, else
    ``default``.
    """
    if not header:
        return default
    candidates = []
    for position, item in enumerate(header.split(',')):
        tag, *params = (piece.strip() for piece in item.split(';'))
        quality = 1.0
        for param in params:
            name, _, value = param.partition('=')
            if name.strip().lower() == 'q':
                try:
                    quality = float(value)
                except ValueError:
                    quality = 0.0
        if tag and quality > 0:
            # Negated quality so that sorting puts the most preferred first;
            # position breaks ties in header order.
            candidates.append((-quality, position, tag.split('-')[0].lower()))
    for _, _, primary in sorted(candidates):
        if primary in supported:
            return primary
    return default


# NOTE(review): this runs once, at import time. Outside a request (the normal
# case for a Gradio app) nothing is negotiated; per-visitor localization would
# need `gr.Request` inside the event handlers instead.
if has_request_context():
    # Accept-Language is a weighted list, never a bare "fr", so it must be
    # parsed rather than compared directly.
    lang = _negotiate_lang(request.headers.get("Accept-Language"))

# Localized (title, description, unknown-label) strings, keyed by language;
# any language without an entry falls back to English.
_UI_TEXT = {
    'fr': (
        "# PokeFace - Quel est ce pokemon ?",
        "## Un classifieur pour les pokemons de 1ere et 2eme générations (001-251)",
        'inconnu',
    ),
    'en': (
        "# PokeFace - What is this pokemon?",
        "## A classifier for 1st-2nd generation pokemons (001-251)",
        'unknown',
    ),
}
title, description, unknown = _UI_TEXT.get(lang, _UI_TEXT['en'])
|
|
|
def classify_image(img):
    """Classify a pokemon image and return a human-readable, localized result.

    Runs ``learn_inf.predict`` on ``img``. When the top-class probability is
    above ``prob_threshold``, returns ``"<n> - <name> (<p>%)"``, where:

    - ``<n>`` is the 1-based row number of the predicted pokemon in
      ``pokemon.csv``;
    - ``<name>`` is read from the ``lang`` column, falling back to ``'en'``
      when no such column exists;
    - ``<p>`` is the probability as a rounded percentage.

    Otherwise returns the module-level ``unknown`` string.
    """
    pred, pred_idx, probs = learn_inf.predict(img)
    # probs[pred_idx] is a 0-d tensor; convert once to a plain float.
    confidence = float(probs[pred_idx])
    if confidence > prob_threshold:
        # Row label of the predicted English name. Assumes every class in the
        # learner's vocab appears in the 'en' column (IndexError otherwise).
        index = pokemon_types_en[pokemon_types_en == pred].index[0]
        # Guard against a `lang` that is not a column (e.g. an unparsed
        # header value) instead of raising KeyError.
        column = lang if lang in pokemon_types.columns else 'en'
        label = pokemon_types.loc[index, column]
        # index + 1 assumes the default RangeIndex from read_csv, i.e. rows
        # ordered by pokedex number starting at 001.
        return f"{index+1} - {label} ({confidence*100:.0f}%)"
    return unknown
|
|
|
# Page layout: title, description, an input/button/output row, then the
# clickable example images loaded from `examplespath`.
with gr.Blocks() as demo:
    with gr.Row():
        gr.Markdown(title)
    with gr.Row():
        gr.Markdown(description)
    with gr.Row():
        image_input = gr.Image(label="Upload an image", width=192, height=192)
        submit_button = gr.Button("Classify")
        label_output = gr.Label(label="Prediction")
    with gr.Row():
        gr.Examples(examples=examplespath, inputs=image_input)

    # Event listeners must be registered inside the Blocks context: run the
    # classifier on the uploaded image when the button is pressed.
    submit_button.click(fn=classify_image, inputs=image_input, outputs=label_output)

# Start the server only when run as a script (e.g. `python app.py`), so that
# importing this module — for tests, or by Gradio's reload mode, which imports
# the file and serves `demo` itself — does not block on launch().
if __name__ == "__main__":
    demo.launch(inline=False)
|
|
|
|