import spaces
import gradio as gr
import torch
import numpy as np
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
from sklearn.decomposition import PCA
from sklearn.manifold import TSNE
from sklearn.cluster import KMeans
from lifelines import KaplanMeierFitter
from yellowbrick.cluster import KElbowVisualizer
from itertools import combinations
from lifelines.statistics import logrank_test

from MultiOmicsGraphAttentionAutoencoderModel import MultiOmicsGraphAttentionAutoencoderModel
from OmicsConfig import OmicsConfig
from Attention_Extracter import Attention_Extracter
from GraphAnalysis import GraphAnalysis

device = torch.device('cuda')  # a CUDA-capable device is assumed to be available

# Load the autoencoder model
autoencoder_config = OmicsConfig.from_pretrained("./lc_models/MultiOmicsAutoencoder/trained_autoencoder")
autoencoder_model = MultiOmicsGraphAttentionAutoencoderModel(autoencoder_config).to(device)

# Initialize Attention Extracter
graph_data_dict_path = './data/hnscc.patient.chg.network.pth'
extracter = Attention_Extracter(graph_data_dict_path, autoencoder_model.encoder, gpu=False)
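
# Analysis helpers: each wraps one GraphAnalysis step and returns either the analysis
# object itself or a short status message for display in the UI.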

def extract_features():
    ga = GraphAnalysis(extracter)
    return ga

def find_optimal_clusters(ga, min_clusters, max_clusters):
    # Slider values may arrive as floats, so cast to int before running the elbow search.
    ga.find_optimal_clusters(min_clusters=int(min_clusters), max_clusters=int(max_clusters), save_path='./temp')
    return f"Optimal number of clusters: {ga.optimal_clusters}"

def perform_clustering(ga, num_clusters):
    ga.cluster_data2(int(num_clusters))
    return "Clustering completed."

def plot_kaplan_meier(ga):
    ga.plot_kaplan_meier()
    return "Kaplan-Meier plot saved."

def plot_median_survival_bar(ga):
    ga.plot_median_survival_bar(name='temp')
    return "Median survival bar plot saved."

def perform_log_rank_test(ga):
    significant_pairs = ga.perform_log_rank_test()
    return f"Significant pairs from log-rank test: {significant_pairs}"

css = """
#col-container {
    margin: 0 auto;
    max-width: 520px;
}
"""

@spaces.GPU
def run_extract_features():
    return extract_features()

@spaces.GPU
def run_find_optimal_clusters(ga, min_clusters, max_clusters):
    return find_optimal_clusters(ga, min_clusters, max_clusters)

@spaces.GPU
def run_perform_clustering(ga, num_clusters):
    return perform_clustering(ga, num_clusters)

@spaces.GPU
def run_plot_kaplan_meier(ga):
    return plot_kaplan_meier(ga)

@spaces.GPU
def run_plot_median_survival_bar(ga):
    return plot_median_survival_bar(ga)

@spaces.GPU
def run_perform_log_rank_test(ga):
    return perform_log_rank_test(ga)
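
# Gradio UI: one button per analysis step, sliders for the cluster-count parameters,
# and a textbox for status messages.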

with gr.Blocks(css=css) as demo:
    with gr.Column(elem_id="col-container"):
        gr.Markdown(f"""
        # Graph to Features and Analysis
        Currently running on {device}.
        """)

        with gr.Row():
            extract_button = gr.Button("Extract Features")
            clustering_button = gr.Button("Find Optimal Clusters")
            cluster_data_button = gr.Button("Perform Clustering")
            kaplan_meier_button = gr.Button("Plot Kaplan-Meier")
            survival_bar_button = gr.Button("Plot Median Survival Bar")
            log_rank_button = gr.Button("Perform Log-Rank Test")

        num_clusters = gr.Slider(label="Number of Clusters", minimum=2, maximum=10, step=1, value=5)
        min_clusters = gr.Slider(label="Min Clusters for Elbow Method", minimum=2, maximum=10, step=1, value=2)
        max_clusters = gr.Slider(label="Max Clusters for Elbow Method", minimum=3, maximum=20, step=1, value=10)

        result = gr.Textbox(label="Result")

    # Shared state holding the GraphAnalysis object produced by "Extract Features".
    ga = gr.State()

    # Wire each button to its GPU-backed handler.
    extract_button.click(fn=run_extract_features, inputs=[], outputs=[ga])
    clustering_button.click(fn=run_find_optimal_clusters, inputs=[ga, min_clusters, max_clusters], outputs=[result])
    cluster_data_button.click(fn=run_perform_clustering, inputs=[ga, num_clusters], outputs=[result])
    kaplan_meier_button.click(fn=run_plot_kaplan_meier, inputs=[ga], outputs=[result])
    survival_bar_button.click(fn=run_plot_median_survival_bar, inputs=[ga], outputs=[result])
    log_rank_button.click(fn=run_perform_log_rank_test, inputs=[ga], outputs=[result])

demo.queue().launch()