|
|
import gradio as gr |
|
|
import numpy as np |
|
|
import tensorflow as tf |
|
|
import cv2 |
|
|
import joblib |
|
|
import os |
|
|
os.environ["TF_ENABLE_ONEDNN_OPTS"] = "0" |
|
|
|
|
|
|
|
|
|
|
|
segmentation_model = tf.keras.models.load_model('unet_optimized.keras', |
|
|
custom_objects={"dice_coefficient": lambda y_true, y_pred: y_pred}) |
|
|
|
|
|
|
|
|
classification_model = joblib.load('knn.pkl') |
|
|
|
|
|
|
|
|
categories = ['Apple___Apple_scab', 'Apple___Black_rot', 'Apple___Cedar_apple_rust', 'Apple___healthy'] |
|
|
|
|
|
def segment_image(image): |
|
|
|
|
|
resized_image = cv2.resize(image, (256, 256)) / 255.0 |
|
|
input_image = np.expand_dims(resized_image, axis=0) |
|
|
|
|
|
|
|
|
mask = segmentation_model.predict(input_image)[0] |
|
|
|
|
|
|
|
|
print("Raw mask - Min:", np.min(mask), "Max:", np.max(mask), "Mean:", np.mean(mask)) |
|
|
|
|
|
|
|
|
if np.max(mask) > 1.0: |
|
|
mask = mask / np.max(mask) |
|
|
|
|
|
|
|
|
mask = (mask.squeeze() > 0.1).astype(np.uint8) |
|
|
|
|
|
|
|
|
cv2.imwrite("binary_mask.png", mask * 255) |
|
|
|
|
|
|
|
|
original_size = (image.shape[1], image.shape[0]) |
|
|
mask_resized = cv2.resize(mask, original_size, interpolation=cv2.INTER_NEAREST) |
|
|
|
|
|
return mask_resized |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def classify_image(image): |
|
|
|
|
|
gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY) |
|
|
hist = cv2.calcHist([gray], [0], None, [256], [0, 256]).flatten() |
|
|
|
|
|
|
|
|
prediction = classification_model.predict([hist]) |
|
|
return prediction[0] |
|
|
|
|
|
|
|
|
def process_image(image): |
|
|
|
|
|
image = np.array(image) |
|
|
|
|
|
|
|
|
mask = segment_image(image) |
|
|
|
|
|
|
|
|
diagnosis = classify_image(image) |
|
|
|
|
|
|
|
|
mask_colored = cv2.cvtColor(mask * 255, cv2.COLOR_GRAY2BGR) |
|
|
|
|
|
return mask_colored, diagnosis |
|
|
|
|
|
|
|
|
interface = gr.Interface( |
|
|
fn=process_image, |
|
|
inputs=gr.Image(label="Chargez une image de feuille", type="pil"), |
|
|
outputs=[ |
|
|
gr.Image(label="Masque de segmentation"), |
|
|
gr.Label(label="Diagnostic") |
|
|
], |
|
|
title="SafeLeaf", |
|
|
description=( |
|
|
"Cette application est une application de détection des maladies des feuilles de pommiers, elle utilise deux modèles : " |
|
|
"1. Un modèle de segmentation pour détecter la zone de la feuille malade. " |
|
|
"2. Un modèle de classification pour diagnostiquer la maladie de la feuille. " |
|
|
"Chargez une image pour commencer." |
|
|
), |
|
|
) |
|
|
|
|
|
|
|
|
interface.launch() |
|
|
|