SafeLeaf / app.py
GiGi2k5
Add application file
459e6b2
import gradio as gr
import numpy as np
import tensorflow as tf
import cv2
import joblib
import os
os.environ["TF_ENABLE_ONEDNN_OPTS"] = "0"
# Charger le modèle de segmentation
segmentation_model = tf.keras.models.load_model('unet_optimized.keras',
custom_objects={"dice_coefficient": lambda y_true, y_pred: y_pred})
# Charger le modèle de classification
classification_model = joblib.load('knn.pkl')
# Classes pour le diagnostic
categories = ['Apple___Apple_scab', 'Apple___Black_rot', 'Apple___Cedar_apple_rust', 'Apple___healthy']
def segment_image(image):
# Redimensionner et normaliser l'image
resized_image = cv2.resize(image, (256, 256)) / 255.0
input_image = np.expand_dims(resized_image, axis=0)
# Prédire le masque
mask = segmentation_model.predict(input_image)[0]
# Debugging : Visualiser les statistiques du masque
print("Raw mask - Min:", np.min(mask), "Max:", np.max(mask), "Mean:", np.mean(mask))
# Si nécessaire, normaliser le masque
if np.max(mask) > 1.0: # Si les valeurs sont hors de l'échelle attendue
mask = mask / np.max(mask)
# Seuillage pour obtenir une image binaire
mask = (mask.squeeze() > 0.1).astype(np.uint8)
# Debugging : Sauvegarder le masque binaire
cv2.imwrite("binary_mask.png", mask * 255)
# Redimensionner le masque à la taille originale
original_size = (image.shape[1], image.shape[0])
mask_resized = cv2.resize(mask, original_size, interpolation=cv2.INTER_NEAREST)
return mask_resized
# Fonction de classification
def classify_image(image):
# Extraire les caractéristiques pour la classification
gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
hist = cv2.calcHist([gray], [0], None, [256], [0, 256]).flatten()
# Prédire la classe
prediction = classification_model.predict([hist])
return prediction[0]
# Fonction principale pour Gradio
def process_image(image):
# Convertir l'image de PIL à NumPy
image = np.array(image)
# Segmentation
mask = segment_image(image)
# Classification
diagnosis = classify_image(image)
# Convertir le masque en image couleur pour l'affichage
mask_colored = cv2.cvtColor(mask * 255, cv2.COLOR_GRAY2BGR)
return mask_colored, diagnosis
# Interface Gradio
interface = gr.Interface(
fn=process_image,
inputs=gr.Image(label="Chargez une image de feuille", type="pil"),
outputs=[
gr.Image(label="Masque de segmentation"),
gr.Label(label="Diagnostic")
],
title="SafeLeaf",
description=(
"Cette application est une application de détection des maladies des feuilles de pommiers, elle utilise deux modèles : "
"1. Un modèle de segmentation pour détecter la zone de la feuille malade. "
"2. Un modèle de classification pour diagnostiquer la maladie de la feuille. "
"Chargez une image pour commencer."
),
)
# Lancer l'application
interface.launch()