| import gradio as gr |
| import tensorflow as tf |
| from tensorflow.keras import layers, Model |
| import keras |
| import numpy as np |
| import pandas as pd |
| import os |
| import matplotlib.pyplot as plt |
|
|
| |
# Path to the saved Keras model file whose weights are loaded at startup.
MODEL_PATH = "./models/aging_score_autoencoder_fixed.keras"

# Runtime state, populated by load_resources():
#   model         - the full network (reconstruction + age outputs)
#   encoder_model - sub-model whose output is the "latent" layer
#   load_error    - message explaining why weights could not be loaded, or None
model = encoder_model = load_error = None
|
|
|
|
def build_model(input_dim=18241, latent_dim=32):
    """
    Rebuild the model architecture in code instead of deserializing it.

    The original author states this mirrors the training notebook exactly;
    NOTE(review): confirm against the notebook. load_resources() copies
    trained weights into this graph positionally (``set_weights``), so the
    layer types, sizes and creation order below must not be changed.

    Args:
        input_dim: Number of input features. Presumably one per gene in the
            training matrix -- TODO confirm against the training data.
        latent_dim: Width of the bottleneck Dense layer named "latent".

    Returns:
        An uncompiled functional ``Model`` with two outputs, in this order:
        the input reconstruction (named "reconstruction") and a single-unit
        age prediction (named "age").
    """
    inputs = layers.Input(shape=(input_dim,))

    # Encoder: Dense(512) -> BatchNorm -> Dropout -> Dense(128) -> latent.
    x = layers.Dense(512, activation="relu")(inputs)
    x = layers.BatchNormalization()(x)
    x = layers.Dropout(0.3)(x)
    x = layers.Dense(128, activation="relu")(x)
    # Linear bottleneck; looked up by name when building the encoder sub-model.
    latent = layers.Dense(latent_dim, name="latent")(x)

    # Decoder: mirrors the encoder back up to the input width.
    x = layers.Dense(128, activation="relu")(latent)
    x = layers.Dense(512, activation="relu")(x)
    reconstruction = layers.Dense(input_dim, name="reconstruction")(x)

    # Supervised head: age is regressed directly from the latent code.
    age_pred = layers.Dense(1, name="age")(latent)

    model = Model(inputs=inputs, outputs=[reconstruction, age_pred])
    return model
|
|
|
|
def load_resources():
    """
    Build the model, load its trained weights and derive the latent encoder.

    Sets the module-level globals ``model``, ``encoder_model`` and
    ``load_error``. Missing or unreadable weights do not raise. The model
    keeps its random initialisation and ``load_error`` records why, so the
    UI can warn the user instead of failing at import time.
    """
    global model, encoder_model, load_error
    load_error = None
    # Reset so that, if encoder creation below fails, no encoder tied to a
    # previously built model survives.
    encoder_model = None

    # The architecture is rebuilt in code; only the weights come from disk.
    model = build_model()

    if not os.path.exists(MODEL_PATH):
        load_error = f"Model file not found at {MODEL_PATH}. Model will run with random weights."
        print(load_error)
    else:
        print(f"Loading weights from {MODEL_PATH}...")
        try:
            # NOTE(security): safe_mode=False allows deserializing arbitrary
            # Python (e.g. lambda layers). That is only acceptable because
            # MODEL_PATH is a trusted local file; never point it at uploads.
            saved_model = keras.saving.load_model(
                MODEL_PATH,
                compile=False,
                safe_mode=False,
            )
            # Positional copy: relies on build_model() creating the same
            # weights, in the same order, as the saved model.
            model.set_weights(saved_model.get_weights())
            print("Weights loaded successfully from .keras file.")
        except Exception as full_load_error:
            # Fallback: load the weights straight into the rebuilt graph.
            try:
                model.load_weights(MODEL_PATH)
                print("Weights loaded successfully.")
            except Exception as weights_error:
                # Report both failures; the first one is otherwise lost.
                load_error = (
                    f"Could not load weights: {weights_error} "
                    f"(full-model load failed first: {full_load_error})"
                )
                print(f"Warning: {load_error}")
                print("Model will run with random weights.")

    # Sub-model exposing the "latent" bottleneck activations.
    try:
        latent_layer = model.get_layer("latent")
        encoder_model = Model(inputs=model.input, outputs=latent_layer.output)
        print("Encoder model created successfully.")
    except Exception as e:
        print(f"Warning: Could not create encoder model: {e}")
|
|
|
|
| |
# Build the network and load its weights once, when the module is imported,
# so every request handled by predict_aging reuses the same objects.
load_resources()
|
|
# Metadata columns that may accompany the expression matrix; every other
# column is treated as a gene feature.
_META_COLS = frozenset(
    ["sample_id", "subject_id", "tissue", "sex", "age", "death_time", "estimated_age"]
)

# Gap in years (biological - chronological) beyond which ageing is reported
# as accelerated or slowed.
_RHYTHM_THRESHOLD = 2


def _upload_path(input_file):
    """Return the filesystem path of a Gradio upload.

    Older Gradio versions pass a tempfile-like object exposing ``.name``;
    newer ones pass the path string itself. Both are accepted.
    """
    return getattr(input_file, "name", input_file)


def _read_expression_table(path):
    """Read *path* as CSV when it ends in ``.csv`` (any case), otherwise as Parquet."""
    if str(path).lower().endswith(".csv"):
        return pd.read_csv(path)
    return pd.read_parquet(path)


def _preprocess(df):
    """Return the log1p-transformed, z-scored gene matrix of *df* as float64."""
    gene_cols = [c for c in df.columns if c not in _META_COLS]
    X = np.log1p(df[gene_cols].to_numpy(dtype=np.float64))
    # NOTE(review): a single mean/std is computed over the whole uploaded
    # matrix, so results depend on which samples are uploaded together.
    # Confirm this matches the preprocessing used at training time.
    return (X - np.mean(X)) / (np.std(X) + 1e-8)


def _rhythm_figure(rhythm, min_limit=15.0):
    """Plot the biological-minus-chronological age gap as a single bar."""
    fig, ax = plt.subplots(figsize=(6, 2))
    ax.barh(["Rythme"], [rhythm], color="#3498db")
    ax.axvline(0, color="black", linestyle="--")
    ax.set_title("Différentiel de Vieillissement (Bio - Chrono)")
    # Widen the axis when the gap exceeds the default range so the bar is not clipped.
    limit = max(min_limit, abs(rhythm) * 1.1)
    ax.set_xlim(-limit, limit)
    # Detach from pyplot's global registry so figures do not accumulate across
    # requests; the Figure object itself remains renderable for Gradio.
    plt.close(fig)
    return fig


def predict_aging(input_file, chron_age):
    """Estimate biological age and a latent ageing score from an uploaded table.

    Args:
        input_file: Gradio upload (path string or tempfile-like object with
            ``.name``) pointing at a CSV or Parquet table, one row per sample.
        chron_age: Chronological age in years entered by the user.

    Returns:
        ``(markdown_report, figure)``. On any error the report holds the
        message and the figure is ``None``. Only the first sample is reported.
    """
    # Every path returns exactly two values: the click handler is wired to
    # two output components.
    if model is None:
        if load_error:
            return f"Error: {load_error}", None
        return "Error: Model not found.", None
    if input_file is None:
        return "Erreur : aucun fichier fourni.", None
    if chron_age is None:
        return "Erreur : veuillez renseigner l'âge chronologique.", None

    try:
        df = _read_expression_table(_upload_path(input_file))
        if df.empty:
            return "Erreur : le fichier ne contient aucun échantillon.", None

        X_scaled = _preprocess(df)
        expected = model.inputs[0].shape[-1]
        if expected is not None and X_scaled.shape[1] != expected:
            return (
                f"Erreur : {X_scaled.shape[1]} colonnes de gènes trouvées, "
                f"{expected} attendues.",
                None,
            )

        # Inference is row-independent (dropout off, batch-norm on moving
        # statistics), so only the reported first sample needs a forward pass.
        first = X_scaled[:1]
        _, age_pred = model.predict(first, verbose=0)
        biological_age = float(age_pred[0][0])

        if encoder_model is not None:
            latent_vector = encoder_model.predict(first, verbose=0)
            # Summarise the latent embedding as its mean activation.
            aging_score = f"{float(np.mean(latent_vector[0])):.4f}"
        else:
            aging_score = "N/A"

        rhythm = biological_age - chron_age
        if rhythm > _RHYTHM_THRESHOLD:
            status = "Vieillissement Accéléré ⚠️"
        elif rhythm < -_RHYTHM_THRESHOLD:
            status = "Vieillissement Ralenti ✅"
        else:
            status = "Vieillissement Normal 🆗"

        # Lines are built explicitly: indented triple-quoted markdown would
        # render as a code block.
        lines = []
        if load_error:
            # The model exists but may be running on random weights; say so.
            lines.append(f"> ⚠️ {load_error}\n")
        lines += [
            "### Résultats d'Analyse",
            f"- **Âge Chronologique :** {chron_age} ans",
            f"- **Âge Biologique (Estimé) :** {biological_age:.2f} ans",
            f"- **Score de Vieillissement (Latent) :** {aging_score}",
            f"- **Statut :** {status}",
        ]
        if len(df) > 1:
            lines.append(f"\n_{len(df)} échantillons fournis : seul le premier est analysé._")

        return "\n".join(lines), _rhythm_figure(rhythm)

    except Exception as e:
        # Top-level UI boundary: surface any failure instead of crashing the app.
        return f"Erreur : {e}", None
|
|
| |
# Gradio interface: upload an expression table, enter the chronological age,
# and display the markdown report plus the ageing-gap plot.
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown("# 🧠 Aging Score Bio-Predictor")
    gr.Markdown("Analyse du rythme de vieillissement biologique via Autoencoder supervisé.")

    with gr.Row():
        with gr.Column():
            input_file = gr.File(label="Données Transcriptomiques (18k gènes)")
            chron_age = gr.Number(label="Âge Chronologique Réel", value=40)
            btn = gr.Button("Calculer l'Aging Score", variant="primary")

        with gr.Column():
            output_text = gr.Markdown()
            output_plot = gr.Plot()

    # predict_aging must return exactly one value per output component (two here).
    btn.click(fn=predict_aging, inputs=[input_file, chron_age], outputs=[output_text, output_plot])


# Start the server only when run as a script, so importing this module
# (tests, Gradio's reload mode, which looks up `demo`) does not block in launch().
if __name__ == "__main__":
    demo.launch()
|
|