Spaces:
Sleeping
Sleeping
| import gradio as gr | |
| from transformers import ViTImageProcessor, ViTForImageClassification | |
| from PIL import Image | |
| import torch | |
| import os | |
| # --- Chargement du modèle et du processeur --- | |
| print("Loading model and processor...") | |
| model_name = "google/vit-base-patch16-224" | |
| processor = ViTImageProcessor.from_pretrained(model_name) | |
| model = ViTForImageClassification.from_pretrained(model_name) | |
| print("Model loaded successfully!") | |
| def predict(image): | |
| """Fonction de prédiction avec gestion d'erreurs et seuil de confiance""" | |
| try: | |
| # Conversion vers RGB pour éviter les erreurs de canaux | |
| if image.mode != 'RGB': | |
| image = image.convert('RGB') | |
| # Pré-traitement de l'image | |
| inputs = processor(images=image, return_tensors="pt") | |
| # Prédiction | |
| with torch.no_grad(): | |
| outputs = model(**inputs) | |
| logits = outputs.logits | |
| # Application de softmax pour obtenir les probabilités | |
| probabilities = torch.nn.functional.softmax(logits, dim=-1)[0] | |
| top_probs, top_indices = torch.topk(probabilities, 5) # Top 5 predictions | |
| # Formatage des résultats sous forme de dictionnaire pour l'affichage | |
| results = {} | |
| for prob, idx in zip(top_probs, top_indices): | |
| pred_label = model.config.id2label[idx.item()] | |
| confidence = prob.item() | |
| if confidence > 0.01: # Seuil de confiance à 1% | |
| results[pred_label] = confidence | |
| if not results: | |
| return {"Aucune prédiction fiable": 0.0}, "Je ne suis pas sûr de reconnaître cet item. Essayez avec une image plus claire." | |
| # Créer un message de résultat | |
| top_prediction = list(results.items())[0] | |
| message = f"🏷️ Prédiction principale: {top_prediction[0]} ({top_prediction[1]:.2%})" | |
| return results, message | |
| except Exception as e: | |
| return {"Erreur": 0.0}, f"Une erreur s'est produite: {str(e)}" | |
| # Interface Gradio améliorée | |
| with gr.Blocks(title="Fashion Classifier", theme=gr.themes.Soft()) as demo: | |
| gr.Markdown("# 👗 Fashion Item Classifier") | |
| gr.Markdown("Téléchargez une image de vêtement pour le classer automatiquement") | |
| with gr.Row(): | |
| with gr.Column(scale=1): | |
| image_input = gr.Image( | |
| type="pil", | |
| label="Image du vêtement", | |
| height=300, | |
| sources=["upload", "webcam", "clipboard"] | |
| ) | |
| upload_btn = gr.Button("🚀 Analyser l'image", variant="primary") | |
| with gr.Column(scale=1): | |
| label_output = gr.Label( | |
| label="Résultats de classification", | |
| num_top_classes=5 | |
| ) | |
| text_output = gr.Textbox( | |
| label="Conclusion", | |
| interactive=False | |
| ) | |
| # Exemples | |
| gr.Examples( | |
| examples=[ | |
| ["https://images.unsplash.com/photo-1552374196-c4e7ffc6e126?w=300"], # T-shirt | |
| ["https://images.unsplash.com/photo-1543163521-1bf539c55dd2?w=300"], # Chaussures | |
| ["https://images.unsplash.com/photo-1594633312681-425c7b97ccd1?w=300"] # Robe | |
| ], | |
| inputs=image_input, | |
| label="Exemples d'images à tester" | |
| ) | |
| # Instructions | |
| gr.Markdown(""" | |
| ### 📋 Instructions | |
| - Téléchargez une image claire d'un vêtement | |
| - L'image doit montrer le vêtement de face | |
| - Fond uni recommandé pour de meilleurs résultats | |
| - Cliquez sur 'Analyser l'image' pour obtenir la classification | |
| """) | |
| # Liaison du bouton | |
| upload_btn.click( | |
| fn=predict, | |
| inputs=image_input, | |
| outputs=[label_output, text_output] | |
| ) | |
| # Liaison aussi quand on upload une image | |
| image_input.upload( | |
| fn=predict, | |
| inputs=image_input, | |
| outputs=[label_output, text_output] | |
| ) | |
| # Lancement de l'application | |
| if __name__ == "__main__": | |
| demo.launch( | |
| debug=True, | |
| server_name="0.0.0.0", | |
| server_port=int(os.environ.get("PORT", 7860)) | |
| ) |