# Author: VatsalPatel18
# Last commit: "Update app.py" (dc549dd, verified)
import spaces
import gradio as gr
import torch
import numpy as np
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
from sklearn.decomposition import PCA
from sklearn.manifold import TSNE
from sklearn.cluster import KMeans
from lifelines import KaplanMeierFitter
from yellowbrick.cluster import KElbowVisualizer
from itertools import combinations
from lifelines.statistics import logrank_test
from MultiOmicsGraphAttentionAutoencoderModel import MultiOmicsGraphAttentionAutoencoderModel
from OmicsConfig import OmicsConfig
from Attention_Extracter import Attention_Extracter
from GraphAnalysis import GraphAnalysis
# Use the GPU when one is available; fall back to CPU so the app can also be
# started on machines without CUDA instead of failing at startup.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Load the autoencoder model.
autoencoder_config = OmicsConfig.from_pretrained("./lc_models/MultiOmicsAutoencoder/trained_autoencoder")
# NOTE(review): the model is constructed from the config only; no trained
# weights are loaded on this line. Confirm that the model class loads its own
# weights, otherwise the encoder used below is randomly initialised.
autoencoder_model = MultiOmicsGraphAttentionAutoencoderModel(autoencoder_config).to(device)

# Initialize the attention extracter over the saved patient graph data.
graph_data_dict_path = './data/hnscc.patient.chg.network.pth'
# NOTE(review): gpu=False is passed even when `device` is CUDA — confirm this
# is intended.
extracter = Attention_Extracter(graph_data_dict_path, autoencoder_model.encoder, gpu=False)
def extract_features():
    """Create a GraphAnalysis over the module-level attention extracter.

    Returns:
        A new ``GraphAnalysis`` instance wrapping ``extracter``.
    """
    return GraphAnalysis(extracter)
def find_optimal_clusters(ga, min_clusters, max_clusters):
    """Run the optimal-cluster-count search on ``ga``.

    Args:
        ga: The GraphAnalysis object produced by ``extract_features``.
        min_clusters: Lower bound of the cluster range to search.
        max_clusters: Upper bound of the cluster range to search.

    Returns:
        The value of ``ga.optimal_clusters`` after the search.

    Raises:
        ValueError: If ``min_clusters`` is not smaller than ``max_clusters``.
    """
    # Slider values may arrive as floats; cluster counts must be integers.
    low, high = int(min_clusters), int(max_clusters)
    # The min/max sliders are independent, so the UI can submit an empty or
    # inverted range; reject it with a clear message.
    if low >= high:
        raise ValueError(
            f"Min clusters ({low}) must be smaller than max clusters ({high})."
        )
    ga.find_optimal_clusters(min_clusters=low, max_clusters=high, save_path='./temp')
    return ga.optimal_clusters
def perform_clustering(ga, num_clusters):
    """Cluster the extracted features held by ``ga``.

    Args:
        ga: The GraphAnalysis object produced by ``extract_features``.
        num_clusters: Number of clusters to form.

    Returns:
        A status message for the UI.
    """
    # Slider values may arrive as floats; a cluster count must be an int.
    ga.cluster_data2(int(num_clusters))
    return "Clustering completed."
def plot_kaplan_meier(ga):
    """Delegate Kaplan-Meier plotting to ``ga.plot_kaplan_meier()``.

    Returns:
        A status message for the UI.
    """
    done_msg = "Kaplan-Meier plot saved."
    ga.plot_kaplan_meier()
    return done_msg
def plot_median_survival_bar(ga):
    """Delegate the median-survival bar plot to ``ga``, using the name ``'temp'``.

    Returns:
        A status message for the UI.
    """
    plot_name = 'temp'
    ga.plot_median_survival_bar(name=plot_name)
    status = "Median survival bar plot saved."
    return status
def perform_log_rank_test(ga):
    """Run ``ga.perform_log_rank_test()`` and format its result for the UI.

    Returns:
        A message embedding whatever the log-rank test returned.
    """
    return "Significant pairs from log-rank test: {}".format(ga.perform_log_rank_test())
css = """
#col-container {
margin: 0 auto;
max-width: 520px;
}
"""
def _require_features(ga):
    """Raise a user-facing error when features have not been extracted yet.

    ``ga`` is the per-session ``gr.State`` value; it stays ``None`` until the
    "Extract Features" button has run, because only that handler writes it
    initially.
    """
    if ga is None:
        raise gr.Error("Please click 'Extract Features' first.")


@spaces.GPU
def run_extract_features():
    """GPU entry point: build the GraphAnalysis stored in the session state."""
    return extract_features()


@spaces.GPU
def run_find_optimal_clusters(ga, min_clusters, max_clusters):
    """GPU entry point: search for the optimal cluster count.

    Returns ``(optimal_clusters, ga)``. ``ga`` is returned so that any state the
    search stores on it is written back to the ``gr.State``: under ZeroGPU the
    decorated call may run in a separate worker process, where in-place changes
    to the argument would otherwise be lost.
    """
    _require_features(ga)
    return find_optimal_clusters(ga, min_clusters, max_clusters), ga


@spaces.GPU
def run_perform_clustering(ga, num_clusters):
    """GPU entry point: cluster the features.

    Returns ``(status_message, ga)``. ``ga`` is returned so the clustering
    result kept on it reaches the later plotting and log-rank steps (see
    ``run_find_optimal_clusters`` for why).
    """
    _require_features(ga)
    return perform_clustering(ga, num_clusters), ga


@spaces.GPU
def run_plot_kaplan_meier(ga):
    """GPU entry point: draw the Kaplan-Meier plot."""
    _require_features(ga)
    return plot_kaplan_meier(ga)


@spaces.GPU
def run_plot_median_survival_bar(ga):
    """GPU entry point: draw the median-survival bar plot."""
    _require_features(ga)
    return plot_median_survival_bar(ga)


@spaces.GPU
def run_perform_log_rank_test(ga):
    """GPU entry point: run the log-rank test and report significant pairs."""
    _require_features(ga)
    return perform_log_rank_test(ga)


with gr.Blocks(css=css) as demo:
    with gr.Column(elem_id="col-container"):
        gr.Markdown(f"""
# Graph to Features and Analysis
Currently running on {device}.
""")
        with gr.Row():
            extract_button = gr.Button("Extract Features")
            clustering_button = gr.Button("Find Optimal Clusters")
            cluster_data_button = gr.Button("Perform Clustering")
            kaplan_meier_button = gr.Button("Plot Kaplan-Meier")
            survival_bar_button = gr.Button("Plot Median Survival Bar")
            log_rank_button = gr.Button("Perform Log-Rank Test")
        num_clusters = gr.Slider(label="Number of Clusters", minimum=2, maximum=10, step=1, value=5)
        min_clusters = gr.Slider(label="Min Clusters for Elbow Method", minimum=2, maximum=10, step=1, value=2)
        max_clusters = gr.Slider(label="Max Clusters for Elbow Method", minimum=3, maximum=20, step=1, value=10)
        result = gr.Textbox(label="Result")
        # Per-session GraphAnalysis object; None until features are extracted.
        ga = gr.State()

        extract_button.click(fn=run_extract_features, inputs=[], outputs=[ga])
        # Handlers that may modify the analysis object also write it back to `ga`.
        clustering_button.click(fn=run_find_optimal_clusters, inputs=[ga, min_clusters, max_clusters], outputs=[result, ga])
        cluster_data_button.click(fn=run_perform_clustering, inputs=[ga, num_clusters], outputs=[result, ga])
        kaplan_meier_button.click(fn=run_plot_kaplan_meier, inputs=[ga], outputs=[result])
        survival_bar_button.click(fn=run_plot_median_survival_bar, inputs=[ga], outputs=[result])
        log_rank_button.click(fn=run_perform_log_rank_test, inputs=[ga], outputs=[result])

demo.queue().launch()