MODLI commited on
Commit
3474c7b
·
verified ·
1 Parent(s): bf440a3

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +17 -17
app.py CHANGED
@@ -2,13 +2,14 @@ import gradio as gr
2
  from transformers import ViTImageProcessor, ViTForImageClassification
3
  from PIL import Image
4
  import torch
 
5
 
6
  # --- Chargement du modèle et du processeur ---
7
- # Modèle de base ViT pré-entraîné sur ImageNet (beaucoup mieux que "beans")
8
- # C'est une solution temporaire en attendant de fine-tuner sur le dataset mode
9
  model_name = "google/vit-base-patch16-224"
10
  processor = ViTImageProcessor.from_pretrained(model_name)
11
  model = ViTForImageClassification.from_pretrained(model_name)
 
12
 
13
  def predict(image):
14
  """Fonction de prédiction avec gestion d'erreurs et seuil de confiance"""
@@ -34,11 +35,9 @@ def predict(image):
34
  for i, (prob, idx) in enumerate(zip(top_probs, top_indices)):
35
  pred_label = model.config.id2label[idx.item()]
36
  confidence = prob.item()
37
- # N'afficher que si la confiance est > 10%
38
- if confidence > 0.1:
39
  predictions.append(f"{pred_label}: {confidence:.2%}")
40
 
41
- # Si aucune prédiction n'a une confiance suffisante
42
  if not predictions:
43
  return "Je ne suis pas sûr de reconnaître cet item. Essayez avec une image plus claire."
44
 
@@ -51,17 +50,10 @@ def predict(image):
51
  title = "Fashion Item Classifier"
52
  description = (
53
  "Upload an image of a clothing item, and I will classify it. "
54
- "⚠️ This is a general-purpose model. For better accuracy on fashion items, "
55
  "a specialized model is needed."
56
  )
57
 
58
- # Exemples d'images (ajoutez vos propres exemples plus tard)
59
- examples = [
60
- ["shirt_example.jpg"],
61
- ["shoe_example.jpg"],
62
- ["dress_example.jpg"]
63
- ]
64
-
65
  # Création de l'interface
66
  demo = gr.Interface(
67
  fn=predict,
@@ -69,10 +61,18 @@ demo = gr.Interface(
69
  outputs=gr.Textbox(label="Classification Results"),
70
  title=title,
71
  description=description,
72
- examples=examples,
73
- allow_flagging="never"
 
 
 
74
  )
75
 
76
- # Lancement de l'application
77
  if __name__ == "__main__":
78
- demo.launch(debug=True, share=False)
 
 
 
 
 
 
2
  from transformers import ViTImageProcessor, ViTForImageClassification
3
  from PIL import Image
4
  import torch
5
+ import os
6
 
7
  # --- Chargement du modèle et du processeur ---
8
+ print("Loading model and processor...")
 
9
  model_name = "google/vit-base-patch16-224"
10
  processor = ViTImageProcessor.from_pretrained(model_name)
11
  model = ViTForImageClassification.from_pretrained(model_name)
12
+ print("Model loaded successfully!")
13
 
14
  def predict(image):
15
  """Fonction de prédiction avec gestion d'erreurs et seuil de confiance"""
 
35
  for i, (prob, idx) in enumerate(zip(top_probs, top_indices)):
36
  pred_label = model.config.id2label[idx.item()]
37
  confidence = prob.item()
38
+ if confidence > 0.1: # Seuil de confiance à 10%
 
39
  predictions.append(f"{pred_label}: {confidence:.2%}")
40
 
 
41
  if not predictions:
42
  return "Je ne suis pas sûr de reconnaître cet item. Essayez avec une image plus claire."
43
 
 
50
  title = "Fashion Item Classifier"
51
  description = (
52
  "Upload an image of a clothing item, and I will classify it. "
53
+ "This is a general-purpose model (ImageNet). For better accuracy on fashion items, "
54
  "a specialized model is needed."
55
  )
56
 
 
 
 
 
 
 
 
57
  # Création de l'interface
58
  demo = gr.Interface(
59
  fn=predict,
 
61
  outputs=gr.Textbox(label="Classification Results"),
62
  title=title,
63
  description=description,
64
+ allow_flagging="never",
65
+ examples=[
66
+ ["https://images.unsplash.com/photo-1552374196-c4e7ffc6e126?w=400"], # T-shirt example
67
+ ["https://images.unsplash.com/photo-1543163521-1bf539c55dd2?w=400"] # Shoe example
68
+ ]
69
  )
70
 
71
+ # Lancement de l'application - CONFIGURATION SPÉCIFIQUE POUR HUGGING FACE SPACES
72
  if __name__ == "__main__":
73
+ # Cette configuration est cruciale pour Hugging Face Spaces
74
+ demo.launch(
75
+ debug=True,
76
+ server_name="0.0.0.0", # Important pour les conteneurs Docker
77
+ server_port=int(os.environ.get("PORT", 7860)) Utilise le port de l'environnement Spaces
78
+ )