"""Gradio app that classifies citrus-leaf diseases with a fine-tuned ViT model."""

import gradio as gr
import numpy as np
import rembg
import torch
from PIL import Image
from transformers import ViTFeatureExtractor, ViTForImageClassification

# Define the model and feature extractor
model_name = "KhadijaAsehnoune12/ViTOrangeLeafDiseaseClassifier"
# NOTE(review): with ignore_mismatched_sizes=True, from_pretrained re-initialises
# the classification head with random weights (instead of failing) if the
# checkpoint's head does not have 10 outputs. Confirm the checkpoint was trained
# with exactly these 10 labels; otherwise predictions are meaningless.
model = ViTForImageClassification.from_pretrained(
    model_name, num_labels=10, ignore_mismatched_sizes=True
)
# from_pretrained already returns the model in eval mode; made explicit here
# because dropout must be disabled for deterministic inference.
model.eval()
feature_extractor = ViTFeatureExtractor.from_pretrained(model_name)

# Mapping from the model's class index (as a string) to a French label.
id2label = {
    "0": "Aleurocanthus spiniferus",
    "1": "Chancre citrique",
    "2": "Cochenille blanche",
    "3": "Dépérissement des agrumes",
    "4": "Feuille saine",
    "5": "Jaunissement des feuilles",
    "6": "Maladie de l'oïdium",
    "7": "Maladie du dragon jaune",
    "8": "Mineuse des agrumes",
    "9": "Trou de balle",
}


def remove_background(image):
    """Remove the background of *image* and replace it with plain white.

    Args:
        image: A PIL image in any mode.

    Returns:
        A new RGB PIL image of the same size: the foreground cut out by
        ``rembg.remove`` composited over a white background.
    """
    rgba = image.convert("RGBA")
    cutout_np = rembg.remove(np.array(rgba))
    # Image.alpha_composite requires both operands to be RGBA; convert the
    # rembg output defensively rather than relying on its exact mode.
    cutout = Image.fromarray(cutout_np).convert("RGBA")
    white_bg = Image.new("RGBA", rgba.size, "WHITE")
    return Image.alpha_composite(white_bg, cutout).convert("RGB")


def predict(image):
    """Classify the disease shown on a citrus-leaf image.

    The background is removed first, then the image is run through the ViT
    classifier and the most probable class is reported.

    Args:
        image: A PIL image, or None when the user submitted no image.

    Returns:
        A string "<label>: <confidence>%" for the top class (confidence is the
        softmax probability times 100, two decimals), or a French prompt asking
        for an image when *image* is None.
    """
    if image is None:
        # gr.Image passes None when the form is submitted without an image.
        return "Veuillez télécharger une image."

    image = remove_background(image)
    inputs = feature_extractor(images=image, return_tensors="pt")

    # Inference only: don't build an autograd graph (saves memory and time).
    with torch.no_grad():
        logits = model(**inputs).logits

    # Softmax over classes for the single image in the batch.
    probs = torch.nn.functional.softmax(logits, dim=-1)[0]
    predicted_class_idx = int(probs.argmax())
    predicted_label = id2label[str(predicted_class_idx)]
    confidence_score = probs[predicted_class_idx].item() * 100  # as a percentage
    return f"{predicted_label}: {confidence_score:.2f}%"
# Create the Gradio interface
image = gr.Image(type="pil")
label = gr.Textbox(label="Prediction")

# Derive the class list from id2label (ordered by class index) so the
# description cannot drift from the labels predict() actually returns.
_class_names = ", ".join(id2label[key] for key in sorted(id2label, key=int))

demo = gr.Interface(
    fn=predict,
    inputs=image,
    outputs=label,
    title="Classification des maladies des agrumes",
    description=(
        "Téléchargez une image d'une feuille d'agrume pour classer sa maladie. "
        f"Le modèle est entraîné sur les maladies suivantes : {_class_names}."
    ),
    # NOTE(review): these example images must ship alongside the app; confirm
    # "critique.jpg" is the intended file name.
    examples=["maladie_du_dragon_jaune.jpg", "critique.jpg", "feuille_saine.jpg"],
)

if __name__ == "__main__":
    # Launch only when run as a script, so importing this module has no side
    # effects. share=True asks Gradio for a temporary public share link.
    demo.launch(share=True)