# app.py — author: KhadijaAsehnoune12
# Last commit: 17ae7fe "Update app.py" (verified), 2.11 kB
import gradio as gr
import torch
from transformers import ViTFeatureExtractor, ViTForImageClassification
from PIL import Image
# Class-index -> human-readable label mapping (French disease names).
# Keys are strings because predict() looks labels up with str(index).
id2label = {
    "0": "Aleurocanthus spiniferus",
    "1": "Chancre citrique",
    "2": "Cochenille blanche",
    "3": "Dépérissement des agrumes",
    "4": "Feuille saine",
    "5": "Jaunissement des feuilles",
    "6": "Maladie de l'oïdium",
    "7": "Maladie du dragon jaune",
    "8": "Mineuse des agrumes",
    "9": "Trou de balle",
}

# Load the fine-tuned ViT classifier and its preprocessing configuration
# from the Hugging Face Hub.
model_name = "KhadijaAsehnoune12/LeafDiseaseDetectorForOranges"
# num_labels is derived from id2label so the classifier head size and the
# label mapping cannot drift apart.
# NOTE(review): ignore_mismatched_sizes=True makes from_pretrained newly
# initialise (i.e. randomly) any checkpoint weight whose shape does not match
# instead of raising. Confirm the checkpoint really ships a head with
# len(id2label) outputs; otherwise predictions would be meaningless.
model = ViTForImageClassification.from_pretrained(
    model_name, num_labels=len(id2label), ignore_mismatched_sizes=True
)
# NOTE(review): ViTFeatureExtractor is deprecated in recent transformers
# releases in favour of ViTImageProcessor; consider migrating.
feature_extractor = ViTFeatureExtractor.from_pretrained(model_name)
def predict(image):
    """Classify a citrus leaf image and report the most likely label.

    Args:
        image: The uploaded image as delivered by the Gradio input component
            (a PIL image when the component uses ``type="pil"``), or ``None``
            when the user submits without uploading anything.

    Returns:
        A string ``"<label>: <confidence>"``, where confidence is the softmax
        probability of the predicted class formatted to two decimals. If no
        image was provided, a short prompt asking for one is returned instead.
    """
    # Gradio calls the function with None when the input is empty; the
    # feature extractor would otherwise fail with an opaque error.
    if image is None:
        return "Please upload an image of a citrus leaf."

    inputs = feature_extractor(images=image, return_tensors="pt")

    # Pure inference: skip building the autograd graph to save memory and time.
    with torch.no_grad():
        logits = model(**inputs).logits

    # Batch of one image, so index row 0 of the probabilities.
    predicted_class_idx = logits.argmax(-1).item()
    probabilities = torch.nn.functional.softmax(logits, dim=-1)
    confidence_score = probabilities[0, predicted_class_idx].item()

    predicted_label = id2label[str(predicted_class_idx)]
    return f"{predicted_label}: {confidence_score:.2f}"
# Build and launch the Gradio interface.
# Using these component objects (rather than the bare strings "image"/"text")
# makes Gradio hand predict() a PIL image and shows a labelled output box.
image = gr.Image(type="pil")
label = gr.Textbox(label="Prediction")
demo = gr.Interface(
    fn=predict,
    inputs=image,
    outputs=label,
    title="Citrus Disease Classification",
    # The disease list is built from id2label so the UI text always matches
    # the labels the model can actually return.
    description=(
        "Upload an image of a citrus leaf to classify its disease. "
        "The model is trained on the following diseases: "
        + ", ".join(id2label.values())
        + "."
    ),
    examples=[
        "maladie_du_dragon_jaune.jpg",
        "mineuse_des_agrumes.jpg",
        "feuille_saine.jpg",
    ],
)
# share=True requests a public share link when run locally.
# NOTE(review): hosted platforms such as Hugging Face Spaces may ignore it.
demo.launch(share=True)