import streamlit as st
import tensorflow as tf
import numpy as np
import time
import tensorflow.keras as keras
from tensorflow.keras.applications import VGG16
from tensorflow.keras.layers import Dense, Flatten
from tensorflow.keras.models import Model, load_model
from datasets import load_dataset
import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split
from sklearn.metrics import confusion_matrix, classification_report
import seaborn as sns
from huggingface_hub import HfApi
import os
# Cache directory for Hugging Face downloads
os.environ["HF_HOME"] = "/app/.cache"
os.environ["HF_DATASETS_CACHE"] = "/app/.cache"

HF_TOKEN = os.getenv("HF_TOKEN")
# Hugging Face authentication
api = HfApi()  # also used later to upload the trained model
if HF_TOKEN:
    user_info = api.whoami(HF_TOKEN)
    st.write(f"Authenticated as {user_info.get('name', 'unknown user')}")
else:
    st.warning("No API token found! Check the secret configured for this Space.")
# Load the dataset
st.write("Loading 300 images from `tiny-imagenet`...")
dataset = load_dataset("zh-plus/tiny-imagenet", split="train")

image_list = []
label_list = []
for i, sample in enumerate(dataset):
    if i >= 300:  # keep only the first 300 images
        break
    # Convert to RGB (the dataset also contains grayscale images), resize, and scale to [0, 1]
    image = tf.image.resize(np.array(sample["image"].convert("RGB")), (64, 64)) / 255.0
    image_list.append(image.numpy())
    label_list.append(sample["label"])

X = np.array(image_list)
y = np.array(label_list)
# Dataset split: 80% training, 20% validation
X_train, X_val, y_train, y_val = train_test_split(X, y, test_size=0.2, random_state=42)

st.write(f"**Training:** {X_train.shape[0]} images")
st.write(f"**Validation:** {X_val.shape[0]} images")
# Checkbox to force retraining even if a saved model already exists
force_training = st.checkbox("Retrain even if Silva.h5 already exists")
# Load an existing model or train a new one
history = None  # stays None when no training is run
if os.path.exists("Silva.h5") and not force_training:
    model = load_model("Silva.h5")
    st.write("Model `Silva.h5` loaded, no new training needed!")
else:
    st.write("Training in progress...")
    base_model = VGG16(weights="imagenet", include_top=False, input_shape=(64, 64, 3))

    # Freeze the convolutional base and train only the new classification head
    for layer in base_model.layers:
        layer.trainable = False

    x = Flatten()(base_model.output)
    x = Dense(256, activation="relu")(x)
    x = Dense(128, activation="relu")(x)
    # The output layer must cover every label index that can occur in y;
    # len(set(y_train)) can be too small when the labels are not contiguous from 0
    num_classes = int(np.max(y)) + 1
    output = Dense(num_classes, activation="softmax")(x)

    model = Model(inputs=base_model.input, outputs=output)
    model.compile(optimizer="adam", loss="sparse_categorical_crossentropy", metrics=["accuracy"])

    history = model.fit(X_train, y_train, epochs=10, validation_data=(X_val, y_val))
    model.save("Silva.h5")
    st.write("Model saved as `Silva.h5`!")
# Validation metrics
y_pred_val = np.argmax(model.predict(X_val), axis=1)
accuracy_val = np.mean(y_pred_val == y_val)
rmse_val = np.sqrt(np.mean((y_pred_val - y_val) ** 2))
report_val = classification_report(y_val, y_pred_val, output_dict=True)
recall_val = report_val["weighted avg"]["recall"]
precision_val = report_val["weighted avg"]["precision"]
f1_score_val = report_val["weighted avg"]["f1-score"]

st.write(f"**Validation Accuracy:** {accuracy_val:.4f}")
st.write(f"**Validation RMSE:** {rmse_val:.4f}")
st.write(f"**Validation Precision:** {precision_val:.4f}")
st.write(f"**Validation Recall:** {recall_val:.4f}")
st.write(f"**Validation F1-Score:** {f1_score_val:.4f}")
# Button to show the validation confusion matrix
if st.button("Generate validation confusion matrix"):
    conf_matrix_val = confusion_matrix(y_val, y_pred_val)
    fig, ax = plt.subplots(figsize=(10, 7))
    sns.heatmap(conf_matrix_val, annot=True, cmap="Blues", fmt="d", ax=ax)
    st.pyplot(fig)
    st.write("Confusion matrix generated!")
# Loss and accuracy curves for training and validation
if history is not None:
    fig, ax = plt.subplots(1, 2, figsize=(12, 5))
    ax[0].plot(history.history["loss"], label="Training Loss")
    ax[0].plot(history.history["val_loss"], label="Validation Loss")
    ax[1].plot(history.history["accuracy"], label="Training Accuracy")
    ax[1].plot(history.history["val_accuracy"], label="Validation Accuracy")
    ax[0].set_title("Training and validation loss")
    ax[1].set_title("Training and validation accuracy")
    ax[0].legend()
    ax[1].legend()
    st.pyplot(fig)
else:
    st.warning("No training was run, so there is no history to plot!")
# Button to download the saved model
if os.path.exists("Silva.h5"):
    with open("Silva.h5", "rb") as f:
        st.download_button(
            label="Download the Silva.h5 model",
            data=f,
            file_name="Silva.h5",
            mime="application/octet-stream"
        )
# Button to upload the model to Hugging Face
def upload_model():
    # Push the saved weights back to the Space repository (requires a write token)
    api.upload_file(
        path_or_fileobj="Silva.h5",
        path_in_repo="Silva.h5",
        repo_id="scontess/trainigVVG16",
        repo_type="space",
        token=HF_TOKEN,
    )
    st.success("Model 'Silva.h5' uploaded to Hugging Face!")

st.write("Upload the Silva model to Hugging Face")
if st.button("Upload Silva to the Model Store"):
    upload_model()
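
# Space dependencies (inferred from the imports above; versions are not pinned
# in the original, so none are listed here). A requirements.txt along these
# lines is assumed:
#   streamlit
#   tensorflow
#   datasets
#   scikit-learn
#   matplotlib
#   seaborn
#   huggingface_hub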