gbyuvd commited on
Commit
deb45f3
·
verified ·
1 Parent(s): 331cbeb

Update train/trainbarlow.py

Browse files
Files changed (1) hide show
  1. train/trainbarlow.py +352 -348
train/trainbarlow.py CHANGED
@@ -1,349 +1,353 @@
1
- #!/usr/bin/env python3
2
- """
3
- Self-Supervised Training for Molecular Representations (SMILES)
4
-
5
- Usage:
6
- python trainbarlow.py --config config.yaml
7
- """
8
- print("Initializing ...")
9
- import os
10
- import json
11
- import argparse
12
- import random
13
- from pathlib import Path
14
- from typing import Dict, Any, Tuple, List
15
-
16
- import numpy as np
17
- import pandas as pd
18
- import torch
19
- import torch.nn as nn
20
- from torch.utils.data import DataLoader
21
- from tqdm.auto import tqdm
22
- from sklearn.metrics.pairwise import cosine_similarity
23
- from sklearn.preprocessing import normalize
24
-
25
- # Suppress RDKit warnings
26
- from rdkit import RDLogger
27
- RDLogger.DisableLog('rdApp.*')
28
-
29
- try:
30
- from rdkit.Chem import MolFromSmiles, MolToSmiles, AllChem
31
- from rdkit import DataStructs
32
- except ImportError:
33
- raise ImportError("RDKit is required. Install with: conda install -c conda-forge rdkit")
34
-
35
- try:
36
- from sentence_transformers import SentenceTransformer, InputExample
37
- except ImportError:
38
- raise ImportError("Install sentence-transformers: pip install sentence-transformers")
39
-
40
-
41
- # ======================
42
- # Projector
43
- # ======================
44
-
45
- class BarlowTwinsProjector(nn.Module):
46
- """Projector with BatchNorm (for Barlow Twins)."""
47
- def __init__(self, in_dim: int, hidden_dim: int = 2048, out_dim: int = 2048):
48
- super().__init__()
49
- self.layers = nn.Sequential(
50
- nn.Linear(in_dim, hidden_dim, bias=False),
51
- nn.BatchNorm1d(hidden_dim),
52
- nn.ReLU(),
53
- nn.Linear(hidden_dim, hidden_dim, bias=False),
54
- nn.BatchNorm1d(hidden_dim),
55
- nn.ReLU(),
56
- nn.Linear(hidden_dim, out_dim, bias=False),
57
- nn.BatchNorm1d(out_dim, affine=False)
58
- )
59
-
60
- def forward(self, x):
61
- return self.layers(x)
62
-
63
- # ======================
64
- # Loss Function
65
- # ======================
66
-
67
- class BarlowTwinsLoss(nn.Module):
68
- def __init__(self, λ: float = 0.005):
69
- super().__init__()
70
- self.λ = λ
71
-
72
- def forward(self, z1: torch.Tensor, z2: torch.Tensor) -> Tuple[torch.Tensor, Dict[str, float]]:
73
- B, d = z1.shape
74
- # Shared standardization
75
- z = torch.cat([z1, z2], dim=0)
76
- z = (z - z.mean(dim=0)) / (z.std(dim=0) + 1e-8)
77
- z1, z2 = z[:B], z[B:]
78
- c = (z1.T @ z2) / B
79
- on_diag = (1 - torch.diagonal(c)).pow(2).sum()
80
- off_diag = (c ** 2).sum() - torch.diagonal(c).pow(2).sum()
81
- off_diag = off_diag / d
82
- total_loss = on_diag + self.λ * off_diag
83
- with torch.no_grad():
84
- diag_mean = torch.diagonal(c).mean().item()
85
- off_diag_mask = ~torch.eye(d, dtype=torch.bool, device=c.device)
86
- off_diag_mean = c[off_diag_mask].abs().mean().item()
87
- return total_loss, {
88
- 'od': on_diag.item(),
89
- 'ofsc': (self.λ * off_diag).item(),
90
- 'ofrw': off_diag.item(),
91
- 'cr_onm': diag_mean,
92
- 'cr_offm': off_diag_mean
93
- }
94
-
95
-
96
- # ======================
97
- # Utilities
98
- # ======================
99
-
100
- def load_config(config_path: str) -> Dict[str, Any]:
101
- config_path = Path(config_path)
102
- if config_path.suffix in {'.yaml', '.yml'}:
103
- import yaml
104
- with open(config_path) as f:
105
- return yaml.safe_load(f)
106
- elif config_path.suffix == '.json':
107
- with open(config_path) as f:
108
- return json.load(f)
109
- else:
110
- raise ValueError(f"Unsupported config format: {config_path.suffix}")
111
-
112
- def sanitize_config(config: Dict[str, Any]) -> Dict[str, Any]:
113
- float_keys = {
114
- "LR", "WEIGHT_DECAY", "BARLOW_LAMBDA", "VICREG_LAMBDA",
115
- "VICREG_MU", "VICREG_NU", "CORINFOMAX_ALPHA"
116
- }
117
- int_keys = {
118
- "BATCH_SIZE", "EFFECTIVE_BATCH", "EPOCHS", "MAX_LENGTH",
119
- "SEED", "EVAL_EVERY_N_PERCENT"
120
- }
121
- bool_keys = {"BEST_BY_HEALTH"}
122
- for key in float_keys:
123
- if key in config:
124
- config[key] = float(config[key])
125
- for key in int_keys:
126
- if key in config:
127
- config[key] = int(config[key])
128
- for key in bool_keys:
129
- if key in config:
130
- val = config[key]
131
- config[key] = val.lower() in {"true", "1", "yes", "on"} if isinstance(val, str) else bool(val)
132
- return config
133
-
134
- def set_seed(seed: int):
135
- torch.manual_seed(seed)
136
- np.random.seed(seed)
137
- random.seed(seed)
138
- if torch.cuda.is_available():
139
- torch.cuda.manual_seed_all(seed)
140
-
141
- def enum_smiles(smi: str, k: int = 2) -> List[str]:
142
- from rdkit.Chem import MolFromSmiles, MolToSmiles
143
- mol = MolFromSmiles(smi)
144
- if mol is None:
145
- return [smi] * k
146
- variants = set()
147
- attempts = 0
148
- while len(variants) < k and attempts < 100:
149
- variants.add(MolToSmiles(mol, doRandom=True, canonical=False))
150
- attempts += 1
151
- return list(variants)[:k]
152
-
153
- def tanimoto(s1: str, s2: str) -> float:
154
- m1, m2 = MolFromSmiles(s1), MolFromSmiles(s2)
155
- if not m1 or not m2:
156
- return 0.0
157
- fp1 = AllChem.GetMorganFingerprintAsBitVect(m1, radius=2, nBits=2048)
158
- fp2 = AllChem.GetMorganFingerprintAsBitVect(m2, radius=2, nBits=2048)
159
- return DataStructs.TanimotoSimilarity(fp1, fp2)
160
-
161
- def uniformity_metrics(emb: np.ndarray) -> Dict[str, float]:
162
- emb = normalize(emb)
163
- sim = cosine_similarity(emb)
164
- mask = ~np.eye(len(sim), dtype=bool)
165
- pairwise = sim[mask]
166
- mean_sim, std_sim = pairwise.mean(), pairwise.std()
167
- distances = 1 - sim
168
- uniformity = np.log(np.exp(-2 * distances[mask]).mean())
169
- return {
170
- 'mean': float(mean_sim),
171
- 'std': float(std_sim),
172
- 'uniformity': float(uniformity),
173
- 'health_old': float(1 - mean_sim),
174
- 'collapsed': mean_sim > 0.7 or std_sim < 0.05
175
- }
176
-
177
- def forward_pooled(model: SentenceTransformer, text_list: List[str], device: torch.device) -> torch.Tensor:
178
- tok = model.tokenize(text_list)
179
- tok = {k: v.to(device) for k, v in tok.items()}
180
- hf_output = model(tok)
181
- return hf_output['token_embeddings'][:, 0, :]
182
-
183
- def evaluate(model, eval_smiles: List[str], device: torch.device, step: int) -> Dict[str, Any]:
184
- model.eval()
185
- with torch.no_grad():
186
- emb = model.encode(eval_smiles, convert_to_numpy=True, show_progress_bar=False, batch_size=32)
187
- um = uniformity_metrics(emb)
188
- same_view = [enum_smiles(s, 1)[0] for s in eval_smiles]
189
- with torch.no_grad():
190
- emb2 = model.encode(same_view, convert_to_numpy=True, show_progress_bar=False, batch_size=32)
191
- same_cos = np.diag(cosine_similarity(emb, emb2))
192
- alignment = 1 - same_cos.mean()
193
- barlow_health = same_cos.mean() - um['mean']
194
- print(f"\n📊 Step {step} | Alignment={alignment:.3f} | Uniformity={um['uniformity']:.3f}")
195
- print(f" Same-mol cos: {same_cos.mean():.3f}±{same_cos.std():.3f} | Pairwise: {um['mean']:.3f}±{um['std']:.3f}")
196
- print(f" Barlow Health: {barlow_health:.3f} (higher = better)")
197
- model.train()
198
- um['health'] = barlow_health
199
- um['alignment'] = alignment
200
- um['same_cos_mean'] = same_cos.mean()
201
- um['same_cos_std'] = same_cos.std()
202
- return um
203
-
204
-
205
- # ======================
206
- # Main
207
- # ======================
208
-
209
- def main():
210
- parser = argparse.ArgumentParser()
211
- parser.add_argument("--config", type=str, required=True)
212
- parser.add_argument("--epochs", type=int)
213
- parser.add_argument("--lr", type=float)
214
- parser.add_argument("--batch_size", type=int)
215
- parser.add_argument("--loss_type", type=str, choices=["barlow", "vicreg", "corinfomax"])
216
- args = parser.parse_args()
217
-
218
- config = load_config(args.config)
219
- for key, value in vars(args).items():
220
- if value is not None and key != "config":
221
- config[key] = value
222
- config = sanitize_config(config)
223
-
224
- set_seed(config.get("SEED", 42))
225
- device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
226
- output_dir = Path(config["OUTPUT_DIR"])
227
- output_dir.mkdir(parents=True, exist_ok=True)
228
-
229
- df = pd.read_csv(config["DATA_PATH"])
230
- smiles_list = df["SMILES"].dropna().tolist()
231
- print(f"📂 Loaded {len(smiles_list)} SMILES")
232
-
233
- train_examples = []
234
- for smi in tqdm(smiles_list, desc="Enumerating SMILES"):
235
- variants = enum_smiles(smi, 2)
236
- if len(variants) < 2:
237
- variants = [smi, smi]
238
- train_examples.append(InputExample(texts=[variants[0], variants[1]]))
239
- print(f" Created {len(train_examples)} pairs")
240
-
241
- eval_size = min(200, len(smiles_list))
242
- eval_smiles = np.random.choice(smiles_list, eval_size, replace=False).tolist()
243
-
244
- # Model
245
- model = SentenceTransformer('./chmbedv2-warmup-l5/final')
246
- model.max_seq_length = config.get("MAX_LENGTH", 512)
247
- embed_dim = model.get_sentence_embedding_dimension()
248
-
249
- # Projector & Loss
250
- loss_type = config.get("LOSS_TYPE", "barlow")
251
- if loss_type == "barlow":
252
- projector = BarlowTwinsProjector(
253
- embed_dim,
254
- hidden_dim=2048,
255
- out_dim=2048
256
- ).to(device)
257
- train_loss = BarlowTwinsLoss(
258
- λ=config.get("BARLOW_LAMBDA", 0.005)
259
- ).to(device)
260
- else:
261
- raise ValueError(f"Unknown loss_type: {loss_type}")
262
-
263
- model.to(device)
264
-
265
- # Optimizer (include projector!)
266
- from ranger21 import Ranger21
267
-
268
- no_decay = ["bias", "LayerNorm.bias", "LayerNorm.weight"]
269
- model_params = [
270
- {"params": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)],
271
- "weight_decay": config.get("WEIGHT_DECAY", 0.01)},
272
- {"params": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)],
273
- "weight_decay": 0.0}
274
- ]
275
-
276
- # Calculate training parameters for Ranger21 scheduling
277
- batch_size = config.get("BATCH_SIZE", 8)
278
- effective_batch = config.get("EFFECTIVE_BATCH", 32)
279
- grad_acc = effective_batch // batch_size
280
- epochs = config.get("EPOCHS", 1)
281
- total_steps = (len(train_examples) // effective_batch) * epochs
282
- train_loader = DataLoader(train_examples, batch_size=batch_size, shuffle=True, collate_fn=lambda x: x)
283
- num_batches_per_epoch = len(train_examples) // effective_batch
284
-
285
- no_decay = ["bias", "LayerNorm.bias", "LayerNorm.weight"]
286
- model_params = [
287
- {"params": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)],
288
- "weight_decay": config.get("WEIGHT_DECAY", 0.01)},
289
- {"params": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)],
290
- "weight_decay": 0.0}
291
- ]
292
-
293
- optimizer = Ranger21(
294
- model_params + [{"params": projector.parameters(), "weight_decay": config.get("WEIGHT_DECAY", 0.01)}],
295
- lr=config.get("LR", 1e-5),
296
- num_epochs=epochs,
297
- num_batches_per_epoch=num_batches_per_epoch,
298
- weight_decay=0.0, # Handle weight decay manually in param groups
299
- )
300
-
301
- # Training loop setup
302
- scheduler = torch.optim.lr_scheduler.LinearLR(
303
- optimizer, start_factor=1.0, end_factor=0.0, total_iters=total_steps
304
- )
305
-
306
-
307
- # Train
308
- model.train()
309
- step = 0
310
- best_health = 0.0
311
- best_step = 0
312
- log_interval = max(1, int(total_steps * config.get("EVAL_EVERY_N_PERCENT", 25) / 100))
313
-
314
- for epoch in range(epochs):
315
- pbar = tqdm(train_loader, desc=f"Epoch {epoch + 1}/{epochs}")
316
- for batch_idx, batch in enumerate(pbar):
317
- texts = [[ex.texts[i] for ex in batch] for i in range(2)]
318
- z1 = forward_pooled(model, texts[0], device)
319
- z2 = forward_pooled(model, texts[1], device)
320
- p1 = projector(z1)
321
- p2 = projector(z2)
322
- loss, extras = train_loss(p1, p2)
323
-
324
- loss = loss / grad_acc
325
- loss.backward()
326
-
327
- if (batch_idx + 1) % grad_acc == 0:
328
- optimizer.step()
329
- scheduler.step()
330
- optimizer.zero_grad()
331
- step += 1
332
-
333
- postfix = {"step": step, "lr": scheduler.get_last_lr()[0]}
334
- for k, v in extras.items():
335
- postfix[k] = f"{v:.3f}"
336
- pbar.set_postfix(postfix)
337
-
338
- if step % log_interval == 0 or step == total_steps:
339
- um = evaluate(model, eval_smiles, device, step)
340
- if config.get("BEST_BY_HEALTH", True) and um["health"] > best_health:
341
- best_health, best_step = um["health"], step
342
- model.save(str(output_dir / "best"))
343
-
344
- model.save(str(output_dir / "final"))
345
- print(f"\n✅ Training complete! Best health: {best_health:.3f} at step {best_step}")
346
-
347
-
348
- if __name__ == "__main__":
 
 
 
 
349
  main()
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ Self-Supervised Training for Molecular Representations (SMILES)
4
+
5
+ Usage:
6
+ python trainbarlow.py --config config.yaml
7
+ """
8
+ print("Initializing ...")
9
+ import os
10
+ import json
11
+ import argparse
12
+ import random
13
+ from pathlib import Path
14
+ from typing import Dict, Any, Tuple, List
15
+
16
+ import numpy as np
17
+ import pandas as pd
18
+ import torch
19
+ import torch.nn as nn
20
+ from torch.utils.data import DataLoader
21
+ from tqdm.auto import tqdm
22
+ from sklearn.metrics.pairwise import cosine_similarity
23
+ from sklearn.preprocessing import normalize
24
+
25
+ # Suppress RDKit warnings
26
+ from rdkit import RDLogger
27
+ RDLogger.DisableLog('rdApp.*')
28
+
29
+ try:
30
+ from rdkit.Chem import MolFromSmiles, MolToSmiles, AllChem
31
+ from rdkit import DataStructs
32
+ except ImportError:
33
+ raise ImportError("RDKit is required. Install with: conda install -c conda-forge rdkit")
34
+
35
+ try:
36
+ from sentence_transformers import SentenceTransformer, InputExample
37
+ except ImportError:
38
+ raise ImportError("Install sentence-transformers: pip install sentence-transformers")
39
+
40
+
41
+ # ======================
42
+ # Projector
43
+ # ======================
44
+
45
+ class BarlowTwinsProjector(nn.Module):
46
+ """Projector with BatchNorm (for Barlow Twins)."""
47
+ def __init__(self, in_dim: int, hidden_dim: int = 2048, out_dim: int = 2048):
48
+ super().__init__()
49
+ self.layers = nn.Sequential(
50
+ nn.Linear(in_dim, hidden_dim, bias=False),
51
+ nn.BatchNorm1d(hidden_dim),
52
+ nn.ReLU(),
53
+ nn.Linear(hidden_dim, hidden_dim, bias=False),
54
+ nn.BatchNorm1d(hidden_dim),
55
+ nn.ReLU(),
56
+ nn.Linear(hidden_dim, out_dim, bias=False),
57
+ nn.BatchNorm1d(out_dim, affine=False)
58
+ )
59
+
60
+ def forward(self, x):
61
+ return self.layers(x)
62
+
63
+ # ======================
64
+ # Loss Function
65
+ # ======================
66
+
67
+ class BarlowTwinsLoss(nn.Module):
68
+ """
69
+ Barlow Twins' Loss Implementation
70
+ with shared standardization and scaled off-diagonals with d.
71
+ """
72
+ def __init__(self, λ: float = 0.005):
73
+ super().__init__()
74
+ self.λ = λ
75
+
76
+ def forward(self, z1: torch.Tensor, z2: torch.Tensor) -> Tuple[torch.Tensor, Dict[str, float]]:
77
+ B, d = z1.shape
78
+ # Shared standardization
79
+ z = torch.cat([z1, z2], dim=0)
80
+ z = (z - z.mean(dim=0)) / (z.std(dim=0) + 1e-8)
81
+ z1, z2 = z[:B], z[B:]
82
+ c = (z1.T @ z2) / B
83
+ on_diag = (1 - torch.diagonal(c)).pow(2).sum()
84
+ off_diag = (c ** 2).sum() - torch.diagonal(c).pow(2).sum()
85
+ off_diag = off_diag / d
86
+ total_loss = on_diag + self.λ * off_diag
87
+ with torch.no_grad():
88
+ diag_mean = torch.diagonal(c).mean().item()
89
+ off_diag_mask = ~torch.eye(d, dtype=torch.bool, device=c.device)
90
+ off_diag_mean = c[off_diag_mask].abs().mean().item()
91
+ return total_loss, {
92
+ 'od': on_diag.item(),
93
+ 'ofsc': (self.λ * off_diag).item(),
94
+ 'ofrw': off_diag.item(),
95
+ 'cr_onm': diag_mean,
96
+ 'cr_offm': off_diag_mean
97
+ }
98
+
99
+
100
+ # ======================
101
+ # Utilities
102
+ # ======================
103
+
104
+ def load_config(config_path: str) -> Dict[str, Any]:
105
+ config_path = Path(config_path)
106
+ if config_path.suffix in {'.yaml', '.yml'}:
107
+ import yaml
108
+ with open(config_path) as f:
109
+ return yaml.safe_load(f)
110
+ elif config_path.suffix == '.json':
111
+ with open(config_path) as f:
112
+ return json.load(f)
113
+ else:
114
+ raise ValueError(f"Unsupported config format: {config_path.suffix}")
115
+
116
+ def sanitize_config(config: Dict[str, Any]) -> Dict[str, Any]:
117
+ float_keys = {
118
+ "LR", "WEIGHT_DECAY", "BARLOW_LAMBDA", "VICREG_LAMBDA",
119
+ "VICREG_MU", "VICREG_NU", "CORINFOMAX_ALPHA"
120
+ }
121
+ int_keys = {
122
+ "BATCH_SIZE", "EFFECTIVE_BATCH", "EPOCHS", "MAX_LENGTH",
123
+ "SEED", "EVAL_EVERY_N_PERCENT"
124
+ }
125
+ bool_keys = {"BEST_BY_HEALTH"}
126
+ for key in float_keys:
127
+ if key in config:
128
+ config[key] = float(config[key])
129
+ for key in int_keys:
130
+ if key in config:
131
+ config[key] = int(config[key])
132
+ for key in bool_keys:
133
+ if key in config:
134
+ val = config[key]
135
+ config[key] = val.lower() in {"true", "1", "yes", "on"} if isinstance(val, str) else bool(val)
136
+ return config
137
+
138
+ def set_seed(seed: int):
139
+ torch.manual_seed(seed)
140
+ np.random.seed(seed)
141
+ random.seed(seed)
142
+ if torch.cuda.is_available():
143
+ torch.cuda.manual_seed_all(seed)
144
+
145
+ def enum_smiles(smi: str, k: int = 2) -> List[str]:
146
+ from rdkit.Chem import MolFromSmiles, MolToSmiles
147
+ mol = MolFromSmiles(smi)
148
+ if mol is None:
149
+ return [smi] * k
150
+ variants = set()
151
+ attempts = 0
152
+ while len(variants) < k and attempts < 100:
153
+ variants.add(MolToSmiles(mol, doRandom=True, canonical=False))
154
+ attempts += 1
155
+ return list(variants)[:k]
156
+
157
+ def tanimoto(s1: str, s2: str) -> float:
158
+ m1, m2 = MolFromSmiles(s1), MolFromSmiles(s2)
159
+ if not m1 or not m2:
160
+ return 0.0
161
+ fp1 = AllChem.GetMorganFingerprintAsBitVect(m1, radius=2, nBits=2048)
162
+ fp2 = AllChem.GetMorganFingerprintAsBitVect(m2, radius=2, nBits=2048)
163
+ return DataStructs.TanimotoSimilarity(fp1, fp2)
164
+
165
+ def uniformity_metrics(emb: np.ndarray) -> Dict[str, float]:
166
+ emb = normalize(emb)
167
+ sim = cosine_similarity(emb)
168
+ mask = ~np.eye(len(sim), dtype=bool)
169
+ pairwise = sim[mask]
170
+ mean_sim, std_sim = pairwise.mean(), pairwise.std()
171
+ distances = 1 - sim
172
+ uniformity = np.log(np.exp(-2 * distances[mask]).mean())
173
+ return {
174
+ 'mean': float(mean_sim),
175
+ 'std': float(std_sim),
176
+ 'uniformity': float(uniformity),
177
+ 'health_old': float(1 - mean_sim),
178
+ 'collapsed': mean_sim > 0.7 or std_sim < 0.05
179
+ }
180
+
181
+ def forward_pooled(model: SentenceTransformer, text_list: List[str], device: torch.device) -> torch.Tensor:
182
+ tok = model.tokenize(text_list)
183
+ tok = {k: v.to(device) for k, v in tok.items()}
184
+ hf_output = model(tok)
185
+ return hf_output['token_embeddings'][:, 0, :]
186
+
187
+ def evaluate(model, eval_smiles: List[str], device: torch.device, step: int) -> Dict[str, Any]:
188
+ model.eval()
189
+ with torch.no_grad():
190
+ emb = model.encode(eval_smiles, convert_to_numpy=True, show_progress_bar=False, batch_size=32)
191
+ um = uniformity_metrics(emb)
192
+ same_view = [enum_smiles(s, 1)[0] for s in eval_smiles]
193
+ with torch.no_grad():
194
+ emb2 = model.encode(same_view, convert_to_numpy=True, show_progress_bar=False, batch_size=32)
195
+ same_cos = np.diag(cosine_similarity(emb, emb2))
196
+ alignment = 1 - same_cos.mean()
197
+ barlow_health = same_cos.mean() - um['mean']
198
+ print(f"\n📊 Step {step} | Alignment={alignment:.3f} | Uniformity={um['uniformity']:.3f}")
199
+ print(f" Same-mol cos: {same_cos.mean():.3f}±{same_cos.std():.3f} | Pairwise: {um['mean']:.3f}±{um['std']:.3f}")
200
+ print(f" Barlow Health: {barlow_health:.3f} (higher = better)")
201
+ model.train()
202
+ um['health'] = barlow_health
203
+ um['alignment'] = alignment
204
+ um['same_cos_mean'] = same_cos.mean()
205
+ um['same_cos_std'] = same_cos.std()
206
+ return um
207
+
208
+
209
+ # ======================
210
+ # Main
211
+ # ======================
212
+
213
+ def main():
214
+ parser = argparse.ArgumentParser()
215
+ parser.add_argument("--config", type=str, required=True)
216
+ parser.add_argument("--epochs", type=int)
217
+ parser.add_argument("--lr", type=float)
218
+ parser.add_argument("--batch_size", type=int)
219
+ parser.add_argument("--loss_type", type=str, choices=["barlow", "vicreg", "corinfomax"])
220
+ args = parser.parse_args()
221
+
222
+ config = load_config(args.config)
223
+ for key, value in vars(args).items():
224
+ if value is not None and key != "config":
225
+ config[key] = value
226
+ config = sanitize_config(config)
227
+
228
+ set_seed(config.get("SEED", 42))
229
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
230
+ output_dir = Path(config["OUTPUT_DIR"])
231
+ output_dir.mkdir(parents=True, exist_ok=True)
232
+
233
+ df = pd.read_csv(config["DATA_PATH"])
234
+ smiles_list = df["SMILES"].dropna().tolist()
235
+ print(f"📂 Loaded {len(smiles_list)} SMILES")
236
+
237
+ train_examples = []
238
+ for smi in tqdm(smiles_list, desc="Enumerating SMILES"):
239
+ variants = enum_smiles(smi, 2)
240
+ if len(variants) < 2:
241
+ variants = [smi, smi]
242
+ train_examples.append(InputExample(texts=[variants[0], variants[1]]))
243
+ print(f" Created {len(train_examples)} pairs")
244
+
245
+ eval_size = min(200, len(smiles_list))
246
+ eval_smiles = np.random.choice(smiles_list, eval_size, replace=False).tolist()
247
+
248
+ # Model
249
+ model = SentenceTransformer('./chmbedv2-warmup-l5/final')
250
+ model.max_seq_length = config.get("MAX_LENGTH", 512)
251
+ embed_dim = model.get_sentence_embedding_dimension()
252
+
253
+ # Projector & Loss
254
+ loss_type = config.get("LOSS_TYPE", "barlow")
255
+ if loss_type == "barlow":
256
+ projector = BarlowTwinsProjector(
257
+ embed_dim,
258
+ hidden_dim=2048,
259
+ out_dim=2048
260
+ ).to(device)
261
+ train_loss = BarlowTwinsLoss(
262
+ λ=config.get("BARLOW_LAMBDA", 0.005)
263
+ ).to(device)
264
+ else:
265
+ raise ValueError(f"Unknown loss_type: {loss_type}")
266
+
267
+ model.to(device)
268
+
269
+ # Optimizer (include projector!)
270
+ from ranger21 import Ranger21
271
+
272
+ no_decay = ["bias", "LayerNorm.bias", "LayerNorm.weight"]
273
+ model_params = [
274
+ {"params": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)],
275
+ "weight_decay": config.get("WEIGHT_DECAY", 0.01)},
276
+ {"params": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)],
277
+ "weight_decay": 0.0}
278
+ ]
279
+
280
+ # Calculate training parameters for Ranger21 scheduling
281
+ batch_size = config.get("BATCH_SIZE", 8)
282
+ effective_batch = config.get("EFFECTIVE_BATCH", 32)
283
+ grad_acc = effective_batch // batch_size
284
+ epochs = config.get("EPOCHS", 1)
285
+ total_steps = (len(train_examples) // effective_batch) * epochs
286
+ train_loader = DataLoader(train_examples, batch_size=batch_size, shuffle=True, collate_fn=lambda x: x)
287
+ num_batches_per_epoch = len(train_examples) // effective_batch
288
+
289
+ no_decay = ["bias", "LayerNorm.bias", "LayerNorm.weight"]
290
+ model_params = [
291
+ {"params": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)],
292
+ "weight_decay": config.get("WEIGHT_DECAY", 0.01)},
293
+ {"params": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)],
294
+ "weight_decay": 0.0}
295
+ ]
296
+
297
+ optimizer = Ranger21(
298
+ model_params + [{"params": projector.parameters(), "weight_decay": config.get("WEIGHT_DECAY", 0.01)}],
299
+ lr=config.get("LR", 1e-5),
300
+ num_epochs=epochs,
301
+ num_batches_per_epoch=num_batches_per_epoch,
302
+ weight_decay=0.0, # Handle weight decay manually in param groups
303
+ )
304
+
305
+ # Training loop setup
306
+ scheduler = torch.optim.lr_scheduler.LinearLR(
307
+ optimizer, start_factor=1.0, end_factor=0.0, total_iters=total_steps
308
+ )
309
+
310
+
311
+ # Train
312
+ model.train()
313
+ step = 0
314
+ best_health = 0.0
315
+ best_step = 0
316
+ log_interval = max(1, int(total_steps * config.get("EVAL_EVERY_N_PERCENT", 25) / 100))
317
+
318
+ for epoch in range(epochs):
319
+ pbar = tqdm(train_loader, desc=f"Epoch {epoch + 1}/{epochs}")
320
+ for batch_idx, batch in enumerate(pbar):
321
+ texts = [[ex.texts[i] for ex in batch] for i in range(2)]
322
+ z1 = forward_pooled(model, texts[0], device)
323
+ z2 = forward_pooled(model, texts[1], device)
324
+ p1 = projector(z1)
325
+ p2 = projector(z2)
326
+ loss, extras = train_loss(p1, p2)
327
+
328
+ loss = loss / grad_acc
329
+ loss.backward()
330
+
331
+ if (batch_idx + 1) % grad_acc == 0:
332
+ optimizer.step()
333
+ scheduler.step()
334
+ optimizer.zero_grad()
335
+ step += 1
336
+
337
+ postfix = {"step": step, "lr": scheduler.get_last_lr()[0]}
338
+ for k, v in extras.items():
339
+ postfix[k] = f"{v:.3f}"
340
+ pbar.set_postfix(postfix)
341
+
342
+ if step % log_interval == 0 or step == total_steps:
343
+ um = evaluate(model, eval_smiles, device, step)
344
+ if config.get("BEST_BY_HEALTH", True) and um["health"] > best_health:
345
+ best_health, best_step = um["health"], step
346
+ model.save(str(output_dir / "best"))
347
+
348
+ model.save(str(output_dir / "final"))
349
+ print(f"\n✅ Training complete! Best health: {best_health:.3f} at step {best_step}")
350
+
351
+
352
+ if __name__ == "__main__":
353
  main()