File size: 1,819 Bytes
10add35
9871674
 
10add35
9871674
0797cfd
10add35
d542ab3
9871674
 
 
 
 
10add35
c6200dd
9871674
 
 
 
 
 
 
c6200dd
d542ab3
10add35
e7ed4b7
10add35
0797cfd
 
e7ed4b7
0797cfd
10add35
 
d542ab3
9871674
 
 
10add35
0797cfd
9871674
 
 
0797cfd
 
 
 
9871674
0797cfd
 
 
d542ab3
9871674
0797cfd
 
9871674
 
 
 
d542ab3
9871674
d542ab3
9871674
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
import gradio as gr
from transformers import pipeline, AutoImageProcessor, AutoModelForImageClassification
from PIL import Image, ImageDraw
import torch
from torchvision import transforms
import pandas as pd

# DATA AUGMENTATION 
augment = transforms.Compose([
    transforms.RandomHorizontalFlip(p=0.5),
    transforms.RandomRotation(10),
    transforms.ColorJitter(brightness=0.2, contrast=0.2),
])

MODEL_ID = "tribber93/my-trash-classification"
trash_classifier = pipeline(
    "image-classification",
    model=MODEL_ID,
    device=0 if torch.cuda.is_available() else -1,
    top_k=3
)


# MAPPING
POUBELLES = {
    "cardboard": "papier/carton",
    "glass": "verre",
    "metal": "métal",
    "paper": "papier",
    "plastic": "plastique",
    "trash": "ordures ménagères",
}

#CLASSIFICATION
def classify_image(image: Image.Image):
    image_aug = augment(image)
    results = trash_classifier(image_aug)

    rows = []
    for r in results:
        label = r["label"]
        score = r["score"]
        poubelle = POUBELLES.get(label.lower(), "inconnue")
        rows.append({
            "Objet": label,
            "Poubelle": poubelle,
            "Confiance (%)": round(score * 100, 2)
        })
    return pd.DataFrame(rows)

#GRADIO
interface = gr.Interface(
    fn=classify_image,
    inputs=gr.Image(type="pil"),
    outputs=gr.Dataframe(
        headers=["Objet", "Poubelle", "Confiance (%)"],
        row_count=(1, 10)
    ),
    title="🗑️ Classifieur de Déchets ",
    description=(
        "Dépose une image de déchet pour savoir dans quelle poubelle la trier !! "
        "Le modèle est fine-tuné sur TrashNet et bénéficie de data augmentation pour une meilleure robustesse."
    ),
    examples=None,
    allow_flagging="never"
)

if __name__ == "__main__":
    interface.launch()