# Hugging Face Spaces page residue captured with this file:
# Space status at capture time: "Runtime error" (shown twice on the page).
import spaces | |
import gradio as gr | |
import torch | |
import numpy as np | |
import pandas as pd | |
import seaborn as sns | |
import matplotlib.pyplot as plt | |
from sklearn.decomposition import PCA | |
from sklearn.manifold import TSNE | |
from sklearn.cluster import KMeans | |
from lifelines import KaplanMeierFitter | |
from yellowbrick.cluster import KElbowVisualizer | |
from itertools import combinations | |
from lifelines.statistics import logrank_test | |
from MultiOmicsGraphAttentionAutoencoderModel import MultiOmicsGraphAttentionAutoencoderModel | |
from OmicsConfig import OmicsConfig | |
from Attention_Extracter import Attention_Extracter | |
from GraphAnalysis import GraphAnalysis | |
# Prefer the GPU when one is visible, but fall back to CPU so the app can
# still start on CPU-only hardware (a hard-coded torch.device('cuda') makes
# the .to(device) call below fail when no GPU is present).
# NOTE(review): `spaces` is imported but no function in this file is decorated
# with @spaces.GPU; if this Space targets ZeroGPU hardware, GPU work may need
# to run inside such a function — confirm the target hardware.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Load the autoencoder model.
# NOTE(review): only the *config* is read from the checkpoint directory; the
# model is then constructed from that config, so its weights are whatever the
# constructor initializes. If the trained weights are meant to be used, they
# must be loaded explicitly — confirm.
autoencoder_config = OmicsConfig.from_pretrained("./lc_models/MultiOmicsAutoencoder/trained_autoencoder")
autoencoder_model = MultiOmicsGraphAttentionAutoencoderModel(autoencoder_config).to(device)

# Initialize the attention extracter over the saved patient graph data, using
# the autoencoder's encoder.
# NOTE(review): gpu=False is passed even when the model was moved to CUDA
# above; confirm Attention_Extracter handles device placement consistently.
graph_data_dict_path = './data/hnscc.patient.chg.network.pth'
extracter = Attention_Extracter(graph_data_dict_path, autoencoder_model.encoder, gpu=False)
def extract_features():
    """Create a new GraphAnalysis over the module-level attention extracter.

    Returns:
        A freshly constructed ``GraphAnalysis(extracter)``; the UI stores it in
        the per-session ``ga`` state via ``run_extract_features``.
    """
    return GraphAnalysis(extracter)
def find_optimal_clusters(ga, min_clusters, max_clusters):
    """Run the elbow search for a good cluster count on *ga*.

    Args:
        ga: GraphAnalysis instance from the "Extract Features" step, or None
            when that step has not been run yet in this session (the UI's
            ``gr.State()`` starts out empty).
        min_clusters: Smallest cluster count to consider (slider value).
        max_clusters: Largest cluster count to consider (slider value).

    Returns:
        ``ga.optimal_clusters`` after the search, or a user-facing message
        string when the search cannot be run with the given inputs.
    """
    if ga is None:
        # Without this guard the call below fails with AttributeError on None.
        return "Please run 'Extract Features' first."
    # Slider values may arrive as floats (e.g. 5.0); cluster counts must be ints.
    min_clusters, max_clusters = int(min_clusters), int(max_clusters)
    if min_clusters > max_clusters:
        return (f"Min clusters ({min_clusters}) must not exceed "
                f"max clusters ({max_clusters}).")
    ga.find_optimal_clusters(min_clusters=min_clusters, max_clusters=max_clusters, save_path='./temp')
    return ga.optimal_clusters
def perform_clustering(ga, num_clusters):
    """Cluster the extracted features of *ga* into *num_clusters* groups.

    Args:
        ga: GraphAnalysis instance from the "Extract Features" step, or None
            when that step has not been run yet in this session.
        num_clusters: Requested number of clusters (slider value).

    Returns:
        A status message for the UI's result textbox.
    """
    if ga is None:
        # Without this guard the call below fails with AttributeError on None.
        return "Please run 'Extract Features' first."
    # Slider values may arrive as floats; a cluster count must be an int.
    ga.cluster_data2(int(num_clusters))
    return "Clustering completed."
def plot_kaplan_meier(ga):
    """Ask *ga* to produce its Kaplan-Meier plot.

    Where the plot is written is decided inside GraphAnalysis and is not
    visible here.

    Args:
        ga: GraphAnalysis instance from the "Extract Features" step, or None
            when that step has not been run yet in this session.

    Returns:
        A status message for the UI's result textbox.
    """
    if ga is None:
        # Without this guard the call below fails with AttributeError on None.
        return "Please run 'Extract Features' first."
    ga.plot_kaplan_meier()
    return "Kaplan-Meier plot saved."
def plot_median_survival_bar(ga):
    """Ask *ga* to produce its median-survival bar plot, named 'temp'.

    Args:
        ga: GraphAnalysis instance from the "Extract Features" step, or None
            when that step has not been run yet in this session.

    Returns:
        A status message for the UI's result textbox.
    """
    if ga is None:
        # Without this guard the call below fails with AttributeError on None.
        return "Please run 'Extract Features' first."
    ga.plot_median_survival_bar(name='temp')
    return "Median survival bar plot saved."
def perform_log_rank_test(ga):
    """Run the log-rank test on *ga* and format its significant pairs.

    Args:
        ga: GraphAnalysis instance from the "Extract Features" step, or None
            when that step has not been run yet in this session.

    Returns:
        A message embedding whatever ``ga.perform_log_rank_test()`` returned.
    """
    if ga is None:
        # Without this guard the call below fails with AttributeError on None.
        return "Please run 'Extract Features' first."
    significant_pairs = ga.perform_log_rank_test()
    return f"Significant pairs from log-rank test: {significant_pairs}"
# Stylesheet passed to gr.Blocks(css=css) for the UI: it targets the element
# created with elem_id="col-container" (the main gr.Column), centering it
# horizontally and capping its width at 520px.
css = """
#col-container {
margin: 0 auto;
max-width: 520px;
}
"""
def run_extract_features():
    """Gradio handler for the "Extract Features" button.

    Returns the new GraphAnalysis object, which the UI stores in its session
    state.
    """
    analysis = extract_features()
    return analysis
def run_find_optimal_clusters(ga, min_clusters, max_clusters):
    """Gradio handler for the "Find Optimal Clusters" button."""
    outcome = find_optimal_clusters(
        ga,
        min_clusters,
        max_clusters,
    )
    return outcome
def run_perform_clustering(ga, num_clusters):
    """Gradio handler for the "Perform Clustering" button."""
    status = perform_clustering(ga, num_clusters)
    return status
def run_plot_kaplan_meier(ga):
    """Gradio handler for the "Plot Kaplan-Meier" button."""
    status = plot_kaplan_meier(ga)
    return status
def run_plot_median_survival_bar(ga):
    """Gradio handler for the "Plot Median Survival Bar" button."""
    status = plot_median_survival_bar(ga)
    return status
def run_perform_log_rank_test(ga):
    """Gradio handler for the "Perform Log-Rank Test" button."""
    report = perform_log_rank_test(ga)
    return report
# Build the single-page UI: a header, one button per analysis step, the
# cluster-count sliders, a shared result textbox, and per-session state that
# holds the GraphAnalysis object between clicks.
with gr.Blocks(css=css) as demo:
    with gr.Column(elem_id="col-container"):
        gr.Markdown(f"\n# Graph to Features and Analysis\nCurrently running on {device}.\n")

        with gr.Row():
            extract_button = gr.Button("Extract Features")
            clustering_button = gr.Button("Find Optimal Clusters")
            cluster_data_button = gr.Button("Perform Clustering")
            kaplan_meier_button = gr.Button("Plot Kaplan-Meier")
            survival_bar_button = gr.Button("Plot Median Survival Bar")
            log_rank_button = gr.Button("Perform Log-Rank Test")

        num_clusters = gr.Slider(label="Number of Clusters", minimum=2, maximum=10, step=1, value=5)
        min_clusters = gr.Slider(label="Min Clusters for Elbow Method", minimum=2, maximum=10, step=1, value=2)
        max_clusters = gr.Slider(label="Max Clusters for Elbow Method", minimum=3, maximum=20, step=1, value=10)
        result = gr.Textbox(label="Result")

        # Per-session slot for the GraphAnalysis object; empty (None) until
        # "Extract Features" has been clicked.
        ga = gr.State()

        # (button, handler, inputs, outputs), listed in registration order.
        event_table = (
            (extract_button, run_extract_features, [], [ga]),
            (clustering_button, run_find_optimal_clusters, [ga, min_clusters, max_clusters], [result]),
            (cluster_data_button, run_perform_clustering, [ga, num_clusters], [result]),
            (kaplan_meier_button, run_plot_kaplan_meier, [ga], [result]),
            (survival_bar_button, run_plot_median_survival_bar, [ga], [result]),
            (log_rank_button, run_perform_log_rank_test, [ga], [result]),
        )
        for trigger, handler, handler_inputs, handler_outputs in event_table:
            trigger.click(fn=handler, inputs=handler_inputs, outputs=handler_outputs)

demo.queue().launch()