KhadijaAsehnoune12's picture
Update app.py
82e3bb8 verified
raw
history blame
2.34 kB
import gradio as gr
import torch
from transformers import ViTFeatureExtractor, ViTForImageClassification
from PIL import Image
import numpy as np
import rembg
# Hugging Face Hub checkpoint shared by the classifier and its preprocessor.
model_name = "KhadijaAsehnoune12/ViTOrangeLeafDiseaseClassifier"

# Classifier with a 10-way head (one output per entry in the label mapping).
# NOTE(review): ignore_mismatched_sizes=True lets loading succeed even if the
# checkpoint's head has a different shape, in which case those weights would
# not come from the checkpoint — confirm the checkpoint really has 10 labels.
model = ViTForImageClassification.from_pretrained(
    model_name,
    num_labels=10,
    ignore_mismatched_sizes=True,
)

# Preprocessor that converts input images into tensors for the model.
# NOTE(review): ViTFeatureExtractor is deprecated in recent transformers
# releases in favour of ViTImageProcessor — consider migrating.
feature_extractor = ViTFeatureExtractor.from_pretrained(model_name)
# Display names of the classes, listed in class-index order.
_CLASS_NAMES = (
    "Aleurocanthus spiniferus",
    "Chancre citrique",
    "Cochenille blanche",
    "Dépérissement des agrumes",
    "Feuille saine",
    "Jaunissement des feuilles",
    "Maladie de l'oïdium",
    "Maladie du dragon jaune",
    "Mineuse des agrumes",
    "Trou de balle",
)

# Maps the class index, as a string key ("0" .. "9"), to its display name.
id2label = {str(index): name for index, name in enumerate(_CLASS_NAMES)}
def remove_background(image):
    """Cut the background out of *image* and flatten the result onto white.

    Args:
        image: A PIL image (anything exposing ``convert`` and ``size``).

    Returns:
        An RGB PIL image of the same size in which the areas made
        transparent by ``rembg`` show a white background.
    """
    rgba_source = image.convert("RGBA")

    # rembg is given the pixels as a NumPy array; its result is turned back
    # into a PIL image below so it can be alpha-composited.
    cutout_pixels = rembg.remove(np.array(rgba_source))
    cutout = Image.fromarray(cutout_pixels)

    # Compositing over an opaque white canvas replaces transparency with white.
    canvas = Image.new("RGBA", rgba_source.size, "WHITE")
    return Image.alpha_composite(canvas, cutout).convert("RGB")
def predict(image):
    """Classify a citrus-leaf image and report the top class with its confidence.

    Args:
        image: The PIL image supplied by the Gradio input, or ``None`` when
            the user submits the form without providing an image.

    Returns:
        A string of the form ``"<label>: <confidence>%"`` where the
        confidence is the softmax probability of the top class as a
        percentage with two decimals, or a short message asking for an
        image when *image* is ``None``.
    """
    # Gradio passes None when nothing was uploaded; answer with a message
    # instead of failing inside the preprocessing step.
    if image is None:
        return "Aucune image fournie. Veuillez télécharger une image."

    image = remove_background(image)
    inputs = feature_extractor(images=image, return_tensors="pt")

    # Inference only: no gradients are needed, so skip autograd bookkeeping.
    with torch.no_grad():
        logits = model(**inputs).logits

    # The batch holds a single image; take its probability vector.
    probs = torch.nn.functional.softmax(logits, dim=-1)[0]
    predicted_class_idx = probs.argmax().item()
    # id2label is keyed by the class index rendered as a string.
    predicted_label = id2label[str(predicted_class_idx)]
    confidence_score = probs[predicted_class_idx].item() * 100
    return f"{predicted_label}: {confidence_score:.2f}%"
# Gradio UI: a PIL image goes in, the prediction text comes out.
image = gr.Image(type="pil")
label = gr.Textbox(label="Prediction")

demo = gr.Interface(
    fn=predict,
    inputs=image,
    outputs=label,
    title="Classification des maladies des agrumes",
    description="Téléchargez une image d'une feuille d'agrume pour classer sa maladie. Le modèle est entraîné sur les maladies suivantes : Aleurocanthus spiniferus, Chancre citrique, Cochenille blanche, Dépérissement des agrumes, Feuille saine, Jaunissement des feuilles, Maladie de l'oïdium, Maladie du dragon jaune, Mineuse des agrumes, Trou de balle.",
    examples=["maladie_du_dragon_jaune.jpg", "feuille_saine.jpg"],
)

# Start the app; share=True also asks Gradio for a public share link.
demo.launch(share=True)