File size: 2,945 Bytes
8d12dce
bc18618
 
 
5da6a11
c9d6d7a
bc18618
c9d6d7a
efac948
6bb1d3a
bc18618
 
c9d6d7a
bc18618
 
 
 
 
 
 
 
 
 
 
 
 
c9d6d7a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5da6a11
bc18618
c9d6d7a
 
5da6a11
c9d6d7a
bc18618
 
c9d6d7a
bc18618
 
c9d6d7a
bc18618
6bb1d3a
c9d6d7a
7e0ee60
6bb1d3a
c9d6d7a
b26ba26
7e0ee60
c9d6d7a
b26ba26
c9d6d7a
7e0ee60
c9d6d7a
da42d5c
fc5c4f6
c9d6d7a
9da9b42
 
fc5c4f6
00b5ecc
1af1578
 
 
 
c9d6d7a
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
import gradio as gr
import torch
from transformers import ViTFeatureExtractor, ViTForImageClassification
from PIL import Image
import numpy as np
import rembg

# Hugging Face Hub checkpoint providing both the fine-tuned ViT weights and the
# preprocessor configuration used to prepare input images for it.
model_name = "KhadijaAsehnoune12/ViTOrangeLeafDiseaseClassifier"

# Image classifier with a head sized for the 10 classes listed in id2label.
# NOTE(review): ignore_mismatched_sizes=True tells transformers to re-initialise
# any mismatched weights (e.g. the classification head) with random values
# instead of raising; confirm the checkpoint's head really has 10 outputs.
model = ViTForImageClassification.from_pretrained(
    model_name,
    num_labels=10,
    ignore_mismatched_sizes=True,
)

# Preprocessor loaded from the same checkpoint so inputs match training.
feature_extractor = ViTFeatureExtractor.from_pretrained(model_name)

# Map each model output index (stored as a string key, e.g. "0") to the French
# display name of its class. A name's position in the list is its class index.
# NOTE(review): "Trou de balle" is also a vulgar French slang phrase; if this
# class is shot-hole disease, a term such as "Criblure" may read better. Left
# unchanged because it is user-facing text that the interface description
# string also repeats verbatim.
id2label = {
    str(class_index): class_name
    for class_index, class_name in enumerate(
        [
            "Aleurocanthus spiniferus",
            "Chancre citrique",
            "Cochenille blanche",
            "Dépérissement des agrumes",
            "Feuille saine",
            "Jaunissement des feuilles",
            "Maladie de l'oïdium",
            "Maladie du dragon jaune",
            "Mineuse des agrumes",
            "Trou de balle",
        ]
    )
}

# Lazily created rembg session shared by every remove_background() call.
_rembg_session = None


def _get_rembg_session():
    """Return the shared rembg session, creating it on first use.

    When ``rembg.remove`` is called without ``session`` it builds a new
    session (loading the segmentation model) on every call; reusing a single
    session avoids paying that cost per request.
    """
    global _rembg_session
    if _rembg_session is None:
        _rembg_session = rembg.new_session()
    return _rembg_session


def remove_background(image, session=None):
    """Remove the background from *image* and flatten the result onto white.

    Args:
        image: PIL image of any mode; it is converted to RGBA first.
        session: optional rembg session to use. Defaults to a lazily created
            module-level session shared across calls.

    Returns:
        An RGB PIL image the same size as *image*, where the pixels rembg
        made transparent show white instead.
    """
    if session is None:
        session = _get_rembg_session()

    rgba = image.convert("RGBA")

    # rembg returns the cutout as an array whose alpha channel masks out the
    # background.
    cutout = Image.fromarray(rembg.remove(np.array(rgba), session=session))

    # alpha_composite needs two RGBA images of the same size.
    white_bg = Image.new("RGBA", rgba.size, "WHITE")
    composited = Image.alpha_composite(white_bg, cutout.convert("RGBA"))

    return composited.convert("RGB")

def predict(image):
    """Classify a citrus-leaf photo and format the most probable class.

    Args:
        image: PIL image from the Gradio ``Image(type="pil")`` input, or
            ``None`` when the user submits without uploading an image.

    Returns:
        ``"<label>: <confidence>%"`` for the top class, with the confidence
        as a percentage rounded to two decimals. If *image* is ``None``, a
        message asking the user to upload an image is returned instead.
    """
    # Gradio passes None when nothing was uploaded; answer instead of crashing.
    if image is None:
        return "Veuillez télécharger une image."

    image = remove_background(image)
    inputs = feature_extractor(images=image, return_tensors="pt")

    # Inference only: disable autograd so no gradient graph is built.
    with torch.no_grad():
        logits = model(**inputs).logits

    # Softmax over the class dimension; [0] selects the single image in the batch.
    probs = torch.nn.functional.softmax(logits, dim=-1)[0]

    predicted_class_idx = probs.argmax().item()
    predicted_label = id2label[str(predicted_class_idx)]
    confidence_score = probs[predicted_class_idx].item() * 100  # as a percentage

    return f"{predicted_label}: {confidence_score:.2f}%"

# Create the Gradio interface: a PIL image in, the formatted prediction out.
image = gr.Image(type="pil")
label = gr.Textbox(label="Prediction")

# Derive the class list from id2label so the description always matches the
# labels predict() can actually return.
description = (
    "Téléchargez une image d'une feuille d'agrume pour classer sa maladie. "
    "Le modèle est entraîné sur les maladies suivantes : "
    + ", ".join(id2label.values())
    + "."
)

demo = gr.Interface(
    fn=predict,
    inputs=image,
    outputs=label,
    title="Classification des maladies des agrumes",
    description=description,
    examples=["maladie_du_dragon_jaune.jpg", "critique.jpg", "feuille_saine.jpg"],
)

# Start the server only when run as a script, so importing this module does
# not launch the app as a side effect.
if __name__ == "__main__":
    demo.launch(share=True)