RCP79 commited on
Commit
578fbf9
·
verified ·
1 Parent(s): 5498f6f

here is the code:

Browse files

#%%
#!/usr/bin/env python3
# Required installations (run these in your environment if not already installed):
# pip install multimolecule umap-learn pacmap plotly torch transformers pandas numpy h5py

import pandas as pd
import numpy as np
import torch
from multimolecule import RnaTokenizer, RnaErnieModel
import umap
import pacmap
import plotly.graph_objects as go
from plotly.subplots import make_subplots
from torch.utils.data import DataLoader, Dataset
import h5py # For saving to HDF5
import os

# =============================================================================
# CONFIGURATION: Customize these parameters for different data frames
# =============================================================================
CONFIG = {
'file_path': '/Users/roger/Desktop/Feina/Colaboracions/Claire/BGI_20250811_data/Analysis/Saved_data/df_merge.h5', # Path to input H5 file
'h5_key': 'df_merge', # Key for the DataFrame in H5 (set to None to auto-detect or use first key)
'seq_col': 'full_sequence', # Column name containing raw sequences (with T instead of U)
'output_seq_col': 'sequence', # New column name for processed sequences (U-replaced)
'filter_query': "(counts_pwt03 > 0) | (counts_pwt04 > 0)", # Pandas query string for filtering (e.g., "col1 > 0 & col2 == 'value'"); set to None for no filter
'from_char': 'T', # Character to replace in sequences (DNA to RNA)
'to_char': 'U', # Replacement character
'batch_size': 512, # Batch size for inference (tune based on seq length and RAM)
'num_workers': 4, # DataLoader workers (increase for faster I/O on multi-core)
'max_length': None, # Max sequence length for tokenization (auto-computed if None)
'model_name': 'multimolecule/rnaernie', # HuggingFace model name (assumes RNAErnie-compatible)
'output_file': 'rna_embeddings_reductions.h5', # Output H5 file path
'umap_neighbors': 15, # UMAP n_neighbors (lower = faster)
'umap_min_dist': 0.1, # UMAP min_dist
'pacmap_iters': 50, # PaCMAP num_iters (lower = faster)
'pacmap_mn_ratio': 0.5 # PaCMAP MN_ratio
}

# =============================================================================
# Custom Dataset Class (Moved to top level to allow pickling for multiprocessing)
# =============================================================================
class SequenceDataset(Dataset):
def __init__(self, sequences, tokenizer, max_length):
self.sequences = sequences
self.tokenizer = tokenizer
self.max_length = max_length

def __len__(self):
return len(self.sequences)

def __getitem__(self, idx):
tokenized = self.tokenizer(
self.sequences[idx],
return_tensors="pt",
padding=True,
truncation=True,
max_length=self.max_length
)
# Squeeze the batch dim (1) to return flat tensors of shape (max_length,)
# This ensures proper collation in DataLoader to (batch_size, max_length)
return {k: v.squeeze(0) for k, v in tokenized.items()}

# =============================================================================
# Load and Preprocess Data
# =============================================================================
def load_and_preprocess(config):
file_path = config['file_path']
h5_key = config['h5_key']

# List keys if needed
with h5py.File(file_path, 'r') as f:
keys = list(f.keys())
print("Keys in the H5 file:", keys)

# Load DataFrame
if h5_key is None:
h5_key = keys[0] if keys else None
df = pd.read_hdf(file_path, key=h5_key)

print("DataFrame shape:", df.shape)
print("Columns:", list(df.columns))

# Create processed sequence column
seq_col = config['seq_col']
output_seq_col = config['output_seq_col']
if seq_col not in df.columns:
raise ValueError(f"Sequence column '{seq_col}' not found in DataFrame.")

df[output_seq_col] = df[seq_col].str.replace(config['from_char'], config['to_char'])

# Filter if query provided
df_filtered = df
if config['filter_query']:
df_filtered = df.query(config['filter_query'])
print(f"Filtered DataFrame shape (query: '{config['filter_query']}'):", df_filtered.shape)

# Extract sequences
sequences = df_filtered[output_seq_col].tolist()

# Compute max length if not provided
max_len = config['max_length']
if max_len is None:
max_len = max(len(seq) for seq in sequences)

print(f"Number of sequences: {len(sequences)}, Max sequence length: {max_len}")

return df, df_filtered, sequences, max_len

# =============================================================================
# Model and Embedding Generation
# =============================================================================
def generate_embeddings(sequences, tokenizer, model, device, batch_size, num_workers, max_length):
# Use the top-level SequenceDataset class
dataset = SequenceDataset(sequences, tokenizer, max_length)
dataloader = DataLoader(
dataset,
batch_size=batch_size,
shuffle=False,
num_workers=num_workers
)

embeddings = []
with torch.no_grad():
for batch in dataloader:
batch = {k: v.to(device) for k, v in batch.items()}
outputs = model(**batch)

# Masked mean pooling
if 'attention_mask' in batch:
mask = batch['attention_mask'].unsqueeze(-1).expand(outputs.last_hidden_state.shape).float()
masked_emb = outputs.last_hidden_state * mask
sum_emb = masked_emb.sum(dim=1)
count_emb = mask.sum(dim=1)
emb = (sum_emb / count_emb).squeeze().cpu().numpy()
else:
emb = outputs.last_hidden_state.mean(dim=1).squeeze().cpu().numpy()

embeddings.extend(emb)

return np.array(embeddings)

# =============================================================================
# Dimensionality Reduction
# =============================================================================
def reduce_dimensions(embeddings, config):
metric = 'cosine'
random_state = 42

# UMAP 2D
umap_2d = umap.UMAP(
n_components=2,
n_neighbors=config['umap_neighbors'],
min_dist=config['umap_min_dist'],
random_state=random_state,
metric=metric
).fit_transform(embeddings)

# UMAP 3D
umap_3d = umap.UMAP(
n_components=3,
n_neighbors=config['umap_neighbors'],
min_dist=config['umap_min_dist'],
random_state=random_state,
metric=metric
).fit_transform(embeddings)

# PaCMAP 2D (init with UMAP)
pac_2d = pacmap.PaCMAP(
n_components=2,
random_state=random_state,
MN_ratio=config['pacmap_mn_ratio'],
num_iters=config['pacmap_iters']
).fit_transform(embeddings, init_low_dim=umap_2d)

# PaCMAP 3D (init with UMAP)
pac_3d = pacmap.PaCMAP(
n_components=3,
random_state=random_state,
MN_ratio=config['pacmap_mn_ratio'],
num_iters=config['pacmap_iters']
).fit_transform(embeddings, init_low_dim=umap_3d)

return {
'umap_2d': umap_2d,
'umap_3d': umap_3d,
'pacmap_2d': pac_2d,
'pacmap_3d': pac_3d
}

# =============================================================================
# Save Outputs
# =============================================================================
def save_outputs(df, df_filtered, embeddings, reductions, output_file):
with h5py.File(output_file, 'w') as f:
# Save full original DataFrame (with processed sequences)
df.to_hdf(f, key='data_full', mode='w')

# Save filtered DataFrame
df_filtered.to_hdf(f, key='data_filtered', mode='r+')

# Save embeddings
f.create_dataset('embeddings', data=embeddings)

# Save reductions
f.create_dataset('umap_2d', data=reductions['umap_2d'])
f.create_dataset('umap_3d', data=reductions['umap_3d'])
f.create_dataset('pacmap_2d', data=reductions['pacmap_2d'])
f.create_dataset('pacmap_3d', data=reductions['pacmap_3d'])

print(f"Saved all data to {output_file}")

# =============================================================================
# Plotting
# =============================================================================
def plot_reductions(reductions):
# 2D comparison
fig_2d = make_subplots(
rows=1, cols=2,
subplot_titles=('UMAP 2D', 'PaCMAP 2D')
)

fig_2d.add_trace(
go.Scatter(
x=reductions['umap_2d'][:, 0],
y=reductions['umap_2d'][:, 1],
mode='markers',
marker=dict(size=5, opacity=0.7),
name='Points'
),
row=1, col=1
)

fig_2d.add_trace(
go.Scatter(
x=reductions['pacmap_2d'][:, 0],
y=reductions['pacmap_2d'][:, 1],
mode='markers',
marker=dict(size=5, opacity=0.7),
name='Points'
),
row=1, col=2
)

fig_2d.update_layout(
title_text='2D Dimensionality Reduction Comparison',
height=500
)
fig_2d.show()

# UMAP 3D
fig_umap_3d = go.Figure(data=go.Scatter3d(
x=reductions['umap_3d'][:, 0],
y=reductions['umap_3d'][:, 1],
z=reductions['umap_3d'][:, 2],
mode='markers',
marker=dict(size=5, opacity=0.7),
name='Points'
))
fig_umap_3d.update_layout(
title='UMAP 3D',
scene=dict(xaxis_title='Dim 1', yaxis_title='Dim 2', zaxis_title='Dim 3'),
height=600
)
fig_umap_3d.show()

# PaCMAP 3D
fig_pac_3d = go.Figure(data=go.Scatter3d(
x=reductions['pacmap_3d'][:, 0],
y=reductions['pacmap_3d'][:, 1],
z=reductions['pacmap_3d'][:, 2],
mode='markers',
marker=dict(size=5, opacity=0.7),
name='Points'
))
fig_pac_3d.update_layout(
title='PaCMAP 3D',

Files changed (2) hide show
  1. analysis.html +103 -0
  2. index.html +24 -11
analysis.html ADDED
@@ -0,0 +1,103 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ <!DOCTYPE html>
2
+ <html lang="en">
3
+ <head>
4
+ <meta charset="UTF-8">
5
+ <meta name="viewport" content="width=device-width, initial-scale=1.0">
6
+ <title>RNA Sequence Analysis</title>
7
+ <script src="https://cdn.tailwindcss.com"></script>
8
+ <script src="https://cdn.plot.ly/plotly-latest.min.js"></script>
9
+ <script src="https://unpkg.com/htmx.org@1.9.6"></script>
10
+ </head>
11
+ <body class="bg-gray-100 min-h-screen">
12
+ <div class="container mx-auto px-4 py-8">
13
+ <header class="mb-8">
14
+ <h1 class="text-3xl font-bold text-gray-800">RNA Sequence Analysis Tool</h1>
15
+ <p class="text-gray-600">Analyze RNA sequences with dimensional reduction techniques</p>
16
+ </header>
17
+
18
+ <div class="grid grid-cols-1 lg:grid-cols-3 gap-8">
19
+ <!-- Configuration Panel -->
20
+ <div class="bg-white rounded-lg shadow p-6 lg:col-span-1">
21
+ <h2 class="text-xl font-semibold mb-4">Configuration</h2>
22
+
23
+ <form id="analysisForm" hx-post="/analyze" hx-target="#results" hx-indicator="#loading">
24
+ <!-- File Input -->
25
+ <div class="mb-4">
26
+ <label class="block text-gray-700 mb-2" for="dataFile">Data File (HDF5)</label>
27
+ <input type="file" id="dataFile" name="dataFile" accept=".h5,.hdf5"
28
+ class="w-full px-3 py-2 border border-gray-300 rounded-md">
29
+ </div>
30
+
31
+ <!-- Model Selection -->
32
+ <div class="mb-4">
33
+ <label class="block text-gray-700 mb-2" for="modelName">RNA Model</label>
34
+ <select id="modelName" name="modelName" class="w-full px-3 py-2 border border-gray-300 rounded-md">
35
+ <option value="multimolecule/rnaernie">RNAErnie (default)</option>
36
+ <option value="multimolecule/rnabert">RNABERT</option>
37
+ <option value="multimolecule/rnadistill">RNADistill</option>
38
+ </select>
39
+ </div>
40
+
41
+ <!-- Sequence Column -->
42
+ <div class="mb-4">
43
+ <label class="block text-gray-700 mb-2" for="seqCol">Sequence Column</label>
44
+ <input type="text" id="seqCol" name="seqCol" value="full_sequence"
45
+ class="w-full px-3 py-2 border border-gray-300 rounded-md">
46
+ </div>
47
+
48
+ <!-- Reduction Method -->
49
+ <div class="mb-4">
50
+ <label class="block text-gray-700 mb-2">Dimensionality Reduction</label>
51
+ <div class="flex items-center mb-2">
52
+ <input type="radio" id="umap" name="reductionMethod" value="umap" checked class="mr-2">
53
+ <label for="umap">UMAP</label>
54
+ </div>
55
+ <div class="flex items-center">
56
+ <input type="radio" id="pacmap" name="reductionMethod" value="pacmap" class="mr-2">
57
+ <label for="pacmap">PaCMAP</label>
58
+ </div>
59
+ </div>
60
+
61
+ <!-- Dimensions -->
62
+ <div class="mb-4">
63
+ <label class="block text-gray-700 mb-2">Dimensions</label>
64
+ <div class="flex items-center mb-2">
65
+ <input type="radio" id="2d" name="dimensions" value="2" checked class="mr-2">
66
+ <label for="2d">2D</label>
67
+ </div>
68
+ <div class="flex items-center">
69
+ <input type="radio" id="3d" name="dimensions" value="3" class="mr-2">
70
+ <label for="3d">3D</label>
71
+ </div>
72
+ </div>
73
+
74
+ <!-- Color By -->
75
+ <div class="mb-6">
76
+ <label class="block text-gray-700 mb-2" for="colorBy">Color By Column (optional)</label>
77
+ <input type="text" id="colorBy" name="colorBy"
78
+ class="w-full px-3 py-2 border border-gray-300 rounded-md">
79
+ </div>
80
+
81
+ <button type="submit" class="w-full bg-blue-600 text-white py-2 px-4 rounded-md hover:bg-blue-700 transition">
82
+ Analyze Sequences
83
+ </button>
84
+ </form>
85
+ </div>
86
+
87
+ <!-- Results Panel -->
88
+ <div class="bg-white rounded-lg shadow p-6 lg:col-span-2">
89
+ <h2 class="text-xl font-semibold mb-4">Results</h2>
90
+ <div id="loading" class="htmx-indicator text-center py-8">
91
+ <div class="animate-spin rounded-full h-12 w-12 border-b-2 border-blue-600 mx-auto mb-4"></div>
92
+ <p class="text-gray-600">Processing data...</p>
93
+ </div>
94
+ <div id="results" class="min-h-64">
95
+ <div class="text-center text-gray-500 py-12">
96
+ <p>Configure and run analysis to see results</p>
97
+ </div>
98
+ </div>
99
+ </div>
100
+ </div>
101
+ </div>
102
+ </body>
103
+ </html>
index.html CHANGED
@@ -4,16 +4,29 @@
4
  <meta charset="utf-8" />
5
  <meta name="viewport" content="width=device-width" />
6
  <title>My static Space</title>
7
- <link rel="stylesheet" href="style.css" />
8
- </head>
 
 
 
 
 
 
 
9
  <body>
10
- <div class="card">
11
- <h1>Welcome to your static Space!</h1>
12
- <p>You can modify this app directly by editing <i>index.html</i> in the Files and versions tab.</p>
13
- <p>
14
- Also don't forget to check the
15
- <a href="https://huggingface.co/docs/hub/spaces" target="_blank">Spaces documentation</a>.
16
- </p>
17
- </div>
18
- </body>
 
 
 
 
 
 
19
  </html>
 
4
  <meta charset="utf-8" />
5
  <meta name="viewport" content="width=device-width" />
6
  <title>My static Space</title>
7
+ <link rel="stylesheet" href="https://cdnjs.cloudflare.com/ajax/libs/font-awesome/6.4.0/css/all.min.css">
8
+ <script src="https://cdn.tailwindcss.com"></script>
9
+ <style>
10
+ body {
11
+ background-color: #f7fafc;
12
+ font-family: 'Segoe UI', Tahoma, Geneva, Verdana, sans-serif;
13
+ }
14
+ </style>
15
+ </head>
16
  <body>
17
+ <div class="max-w-2xl mx-auto p-8 bg-white rounded-lg shadow">
18
+ <h1 class="text-3xl font-bold text-gray-800 mb-4">RNA Sequence Analysis</h1>
19
+ <p class="text-gray-600 mb-6">
20
+ Analyze RNA sequences using state-of-the-art models and dimensional reduction techniques.
21
+ </p>
22
+ <div class="space-y-4">
23
+ <a href="/analysis.html" class="block w-full bg-blue-600 hover:bg-blue-700 text-white font-medium py-3 px-6 rounded-lg text-center transition">
24
+ Launch Analysis Tool
25
+ </a>
26
+ <p class="text-sm text-gray-500">
27
+ Requires Python dependencies: multimolecule, umap-learn, pacmap, plotly, torch, transformers, pandas, numpy, h5py
28
+ </p>
29
+ </div>
30
+ </div>
31
+ </body>
32
  </html>