import spaces
import gradio as gr
import torch
import numpy as np
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
from sklearn.decomposition import PCA
from sklearn.manifold import TSNE
from sklearn.cluster import KMeans
from lifelines import KaplanMeierFitter
from yellowbrick.cluster import KElbowVisualizer
from itertools import combinations
from lifelines.statistics import logrank_test

from MultiOmicsGraphAttentionAutoencoderModel import MultiOmicsGraphAttentionAutoencoderModel
from OmicsConfig import OmicsConfig
from Attention_Extracter import Attention_Extracter
from GraphAnalysis import GraphAnalysis

device = torch.device('cuda')

# Load the trained autoencoder
autoencoder_config = OmicsConfig.from_pretrained("./lc_models/MultiOmicsAutoencoder/trained_autoencoder")
autoencoder_model = MultiOmicsGraphAttentionAutoencoderModel(autoencoder_config).to(device)

# Initialize the attention extracter on the patient graph data
graph_data_dict_path = './data/hnscc.patient.chg.network.pth'
extracter = Attention_Extracter(graph_data_dict_path, autoencoder_model.encoder, gpu=False)


# Analysis steps, written as plain functions so they can also be reused outside the UI.

def extract_features():
    ga = GraphAnalysis(extracter)
    return ga


def find_optimal_clusters(ga, min_clusters, max_clusters):
    ga.find_optimal_clusters(min_clusters=min_clusters, max_clusters=max_clusters, save_path='./temp')
    return ga.optimal_clusters


def perform_clustering(ga, num_clusters):
    ga.cluster_data2(num_clusters)
    return "Clustering completed."


def plot_kaplan_meier(ga):
    ga.plot_kaplan_meier()
    return "Kaplan-Meier plot saved."


def plot_median_survival_bar(ga):
    ga.plot_median_survival_bar(name='temp')
    return "Median survival bar plot saved."


def perform_log_rank_test(ga):
    significant_pairs = ga.perform_log_rank_test()
    return f"Significant pairs from log-rank test: {significant_pairs}"


css = """
#col-container {
    margin: 0 auto;
    max-width: 520px;
}
"""


# @spaces.GPU wrappers so each step runs on the Space's GPU when triggered from the UI.

@spaces.GPU
def run_extract_features():
    return extract_features()


@spaces.GPU
def run_find_optimal_clusters(ga, min_clusters, max_clusters):
    return find_optimal_clusters(ga, min_clusters, max_clusters)


@spaces.GPU
def run_perform_clustering(ga, num_clusters):
    return perform_clustering(ga, num_clusters)


@spaces.GPU
def run_plot_kaplan_meier(ga):
    return plot_kaplan_meier(ga)


@spaces.GPU
def run_plot_median_survival_bar(ga):
    return plot_median_survival_bar(ga)


@spaces.GPU
def run_perform_log_rank_test(ga):
    return perform_log_rank_test(ga)


# Gradio UI: one button per analysis step; the GraphAnalysis object is carried in gr.State.
with gr.Blocks(css=css) as demo:
    with gr.Column(elem_id="col-container"):
        gr.Markdown(f"""
        # Graph to Features and Analysis
        Currently running on {device}.
        """)

        with gr.Row():
            extract_button = gr.Button("Extract Features")
            clustering_button = gr.Button("Find Optimal Clusters")
            cluster_data_button = gr.Button("Perform Clustering")
            kaplan_meier_button = gr.Button("Plot Kaplan-Meier")
            survival_bar_button = gr.Button("Plot Median Survival Bar")
            log_rank_button = gr.Button("Perform Log-Rank Test")

        num_clusters = gr.Slider(label="Number of Clusters", minimum=2, maximum=10, step=1, value=5)
        min_clusters = gr.Slider(label="Min Clusters for Elbow Method", minimum=2, maximum=10, step=1, value=2)
        max_clusters = gr.Slider(label="Max Clusters for Elbow Method", minimum=3, maximum=20, step=1, value=10)

        result = gr.Textbox(label="Result")
        ga = gr.State()

        extract_button.click(fn=run_extract_features, inputs=[], outputs=[ga])
        clustering_button.click(fn=run_find_optimal_clusters, inputs=[ga, min_clusters, max_clusters], outputs=[result])
        cluster_data_button.click(fn=run_perform_clustering, inputs=[ga, num_clusters], outputs=[result])
        kaplan_meier_button.click(fn=run_plot_kaplan_meier, inputs=[ga], outputs=[result])
        survival_bar_button.click(fn=run_plot_median_survival_bar, inputs=[ga], outputs=[result])
        log_rank_button.click(fn=run_perform_log_rank_test, inputs=[ga], outputs=[result])

demo.queue().launch()