File size: 2,945 Bytes
8d12dce bc18618 5da6a11 c9d6d7a bc18618 c9d6d7a efac948 6bb1d3a bc18618 c9d6d7a bc18618 c9d6d7a 5da6a11 bc18618 c9d6d7a 5da6a11 c9d6d7a bc18618 c9d6d7a bc18618 c9d6d7a bc18618 6bb1d3a c9d6d7a 7e0ee60 6bb1d3a c9d6d7a b26ba26 7e0ee60 c9d6d7a b26ba26 c9d6d7a 7e0ee60 c9d6d7a da42d5c fc5c4f6 c9d6d7a 9da9b42 fc5c4f6 00b5ecc 1af1578 c9d6d7a |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 |
import gradio as gr
import numpy as np
import rembg
import torch
from PIL import Image
from transformers import ViTFeatureExtractor, ViTForImageClassification, ViTImageProcessor
# Define the model and its image preprocessor (both loaded from the same Hub repo).
model_name = "KhadijaAsehnoune12/ViTOrangeLeafDiseaseClassifier"
# num_labels=10 matches the ten entries of id2label below. With
# ignore_mismatched_sizes=True, transformers still loads the checkpoint if its
# classifier head has a different size, re-initializing those weights randomly.
# NOTE(review): confirm the checkpoint head is 10-way; otherwise predictions
# come from an untrained layer.
model = ViTForImageClassification.from_pretrained(
    model_name, num_labels=10, ignore_mismatched_sizes=True
)
# ViTImageProcessor supersedes the deprecated ViTFeatureExtractor (same
# preprocessing config and call signature). The name `feature_extractor` is
# kept so the rest of the script is unaffected.
feature_extractor = ViTImageProcessor.from_pretrained(model_name)
# Label mapping: class index (as a string key) -> French disease name.
# Position in the tuple below becomes the index key.
# NOTE(review): presumably matches the checkpoint's output order — confirm
# against model.config.id2label.
id2label = {
    str(index): disease
    for index, disease in enumerate(
        (
            "Aleurocanthus spiniferus",
            "Chancre citrique",
            "Cochenille blanche",
            "Dépérissement des agrumes",
            "Feuille saine",
            "Jaunissement des feuilles",
            "Maladie de l'oïdium",
            "Maladie du dragon jaune",
            "Mineuse des agrumes",
            "Trou de balle",
        )
    )
}
# Lazily created rembg session shared by every remove_background() call.
_rembg_session = None


def _get_rembg_session():
    """Return the shared rembg session, creating it on first use."""
    global _rembg_session
    if _rembg_session is None:
        # rembg.remove() builds a fresh session whenever none is passed; reusing
        # one avoids reloading the segmentation model for every request.
        _rembg_session = rembg.new_session()
    return _rembg_session


def remove_background(image):
    """Replace the background of a leaf photo with solid white.

    Args:
        image: PIL image to process.

    Returns:
        A new RGB PIL image of the same size. rembg removes the background,
        and the result is composited over an opaque white canvas.

    Raises:
        ValueError: if ``image`` is None (e.g. nothing was uploaded).
    """
    if image is None:
        raise ValueError("remove_background() requires an image, got None")
    rgba = image.convert("RGBA")
    cutout = rembg.remove(np.array(rgba), session=_get_rembg_session())
    # alpha_composite requires both operands in RGBA mode and of equal size.
    foreground = Image.fromarray(cutout).convert("RGBA")
    white_bg = Image.new("RGBA", rgba.size, "WHITE")
    return Image.alpha_composite(white_bg, foreground).convert("RGB")
def predict(image):
    """Classify the disease shown on a citrus leaf photo.

    The background is replaced with white, the image is preprocessed with the
    checkpoint's image processor, and the ViT classifier runs once.

    Args:
        image: PIL image from the Gradio input, or None when nothing was uploaded.

    Returns:
        A string "<label>: <confidence>%" for the most probable class. The
        confidence is the softmax probability as a percentage with two decimals.

    Raises:
        gr.Error: if no image was provided, so Gradio shows a readable message.
    """
    if image is None:
        raise gr.Error("Veuillez télécharger une image de feuille d'agrume.")
    image = remove_background(image)
    inputs = feature_extractor(images=image, return_tensors="pt")
    # Inference only: disable autograd so no computation graph is recorded.
    with torch.no_grad():
        logits = model(**inputs).logits
    # Softmax over the class dimension; [0] selects the single image in the batch.
    probs = torch.nn.functional.softmax(logits, dim=-1)[0]
    predicted_class_idx = probs.argmax().item()
    predicted_label = id2label[str(predicted_class_idx)]
    confidence_score = probs[predicted_class_idx].item() * 100  # as a percentage
    return f"{predicted_label}: {confidence_score:.2f}%"
# Create the Gradio interface: one PIL image input, one text output.
image = gr.Image(type="pil")
label = gr.Textbox(label="Prediction")
# The class list in the description is generated from id2label so the UI text
# cannot drift from the labels predict() actually returns.
demo = gr.Interface(
    fn=predict,
    inputs=image,
    outputs=label,
    title="Classification des maladies des agrumes",
    description=(
        "Téléchargez une image d'une feuille d'agrume pour classer sa maladie. "
        "Le modèle est entraîné sur les maladies suivantes : "
        + ", ".join(id2label.values())
        + "."
    ),
    # NOTE(review): these example files must ship alongside the app.
    examples=["maladie_du_dragon_jaune.jpg", "critique.jpg", "feuille_saine.jpg"],
)
demo.launch(share=True)
|