test / app.py
saaara's picture
Update app.py
3152ad7 verified
import tensorflow as tf
import numpy as np
import gradio as gr
from tensorflow.keras.preprocessing import image
from sklearn.preprocessing import LabelEncoder
import matplotlib.pyplot as plt
from PIL import Image
from tensorflow.keras.metrics import MeanSquaredError
# Dictionnaire des sous-catégories
subcategory_dict = {
"Furniture": ["Home Decor"],
"Home Decor": [
"Poufs & Ottomans", "Rugs", "Antique items", "Brass Lamps",
"Candle Holders", "Pottery", "Kilim poufs", "Pillow Covers",
"Wall Decor", "Straw Lamps"
],
# Ajoutez d'autres catégories ici
}
# Fonction pour charger et prétraiter l'image
def preprocess_image(img):
img = img.resize((224, 224)) # Redimensionner
img_array = np.array(img) / 255.0 # Normaliser
img_array = np.expand_dims(img_array, axis=0) # Ajouter une dimension batch
return img_array
# Fonction pour prédire la catégorie, le prix et la sous-catégorie
def predict_image(img):
# Prétraiter l'image
img_array = preprocess_image(img)
# Faire les prédictions
category_pred, price_pred = model.predict(img_array)
# Décoder la catégorie
category_pred_class = np.argmax(category_pred, axis=1)[0] # La classe avec la plus haute probabilité
category_name = label_encoder.inverse_transform([category_pred_class])[0]
# Trouver les sous-catégories correspondantes
subcategories = subcategory_dict.get(category_name, [])
# Préparer les résultats sous forme de texte
results = f"Category: {category_name}\nPrice ($): {price_pred[0][0]:.2f}\nSubcategories: {', '.join(subcategories) if subcategories else 'No subcategories'}"
return results
# Charger le modèle pré-entraîné
model = tf.keras.models.load_model('trained_model.h5', custom_objects={'mse': MeanSquaredError()})
# Interface Gradio
interface = gr.Interface(
fn=predict_image,
inputs=gr.Image(type="pil"),
outputs=gr.Textbox(label="Prediction Results"),
title="Image Classification with TensorFlow",
description="Upload an image to predict its category, price, and subcategories."
)
# Lancer l'interface
interface.launch()