|
import gradio as gr |
|
import tensorflow as tf |
|
import numpy as np |
|
from keras.models import load_model |
|
from tensorflow.keras.preprocessing.image import load_img, img_to_array |
|
|
|
|
|
# Load the trained classifier from disk. compile=False skips restoring any
# saved training configuration; the model is compiled explicitly below.
model = load_model("best_model_v2.keras", compile=False)

# NOTE(review): compiling is not required for model.predict(), which is the
# only use visible in this module; it only matters for fit()/evaluate().
# from_logits=True assumes the network's last layer emits raw logits (no
# softmax) — TODO confirm against the training code.
model.compile(
    optimizer="adam",
    loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    metrics=["accuracy"],
)
|
|
|
def format_decimal(value):
    """Return *value* rendered as a string with exactly two decimal places.

    Example: ``format_decimal(2.5)`` returns ``"2.50"``.
    """
    return f"{value:.2f}"
|
|
|
# Labels indexed by the model's output position.
# NOTE(review): this order must match the label encoding used at training
# time — TODO confirm against the training pipeline.
CLASS_NAMES = ['AMAZONE', 'BIOGUERRA', 'REVENANT', 'ZANGBETO', 'PORTE DU NON RETOUR']

# (width, height) that input images are resized to before inference.
_INPUT_SIZE = (256, 256)


def _preprocess(img):
    """Convert a PIL image into a single-image float batch scaled to [0, 1]."""
    # Force 3 channels: RGBA or greyscale images would otherwise produce 4- or
    # 1-channel arrays that do not match the model's input.
    img = img.convert("RGB").resize(_INPUT_SIZE)
    arr = img_to_array(img)
    # Add the leading batch axis; with Keras' default channels_last format
    # this yields shape (1, 256, 256, 3).
    arr = np.expand_dims(arr, axis=0)
    # Assumes the model was trained on pixels rescaled to [0, 1] — TODO confirm.
    return arr / 255.0


def detect(img):
    """Classify a PIL image and return the predicted class name.

    Args:
        img: PIL image supplied by the Gradio ``gr.Image(type="pil")`` input,
            or ``None`` when the user submits without an image.

    Returns:
        The entry of ``CLASS_NAMES`` with the highest model score, or ``""``
        when no image was provided.

    Raises:
        ValueError: if the model's output size does not match ``CLASS_NAMES``.
    """
    # Gradio passes None when the form is submitted without an image.
    if img is None:
        return ""

    batch = _preprocess(img)
    # verbose=0 suppresses the per-call progress bar in the server logs.
    prediction = model.predict(batch, verbose=0)[0]

    if len(prediction) != len(CLASS_NAMES):
        raise ValueError(
            f"model produced {len(prediction)} scores but "
            f"{len(CLASS_NAMES)} class names are defined"
        )

    # argmax is unaffected by softmax, so this works for logits or probabilities.
    return CLASS_NAMES[int(np.argmax(prediction))]
|
|
|
|
|
title = "Orisha"

# Web UI: one RGB image (as a PIL object) in, the predicted class name out.
iface = gr.Interface(
    fn=detect,
    inputs=gr.Image(type="pil", image_mode='RGB'),
    outputs=gr.Textbox(label="Classe", lines=10),
    title=title,
)

# Start the server only when executed as a script, so that importing this
# module (from tests or another app) does not block inside launch().
if __name__ == "__main__":
    iface.launch(inline=False)
|
|