Spaces:
Runtime error
Runtime error
File size: 4,082 Bytes
dc549dd ce0e9f1 648ca95 ce0e9f1 36dd86e ce0e9f1 648ca95 f0e9ca6 648ca95 f0e9ca6 36dd86e 648ca95 9adde00 648ca95 ce0e9f1 dc549dd 36dd86e dc549dd 36dd86e dc549dd 36dd86e dc549dd 36dd86e dc549dd 36dd86e dc549dd 36dd86e ce0e9f1 648ca95 ce0e9f1 648ca95 ce0e9f1 648ca95 36dd86e 648ca95 30724c3 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 |
import spaces
import gradio as gr
import torch
import numpy as np
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
from sklearn.decomposition import PCA
from sklearn.manifold import TSNE
from sklearn.cluster import KMeans
from lifelines import KaplanMeierFitter
from yellowbrick.cluster import KElbowVisualizer
from itertools import combinations
from lifelines.statistics import logrank_test
from MultiOmicsGraphAttentionAutoencoderModel import MultiOmicsGraphAttentionAutoencoderModel
from OmicsConfig import OmicsConfig
from Attention_Extracter import Attention_Extracter
from GraphAnalysis import GraphAnalysis
# Prefer the GPU, but fall back to CPU so the app can still start (e.g. for
# local debugging) on machines without CUDA, where `.to('cuda')` would raise.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Load the autoencoder configuration and build the model on `device`.
# NOTE(review): only the *config* is read from the checkpoint directory; the
# model is then constructed from that config. Unless the model class loads
# weights itself, its parameters are freshly initialised — confirm, and
# consider `MultiOmicsGraphAttentionAutoencoderModel.from_pretrained(...)`
# if the class supports it.
autoencoder_config = OmicsConfig.from_pretrained("./lc_models/MultiOmicsAutoencoder/trained_autoencoder")
autoencoder_model = MultiOmicsGraphAttentionAutoencoderModel(autoencoder_config).to(device)
# This app only runs inference, so switch off training-time behaviour such as dropout.
autoencoder_model.eval()

# Initialize the attention extracter over the graph data stored at this path.
graph_data_dict_path = './data/hnscc.patient.chg.network.pth'
# NOTE(review): the model lives on `device`, yet the extracter is told
# gpu=False — confirm both agree on where tensors are placed.
extracter = Attention_Extracter(graph_data_dict_path, autoencoder_model.encoder, gpu=False)
def extract_features():
    """Create the analysis object for a session.

    Wraps the module-level ``extracter`` in a new ``GraphAnalysis``; every
    later step in the UI operates on the object returned here.
    """
    return GraphAnalysis(extracter)
def find_optimal_clusters(ga, min_clusters, max_clusters):
    """Run the elbow search for the best cluster count on *ga*.

    Args:
        ga: The ``GraphAnalysis`` object produced by ``extract_features``.
        min_clusters: Lower bound of the cluster-count search. UI sliders may
            deliver floats, so the value is coerced to ``int``.
        max_clusters: Upper bound of the search (whether it is inclusive is
            decided by ``GraphAnalysis``); coerced to ``int``.

    Returns:
        ``ga.optimal_clusters`` after ``ga.find_optimal_clusters`` has run
        (presumably set by that call — confirm in ``GraphAnalysis``).

    Raises:
        ValueError: if ``min_clusters`` is not smaller than ``max_clusters``.
    """
    min_clusters = int(min_clusters)
    max_clusters = int(max_clusters)
    # The two sliders are independent, so the UI can submit an inverted or
    # empty range; reject it here with a clear message instead of failing
    # somewhere inside the elbow search.
    if min_clusters >= max_clusters:
        raise ValueError(
            f"min_clusters ({min_clusters}) must be smaller than "
            f"max_clusters ({max_clusters})."
        )
    ga.find_optimal_clusters(min_clusters=min_clusters, max_clusters=max_clusters, save_path='./temp')
    return ga.optimal_clusters
def perform_clustering(ga, num_clusters):
    """Cluster the extracted features of *ga* into *num_clusters* groups.

    Args:
        ga: The ``GraphAnalysis`` object produced by ``extract_features``.
        num_clusters: Desired number of clusters; coerced to ``int``.

    Returns:
        The fixed status message ``"Clustering completed."``.
    """
    # UI sliders may deliver floats. The count is presumably handed on to a
    # clusterer (this module imports sklearn's KMeans), and KMeans rejects
    # non-integer n_clusters, so pass an int.
    ga.cluster_data2(int(num_clusters))
    return "Clustering completed."
def plot_kaplan_meier(ga):
    """Delegate Kaplan-Meier plotting to *ga* and report a status line.

    Returns:
        The fixed UI message ``"Kaplan-Meier plot saved."``; whether and
        where a figure is written is up to ``GraphAnalysis`` (not visible here).
    """
    draw = ga.plot_kaplan_meier
    draw()
    return "Kaplan-Meier plot saved."
def plot_median_survival_bar(ga):
    """Ask *ga* to draw its median-survival bar chart under the name 'temp'.

    Returns:
        The fixed UI message ``"Median survival bar plot saved."``; the actual
        output handling lives in ``GraphAnalysis``.
    """
    output_name = 'temp'
    ga.plot_median_survival_bar(name=output_name)
    return "Median survival bar plot saved."
def perform_log_rank_test(ga):
    """Run *ga*'s log-rank test and format the result for the UI.

    Returns:
        A message embedding whatever ``ga.perform_log_rank_test()`` returned,
        rendered with ``str.format``.
    """
    pairs = ga.perform_log_rank_test()
    return "Significant pairs from log-rank test: {}".format(pairs)
# Page CSS: centers the main column (elem_id "col-container" in the UI below)
# and caps its width at 520px.
css = """
#col-container {
margin: 0 auto;
max-width: 520px;
}
"""
def _require_analysis(ga):
    """Fail with a readable message if features have not been extracted yet.

    ``ga`` is the per-session ``gr.State`` value defined below. It starts out
    as ``None`` until the "Extract Features" button has produced an analysis
    object. Without this check, every later step would die with an opaque
    ``AttributeError`` on ``None``.
    """
    if ga is None:
        raise gr.Error("Please click 'Extract Features' before running this step.")


# On ZeroGPU, functions decorated with @spaces.GPU execute in a separate worker
# process, and only their return values come back to the app process. In-place
# changes a step makes to `ga` (e.g. cluster assignments) would otherwise be
# lost. So every step that receives `ga` also returns it, and the UI writes it
# back into the session state. Outside ZeroGPU this round-trip is harmless.
@spaces.GPU
def run_extract_features():
    return extract_features()


@spaces.GPU
def run_find_optimal_clusters(ga, min_clusters, max_clusters):
    _require_analysis(ga)
    optimal = find_optimal_clusters(ga, min_clusters, max_clusters)
    return ga, optimal


@spaces.GPU
def run_perform_clustering(ga, num_clusters):
    _require_analysis(ga)
    message = perform_clustering(ga, num_clusters)
    return ga, message


@spaces.GPU
def run_plot_kaplan_meier(ga):
    _require_analysis(ga)
    message = plot_kaplan_meier(ga)
    return ga, message


@spaces.GPU
def run_plot_median_survival_bar(ga):
    _require_analysis(ga)
    message = plot_median_survival_bar(ga)
    return ga, message


@spaces.GPU
def run_perform_log_rank_test(ga):
    _require_analysis(ga)
    message = perform_log_rank_test(ga)
    return ga, message


with gr.Blocks(css=css) as demo:
    with gr.Column(elem_id="col-container"):
        gr.Markdown(f"""
# Graph to Features and Analysis
Currently running on {device}.
""")
        with gr.Row():
            extract_button = gr.Button("Extract Features")
            clustering_button = gr.Button("Find Optimal Clusters")
            cluster_data_button = gr.Button("Perform Clustering")
            kaplan_meier_button = gr.Button("Plot Kaplan-Meier")
            survival_bar_button = gr.Button("Plot Median Survival Bar")
            log_rank_button = gr.Button("Perform Log-Rank Test")
        num_clusters = gr.Slider(label="Number of Clusters", minimum=2, maximum=10, step=1, value=5)
        min_clusters = gr.Slider(label="Min Clusters for Elbow Method", minimum=2, maximum=10, step=1, value=2)
        max_clusters = gr.Slider(label="Max Clusters for Elbow Method", minimum=3, maximum=20, step=1, value=10)
        result = gr.Textbox(label="Result")
        # Per-session analysis object; None until "Extract Features" has run.
        ga = gr.State()

    # Each step returns the (possibly updated) analysis object first, so it is
    # written back into `ga` alongside the status text.
    extract_button.click(fn=run_extract_features, inputs=[], outputs=[ga])
    clustering_button.click(fn=run_find_optimal_clusters, inputs=[ga, min_clusters, max_clusters], outputs=[ga, result])
    cluster_data_button.click(fn=run_perform_clustering, inputs=[ga, num_clusters], outputs=[ga, result])
    kaplan_meier_button.click(fn=run_plot_kaplan_meier, inputs=[ga], outputs=[ga, result])
    survival_bar_button.click(fn=run_plot_median_survival_bar, inputs=[ga], outputs=[ga, result])
    log_rank_button.click(fn=run_perform_log_rank_test, inputs=[ga], outputs=[ga, result])

demo.queue().launch()
|