File size: 5,360 Bytes
7d22163
1ae86c6
 
d694d14
1ae86c6
 
 
 
d14694e
1ae86c6
b5b6de6
a16d3e9
1ae86c6
d694d14
1ae86c6
dee6b0b
d694d14
 
 
 
 
 
 
 
 
 
 
 
 
a16d3e9
3e8960a
f3ee846
 
1ae86c6
 
d694d14
5ba5379
d694d14
5ba5379
a16d3e9
1ae86c6
d592c7c
1ae86c6
b5b6de6
 
1ae86c6
b5b6de6
 
 
 
 
1ae86c6
9e3e3ef
 
 
 
d694d14
7308952
9e3e3ef
1ae86c6
d694d14
1ae86c6
9e3e3ef
d5a3bd1
1ae86c6
 
 
 
 
 
 
 
 
 
 
b5b6de6
1ae86c6
d694d14
1ae86c6
b5b6de6
 
 
 
 
 
d694d14
 
 
 
b5b6de6
 
d694d14
 
 
b5b6de6
d694d14
 
b5b6de6
a16d3e9
b5b6de6
a16d3e9
d694d14
a16d3e9
d694d14
9e3e3ef
7308952
 
 
 
 
d694d14
 
7308952
 
 
d694d14
 
a16d3e9
1ae86c6
 
 
 
 
 
 
 
 
 
b7612aa
1ae86c6
 
 
 
a16d3e9
d809567
1ae86c6
a16d3e9
7d22163
1ae86c6
 
a16d3e9
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
import streamlit as st
import tensorflow as tf
import numpy as np
import time
import tensorflow.keras as keras
from tensorflow.keras.applications import VGG16
from tensorflow.keras.layers import Dense, Flatten
from tensorflow.keras.models import Model, load_model
from datasets import load_dataset
import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split
from sklearn.metrics import confusion_matrix, classification_report
import seaborn as sns
from huggingface_hub import HfApi
import os

# Cache location for Hugging Face downloads.
# NOTE(review): `datasets` and `huggingface_hub` are imported above and appear
# to read these variables at import time, so assigning them here likely does
# not change their default cache paths — confirm; if so, set them in the
# container environment (or before the imports) instead.
os.environ["HF_HOME"] = "/app/.cache"
os.environ["HF_DATASETS_CACHE"] = "/app/.cache"
HF_TOKEN = os.getenv("HF_TOKEN")

# Hugging Face authentication.
# The client is created unconditionally so that code using `api` later (the
# upload helper) never hits a NameError when no token is configured. With
# token=None, huggingface_hub applies its own credential lookup.
api = HfApi(token=HF_TOKEN)
if HF_TOKEN:
    try:
        user_info = api.whoami(HF_TOKEN)
    except Exception as err:  # UI boundary: report the failure instead of crashing the app
        st.error(f"Autenticazione Hugging Face non riuscita: {err}")
    else:
        st.write(f"βœ… Autenticato come {user_info.get('name', 'Utente sconosciuto')}")
else:
    st.warning("⚠️ Nessun token API trovato! Verifica il Secret nello Space.")

# Dataset loading: a fixed-seed random sample of 300 images from tiny-imagenet.
st.write("πŸ”„ Caricamento di 300 immagini da `tiny-imagenet`...")
dataset = load_dataset(
    "zh-plus/tiny-imagenet",
    split="train",
    # Explicit cache directory: the HF_DATASETS_CACHE environment variable may
    # only be honoured when it is set before `datasets` is imported.
    cache_dir=os.environ.get("HF_DATASETS_CACHE"),
)

SAMPLE_SIZE = 300

# Draw a seeded random sample instead of the first rows. If the split is stored
# grouped by class (NOTE(review): appears to be the case for tiny-imagenet's
# train split — confirm), the first 300 rows would cover very few labels and
# the classifier would degenerate. The fixed seed keeps runs reproducible.
sample_ds = dataset.shuffle(seed=42).select(range(min(SAMPLE_SIZE, len(dataset))))

image_list = []
label_list = []

for sample in sample_ds:
    # Assumes the "image" column decodes to a PIL image (datasets' default for
    # Image features). Images are not guaranteed to be RGB; a grayscale or
    # palette image would yield a 2-D array, so force 3 channels to match
    # VGG16's (64, 64, 3) input.
    rgb = np.asarray(sample["image"].convert("RGB"))
    image = tf.image.resize(rgb, (64, 64)) / 255.0  # resize, then scale pixels to [0, 1]
    image_list.append(image.numpy())
    label_list.append(int(sample["label"]))

X = np.array(image_list)
y = np.array(label_list)

# Split: 80% training, 20% validation.
X_train, X_val, y_train, y_val = train_test_split(X, y, test_size=0.2, random_state=42)

st.write(f"πŸ“Š **Training:** {X_train.shape[0]} immagini")
st.write(f"πŸ“Š **Validation:** {X_val.shape[0]} immagini")

# Checkbox: retrain even when a saved Silva.h5 is already on disk.
# NOTE(review): Streamlit re-runs the whole script on every widget interaction,
# so while this box is ticked, clicking any button below retrains the model
# from scratch — consider caching the trained model (e.g. st.cache_resource).
force_training = st.checkbox("πŸ”„ Rifai il training anche se Silva.h5 esiste")

# Load the saved model, or build and train a new one.
history = None  # remains None when the saved model is reused (no training curves to plot)

if os.path.exists("Silva.h5") and not force_training:
    model = load_model("Silva.h5")
    st.write("βœ… Modello `Silva.h5` caricato, nessun nuovo training necessario!")
else:
    st.write("πŸš€ Training in corso...")
    # VGG16 convolutional base with ImageNet weights, frozen so that only the
    # new dense head below is trained.
    base_model = VGG16(weights="imagenet", include_top=False, input_shape=(64, 64, 3))
    for layer in base_model.layers:
        layer.trainable = False

    # sparse_categorical_crossentropy needs 0 <= label < number of output units.
    # Sizing the softmax from the largest label over the whole sample (train
    # AND validation) guarantees that. Counting distinct training labels does
    # not, when labels are not exactly 0..k-1 or a class occurs only in validation.
    num_classes = int(y.max()) + 1

    x = Flatten()(base_model.output)
    x = Dense(256, activation="relu")(x)
    x = Dense(128, activation="relu")(x)
    output = Dense(num_classes, activation="softmax")(x)

    model = Model(inputs=base_model.input, outputs=output)
    model.compile(optimizer="adam", loss="sparse_categorical_crossentropy", metrics=["accuracy"])

    history = model.fit(X_train, y_train, epochs=10, validation_data=(X_val, y_val))
    model.save("Silva.h5")
    st.write("βœ… Modello salvato come `Silva.h5`!")

# Validation metrics: predicted class = argmax of the softmax output.
y_pred_val = np.argmax(model.predict(X_val), axis=1)
accuracy_val = np.mean(y_pred_val == y_val)
# NOTE(review): RMSE over class indices treats nominal labels as ordered
# numbers, so it carries no real meaning for classification; kept only
# because it is displayed below.
rmse_val = np.sqrt(np.mean((y_pred_val - y_val) ** 2))
# zero_division=0 gives the same values as the default ("warn" acts as 0)
# without an UndefinedMetricWarning for every class that is never predicted.
report_val = classification_report(y_val, y_pred_val, output_dict=True, zero_division=0)

# Support-weighted averages across classes.
recall_val = report_val["weighted avg"]["recall"]
precision_val = report_val["weighted avg"]["precision"]
f1_score_val = report_val["weighted avg"]["f1-score"]

st.write(f"πŸ“Š **Validation Accuracy:** {accuracy_val:.4f}")
st.write(f"πŸ“Š **Validation RMSE:** {rmse_val:.4f}")
st.write(f"πŸ“Š **Validation Precision:** {precision_val:.4f}")
st.write(f"πŸ“Š **Validation Recall:** {recall_val:.4f}")
st.write(f"πŸ“Š **Validation F1-Score:** {f1_score_val:.4f}")

# Button: plot the confusion matrix on the validation split.
if st.button("πŸ”Ž Genera matrice di confusione per validazione"):
    # Rows/columns are the class ids occurring in either the true or the
    # predicted labels (the same set confusion_matrix uses by default). They
    # are passed explicitly so the heatmap ticks show real class ids instead
    # of positional indices 0..k-1.
    class_ids = np.union1d(y_val, y_pred_val)
    conf_matrix_val = confusion_matrix(y_val, y_pred_val, labels=class_ids)
    fig, ax = plt.subplots(figsize=(10, 7))
    sns.heatmap(
        conf_matrix_val,
        annot=True,
        cmap="Blues",
        fmt="d",
        ax=ax,
        xticklabels=class_ids,
        yticklabels=class_ids,
    )
    ax.set_xlabel("Classe predetta")
    ax.set_ylabel("Classe reale")
    st.pyplot(fig)
    # pyplot keeps every figure alive until it is closed, and the script
    # re-runs on each interaction, so release it once rendered.
    plt.close(fig)
    st.write("βœ… Matrice di confusione generata!")

# Loss and accuracy curves per epoch (training vs. validation). They are
# available only when the model was trained during this script run.
if history is not None:
    fig, ax = plt.subplots(1, 2, figsize=(12, 5))
    ax[0].plot(history.history["loss"], label="Training Loss")
    ax[0].plot(history.history["val_loss"], label="Validation Loss")
    ax[1].plot(history.history["accuracy"], label="Training Accuracy")
    ax[1].plot(history.history["val_accuracy"], label="Validation Accuracy")
    ax[0].set_title("Loss durante il training e validazione")
    ax[1].set_title("Accuracy durante il training e validazione")
    for panel in ax:
        panel.set_xlabel("Epoca")
        panel.legend()
    st.pyplot(fig)
    # Release the figure: pyplot retains open figures across script reruns.
    plt.close(fig)
else:
    st.warning("⚠️ Nessun training eseguito, impossibile mostrare il grafico!")

# Offer the saved model file for download, but only when it exists on disk.
if os.path.exists("Silva.h5"):
    with open("Silva.h5", "rb") as model_file:
        st.download_button(
            data=model_file,
            label="πŸ“₯ Scarica il modello Silva.h5",
            mime="application/octet-stream",
            file_name="Silva.h5",
        )

# Helper behind the "upload to Hugging Face" button.
def upload_model():
    """Upload ``Silva.h5`` to the ``scontess/trainigVVG16`` Space repository.

    The outcome is reported in the Streamlit UI:
    - an error when the file is missing or the upload fails;
    - a success message otherwise.
    Nothing is raised to the caller.
    """
    if not os.path.exists("Silva.h5"):
        st.error("Il file Silva.h5 non esiste: esegui prima il training.")
        return

    # Build a client here rather than relying on a module-level `api` being
    # defined. With token=None, huggingface_hub applies its own credential lookup.
    client = HfApi(token=HF_TOKEN)
    try:
        client.upload_file(
            path_or_fileobj="Silva.h5",
            path_in_repo="Silva.h5",
            repo_id="scontess/trainigVVG16",
            repo_type="space",
        )
    except Exception as err:  # UI boundary: auth/network failures are shown, not raised
        st.error(f"Caricamento su Hugging Face non riuscito: {err}")
        return
    st.success("βœ… Modello 'Silva.h5' caricato su Hugging Face!")

# Section header plus the button that pushes the saved model to Hugging Face.
st.write("πŸ“₯ Carica il modello Silva su Hugging Face")
upload_requested = st.button("πŸš€ Carica Silva su Model Store")
if upload_requested:
    upload_model()