AliSaadatV committed
Commit cdc3ab7 · verified · 1 Parent(s): 1351015

Add comprehensive project README

README.md ADDED
@@ -0,0 +1,210 @@
# GeneSetCLIP: Contrastive Pretraining for Gene Set–Text Alignment

A CLIP-style contrastive model that aligns **biological text descriptions** with **gene-set representations**, trained on MSigDB v2024.1 (human + mouse).

Given a text query like *"type I interferon signaling"*, the model retrieves the corresponding gene set, and vice versa.

## Architecture

```
       TEXT SIDE                      GENE SET SIDE
 ─────────────────────         ──────────────────────────
 "Genes up-regulated in           {STAT1, IRF7, ISG15,
  response to IFN-α..."            OAS1, MX1, IFIT1, ...}
           │                                │
           ▼                                ▼
 BioLORD-2023 (frozen)          GSFM (fine-tuned, lr/10)
       [768-dim]                        [256-dim]
           │                                │
           ▼                                ▼
 text_proj (trainable)            gene_proj (trainable)
    768 → 512 → 256                  256 → 256 → 256
           │                                │
           ▼                                ▼
     z_text [256]                     z_gene [256]
           │                                │
           └────── L2-normalize ────────────┘
                           │
                           ▼
              InfoNCE loss (τ learnable)
```

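The last stage of the diagram (L2-normalize both projections, then apply a symmetric InfoNCE loss) can be sketched as follows. This is a minimal, dependency-free illustration rather than the repository's training code, which presumably operates on PyTorch tensors. The function names are hypothetical, and the temperature is passed as a fixed number here, whereas the recipe describes it as learnable.

```python
import math

def l2_normalize(v):
    """Scale a vector to unit length (zero vectors are left unchanged)."""
    norm = math.sqrt(sum(x * x for x in v)) or 1.0
    return [x / norm for x in v]

def cross_entropy_diag(logits):
    """Mean cross-entropy where the correct class for row i is column i."""
    total = 0.0
    for i, row in enumerate(logits):
        m = max(row)  # subtract the max for numerical stability
        log_sum_exp = m + math.log(sum(math.exp(x - m) for x in row))
        total += log_sum_exp - row[i]
    return total / len(logits)

def symmetric_info_nce(text_embs, gene_embs, temperature=0.07):
    """Average of the text->gene and gene->text contrastive losses."""
    z_text = [l2_normalize(v) for v in text_embs]
    z_gene = [l2_normalize(v) for v in gene_embs]
    # Cosine similarity of every text with every gene set, scaled by 1/temperature.
    logits = [[sum(a * b for a, b in zip(t, g)) / temperature for g in z_gene]
              for t in z_text]
    logits_t = [list(col) for col in zip(*logits)]  # gene->text direction
    return 0.5 * (cross_entropy_diag(logits) + cross_entropy_diag(logits_t))

# Toy batch of two pairs: matched pairs give a small loss, swapped pairs a large one.
aligned = [[1.0, 0.0], [0.0, 1.0]]
swapped = [[0.0, 1.0], [1.0, 0.0]]
loss_aligned = symmetric_info_nce(aligned, aligned)
loss_swapped = symmetric_info_nce(aligned, swapped)
print(loss_aligned < loss_swapped)  # True
```

Each row of the similarity matrix treats the matching pair (the diagonal entry) as the positive and every other item in the batch as a negative, which is why larger batches provide more negatives per step.
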
### Components

| Component | Model | Dim | Training |
|-----------|-------|-----|----------|
| **Gene encoder** | [GSFM](https://huggingface.co/maayanlab/gsfm-rummagene) (MLP autoencoder, Set model) | 256 | Fine-tuned at 1/10 LR |
| **Text encoder** | [BioLORD-2023](https://huggingface.co/FremyCompany/BioLORD-2023) (MPNet-base) | 768 | Frozen |
| **Gene projection** | MLP: 256 → 256 → 256 + LayerNorm | 256 | Trained |
| **Text projection** | MLP: 768 → 512 → 256 + LayerNorm | 256 | Trained |

### Why these encoders?

- **GSFM**: Purpose-built gene-set encoder from the Ma'ayan Lab. It takes variable-length gene sets as input (multi-hot encoding → MLP) and produces permutation-invariant 256-dim embeddings. Pretrained on Rummagene (gene sets from PubMed tables).
- **BioLORD-2023**: Ontology-grounded biomedical sentence embeddings. Trained on UMLS concept name-synonym pairs plus LLM-generated definitions, a format structurally similar to MSigDB gene set descriptions (name + definition anchored in GO/KEGG/Reactome).

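To make the "multi-hot encoding" and "permutation-invariant" points concrete, here is a small sketch. It is not GSFM's actual code: the `toy_vocab` mapping and the choice to skip genes missing from the vocabulary are assumptions made purely for illustration.

```python
def multi_hot(genes, vocab):
    """Encode a gene list as a 0/1 vector with one slot per vocabulary entry."""
    vec = [0] * len(vocab)
    for gene in genes:
        idx = vocab.get(gene)
        if idx is not None:  # assumption: genes outside the vocabulary are skipped
            vec[idx] = 1
    return vec

# Hypothetical five-gene vocabulary; a real vocabulary covers far more symbols.
toy_vocab = {"STAT1": 0, "IRF7": 1, "ISG15": 2, "OAS1": 3, "MX1": 4}

a = multi_hot(["STAT1", "IRF7", "MX1"], toy_vocab)
b = multi_hot(["MX1", "STAT1", "IRF7"], toy_vocab)  # same set, different order
print(a)       # [1, 1, 0, 0, 1]
print(a == b)  # True
```

Because the vector only records which genes are present, any ordering of the same set yields the same input, so whatever MLP sits on top is permutation-invariant by construction.
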
## Training Data

**MSigDB v2024.1**: 50,896 gene set–text pairs from the Molecular Signatures Database.

| Split | Collections | Pairs | Purpose |
|-------|-------------|-------|---------|
| Train | C2, C5, C8, C1, M2, M5, M8, M1 | 38,622 | Curated, GO, cell type signatures |
| Val | C3, C4, M3 | 6,766 | Regulatory targets, computational |
| Test | H, C6, C7, MH | 5,508 | Hallmarks, oncogenic, immunologic |

Each pair consists of:
- **Text**: `[Collection: H] [Species: human]\nHALLMARK APOPTOSIS\nGenes mediating programmed cell death by activation of caspases.`
- **Genes**: `["CASP3", "CASP6", "TP53", "BAX", ...]`

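The text side of a pair follows a fixed template, which might be assembled along these lines. The helper name `build_text` is hypothetical, and replacing underscores in the MSigDB standard name with spaces is an assumption inferred from the example above, not something the README states.

```python
def build_text(collection, species, standard_name, description):
    """Format one gene set's metadata into the three-line text template."""
    readable_name = standard_name.replace("_", " ")  # assumption, see lead-in
    return f"[Collection: {collection}] [Species: {species}]\n{readable_name}\n{description}"

text = build_text(
    "H",
    "human",
    "HALLMARK_APOPTOSIS",
    "Genes mediating programmed cell death by activation of caspases.",
)
print(text)
```

Prefixing the collection and species gives the text encoder explicit context that would otherwise be visible only in the gene set's name.
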
Data augmentation: 20% gene dropout (genes are randomly removed from each set every epoch).

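A gene-dropout augmentation of this kind could look like the sketch below. It assumes each gene is dropped independently with probability 0.2; the README does not say whether dropout is per-gene or removes an exact 20% of the set. The `min_genes` safeguard is a hypothetical addition so that a set is never emptied.

```python
import random

def gene_dropout(genes, drop_prob=0.2, rng=random, min_genes=1):
    """Return a copy of `genes` with each gene independently dropped."""
    kept = [g for g in genes if rng.random() >= drop_prob]
    if len(kept) < min_genes:
        # Hypothetical safeguard: fall back to a random non-empty subset.
        kept = rng.sample(genes, min(min_genes, len(genes)))
    return kept

genes = ["CASP3", "CASP6", "TP53", "BAX", "BCL2", "FAS", "BID", "CYCS"]
rng = random.Random(0)  # seeded only to make the demo repeatable
for epoch in range(3):
    # A fresh random subset is drawn every epoch.
    print(epoch, gene_dropout(genes, rng=rng))
```

Seeing a slightly different subset each epoch discourages the encoder from relying on any single gene to identify a set.
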
## Training Recipe

Based on [ProtST](https://arxiv.org/abs/2301.12040) (ICML 2023), adapted for gene sets:

| Parameter | Value |
|-----------|-------|
| Loss | Symmetric InfoNCE (NT-Xent) |
| Temperature | 0.07 (learnable, clamped [0.01, 1.0]) |
| Batch size | 256 |
| LR (projections) | 1e-4 |
| LR (gene encoder) | 1e-5 (10× lower) |
| LR (text encoder) | 0 (frozen) |
| Optimizer | AdamW (weight_decay=0.01) |
| Schedule | 500-step warmup → cosine decay |
| Epochs | 50 (early stopping, patience=10) |
| Gene dropout | 20% |
| Max gene set size | 512 |
| Hardware | 1× T4 GPU (16 GB) |

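The schedule row can be illustrated with a small function. This is a sketch under stated assumptions: the warmup is taken to be linear, the decay is taken to reach zero, and the total step count assumes all 50 epochs run (no early stopping) with the final partial batch kept. None of these details are given in the table, and `lr_at_step` is a hypothetical name.

```python
import math

def lr_at_step(step, total_steps, base_lr=1e-4, warmup_steps=500, min_lr=0.0):
    """Linear warmup to base_lr, then cosine decay towards min_lr."""
    if step < warmup_steps:
        return base_lr * (step + 1) / warmup_steps
    progress = (step - warmup_steps) / max(1, total_steps - warmup_steps)
    progress = min(1.0, progress)
    return min_lr + 0.5 * (base_lr - min_lr) * (1.0 + math.cos(math.pi * progress))

steps_per_epoch = math.ceil(38_622 / 256)  # train pairs / batch size
total_steps = steps_per_epoch * 50         # assumes every epoch is run

for step in (0, 499, total_steps // 2, total_steps):
    proj_lr = lr_at_step(step, total_steps, base_lr=1e-4)  # projection heads
    gene_lr = lr_at_step(step, total_steps, base_lr=1e-5)  # gene encoder, 10x lower
    print(step, f"{proj_lr:.2e}", f"{gene_lr:.2e}")
```

Applying the same schedule shape to both parameter groups keeps the gene encoder at one tenth of the projection learning rate throughout training.
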
## Quick Start

### Installation
```bash
pip install torch sentence-transformers huggingface_hub safetensors lightning
GIT_LFS_SKIP_SMUDGE=1 pip install "git+https://huggingface.co/maayanlab/gsfm"
```

### Inference
```python
import torch
import torch.nn as nn
import torch.nn.functional as F
from gsfm import GSFM, Vocab
from sentence_transformers import SentenceTransformer
from huggingface_hub import hf_hub_download

# Load gene encoder + vocab.
# Note: this loads the pretrained GSFM checkpoint; the fine-tuned encoder
# weights are listed separately as gene_encoder.pt in the Files table.
gene_encoder = GSFM.from_pretrained("maayanlab/gsfm-rummagene")
vocab = Vocab.from_pretrained("maayanlab/gsfm-rummagene")
gene_encoder.eval()

# Load text encoder (frozen BioLORD-2023)
text_encoder = SentenceTransformer("FremyCompany/BioLORD-2023")

# Download the trained projection heads
clip_path = hf_hub_download("AliSaadatV/GeneSetCLIP", "clip_model.pt")

class ProjectionHead(nn.Module):
    """Two-layer MLP followed by LayerNorm."""
    def __init__(self, d_in, d_h, d_out):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(d_in, d_h), nn.GELU(), nn.Dropout(0.1),
            nn.Linear(d_h, d_out), nn.LayerNorm(d_out))

    def forward(self, x):
        return self.net(x)

class GeneSetCLIP(nn.Module):
    """Container for the learnable temperature and both projection heads."""
    def __init__(self):
        super().__init__()
        self.log_temperature = nn.Parameter(torch.zeros(1))
        self.text_proj = ProjectionHead(768, 512, 256)
        self.gene_proj = ProjectionHead(256, 256, 256)

clip_model = GeneSetCLIP()
clip_model.load_state_dict(torch.load(clip_path, map_location="cpu", weights_only=True))
clip_model.eval()  # disables dropout in the projection heads

# --- Encode a gene set ---
genes = ["STAT1", "IRF7", "ISG15", "OAS1", "MX1", "IFIT1"]
gene_ids = torch.tensor([vocab(genes)])
with torch.no_grad():
    gene_emb = gene_encoder.encode(gene_ids)
    z_gene = F.normalize(clip_model.gene_proj(gene_emb), dim=-1)

# --- Encode text queries ---
queries = [
    "Interferon alpha response genes",
    "Apoptosis signaling",
    "Cell cycle regulation",
]
# Move embeddings to CPU so they match the device of the projection heads.
text_embs = text_encoder.encode(queries, convert_to_tensor=True).cpu()
with torch.no_grad():
    z_text = F.normalize(clip_model.text_proj(text_embs), dim=-1)

# --- Compute similarities ---
sims = (z_gene @ z_text.T).squeeze(0)  # one cosine similarity per query
for q, s in zip(queries, sims):
    print(f"  {s.item():.3f}  {q}")
# Expected: highest similarity for "Interferon alpha response genes"
```

## Training from Scratch

### 1. Process MSigDB data
```bash
python data_processing.py
```
This downloads all MSigDB GMT files and scrapes the gene set descriptions.

### 2. Train
```bash
# Self-contained (downloads data from Hub automatically)
python train_job.py

# Or with local data
python train.py
```

### 3. On HF Jobs (GPU)
```python
from huggingface_hub import HfApi
# Submit as HF Job with GPU
# See train_job.py for the self-contained script
```

## Downstream Applications

1. **Zero-shot gene set annotation**: Embed a gene list from an experiment → find nearest text descriptions
2. **Cross-modal search**: Text query → gene sets, or gene list → pathway descriptions
3. **Gene set similarity**: Compare gene sets via embedding cosine similarity (captures functional similarity beyond gene overlap)
4. **Cell type annotation**: Embed cell marker gene sets → match to cell type text descriptions
5. **Biological RAG**: Use MSigDB embeddings as retrieval corpus for LLM-based reasoning

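Most of these applications reduce to nearest-neighbour search in the shared embedding space. The sketch below ranks candidate descriptions by cosine similarity to a query embedding. The three-dimensional vectors are made-up placeholders, not model outputs, and `top_k` is a hypothetical helper; real embeddings would be the 256-dim `z_text` and `z_gene` vectors.

```python
import math

def cosine(u, v):
    """Cosine similarity between two equal-length vectors."""
    dot = sum(a * b for a, b in zip(u, v))
    norm_u = math.sqrt(sum(a * a for a in u))
    norm_v = math.sqrt(sum(b * b for b in v))
    return dot / (norm_u * norm_v) if norm_u and norm_v else 0.0

def top_k(query_emb, corpus, k=3):
    """Return the k (score, label) pairs most similar to the query."""
    scored = [(cosine(query_emb, emb), label) for label, emb in corpus.items()]
    scored.sort(key=lambda pair: pair[0], reverse=True)
    return scored[:k]

# Placeholder embeddings for illustration only.
toy_corpus = {
    "Interferon alpha response genes": [0.9, 0.1, 0.0],
    "Apoptosis signaling": [0.1, 0.9, 0.1],
    "Cell cycle regulation": [0.0, 0.2, 0.9],
}
query = [1.0, 0.2, 0.0]  # stands in for an embedded gene list

results = top_k(query, toy_corpus, k=2)
for score, label in results:
    print(f"{score:.3f}  {label}")
```

For a corpus the size of MSigDB, the same ranking would normally be done as a single matrix product over pre-normalized embeddings rather than a Python loop.
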
## Key References

- [ProtST](https://arxiv.org/abs/2301.12040) (ICML 2023): Protein-text contrastive alignment
- [MoleculeSTM](https://arxiv.org/abs/2212.10789) (Nature Machine Intelligence 2023): Molecule-text alignment
- [LangCell](https://arxiv.org/abs/2405.06708): Cell-text contrastive learning with MSigDB pathways
- [BioLORD-2023](https://arxiv.org/abs/2311.16075) (JAMIA 2024): Biomedical sentence embeddings
- [Set Transformer](https://arxiv.org/abs/1810.00825): Permutation-invariant set encoding

## Files

| File | Description |
|------|-------------|
| `clip_model.pt` | Trained projection heads (text + gene) |
| `gene_encoder.pt` | Fine-tuned GSFM gene encoder |
| `config.json` | Training configuration |
| `vocab.json` | Gene symbol → token ID mapping |
| `test_results.json` | Evaluation metrics on test set |
| `train_job.py` | Self-contained training script (for HF Jobs) |
| `train.py` | Modular training script |
| `data_processing.py` | MSigDB data download + processing |

## License

- Code: MIT
- GSFM model: BSD-3-Clause
- BioLORD-2023: Other (requires UMLS account)
- MSigDB data: [Creative Commons Attribution 4.0](https://www.gsea-msigdb.org/gsea/msigdb/licenses.jsp)