Spaces:
Running
Running
File size: 1,819 Bytes
10add35 9871674 10add35 9871674 0797cfd 10add35 d542ab3 9871674 10add35 c6200dd 9871674 c6200dd d542ab3 10add35 e7ed4b7 10add35 0797cfd e7ed4b7 0797cfd 10add35 d542ab3 9871674 10add35 0797cfd 9871674 0797cfd 9871674 0797cfd d542ab3 9871674 0797cfd 9871674 d542ab3 9871674 d542ab3 9871674 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 |
import gradio as gr
from transformers import pipeline, AutoImageProcessor, AutoModelForImageClassification
from PIL import Image, ImageDraw
import torch
from torchvision import transforms
import pandas as pd
# DATA AUGMENTATION
# Random image perturbations, composed in the order listed:
#   1. mirror the image left/right with probability 0.5,
#   2. rotate by a random angle drawn from the +/-10 degree range,
#   3. jitter brightness and contrast with a strength of 0.2 each.
# Each call draws fresh random parameters, so the output is not deterministic.
augment = transforms.Compose(
    [
        transforms.RandomHorizontalFlip(p=0.5),
        transforms.RandomRotation(degrees=10),
        transforms.ColorJitter(brightness=0.2, contrast=0.2),
    ]
)
# Identifier of the checkpoint handed to the transformers pipeline.
MODEL_ID = "tribber93/my-trash-classification"

# Device index for the pipeline: 0 when CUDA is available, otherwise -1.
_DEVICE = 0 if torch.cuda.is_available() else -1

# Image-classification pipeline, built once at import time.
# top_k=3 asks for the three best-scoring labels per image.
trash_classifier = pipeline(
    task="image-classification",
    model=MODEL_ID,
    device=_DEVICE,
    top_k=3,
)
# MAPPING
# Lower-case model label -> French name of the bin the item belongs in.
# Lookups are expected to lower-case the label first; labels missing from
# this table are handled by the caller.
POUBELLES = dict(
    cardboard="papier/carton",
    glass="verre",
    metal="métal",
    paper="papier",
    plastic="plastique",
    trash="ordures ménagères",
)
# CLASSIFICATION
def classify_image(image: Image.Image):
    """Classify a waste image and return the predictions as a table.

    Args:
        image: PIL image supplied by the UI, or ``None`` when the user
            submits the form without providing an image.

    Returns:
        A ``pandas.DataFrame`` with the columns ``"Objet"`` (model label),
        ``"Poubelle"`` (bin looked up in ``POUBELLES``, ``"inconnue"`` for an
        unmapped label) and ``"Confiance (%)"`` (score as a percentage,
        rounded to two decimals), one row per prediction returned by
        ``trash_classifier``. The frame has these columns and no rows when
        ``image`` is ``None``.
    """
    columns = ["Objet", "Poubelle", "Confiance (%)"]

    # Submitting with no image hands us None; show an empty table rather
    # than crashing inside the model call.
    if image is None:
        return pd.DataFrame(columns=columns)

    # The image goes to the model untouched. Random flips, rotations and
    # colour jitter are a training-time technique; applying them at inference
    # makes the same picture yield different labels/scores on every call.
    predictions = trash_classifier(image)

    rows = [
        {
            "Objet": pred["label"],
            "Poubelle": POUBELLES.get(pred["label"].lower(), "inconnue"),
            "Confiance (%)": round(pred["score"] * 100, 2),
        }
        for pred in predictions
    ]
    # Passing columns keeps the header stable even if no prediction came back.
    return pd.DataFrame(rows, columns=columns)
# GRADIO
# Web UI: a single image input whose classification is shown as a table.
interface = gr.Interface(
    fn=classify_image,
    # type="pil" makes Gradio hand the callback a PIL image.
    inputs=gr.Image(type="pil"),
    outputs=gr.Dataframe(
        headers=["Objet", "Poubelle", "Confiance (%)"],
        # row_count is (initial_rows, "fixed" | "dynamic"); the second slot
        # is a mode string, not a maximum number of rows.
        row_count=(1, "dynamic"),
    ),
    title="🗑️ Classifieur de Déchets ",
    description=(
        "Dépose une image de déchet pour savoir dans quelle poubelle la trier !! "
        "Le modèle est fine-tuné sur TrashNet et bénéficie de data augmentation pour une meilleure robustesse."
    ),
    examples=None,
    # NOTE(review): newer Gradio releases rename this option to
    # `flagging_mode` -- confirm against the pinned Gradio version.
    allow_flagging="never",
)

# Start the web server only when run as a script, not when imported.
if __name__ == "__main__":
    interface.launch()
|