"""
Streamlit interface for waste classification - Hugging Face Spaces version

Deployed on Hugging Face Spaces with automatic download of the models.
"""

import streamlit as st
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image
from tensorflow.keras.models import load_model
from tensorflow.keras.preprocessing import image
import os
from pathlib import Path
import logging
import requests

st.set_page_config(
    page_title="Classificateur de Déchets IA",
    page_icon="♻️",
    layout="wide",
    initial_sidebar_state="expanded"
)

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


class WasteClassifierUI:
    """Main class for the waste classification interface."""

    def __init__(self):
        self.model_v1 = None
        self.model_v2 = None
        self.class_names = ["Papier", "Plastique"]
        self.target_size = (96, 96)

        # Local directory where downloaded models are cached.
        self.models_dir = Path("models")
        self.models_dir.mkdir(exist_ok=True)

        self.model_v1_path = self.models_dir / "waste_classifier_v1.h5"
        self.model_v2_path = self.models_dir / "waste_classifier_v2.h5"

        # Download URLs, overridable via environment variables.
        self.model_v1_url = os.getenv('MODEL_V1_URL', "https://huggingface.co/your-username/waste-classifier/resolve/main/models/waste_classifier_v1.h5")
        self.model_v2_url = os.getenv('MODEL_V2_URL', "https://huggingface.co/your-username/waste-classifier/resolve/main/models/waste_classifier_v2.h5")
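
    # The default URLs above are placeholders ("your-username"); in practice they
    # are expected to be overridden through the MODEL_V1_URL / MODEL_V2_URL
    # environment variables (for example as Hugging Face Space secrets). One way
    # to set them when running locally (assuming this file is saved as app.py):
    #
    #     MODEL_V1_URL="https://huggingface.co/<user>/<repo>/resolve/main/models/waste_classifier_v1.h5" \
    #     MODEL_V2_URL="https://huggingface.co/<user>/<repo>/resolve/main/models/waste_classifier_v2.h5" \
    #     streamlit run app.py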

    def download_model(self, url, local_path):
        """Download a model from a URL, skipping the download if the file is already cached."""
        try:
            if local_path.exists():
                logger.info(f"Modèle déjà présent: {local_path}")
                return True

            logger.info(f"Téléchargement du modèle depuis: {url}")
            response = requests.get(url, stream=True)
            response.raise_for_status()

            # Stream the file to disk in 8 KB chunks to keep memory usage low.
            with open(local_path, 'wb') as f:
                for chunk in response.iter_content(chunk_size=8192):
                    f.write(chunk)

            logger.info(f"Modèle téléchargé avec succès: {local_path}")
            return True

        except Exception as e:
            logger.error(f"Erreur lors du téléchargement: {e}")
            return False
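
    # Note: on Hugging Face Spaces, an alternative to the manual requests-based
    # download above is the huggingface_hub client, which handles caching and
    # redirects for files hosted in a Hub repo. A minimal sketch, assuming the
    # models live in the same "<user>/waste-classifier" repo as in the default
    # URLs (not wired into this class):
    #
    #     from huggingface_hub import hf_hub_download
    #     local_path = hf_hub_download(
    #         repo_id="your-username/waste-classifier",
    #         filename="models/waste_classifier_v1.h5",
    #     )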

    def load_models(self):
        """Load models v1 and v2, downloading them first if necessary."""
        try:
            # Model v1: download if missing, then load.
            if not self.model_v1_path.exists():
                st.info("Téléchargement du modèle v1...")
                if not self.download_model(self.model_v1_url, self.model_v1_path):
                    st.warning("Impossible de télécharger le modèle v1")
                else:
                    st.success("Modèle v1 téléchargé avec succès!")

            if self.model_v1_path.exists():
                self.model_v1 = load_model(self.model_v1_path)
                logger.info("Modèle v1 chargé avec succès")
            else:
                logger.warning("Modèle v1 non disponible")

            # Model v2: same download-then-load logic.
            if not self.model_v2_path.exists():
                st.info("Téléchargement du modèle v2...")
                if not self.download_model(self.model_v2_url, self.model_v2_path):
                    st.warning("Impossible de télécharger le modèle v2")
                else:
                    st.success("Modèle v2 téléchargé avec succès!")

            if self.model_v2_path.exists():
                self.model_v2 = load_model(self.model_v2_path)
                logger.info("Modèle v2 chargé avec succès")
            else:
                logger.warning("Modèle v2 non disponible")

        except Exception as e:
            logger.error(f"Erreur lors du chargement des modèles: {e}")
            st.error(f"Erreur lors du chargement des modèles: {e}")

    def preprocess_image(self, img, target_size=(96, 96)):
        """Preprocess a PIL image for prediction: resize, scale to [0, 1], add a batch dimension."""
        try:
            img_resized = img.resize(target_size)
            img_array = image.img_to_array(img_resized)
            # Scale pixel values to [0, 1] (assumes the models were trained on inputs scaled this way).
            img_array = img_array / 255.0
            img_array = np.expand_dims(img_array, axis=0)
            return img_array

        except Exception as e:
            logger.error(f"Erreur lors du preprocessing: {e}")
            st.error(f"Erreur lors du preprocessing: {e}")
            return None
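
    # For reference, preprocess_image turns any RGB PIL image into a batch of one
    # 96x96x3 tensor with values in [0, 1]. An illustrative standalone check
    # (placeholder input, not part of the app):
    #
    #     from PIL import Image
    #     arr = WasteClassifierUI().preprocess_image(Image.new("RGB", (640, 480)))
    #     assert arr.shape == (1, 96, 96, 3)
    #     assert 0.0 <= arr.min() and arr.max() <= 1.0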

    def predict_image(self, img_array, model, model_name):
        """Predict the class of an image with the given model."""
        try:
            if model is None:
                return None

            predictions = model.predict(img_array, verbose=0)

            # The model is expected to output one probability per class (softmax).
            predicted_class_idx = np.argmax(predictions[0])
            confidence = predictions[0][predicted_class_idx]
            predicted_class = self.class_names[predicted_class_idx]

            class_probabilities = {}
            for i, class_name in enumerate(self.class_names):
                class_probabilities[class_name] = float(predictions[0][i])

            result = {
                'model_name': model_name,
                'predicted_class': predicted_class,
                'confidence': float(confidence),
                'class_probabilities': class_probabilities
            }

            return result

        except Exception as e:
            logger.error(f"Erreur lors de la prédiction avec {model_name}: {e}")
            st.error(f"Erreur lors de la prédiction avec {model_name}: {e}")
            return None
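
    # predict_image returns a plain dict, for example (illustrative values only):
    #
    #     {
    #         'model_name': 'Modèle v1',
    #         'predicted_class': 'Papier',
    #         'confidence': 0.93,
    #         'class_probabilities': {'Papier': 0.93, 'Plastique': 0.07},
    #     }
    #
    # which is the structure consumed by create_confidence_chart and the result
    # panels in run().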

    def create_confidence_chart(self, results):
        """Create a bar chart of the class probabilities predicted by each model."""
        if not results:
            return None

        fig, axes = plt.subplots(1, len(results), figsize=(6 * len(results), 5))
        if len(results) == 1:
            axes = [axes]

        for i, result in enumerate(results):
            if result is None:
                continue

            classes = list(result['class_probabilities'].keys())
            probabilities = list(result['class_probabilities'].values())

            # Highlight the predicted class in green, the others in blue.
            bars = axes[i].bar(classes, probabilities,
                               color=['#2E8B57' if c == result['predicted_class'] else '#4682B4'
                                      for c in classes])

            axes[i].set_title(f"{result['model_name']}\nPrédiction: {result['predicted_class']}\nConfiance: {result['confidence']:.3f}")
            axes[i].set_ylabel("Probabilité")
            axes[i].set_ylim(0, 1)

            # Annotate each bar with its probability.
            for bar, prob in zip(bars, probabilities):
                height = bar.get_height()
                axes[i].text(bar.get_x() + bar.get_width()/2., height + 0.01,
                             f'{prob:.3f}', ha='center', va='bottom', fontweight='bold')

        plt.tight_layout()
        return fig

    def run(self):
        """Run the Streamlit interface."""
        st.title("♻️ Classificateur de Déchets IA")
        st.markdown("---")

        # Load the models on first use.
        if self.model_v1 is None or self.model_v2 is None:
            with st.spinner("Chargement des modèles..."):
                self.load_models()

        # Sidebar: model selection.
        st.sidebar.header("Configuration")

        available_models = []
        if self.model_v1 is not None:
            available_models.append("Modèle v1")
        if self.model_v2 is not None:
            available_models.append("Modèle v2")

        if not available_models:
            st.error("Aucun modèle disponible. Vérifiez la connexion internet et réessayez.")
            return

        selected_models = st.sidebar.multiselect(
            "Sélectionnez les modèles à utiliser:",
            available_models,
            default=available_models
        )

        # Sidebar: image upload.
        st.sidebar.header("Upload d'image")
        uploaded_file = st.sidebar.file_uploader(
            "Choisissez une image de déchet:",
            type=['jpg', 'jpeg', 'png', 'bmp', 'tiff'],
            help="Formats supportés: JPG, JPEG, PNG, BMP, TIFF"
        )

        col1, col2 = st.columns([1, 1])

        with col1:
            st.header("Image d'entrée")
            if uploaded_file is not None:
                image_pil = Image.open(uploaded_file)
                st.image(image_pil, caption="Image uploadée", use_container_width=True)

                st.info(f"**Dimensions originales:** {image_pil.size[0]} x {image_pil.size[1]} pixels")

                if st.button("🔍 Classifier l'image", type="primary"):
                    if not selected_models:
                        st.warning("Veuillez sélectionner au moins un modèle.")
                    else:
                        with st.spinner("Classification en cours..."):
                            img_array = self.preprocess_image(image_pil, self.target_size)

                            if img_array is not None:
                                # Run every selected (and loaded) model on the same preprocessed image.
                                results = []
                                for model_name in selected_models:
                                    if model_name == "Modèle v1" and self.model_v1 is not None:
                                        result = self.predict_image(img_array, self.model_v1, "Modèle v1")
                                        results.append(result)
                                    elif model_name == "Modèle v2" and self.model_v2 is not None:
                                        result = self.predict_image(img_array, self.model_v2, "Modèle v2")
                                        results.append(result)

                                # Keep the results in the session so they survive reruns.
                                st.session_state['prediction_results'] = results
                                st.session_state['uploaded_image'] = image_pil
            else:
                st.info("Veuillez uploader une image pour commencer la classification.")

        with col2:
            st.header("Résultats de classification")

            if 'prediction_results' in st.session_state and st.session_state['prediction_results']:
                results = st.session_state['prediction_results']

                st.subheader("📊 Résumé des prédictions")

                for result in results:
                    if result is not None:
                        col_pred, col_conf = st.columns([2, 1])
                        with col_pred:
                            st.write(f"**{result['model_name']}:**")
                        with col_conf:
                            confidence_pct = result['confidence'] * 100
                            st.metric("Confiance", f"{confidence_pct:.1f}%")

                        st.progress(result['confidence'])

                        with st.expander(f"Détails - {result['model_name']}"):
                            for class_name, prob in result['class_probabilities'].items():
                                prob_pct = prob * 100
                                st.write(f"**{class_name}:** {prob_pct:.2f}%")

                if len(results) > 1:
                    st.subheader("📈 Comparaison des modèles")
                    fig = self.create_confidence_chart(results)
                    if fig is not None:
                        st.pyplot(fig)

                st.subheader("💡 Recommandation")
                if len(results) == 1:
                    result = results[0]
                    if result is not None:
                        confidence_pct = result['confidence'] * 100
                        if confidence_pct >= 80:
                            st.success(f"Classification très fiable: {result['predicted_class']} ({confidence_pct:.1f}%)")
                        elif confidence_pct >= 60:
                            st.warning(f"Classification modérée: {result['predicted_class']} ({confidence_pct:.1f}%)")
                        else:
                            st.error(f"Classification incertaine: {result['predicted_class']} ({confidence_pct:.1f}%)")
                else:
                    # Only compare results that are not None, so predictions, confidences
                    # and model names stay aligned.
                    valid_results = [r for r in results if r is not None]
                    predictions = [r['predicted_class'] for r in valid_results]
                    confidences = [r['confidence'] for r in valid_results]

                    if len(set(predictions)) == 1:
                        st.success(f"✅ Consensus: Tous les modèles prédisent '{predictions[0]}'")
                    else:
                        st.warning("⚠️ Divergence: Les modèles donnent des prédictions différentes")
                        for r, pred, conf in zip(valid_results, predictions, confidences):
                            st.write(f"- {r['model_name']}: {pred} ({conf*100:.1f}%)")
            else:
                st.info("Les résultats de classification apparaîtront ici après l'analyse.")
st.markdown("---") |
|
|
st.markdown( |
|
|
""" |
|
|
<div style='text-align: center; color: #666;'> |
|
|
<p>Classificateur de Déchets IA - Modèles v1 et v2</p> |
|
|
<p>Déployé sur Hugging Face Spaces</p> |
|
|
</div> |
|
|
""", |
|
|
unsafe_allow_html=True |
|
|
) |
|
|
|


def main():
    """Application entry point."""
    classifier_ui = WasteClassifierUI()
    classifier_ui.run()


if __name__ == "__main__":
    main()
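
# To run this interface locally, a minimal setup could look like the following
# (assuming this file is saved as app.py; exact dependency versions would live in
# the project's requirements file, not here):
#
#     pip install streamlit tensorflow pillow matplotlib requests numpy
#     streamlit run app.py
#
# On Hugging Face Spaces with the Streamlit SDK, app.py is typically launched
# automatically once the Space builds.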