KhadijaAsehnoune12's picture
Update app.py
dac200a verified
raw
history blame
No virus
2.36 kB
import gradio as gr
import torch
from transformers import ViTFeatureExtractor, ViTForImageClassification
from PIL import Image
# Définir le modèle et le feature extractor
model_name ="KhadijaAsehnoune12/ViTOrangeLeafDiseaseClassifier"
model = ViTForImageClassification.from_pretrained(model_name, num_labels=10, ignore_mismatched_sizes=True)
feature_extractor = ViTFeatureExtractor.from_pretrained(model_name)
# Définir la mappage des labels
id2label = {
"0": "Aleurocanthus spiniferus",
"1": "Chancre citrique",
"2": "Cochenille blanche",
"3": "Dépérissement des agrumes",
"4": "Feuille saine",
"5": "Jaunissement des feuilles",
"6": "Maladie de l'oïdium",
"7": "Maladie du dragon jaune",
"8": "Mineuse des agrumes",
"9": "Trou de balle"
}
def predict(image):
# Prétraiter l'image
inputs = feature_extractor(images=image, return_tensors="pt")
# Passage en avant dans le modèle
outputs = model(**inputs)
# Obtenir les logits
logits = outputs.logits
# Calculer les scores de confiance avec softmax
probs = torch.nn.functional.softmax(logits, dim=-1)[0]
# Obtenir l'indice de la classe la plus probable
predicted_class_idx = probs.argmax().item()
# Obtenir le label et le score de confiance de la classe la plus probable
predicted_label = id2label[str(predicted_class_idx)]
confidence_score = probs[predicted_class_idx].item() * 100 # Multiplie par 100 pour obtenir un pourcentage
# Retourner le label et le score de confiance
return f"{predicted_label}: {confidence_score:.2f}%"
# Créer l'interface Gradio
image = gr.Image(type="pil")
label = gr.Textbox(label="Prediction")
gr.Interface(fn=predict,
inputs=image,
outputs=label,
title="Classification des maladies des agrumes",
description="Téléchargez une image d'une feuille d'agrume pour classer sa maladie. Le modèle est entraîné sur les maladies suivantes : Aleurocanthus spiniferus, Chancre citrique, Cochenille blanche, Dépérissement des agrumes, Feuille saine, Jaunissement des feuilles, Maladie de l'oïdium, Maladie du dragon jaune, Mineuse des agrumes, Trou de balle.",
examples=["maladie_du_dragon_jaune.jpg", "mineuse_des_agrumes.jpg","feuille_saine.jpg"]).launch(share=True)