MODLI's picture
Update app.py
b06362e verified
import gradio as gr
from transformers import ViTImageProcessor, ViTForImageClassification
from PIL import Image
import torch
import os
# --- Chargement du modèle et du processeur ---
print("Loading model and processor...")
model_name = "google/vit-base-patch16-224"
processor = ViTImageProcessor.from_pretrained(model_name)
model = ViTForImageClassification.from_pretrained(model_name)
print("Model loaded successfully!")
def predict(image):
"""Fonction de prédiction avec gestion d'erreurs et seuil de confiance"""
try:
# Conversion vers RGB pour éviter les erreurs de canaux
if image.mode != 'RGB':
image = image.convert('RGB')
# Pré-traitement de l'image
inputs = processor(images=image, return_tensors="pt")
# Prédiction
with torch.no_grad():
outputs = model(**inputs)
logits = outputs.logits
# Application de softmax pour obtenir les probabilités
probabilities = torch.nn.functional.softmax(logits, dim=-1)[0]
top_probs, top_indices = torch.topk(probabilities, 5) # Top 5 predictions
# Formatage des résultats sous forme de dictionnaire pour l'affichage
results = {}
for prob, idx in zip(top_probs, top_indices):
pred_label = model.config.id2label[idx.item()]
confidence = prob.item()
if confidence > 0.01: # Seuil de confiance à 1%
results[pred_label] = confidence
if not results:
return {"Aucune prédiction fiable": 0.0}, "Je ne suis pas sûr de reconnaître cet item. Essayez avec une image plus claire."
# Créer un message de résultat
top_prediction = list(results.items())[0]
message = f"🏷️ Prédiction principale: {top_prediction[0]} ({top_prediction[1]:.2%})"
return results, message
except Exception as e:
return {"Erreur": 0.0}, f"Une erreur s'est produite: {str(e)}"
# Interface Gradio améliorée
with gr.Blocks(title="Fashion Classifier", theme=gr.themes.Soft()) as demo:
gr.Markdown("# 👗 Fashion Item Classifier")
gr.Markdown("Téléchargez une image de vêtement pour le classer automatiquement")
with gr.Row():
with gr.Column(scale=1):
image_input = gr.Image(
type="pil",
label="Image du vêtement",
height=300,
sources=["upload", "webcam", "clipboard"]
)
upload_btn = gr.Button("🚀 Analyser l'image", variant="primary")
with gr.Column(scale=1):
label_output = gr.Label(
label="Résultats de classification",
num_top_classes=5
)
text_output = gr.Textbox(
label="Conclusion",
interactive=False
)
# Exemples
gr.Examples(
examples=[
["https://images.unsplash.com/photo-1552374196-c4e7ffc6e126?w=300"], # T-shirt
["https://images.unsplash.com/photo-1543163521-1bf539c55dd2?w=300"], # Chaussures
["https://images.unsplash.com/photo-1594633312681-425c7b97ccd1?w=300"] # Robe
],
inputs=image_input,
label="Exemples d'images à tester"
)
# Instructions
gr.Markdown("""
### 📋 Instructions
- Téléchargez une image claire d'un vêtement
- L'image doit montrer le vêtement de face
- Fond uni recommandé pour de meilleurs résultats
- Cliquez sur 'Analyser l'image' pour obtenir la classification
""")
# Liaison du bouton
upload_btn.click(
fn=predict,
inputs=image_input,
outputs=[label_output, text_output]
)
# Liaison aussi quand on upload une image
image_input.upload(
fn=predict,
inputs=image_input,
outputs=[label_output, text_output]
)
# Lancement de l'application
if __name__ == "__main__":
demo.launch(
debug=True,
server_name="0.0.0.0",
server_port=int(os.environ.get("PORT", 7860))
)