AliSaadatV committed on
Commit
06fa7f1
·
verified ·
1 Parent(s): c17e8fd

Add data processing pipeline

Browse files
Files changed (1) hide show
  1. data_processing.py +164 -0
data_processing.py ADDED
@@ -0,0 +1,164 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
"""
MSigDB Data Processing Pipeline for Contrastive Pretraining.

Strategy:
    1. Download full GMT files from Broad data server (no auth needed)
    2. Fetch brief descriptions from MSigDB HTML card pages (no auth needed)
    3. Build text-gene paired dataset

Usage:
    python data_processing.py
"""
12
+
13
+ import json
14
+ import os
15
+ import time
16
+ import html
17
+ import re
18
+ import random
19
+ import concurrent.futures
20
+ from collections import defaultdict
21
+ import requests
22
+
23
# Base URL of the release tree that GMT download URLs are built from.
BROAD_BASE = "https://data.broadinstitute.org/gsea-msigdb/msigdb/release"
# Release identifier; used both in the release directory and in GMT file names.
VERSION = "2024.1"

# Collection code -> GMT file-name prefix for human ("Hs") collections.
# Insertion order is the download order.
HUMAN_GMTS = {
    "H": "h.all",
    "C1": "c1.all",
    "C2": "c2.all",
    "C3": "c3.all",
    "C4": "c4.all",
    "C5": "c5.all",
    "C6": "c6.all",
    "C7": "c7.all",
    "C8": "c8.all",
}
# Collection code -> GMT file-name prefix for mouse ("Mm") collections.
MOUSE_GMTS = {
    "MH": "mh.all",
    "M1": "m1.all",
    "M2": "m2.all",
    "M3": "m3.all",
    "M5": "m5.all",
    "M8": "m8.all",
}
30
+
31
+
32
def download_gmt(collection_code, species="Hs", version=VERSION):
    """Download one collection's GMT file and parse it into gene-set records.

    Args:
        collection_code: Key into ``HUMAN_GMTS`` when ``species == "Hs"``,
            otherwise into ``MOUSE_GMTS`` (e.g. ``"H"``, ``"M2"``).
        species: Species suffix used in the release directory and file name.
        version: Release string used in the release directory and file name.

    Returns:
        A list of dicts with keys ``name``, ``url``, ``genes``,
        ``collection`` and ``species`` (``"human"`` or ``"mouse"``).
        An empty list is returned when the collection code is unknown or
        the download fails (a warning is printed in the latter case).
    """
    gmts = HUMAN_GMTS if species == "Hs" else MOUSE_GMTS
    prefix = gmts.get(collection_code)
    if not prefix:
        return []

    filename = f"{prefix}.v{version}.{species}.symbols.gmt"
    url = f"{BROAD_BASE}/{version}.{species}/{filename}"
    try:
        resp = requests.get(url, timeout=60)
        resp.raise_for_status()
    except requests.RequestException as e:
        # Best-effort: a missing or unreachable collection must not abort
        # the download of the remaining ones.
        print(f" Warning: {url}: {e}")
        return []

    species_label = "human" if species == "Hs" else "mouse"
    gene_sets = []
    # Rows are newline-separated and fields tab-separated. These must be the
    # real control characters: splitting on an escaped backslash sequence
    # would leave the whole file as a single unparseable row.
    for line in resp.text.strip().split("\n"):
        parts = line.split("\t")
        if len(parts) < 3:
            # Needs at least name, url and one gene.
            continue
        gene_sets.append({
            "name": parts[0].strip(),
            "url": parts[1].strip(),
            # strip() also removes a trailing "\r" left by CRLF line endings.
            "genes": [g.strip() for g in parts[2:] if g.strip()],
            "collection": collection_code,
            "species": species_label,
        })
    return gene_sets
57
+
58
+
59
def download_all_gmts(output_dir="data/raw"):
    """Download every configured human and mouse collection and save them.

    Iterates ``HUMAN_GMTS`` (species ``"Hs"``) and then ``MOUSE_GMTS``
    (species ``"Mm"``), printing a per-collection count, and writes the
    combined list to ``<output_dir>/all_gmt_genesets.json``.

    Args:
        output_dir: Directory for the JSON dump; created if missing.

    Returns:
        The combined list of gene-set dicts returned by ``download_gmt``.
    """
    os.makedirs(output_dir, exist_ok=True)
    all_gs = []
    batches = (
        ("Downloading human gene sets...", HUMAN_GMTS, "Hs"),
        ("Downloading mouse gene sets...", MOUSE_GMTS, "Mm"),
    )
    for banner, codes, species in batches:
        print(banner)
        for code in codes:
            gs = download_gmt(code, species)
            all_gs.extend(gs)
            print(f" {code}: {len(gs)}")

    out_path = os.path.join(output_dir, "all_gmt_genesets.json")
    # Explicit encoding so the dump does not depend on the platform default.
    with open(out_path, "w", encoding="utf-8") as f:
        json.dump(all_gs, f)
    print(f"Total: {len(all_gs)}")
    return all_gs
76
+
77
+
78
def fetch_description_html(name, species="human"):
    """Fetch the "Brief description" text from a gene set's HTML card page.

    Args:
        name: Gene-set name, used as the last path segment of the card URL.
        species: Species path segment of the card URL.

    Returns:
        The description with HTML tags removed and entities unescaped, or
        ``""`` when the request fails, no description cell is found, or the
        description is empty / ``"na"`` / ``"n/a"`` (case-insensitive).
    """
    url = f"https://www.gsea-msigdb.org/gsea/msigdb/{species}/geneset/{name}"
    try:
        resp = requests.get(url, timeout=15)
        resp.raise_for_status()
    except requests.RequestException:
        # Deliberately best-effort: a failed lookup simply yields no
        # description for this gene set.
        return ""

    # First table cell following the "Brief description" label. The pattern
    # uses the regex whitespace class "\s"; a doubled backslash inside a raw
    # string would instead require a literal backslash in the page text.
    found = re.search(
        r'Brief\s+description.*?<td[^>]*>(.*?)</td>',
        resp.text,
        re.DOTALL | re.IGNORECASE,
    )
    if not found:
        return ""

    desc = re.sub(r'<[^>]+>', '', found.group(1)).strip()
    desc = html.unescape(desc)
    if desc and desc.lower() not in ("na", "n/a"):
        return desc
    return ""
92
+
93
+
94
def _save_description_cache(cache, cache_path):
    """Write *cache* as JSON to *cache_path*, creating its directory if needed."""
    cache_dir = os.path.dirname(cache_path)
    if cache_dir:
        os.makedirs(cache_dir, exist_ok=True)
    # Write to a sibling temp file and swap it in, so an interruption in the
    # middle of a write cannot leave a truncated cache behind.
    tmp_path = f"{cache_path}.tmp"
    with open(tmp_path, "w", encoding="utf-8") as fp:
        json.dump(cache, fp)
    os.replace(tmp_path, cache_path)


def fetch_descriptions_batch(gene_sets, max_workers=10, cache_path="data/raw/descriptions_cache.json"):
    """Fetch descriptions for all gene sets not yet cached, in a thread pool.

    The cache is a JSON object mapping gene-set name -> description string.
    It is loaded from ``cache_path`` if present, checkpointed every 500
    completed fetches, and written again at the end (also when a fetch
    raises, so finished work is not lost).

    Args:
        gene_sets: Iterable of dicts with at least ``name`` and ``species``.
        max_workers: Thread-pool size for the concurrent fetches.
        cache_path: JSON cache file; its directory is created if missing.

    Returns:
        The updated cache dict (name -> description).
    """
    cache = {}
    if os.path.exists(cache_path):
        with open(cache_path, encoding="utf-8") as f:
            cache = json.load(f)

    # The cache is keyed by name only, so fetch each missing name once
    # (first occurrence wins) instead of racing duplicate requests.
    pending = {}
    for gs in gene_sets:
        name = gs["name"]
        if name not in cache and name not in pending:
            pending[name] = gs["species"]
    print(f"Need to fetch {len(pending)} descriptions ({len(cache)} cached)")

    try:
        if pending:
            fetched = 0
            with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as ex:
                futures = {
                    ex.submit(fetch_description_html, name, species): name
                    for name, species in pending.items()
                }
                for fut in concurrent.futures.as_completed(futures):
                    cache[futures[fut]] = fut.result()
                    fetched += 1
                    if fetched % 500 == 0:
                        _save_description_cache(cache, cache_path)
                        print(f" {fetched}/{len(pending)}")
    finally:
        _save_description_cache(cache, cache_path)
    return cache
116
+
117
+
118
def build_pairs(gene_sets, descriptions, min_genes=5, max_genes=2000):
    """Build text/gene-list training pairs from gene sets and descriptions.

    The text of each pair is, one item per line: a
    ``[Collection: ..] [Species: ..]`` header, the gene-set name with
    underscores replaced by spaces, and (when available) the description
    with HTML tags removed and entities unescaped.

    Args:
        gene_sets: Iterable of dicts with ``name``, ``genes``,
            ``collection`` and ``species``.
        descriptions: Mapping of gene-set name -> description string.
        min_genes: Gene sets with fewer genes are skipped.
        max_genes: Gene sets with more genes are skipped.

    Returns:
        A list of dicts with keys ``id``, ``text``, ``genes``, ``n_genes``,
        ``collection``, ``species`` and ``has_description``.
    """
    pairs = []
    for gs in gene_sets:
        genes = gs["genes"]
        if not min_genes <= len(genes) <= max_genes:
            continue

        raw_desc = descriptions.get(gs["name"], "") or ""
        # Clean first, so a description that is only markup or whitespace
        # counts as "no description" rather than adding an empty line.
        desc = html.unescape(re.sub(r'<[^>]+>', ' ', raw_desc)).strip()

        parts = [
            f"[Collection: {gs['collection']}] [Species: {gs['species']}]",
            gs["name"].replace("_", " "),
        ]
        if desc:
            parts.append(desc)
        # Join with a real newline, not the two characters backslash + "n".
        text = "\n".join(parts)
        # NOTE(review): with the collection codes and the "human"/"mouse"
        # labels used in this file the header alone is over 30 characters,
        # so this guard only matters for unusually short inputs.
        if len(text) < 30:
            continue
        pairs.append({
            "id": gs["name"],
            "text": text,
            "genes": genes,
            "n_genes": len(genes),
            "collection": gs["collection"],
            "species": gs["species"],
            "has_description": bool(desc),
        })
    return pairs
136
+
137
+
138
def split_and_save(pairs, output_dir="data/processed"):
    """Split pairs into train/val/test by collection and write JSONL files.

    Collections C3, C4 and M3 go to ``val``; H, C6, C7 and MH go to
    ``test``; everything else (C1, C2, C5, C8, M1, M2, M5, M8 and any
    unrecognised code) goes to ``train``. One ``<split>.jsonl`` file is
    written per split, one JSON object per line.

    Args:
        pairs: Iterable of dicts, each with a ``collection`` key.
        output_dir: Directory for the JSONL files; created if missing.

    Returns:
        Dict mapping ``"train"``, ``"val"`` and ``"test"`` to lists of pairs.
    """
    os.makedirs(output_dir, exist_ok=True)
    val_cols = {"C3", "C4", "M3"}
    test_cols = {"H", "C6", "C7", "MH"}

    splits = {"train": [], "val": [], "test": []}
    for p in pairs:
        c = p["collection"]
        if c in val_cols:
            target = "val"
        elif c in test_cols:
            target = "test"
        else:
            # Listed training collections and unknown codes alike.
            target = "train"
        splits[target].append(p)

    for name, data in splits.items():
        path = os.path.join(output_dir, f"{name}.jsonl")
        with open(path, "w", encoding="utf-8") as f:
            # Each record needs a real newline terminator; an escaped
            # backslash-n would put all records on one line.
            f.writelines(json.dumps(r) + "\n" for r in data)
        print(f"{name}: {len(data)} pairs -> {path}")
    return splits
157
+
158
+
159
if __name__ == "__main__":
    # Pipeline: download GMT files -> fetch descriptions -> build pairs ->
    # split by collection and write JSONL.
    all_gs = download_all_gmts()
    descs = fetch_descriptions_batch(all_gs)
    pairs = build_pairs(all_gs, descs)
    # Leading real newline separates the summary from the progress output.
    print(f"\nTotal pairs: {len(pairs)}")
    split_and_save(pairs)