# TESTFASHION / app.py
# Last commit by MODLI: "Update app.py" (ca655fa, verified)
import os
os.environ['HF_HOME'] = '/tmp/cache'
os.environ['TORCH_HOME'] = '/tmp/cache'
import json
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from PIL import Image
import torch
import requests
from io import BytesIO
from transformers import CLIPProcessor, CLIPModel
app = FastAPI(title="Fashion Classification API")

# CORS: wide open - any origin, method and request header is accepted, and
# every response header is exposed to browser JavaScript.
# NOTE(review): allow_credentials=True together with allow_origins=["*"] is a
# risky combination (see the Starlette CORSMiddleware docs). No endpoint in
# this file relies on cookies or auth headers, so consider setting it to False.
_CORS_OPTIONS = {
    "allow_origins": ["*"],
    "allow_credentials": True,
    "allow_methods": ["*"],
    "allow_headers": ["*"],
    "expose_headers": ["*"],
}
app.add_middleware(CORSMiddleware, **_CORS_OPTIONS)
# --- Model configuration ---
# Both globals start out empty; load_model() fills them in, and the HTTP
# endpoints check them before use.
print("🔄 Chargement du modèle Fashion CLIP...")
model = processor = None
def load_model():
    """Load the Fashion-CLIP model and its processor into the module globals.

    Both objects are loaded into locals first and then published together.
    If either download fails, the globals are never left half-initialised
    (model set but processor still None).
    On failure the error is printed and ``model``/``processor`` keep their
    previous values; nothing is raised, so the API process keeps running.
    """
    global model, processor
    model_name = "patrickjohncyh/fashion-clip"
    try:
        loaded_model = CLIPModel.from_pretrained(model_name)
        loaded_processor = CLIPProcessor.from_pretrained(model_name)
    except Exception as e:
        # Best-effort at startup: keep serving so /health can show that the
        # model is not available instead of the process crashing.
        print(f"❌ Erreur de chargement: {e}")
        return
    # Publish both at once so readers never observe a partial state.
    model, processor = loaded_model, loaded_processor
    print("✅ Modèle chargé avec succès!")
# Load the model eagerly, when this module is imported (i.e. at server start).
load_model()
# Zero-shot label set: French category name -> English text prompts used to
# score an image against that category. Several prompts per category cover
# common synonyms / sub-types.
CATEGORIES_FR = {
    "haut": ["a t-shirt", "a shirt", "a sweater", "a blouse", "a top"],  # tops
    "pantalon": ["jeans", "pants", "trousers", "leggings"],  # trousers
    "robe": ["a dress", "a gown", "a sundress"],  # dresses
    "jupe": ["a skirt"],  # skirts
    "short": ["shorts", "bermuda shorts"],  # shorts
    "veste": ["a jacket", "a blazer", "a leather jacket"],  # jackets
    "manteau": ["a coat", "a winter coat", "a parka"],  # coats
    "chaussures": ["sneakers", "high heels", "boots", "sandals"],  # shoes
    "sac": ["a handbag", "a purse", "a backpack"],  # bags
    "accessoire": ["a hat", "sunglasses", "a scarf", "a belt"],  # accessories
    "autre": ["clothing", "fashion item"],  # other / generic
}
@app.get("/")
def read_root():
    """Liveness probe: confirm that the API process is up and answering."""
    message = "Fashion Classification API is running!"
    return {"message": message, "status": "OK"}
@app.get("/health")
def health_check():
    """Report whether the service can classify images yet.

    Both fields use the same readiness rule that /classify enforces: the model
    AND the processor must be loaded. ``model_loaded`` is that boolean;
    ``status`` is "ready" when it holds and "loading" otherwise.
    NOTE(review): nothing in this file retries load_model(), so if loading
    failed at startup the status stays "loading" indefinitely.
    """
    # Explicit None checks (not truthiness) so both fields always agree.
    ready = model is not None and processor is not None
    return {
        "model_loaded": ready,
        "status": "ready" if ready else "loading",
    }
# Flat list of the English CLIP prompts (in CATEGORIES_FR order) and a lookup
# from each prompt back to its French category. Built once at import time
# instead of on every request.
_PROMPTS = [prompt for prompts in CATEGORIES_FR.values() for prompt in prompts]
_PROMPT_TO_CATEGORY = {
    prompt: fr_category
    for fr_category, prompts in CATEGORIES_FR.items()
    for prompt in prompts
}


def _download_image(image_url):
    """Download *image_url* and return it as a 224x224 RGB PIL image.

    Raises:
        requests.exceptions.RequestException: the download failed (bad URL,
            network error, timeout, non-2xx status).
        PIL.UnidentifiedImageError: the payload is not a readable image.
    """
    # NOTE(review): the URL is client-controlled, so the server will fetch any
    # address it is given (SSRF risk) and the body size is not capped -
    # consider an allow-list and a maximum download size.
    response = requests.get(image_url, timeout=30)
    response.raise_for_status()
    image = Image.open(BytesIO(response.content)).convert("RGB")
    # NOTE(review): the CLIP processor also resizes/crops its input; this
    # explicit square resize stretches non-square photos instead of letting
    # the processor crop them - confirm which works better for this model.
    return image.resize((224, 224))


def _prompt_logits(image):
    """Return a 1-D tensor holding one CLIP image-text logit per entry of _PROMPTS."""
    # Every prompt is scored in a single forward pass, so the image is
    # encoded once rather than once per prompt.
    inputs = processor(
        text=_PROMPTS,
        images=image,
        return_tensors="pt",
        padding=True,  # prompts tokenize to different lengths
        truncation=True,
        max_length=77,
    )
    with torch.no_grad():
        outputs = model(**inputs)
    # logits_per_image is (num_images, num_texts); there is exactly one image.
    return outputs.logits_per_image[0]


def _classify_image_url(image_url):
    """Blocking part of /classify: download, score and build the response body."""
    logits = _prompt_logits(_download_image(image_url))
    # Softmax over all prompts yields a genuine probability. CLIP logits are
    # cosine similarities multiplied by a large learned scale, so a sigmoid of
    # a single logit would sit at ~1.0 for almost any image.
    probabilities = logits.softmax(dim=-1)
    best_index = int(torch.argmax(logits))  # first maximum wins on ties
    best_english_category = _PROMPTS[best_index]
    confidence = probabilities[best_index].item()
    return {
        "success": True,
        "category": _PROMPT_TO_CATEGORY.get(best_english_category, "autre"),
        "confidence": round(confidence, 4),
        "colorHex": "#000000",  # placeholder: no colour extraction is done
        "originalCategory": best_english_category,
        "method": "modli-api",
    }


@app.post("/classify")
async def classify_fashion(image_data: dict):
    """
    Endpoint for Lovable - accepts an image URL and returns its garment category.

    Request body: a JSON object with an ``imageUrl`` key.
    Response: ``success``, ``category`` (French label from CATEGORIES_FR),
    ``confidence`` (softmax probability, 0-1, of the best-matching prompt
    among all prompts), ``colorHex`` (fixed placeholder), ``originalCategory``
    (the matching English prompt) and ``method``.

    Raises HTTPException:
        503 if the model or processor is not loaded,
        400 if imageUrl is missing, cannot be downloaded or is not an image,
        500 for any other failure during classification.
    """
    from fastapi.concurrency import run_in_threadpool
    from PIL import UnidentifiedImageError

    # Validation happens outside the try below so these status codes reach
    # the client instead of being rewrapped as 500s.
    if model is None or processor is None:
        raise HTTPException(status_code=503, detail="Model not loaded yet")
    image_url = image_data.get("imageUrl")
    if not image_url:
        raise HTTPException(status_code=400, detail="imageUrl is required")

    try:
        # The download and the inference are blocking; run them in a worker
        # thread so they do not stall the event loop for other requests.
        return await run_in_threadpool(_classify_image_url, image_url)
    except requests.exceptions.RequestException as e:
        raise HTTPException(status_code=400, detail=f"Invalid image URL: {str(e)}") from e
    except UnidentifiedImageError as e:
        raise HTTPException(status_code=400, detail=f"Invalid image: {str(e)}") from e
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Classification error: {str(e)}") from e
if __name__ == "__main__":
    # Running this file directly starts a uvicorn server listening on every
    # interface, port 7860 (presumably the port the hosting Space expects).
    import uvicorn

    uvicorn.run(app, port=7860, host="0.0.0.0")