File size: 2,154 Bytes
73dd2e6
 
 
 
 
 
 
7cd8cc1
73dd2e6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8728306
73dd2e6
 
 
 
 
3152ad7
 
73dd2e6
 
 
1a04da2
 
73dd2e6
 
 
 
 
3152ad7
73dd2e6
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
import tensorflow as tf
import numpy as np
import gradio as gr
from tensorflow.keras.preprocessing import image
from sklearn.preprocessing import LabelEncoder
import matplotlib.pyplot as plt
from PIL import Image
from tensorflow.keras.metrics import MeanSquaredError

# Dictionnaire des sous-catégories
subcategory_dict = {
    "Furniture": ["Home Decor"],
    "Home Decor": [
        "Poufs & Ottomans", "Rugs", "Antique items", "Brass Lamps",
        "Candle Holders", "Pottery", "Kilim poufs", "Pillow Covers",
        "Wall Decor", "Straw Lamps"
    ],
    # Ajoutez d'autres catégories ici
}

# Fonction pour charger et prétraiter l'image
def preprocess_image(img):
    img = img.resize((224, 224))  # Redimensionner
    img_array = np.array(img) / 255.0  # Normaliser
    img_array = np.expand_dims(img_array, axis=0)  # Ajouter une dimension batch
    return img_array

# Fonction pour prédire la catégorie, le prix et la sous-catégorie
def predict_image(img):
    # Prétraiter l'image
    img_array = preprocess_image(img)

    # Faire les prédictions
    category_pred, price_pred = model.predict(img_array)

    # Décoder la catégorie
    category_pred_class = np.argmax(category_pred, axis=1)[0]  # La classe avec la plus haute probabilité
    category_name = label_encoder.inverse_transform([category_pred_class])[0]

    # Trouver les sous-catégories correspondantes
    subcategories = subcategory_dict.get(category_name, [])

    # Préparer les résultats sous forme de texte
    results = f"Category: {category_name}\nPrice ($): {price_pred[0][0]:.2f}\nSubcategories: {', '.join(subcategories) if subcategories else 'No subcategories'}"
    return results

# Charger le modèle pré-entraîné
model = tf.keras.models.load_model('trained_model.h5', custom_objects={'mse': MeanSquaredError()})


# Interface Gradio
interface = gr.Interface(
    fn=predict_image,
    inputs=gr.Image(type="pil"),
    outputs=gr.Textbox(label="Prediction Results"),
    title="Image Classification with TensorFlow",
    description="Upload an image to predict its category, price, and subcategories."
)

# Lancer l'interface
interface.launch()