File size: 2,962 Bytes
459e6b2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
import gradio as gr
import numpy as np
import tensorflow as tf
import cv2
import joblib
import os
os.environ["TF_ENABLE_ONEDNN_OPTS"] = "0"


# Charger le modèle de segmentation
segmentation_model = tf.keras.models.load_model('unet_optimized.keras', 
        custom_objects={"dice_coefficient": lambda y_true, y_pred: y_pred})

# Charger le modèle de classification
classification_model = joblib.load('knn.pkl')

# Classes pour le diagnostic
categories = ['Apple___Apple_scab', 'Apple___Black_rot', 'Apple___Cedar_apple_rust', 'Apple___healthy']

def segment_image(image):
    # Redimensionner et normaliser l'image
    resized_image = cv2.resize(image, (256, 256)) / 255.0
    input_image = np.expand_dims(resized_image, axis=0)

    # Prédire le masque
    mask = segmentation_model.predict(input_image)[0]

    # Debugging : Visualiser les statistiques du masque
    print("Raw mask - Min:", np.min(mask), "Max:", np.max(mask), "Mean:", np.mean(mask))

    # Si nécessaire, normaliser le masque
    if np.max(mask) > 1.0:  # Si les valeurs sont hors de l'échelle attendue
        mask = mask / np.max(mask)

    # Seuillage pour obtenir une image binaire
    mask = (mask.squeeze() > 0.1).astype(np.uint8)

    # Debugging : Sauvegarder le masque binaire
    cv2.imwrite("binary_mask.png", mask * 255)

    # Redimensionner le masque à la taille originale
    original_size = (image.shape[1], image.shape[0])
    mask_resized = cv2.resize(mask, original_size, interpolation=cv2.INTER_NEAREST)

    return mask_resized



# Fonction de classification
def classify_image(image):
    # Extraire les caractéristiques pour la classification
    gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
    hist = cv2.calcHist([gray], [0], None, [256], [0, 256]).flatten()

    # Prédire la classe
    prediction = classification_model.predict([hist])
    return prediction[0]

# Fonction principale pour Gradio
def process_image(image):
    # Convertir l'image de PIL à NumPy
    image = np.array(image)

    # Segmentation
    mask = segment_image(image)

    # Classification
    diagnosis = classify_image(image)

    # Convertir le masque en image couleur pour l'affichage
    mask_colored = cv2.cvtColor(mask * 255, cv2.COLOR_GRAY2BGR)

    return mask_colored, diagnosis

# Interface Gradio
interface = gr.Interface(
    fn=process_image,
    inputs=gr.Image(label="Chargez une image de feuille", type="pil"),
    outputs=[
        gr.Image(label="Masque de segmentation"),
        gr.Label(label="Diagnostic")
    ],
    title="SafeLeaf",
    description=(
        "Cette application est une application de détection des maladies des feuilles de pommiers, elle utilise deux modèles : "
        "1. Un modèle de segmentation pour détecter la zone de la feuille malade. "
        "2. Un modèle de classification pour diagnostiquer la maladie de la feuille. "
        "Chargez une image pour commencer."
    ),
)

# Lancer l'application
interface.launch()