# KhadijaAsehnoune12's picture
# Update app.py
# c9d6d7a verified
# raw
# history blame
# No virus
# 2.95 kB
import gradio as gr
import torch
from transformers import ViTFeatureExtractor, ViTForImageClassification
from PIL import Image
import numpy as np
import rembg
# Hugging Face Hub repository that supplies both the classifier weights and
# the image preprocessing configuration.
model_name = "KhadijaAsehnoune12/ViTOrangeLeafDiseaseClassifier"

# Load the classifier with a 10-way classification head.
# NOTE(review): ignore_mismatched_sizes=True lets loading proceed even when a
# checkpoint tensor's shape differs from the requested configuration —
# confirm the checkpoint really ships a 10-class head.
model = ViTForImageClassification.from_pretrained(
    model_name,
    num_labels=10,
    ignore_mismatched_sizes=True,
)

# Preprocessor loaded from the same repository as the model.
# NOTE(review): recent transformers releases deprecate ViTFeatureExtractor in
# favour of ViTImageProcessor — consider migrating.
feature_extractor = ViTFeatureExtractor.from_pretrained(model_name)
# Maps a class index (as a string key) to the label text shown to the user.
# Each name's position in the tuple below is its class index.
id2label = {
    str(index): name
    for index, name in enumerate(
        (
            "Aleurocanthus spiniferus",
            "Chancre citrique",
            "Cochenille blanche",
            "Dépérissement des agrumes",
            "Feuille saine",
            "Jaunissement des feuilles",
            "Maladie de l'oïdium",
            "Maladie du dragon jaune",
            "Mineuse des agrumes",
            "Trou de balle",
        )
    )
}
def remove_background(image):
    """Return *image* with its background replaced by solid white.

    The input is converted to RGBA, handed to ``rembg.remove`` as a NumPy
    array, composited over a white canvas of the same size, and finally
    converted to RGB.

    Args:
        image: A PIL image (the code relies on its ``convert`` method and
            ``size`` attribute).

    Returns:
        A new RGB PIL image.
    """
    rgba = image.convert("RGBA")

    # rembg is given a NumPy array and its result is wrapped back into a PIL
    # image; presumably the result is RGBA with the background made
    # transparent — alpha_composite below requires an RGBA image.
    cutout = Image.fromarray(rembg.remove(np.array(rgba)))

    # Flatten the cut-out onto white so transparent areas become white
    # rather than being dropped arbitrarily by the RGB conversion.
    canvas = Image.new("RGBA", rgba.size, "WHITE")
    return Image.alpha_composite(canvas, cutout).convert("RGB")
def predict(image):
    """Classify a leaf image and report the top class with its confidence.

    The background is removed first, then the image is preprocessed and run
    through the ViT classifier. Softmax turns the logits into probabilities
    and the most probable class is reported.

    Args:
        image: A PIL image, or ``None`` when no image was supplied.

    Returns:
        A string of the form ``"<label>: <confidence>%"`` with the confidence
        formatted to two decimals, or a short message when *image* is
        ``None``.
    """
    # The UI may invoke this callback without an image; fail soft with a
    # readable message instead of raising AttributeError on ``None``.
    if image is None:
        return "Aucune image fournie."

    cleaned = remove_background(image)
    inputs = feature_extractor(images=cleaned, return_tensors="pt")

    # Inference only: disable autograd so no computation graph is recorded.
    with torch.no_grad():
        logits = model(**inputs).logits

    # Probabilities for the single image in the batch.
    probs = torch.nn.functional.softmax(logits, dim=-1)[0]
    predicted_class_idx = probs.argmax().item()

    # id2label is keyed by the class index as a string.
    predicted_label = id2label[str(predicted_class_idx)]
    confidence_score = probs[predicted_class_idx].item() * 100  # as a percentage

    return f"{predicted_label}: {confidence_score:.2f}%"
# UI components: a PIL image as input and a text box for the prediction.
image = gr.Image(type="pil")
label = gr.Textbox(label="Prediction")

# Build the interface around predict() and launch it immediately.
# share=True asks Gradio to create a public share link.
# NOTE(review): the example file names are presumably resolved relative to
# the working directory — confirm the three images ship alongside this file.
gr.Interface(
    fn=predict,
    inputs=image,
    outputs=label,
    title="Classification des maladies des agrumes",
    description="Téléchargez une image d'une feuille d'agrume pour classer sa maladie. Le modèle est entraîné sur les maladies suivantes : Aleurocanthus spiniferus, Chancre citrique, Cochenille blanche, Dépérissement des agrumes, Feuille saine, Jaunissement des feuilles, Maladie de l'oïdium, Maladie du dragon jaune, Mineuse des agrumes, Trou de balle.",
    examples=[
        "maladie_du_dragon_jaune.jpg",
        "critique.jpg",
        "feuille_saine.jpg",
    ],
).launch(share=True)