Spaces:
Sleeping
Sleeping
Create app.py
Browse files
app.py
ADDED
|
@@ -0,0 +1,219 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import re
|
| 3 |
+
import numpy as np
|
| 4 |
+
import pandas as pd
|
| 5 |
+
import gradio as gr
|
| 6 |
+
import plotly.graph_objects as go
|
| 7 |
+
from plotly.subplots import make_subplots
|
| 8 |
+
from collections import defaultdict
|
| 9 |
+
from huggingface_hub import login
|
| 10 |
+
from datasets import load_dataset
|
| 11 |
+
|
| 12 |
+
def hex_to_rgba(hex_color, alpha=1):
    """Convert a hex colour string into a CSS-style ``rgba(r,g,b,a)`` string.

    Args:
        hex_color: Hex colour such as ``"#ffc966"``; the leading ``#`` is
            optional. Three- and four-digit shorthand (``"#fc6"``) is
            expanded by doubling each digit. Only the first six digits are
            read, so the alpha byte of an eight-digit ``#RRGGBBAA`` value is
            ignored in favour of ``alpha``.
        alpha: Opacity placed verbatim in the fourth ``rgba`` component.

    Returns:
        A string of the form ``"rgba(R,G,B,alpha)"`` with decimal channels.

    Raises:
        ValueError: If the red, green or blue digits are not valid hex.
    """
    digits = hex_color.lstrip('#')
    # Shorthand notation: "fc6" -> "ffcc66" (and "fc6a" -> "ffcc66aa").
    if len(digits) in (3, 4):
        digits = ''.join(ch * 2 for ch in digits)
    r, g, b = (int(digits[i:i + 2], 16) for i in (0, 2, 4))
    return f"rgba({r},{g},{b},{alpha})"
|
| 16 |
+
|
| 17 |
+
def create_sankey(df):
    """Build the combined Sankey / sunburst / treemap figure.

    The Sankey diagram has one column of nodes per stage (transfer outcome
    columns followed by aberration, chromosome, parental origin and
    segregation origin) and one link per observed pair of values in
    consecutive stages, weighted by row count.

    Args:
        df: DataFrame containing the columns listed in ``stages`` below.
            It is not modified; the function works on a copy.

    Returns:
        A plotly figure with three rows: the Sankey trace plus the sunburst
        and treemap traces returned by ``build_sunburst_trace``.
    """
    # Work on a copy so the caller's frame is left untouched by the
    # aberration-label normalisation.
    df = df.copy()
    df['Aberration'] = df['Aberration'].replace({
        r'Deletion.*': 'Deletion',
        r'Duplication.*': 'Duplication'
    }, regex=True)

    stages = ["Transferred", "Pos_HCG", "Heart_action", "Birth", "Aberration", "Aff_Chrom", "Par_Or", "Seg_Or"]
    stage_labels = ["Transfer", "Implantation", "Pregnancy", "Live Birth", "Aberration", "Chromosome", "Parental Origin", "Segregation Origin"]

    category_orders = {
        "Transferred": ["Y", "N", "NA"],
        "Pos_HCG": ["Y", "N", "NA"],
        "Heart_action": ["Y", "N", "NA"],
        "Birth": ["Y", "N", "NA"],
        "Aberration": ["Monosomy", "Trisomy", "Tetraploidy", "Pentasomy", "Mixoploidy", "UPD", "Duplication", "Deletion"],
        "Par_Or": ["Maternal", "Paternal", "Both", "ND"],
        "Seg_Or": ["Meiotic I", "Meiotic II", "Mitotic", "ND"],
        "Aff_Chrom": [str(i) for i in range(1, 23)] + ["X", "GW"]
    }

    colors = {
        "Transferred": {"Y": "#ffe3c5", "N": "#b0cdc0"},
        "Pos_HCG": {"Y": "#ffe3c5", "N": "#8a9991"},
        "Heart_action": {"Y": "#ffe3c5", "N": "#8a9991"},
        "Birth": {"Y": "#ffe3c5", "N": "#8a9991"},
        "Aberration": {
            "Monosomy": "#FFFACD", "Trisomy": "#ffc966", "Tetraploidy": "#ffc966",
            "Pentasomy": "#ffc966", "Mixoploidy": "#ffc966", "UPD": "#d0b783",
            "Duplication": "#ffc966", "Deletion": "#FFFACD"
        },
        "Par_Or": {"Maternal": "#C77CFF", "Paternal": "#F37735", "Both": "#993300", "ND": "#b48585"},
        "Seg_Or": {"Meiotic I": "#9ECAE1", "Meiotic II": "#003366", "Mitotic": "#3182BD", "ND": "#b48585"},
        # Keys must match the un-padded chromosome names in category_orders
        # ("1".."22"); zero-padded keys would never be found and every
        # autosome would fall back to the default grey.
        "Aff_Chrom": {str(i): f"#{hex(i*123456%0xFFFFFF)[2:].zfill(6)}99" for i in range(1, 23)} |
                     {"X": "#66666699", "GW": "#CCCCCC99"}
    }

    # Ordered values per stage: the predefined order first, then any other
    # value present in the data so that building links cannot hit an
    # unknown node.
    stage_values = {}
    for stage in stages:
        ordered = list(category_orders.get(stage, []))
        extras = sorted({str(v) for v in df[stage].dropna().unique()} - set(ordered))
        stage_values[stage] = ordered + extras

    # Create nodes and their attributes. Each (stage, value) pair gets its
    # own sequential index; tuple keys avoid having to split a joined string
    # whose stage part itself contains underscores.
    nodes = {}
    node_labels = []
    node_hex = []
    node_x = []
    node_y = []
    for stage_idx, stage in enumerate(stages):
        values = stage_values[stage]
        for pos, value in enumerate(values):
            nodes[(stage, value)] = len(nodes)
            node_labels.append("Yes" if value == 'Y' else "No" if value == 'N' else value)
            node_hex.append(colors.get(stage, {}).get(value, "#A2A2A2"))
            node_x.append(stage_idx / (len(stages) - 1) * 0.9 + 0.05)
            node_y.append(0.01 + (pos + 0.5) * 0.98 / max(len(values), 1))
    node_colors = [hex_to_rgba(color) for color in node_hex]

    # Create links between each pair of consecutive stages.
    links = []
    for src_stage, dst_stage in zip(stages, stages[1:]):
        grouped = df.groupby([src_stage, dst_stage]).size().reset_index(name='count')
        for src_value, dst_value, count in zip(grouped[src_stage], grouped[dst_stage], grouped['count']):
            links.append({
                "source": nodes[(src_stage, str(src_value))],
                "target": nodes[(dst_stage, str(dst_value))],
                "value": int(count)
            })

    # Links take the colour of their source node at 40% opacity.
    link_colors = [hex_to_rgba(node_hex[link["source"]], 0.4) for link in links]

    # Calculate each link's share of its source node's outgoing total.
    source_totals = defaultdict(int)
    for link in links:
        source_totals[link["source"]] += link["value"]
    link_customdata = [
        [link["value"] / source_totals[link["source"]] * 100 if source_totals[link["source"]] else 0.0]
        for link in links
    ]

    # Create sankey trace
    sankey_trace = go.Sankey(
        arrangement="snap",
        node=dict(
            pad=20,
            thickness=40,
            line=dict(color="black", width=0.8),
            label=node_labels,
            color=node_colors,
            x=node_x,
            y=node_y,
            hovertemplate='%{label}<extra></extra>'
        ),
        link=dict(
            source=[link["source"] for link in links],
            target=[link["target"] for link in links],
            value=[link["value"] for link in links],
            color=link_colors,
            customdata=link_customdata,
            hovertemplate='%{source.label} → %{target.label}<br>Count: %{value} (%{customdata[0]:.2f}%)<extra></extra>'
        )
    )

    # Create annotations: one column heading per stage.
    annotations = [
        dict(
            x=i / (len(stage_labels) - 1) * 0.88 + 0.068,
            y=1.02,
            text=label,
            showarrow=False,
            font=dict(size=18, family="Arial"),
            xanchor="center"
        ) for i, label in enumerate(stage_labels)
    ]

    # Create subplot
    fig = make_subplots(
        rows=3, cols=1,
        row_heights=[0.3, 0.35, 0.35],
        specs=[[{"type": "sankey"}], [{"type": "sunburst"}], [{"type": "treemap"}]],
        vertical_spacing=0.1
    )

    # Add traces
    # NOTE(review): build_sunburst_trace is not defined in the visible part
    # of this file; it must be provided elsewhere for this call to work.
    sunburst_trace, treemap_trace = build_sunburst_trace(df)

    fig.add_trace(sankey_trace, row=1, col=1)
    fig.add_trace(sunburst_trace, row=2, col=1)
    fig.add_trace(treemap_trace, row=3, col=1)

    # Update layout
    fig.update_layout(
        title=dict(text="Embryo Aberration Analysis (PGT-AO Study)",
                   font=dict(size=26, family="Arial"),
                   x=0.5, y=0.98),
        width=1700,
        height=3000,
        font=dict(family="Arial", size=14),
        margin=dict(l=200, r=200, t=150, b=200),
        paper_bgcolor="white",
        plot_bgcolor="white",
        annotations=annotations
    )

    return fig
|
| 166 |
+
|
| 167 |
+
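# NOTE: The remaining functions are largely unchanged apart from minor optimizations and are not reproduced here; this includes build_sunburst_trace(), which create_sankey() calls.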
|
| 168 |
+
|
| 169 |
+
def _chromosome_label(value):
    """Return the display label for one ``Aff_Chrom`` entry."""
    # Missing entries are passed through unchanged instead of being turned
    # into pseudo-chromosomes such as "Chrnan" or "ChrNA".
    if pd.isna(value) or value == 'NA':
        return value
    # Integer-valued floats (e.g. 7.0) are labelled like the integer 7.
    if isinstance(value, float) and value.is_integer():
        value = int(value)
    text = str(value)
    if text == 'GW':
        return 'GW'
    if text in ('X', 'Y'):
        return f'Chr{text}'
    return f'Chr{text.zfill(2)}'


def clean_data(df):
    """Return a cleaned copy of the study table; ``df`` itself is not modified.

    Steps, in order:
      * ``Aff_Chrom`` values become ``ChrNN`` / ``ChrX`` / ``ChrY`` labels,
        with ``GW`` kept as is and missing values left missing.
      * Every ``'NA'`` string in the frame is replaced by ``''``.
      * ``Heart_action`` and ``Birth`` are set to ``'Not Applicable'`` for
        transferred embryos with a negative ``Pos_HCG``.
      * ``Transferred`` is reduced to ``'Y'`` / ``'N'``; anything else
        (missing values, ``'N (deg)'``, ...) becomes ``'N'``.
      * Missing ``Pos_HCG`` / ``Heart_action`` / ``Birth`` values become
        ``'Not Applicable'``.
      * ``Birth`` is set to ``'No'`` where ``Pos_HCG`` is ``'Y'`` and
        ``Heart_action`` is ``'N'``.
      * ``'ND'`` in ``Par_Or`` becomes ``'Not Determinable'``.
      * ``'Options not set correctly'`` in ``Percentage`` becomes NaN.

    Args:
        df: DataFrame with the columns ``Aff_Chrom``, ``Transferred``,
            ``Pos_HCG``, ``Heart_action``, ``Birth``, ``Par_Or`` and
            ``Percentage``.

    Returns:
        The cleaned DataFrame.
    """
    df = df.copy()
    df['Aff_Chrom'] = df['Aff_Chrom'].apply(_chromosome_label)
    df = df.replace('NA', '')

    outcome_cols = ['Pos_HCG', 'Heart_action', 'Birth']
    failed_implantation = (df['Transferred'] == 'Y') & (df['Pos_HCG'] == 'N')
    df.loc[failed_implantation, ['Heart_action', 'Birth']] = 'Not Applicable'

    # Anything that is not an explicit 'Y'/'N' counts as not transferred.
    df['Transferred'] = df['Transferred'].where(df['Transferred'].isin(['Y', 'N']), 'N')
    df[outcome_cols] = df[outcome_cols].fillna('Not Applicable')

    # NOTE(review): this writes 'No' while the other outcome values use
    # 'N' -- confirm that downstream consumers expect the distinct label.
    no_heart_action = (df['Pos_HCG'] == 'Y') & (df['Heart_action'] == 'N')
    df.loc[no_heart_action, 'Birth'] = 'No'

    df['Par_Or'] = df['Par_Or'].replace('ND', 'Not Determinable')
    df['Percentage'] = df['Percentage'].replace('Options not set correctly', np.nan)
    return df
|
| 184 |
+
|
| 185 |
+
def create_footer():
    """Return the HTML for the fixed page footer.

    The footer is pinned to the bottom of the viewport and contains a
    copyright line plus a website link and a contact (mailto) link.

    Returns:
        The footer markup as a string.
    """
    # rel="noopener noreferrer" keeps the page opened via target="_blank"
    # from getting a handle on this window through window.opener.
    return """
    <div style="
        position: fixed; bottom: 0; left: 0; width: 100%; padding: 15px;
        background-color: #f8fafc; z-index: 1000; display: flex;
        align-items: center; justify-content: center; flex-wrap: wrap; gap: 10px;">
        <div style="text-align: center;">
            <p style="margin: 5px 0; color: #4b5563;">© 2025 CGM</p>
            <p style="margin: 0;">
                <a href="https://www.zamanilab.org/" target="_blank" rel="noopener noreferrer" style="color: #1d4ed8; text-decoration: none; font-weight: 500;">Website</a> |
                <a href="mailto:masoud.zamaniesteki@mumc.nl" style="color: #1d4ed8; text-decoration: none; font-weight: 500;">Contact</a>
            </p>
        </div>
    </div>
    """
|
| 200 |
+
|
| 201 |
+
def main():
    """Load the study data and launch the Gradio dashboard.

    Logs in to the Hugging Face Hub with the token read from the
    ``hf_api_key`` environment variable, loads ``202507_PGTAO.csv`` from the
    ``CellularGenomicMedicine/pgt-ao`` dataset, and serves a page with the
    figure from ``create_sankey``, an insights section and the footer.
    """
    hub_token = os.environ.get("hf_api_key")
    login(token=hub_token)

    pgt_dataset = load_dataset(
        "CellularGenomicMedicine/pgt-ao",
        data_files="202507_PGTAO.csv",
        token=True,
    )
    df = pgt_dataset['train'].to_pandas()

    with gr.Blocks(theme=gr.themes.Soft(), css="...") as demo:
        # Page header and introduction.
        gr.Markdown('<h1 class="main-title">🧬 PGT-AO Study Data Visualization Dashboard</h1>')
        gr.Markdown("...")

        # Main chart section.
        with gr.Group():
            gr.Markdown('<div class="chart-title">🌳 Hierarchical Treemap</div>')
            figure = create_sankey(df)
            gr.Plot(figure, show_label=False)

        # Textual summary section.
        with gr.Group():
            gr.Markdown('<div class="chart-title">💡 Key Insights</div>')
            gr.Markdown("...")

        gr.HTML(create_footer())

    demo.launch(show_error=True)
|
| 217 |
+
|
| 218 |
+
# Start the dashboard only when this file is executed as a script.
if __name__ == "__main__":
    main()
|