timcryt committed
Commit f6fc460 · verified · 1 parent: fb07bfc

Initial commit

Files changed (2)
  1. test_model.py +96 -0
  2. train_model.py +312 -0
test_model.py ADDED
@@ -0,0 +1,96 @@
+ import argparse
+ import os
+ import sys
+ import numpy as np
+ import pandas as pd
+ import torch
+ from transformers import AutoTokenizer, AutoModelForMaskedLM
+ from sklearn.ensemble import RandomForestRegressor, RandomForestClassifier
+ from sklearn.model_selection import cross_validate
+
+ # Default configuration
+ DEFAULT_TASKS = ['ESOL', 'FreeSolv', 'HIV', 'BACE', 'BBBP', 'ClinTox']
+ MODEL_NAME = "DeepChem/ChemBERTa-10M-MLM"
+
+ def load_model_and_checkpoint(checkpoint_path, device="cpu"):
+     print(f"Loading model {MODEL_NAME}...", file=sys.stderr)
+     tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
+     backbone = AutoModelForMaskedLM.from_pretrained(MODEL_NAME).roberta.to(device)
+
+     if not os.path.exists(checkpoint_path):
+         raise FileNotFoundError(f"File not found: {checkpoint_path}")
+
+     checkpoint = torch.load(checkpoint_path, map_location=device, weights_only=False)
+     backbone.load_state_dict(checkpoint['backbone'])
+     backbone.eval()
+     print("Model is loaded", file=sys.stderr)
+     return tokenizer, backbone
+
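+ # Molecule embedding: concatenate the [CLS] hidden state with the mean of the
+ # remaining token states, giving a 2 * hidden_size feature vector per SMILES.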
+ @torch.no_grad()
+ def mol_to_emb(smiles, tokenizer, model, device="cpu"):
+     tokenized = tokenizer([smiles], padding=False, return_tensors="pt")
+     input_ids = tokenized['input_ids'].to(device)
+     hs = model(input_ids).last_hidden_state
+
+     emb = torch.cat([hs[:, 0], hs[:, 1:].mean(dim=1)], dim=1)
+     return emb.squeeze(0).cpu().numpy()
+
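+ # For each benchmark: embed every molecule, then score a random forest on the
+ # embeddings with 5-fold cross-validation (MAE for regression tasks, F1-macro otherwise).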
+ def evaluate_tasks(checkpoint_path, data_dir='./support/', device="cpu"):
+     tasks = DEFAULT_TASKS
+     tokenizer, model = load_model_and_checkpoint(checkpoint_path, device)
+
+     results = {}
+     for task in tasks:
+         csv_path = os.path.join(data_dir, f"{task}.csv")
+         if not os.path.exists(csv_path):
+             print(f"\n[WARN] File {csv_path} not found. Skipping '{task}'.", file=sys.stderr)
+             continue
+
+         print(f"Task: {task}", file=sys.stderr)
+         ds = pd.read_csv(csv_path, sep='\t')
+
+         # Compute an embedding for each molecule
+         ds['v'] = ds['X'].apply(lambda x: mol_to_emb(x, tokenizer, model, device))
+         ds = ds.sample(frac=1, random_state=42).reset_index(drop=True)
+
+         # Prepare data for sklearn
+         X = np.stack(ds['v'].values)
+         y = ds['y'].to_numpy()
+
+         # Choose model and metric
+         if task in ['ESOL', 'FreeSolv']:
+             rf_model = RandomForestRegressor(random_state=42, n_jobs=5)
+             scoring = 'neg_mean_absolute_error'
+             metric_name = "MAE"
+         else:
+             rf_model = RandomForestClassifier(random_state=42, n_jobs=5)
+             scoring = 'f1_macro'
+             metric_name = "F1-macro"
+
+         # Cross-validation
+         cv_results = cross_validate(rf_model, X, y, cv=5, scoring=scoring, n_jobs=1)
+         mean_score = cv_results['test_score'].mean()
+         std_score = cv_results['test_score'].std()
+         results[task] = (mean_score, std_score)
+         print(f" {metric_name}: {mean_score:.4f} ± {std_score:.4f}", file=sys.stderr)
+
+
+     for task, (mean, std) in results.items():
+         print(f"{task:10}: {mean:.4f} ± {std:.4f}")
+     return results
+
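+ # Example usage: python test_model.py model.pth --device cuda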
+ if __name__ == "__main__":
+     parser = argparse.ArgumentParser()
+     parser.add_argument(
+         "checkpoint_path", type=str,
+         help="Path to checkpoint file (.pth)"
+     )
+     parser.add_argument(
+         "--device", type=str, default="cpu", choices=["cpu", "cuda"],
+     )
+     args = parser.parse_args()
+
+     evaluate_tasks(
+         checkpoint_path=args.checkpoint_path,
+         device=args.device
+     )
train_model.py ADDED
@@ -0,0 +1,312 @@
+ import os
+ import time
+ import math
+ import logging
+ import argparse
+ from datetime import datetime
+ from typing import Dict, List, Any
+
+ import numpy as np
+ import torch
+ import torch.nn.functional as F
+ import torch.nn as nn
+ from torch.utils.data import DataLoader
+ from torch.utils.tensorboard import SummaryWriter
+ from tqdm import tqdm
+
+ from datasets import load_dataset
+ from transformers import (
+     AutoTokenizer,
+     AutoModelForMaskedLM,
+     DataCollatorForLanguageModeling,
+     PreTrainedTokenizerBase
+ )
+ from rdkit import Chem
+ from rdkit.Chem import Descriptors
+
+ logging.basicConfig(level=logging.INFO, format="%(asctime)s | %(levelname)s | %(message)s")
+ logger = logging.getLogger(__name__)
+
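+ # Full RDKit descriptor vector for one SMILES (210 values); unparsable or
+ # failing molecules fall back to a zero vector so the batch shape stays fixed.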
+ def compute_rdkit_features(smiles: str) -> np.ndarray:
+     try:
+         mol = Chem.MolFromSmiles(smiles)
+         if mol is None:
+             return np.zeros(210, dtype=np.float32)
+         return np.array(list(Descriptors.CalcMolDescriptors(mol).values()))
+     except Exception:
+         return np.zeros(210, dtype=np.float32)
+
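+ # Collator for joint MLM + descriptor regression: tokenizes the raw SMILES,
+ # applies dynamic MLM masking via DataCollatorForLanguageModeling, and stacks
+ # the precomputed RDKit descriptors into a float tensor alongside the text batch.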
+ class SMILESAndDescriptorCollator:
+     def __init__(
+         self,
+         tokenizer: PreTrainedTokenizerBase,
+         max_length: int = 512,
+         mlm_probability: float = 0.15,
+         do_mlm: bool = True
+     ):
+         self.tokenizer = tokenizer
+         self.max_length = max_length
+         self.do_mlm = do_mlm
+         if self.do_mlm:
+             self.mlm_collator = DataCollatorForLanguageModeling(
+                 tokenizer=self.tokenizer,
+                 mlm=True,
+                 mlm_probability=mlm_probability,
+                 return_tensors="pt"
+             )
+         else:
+             self.mlm_collator = None
+
+     def __call__(self, features: List[Dict[str, Any]]) -> Dict[str, torch.Tensor]:
+         smiles_batch = [f['smiles'] for f in features]
+         descriptors_list = [f['descriptors'] for f in features]
+
+         tokenized = self.tokenizer(
+             smiles_batch,
+             padding=False,
+             truncation=True,
+             max_length=self.max_length,
+             return_tensors=None
+         )
+
+         features_for_mlm = [
+             {k: v[i] for k, v in tokenized.items()}
+             for i in range(len(smiles_batch))
+         ]
+
+         if self.do_mlm and self.mlm_collator:
+             batch_text = self.mlm_collator(features_for_mlm)
+         else:
+             tokenized_padded = self.tokenizer.pad(
+                 features_for_mlm,
+                 padding=True,
+                 max_length=self.max_length,
+                 return_tensors="pt"
+             )
+             batch_text = dict(tokenized_padded)
+
+         descriptors_tensor = torch.tensor(np.stack(descriptors_list), dtype=torch.float32)
+
+         batch = batch_text
+         batch['descriptors'] = descriptors_tensor
+
+         return batch
+
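+ # Flattens and concatenates the gradients of all backbone parameters, skipping
+ # any parameter whose name contains one of the excluded keywords.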
+ def get_backbone_grad_vector(module, exclude_keywords=None):
+     if exclude_keywords is None:
+         exclude_keywords = []
+
+     grads = []
+     for name, param in module.named_parameters():
+         if any(keyword in name.lower() for keyword in exclude_keywords):
+             continue
+         if param.grad is not None:
+             grads.append(param.grad.detach().flatten())
+
+     if len(grads) == 0:
+         return torch.tensor([])
+
+     return torch.cat(grads)
+
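+ # Task-conflict diagnostic: backpropagates each loss separately, records the
+ # norm of the backbone gradient it produces, and the angle between the two
+ # gradient directions.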
+ def compute_gradient_metrics(model, loss1, loss2, exclude_keywords=None):
+     if exclude_keywords is None:
+         exclude_keywords = []
+
+     model.zero_grad(set_to_none=True)
+     loss1.backward(retain_graph=True)
+     g1 = get_backbone_grad_vector(model, exclude_keywords)
+     norm_mtr = g1.norm().item() if g1 is not None and g1.numel() > 0 else None
+
+     model.zero_grad(set_to_none=True)
+     loss2.backward(retain_graph=True)
+     g2 = get_backbone_grad_vector(model, exclude_keywords)
+     norm_mlm = g2.norm().item() if g2 is not None and g2.numel() > 0 else None
+
+     model.zero_grad(set_to_none=True)
+
+     angle_deg = None
+     if (g1 is not None and g2 is not None and
+             g1.numel() > 0 and g2.numel() > 0 and
+             g1.numel() == g2.numel()):
+         cos_sim = F.cosine_similarity(g1.unsqueeze(0), g2.unsqueeze(0), dim=1).item()
+         cos_sim = max(min(cos_sim, 1.0), -1.0)
+         angle_rad = math.acos(cos_sim)
+         angle_deg = math.degrees(angle_rad)
+
+     return {
+         'angle_deg': angle_deg,
+         'norm_mtr': norm_mtr,
+         'norm_mlm': norm_mlm
+     }
+
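+ # Multi-task pre-training: the shared ChemBERTa backbone is optimized jointly on
+ # masked-language modelling over SMILES and regression of normalized RDKit
+ # descriptors (MTR), with gradient angle/norm diagnostics logged to TensorBoard.
+ # Example: python train_model.py --smiles_file support/smiles_10k.txt --epochs 1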
+ if __name__ == '__main__':
+     parser = argparse.ArgumentParser(description="ChemBERTa Multi-Task Training")
+     parser.add_argument("--smiles_file", type=str, default="support/smiles_10k.txt")
+     parser.add_argument("--stats_file", type=str, default="support/normalization_params.pth")
+     parser.add_argument("--output_file", type=str, default="model.pth")
+     parser.add_argument("--batch_size", type=int, default=64)
+     parser.add_argument("--max_length", type=int, default=128)
+     parser.add_argument("--mlm_weight", type=float, default=1.0)
+     parser.add_argument("--mtr_weight", type=float, default=1.0)
+     parser.add_argument("--lr", type=float, default=3e-5)
+     parser.add_argument("--epochs", type=int, default=1)
+     args = parser.parse_args()
+
+     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+     logger.info(f"Using device: {device}")
+
+     logger.info("Loading model...")
+     tokenizer = AutoTokenizer.from_pretrained("DeepChem/ChemBERTa-10M-MLM")
+     model_base = AutoModelForMaskedLM.from_pretrained("DeepChem/ChemBERTa-10M-MLM").roberta
+     model_dim = 384
+
+     mlm_head = nn.Sequential(
+         nn.Linear(model_dim, model_dim * 2),
+         nn.GELU(),
+         nn.Linear(model_dim * 2, tokenizer.vocab_size),
+     )
+     rdkit_head = nn.Sequential(
+         nn.Linear(model_dim, model_dim * 2),
+         nn.GELU(),
+         nn.Linear(model_dim * 2, 210),
+     )
+
+     model_base.to(device)
+     mlm_head.to(device)
+     rdkit_head.to(device)
+
+     logger.info("Loading dataset...")
+     raw_dataset = load_dataset("text", data_files={"train": args.smiles_file})
+     raw_dataset = raw_dataset.rename_column("text", "smiles")
+
+     logger.info("Calculating RDKit features...")
+     processed_dataset = raw_dataset.map(
+         lambda x: {"descriptors": compute_rdkit_features(x["smiles"])},
+         num_proc=8,
+         desc="Calculating RDKit features"
+     )
+
+     collator = SMILESAndDescriptorCollator(tokenizer=tokenizer, max_length=args.max_length)
+     dataloader = DataLoader(
+         processed_dataset["train"],
+         batch_size=args.batch_size,
+         collate_fn=collator
+     )
+
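+     # Per-descriptor mean/std computed offline; near-zero stds are clamped to 1
+     # so the normalized MTR targets stay finite.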
+     logger.info("Loading normalization stats...")
+     stats = torch.load(args.stats_file, map_location=device)
+     means = stats["means"].to(device)
+     stds = stats["stds"].to(device)
+     stds[stds < 1e-6] = 1.0
+
+     optimizer = torch.optim.AdamW(
+         list(model_base.parameters()) + list(mlm_head.parameters()) + list(rdkit_head.parameters()),
+         lr=args.lr, weight_decay=1e-4
+     )
+
+     clip_grad_norm = 1.0
+     BACKBONE_EXCLUDE_KEYWORDS = ["head", "rdkit", "mlm", "classifier", "pooler"]
+     LOG_GRAD_METRICS_EVERY_N_BATCHES = 10
+
+     log_dir = os.path.join("runs", f"train_{datetime.now().strftime('%Y%m%d_%H%M%S')}")
+     writer = SummaryWriter(log_dir=log_dir)
+     global_step = 0
+
+     logger.info("Starting training")
+
+     for epoch in range(args.epochs):
+         start_time = time.time()
+         total_loss_mtr = 0.0
+         total_loss_mlm = 0.0
+         total_loss = 0.0
+         num_batches = 0
+
+         pbar = tqdm(dataloader, desc=f"Epoch {epoch+1}/{args.epochs}")
+         model_base.train()
+         mlm_head.train()
+         rdkit_head.train()
+
+         for batch_idx, batch in enumerate(pbar):
+             input_ids = batch["input_ids"].to(device)
+             attention_mask = batch["attention_mask"].to(device)
+             descriptors = batch["descriptors"].to(device)
+             labels = batch["labels"].to(device)
+
+             outputs = model_base(input_ids, attention_mask=attention_mask)
+             result = outputs.last_hidden_state
+
+             mtr_res = rdkit_head(result[:, 0])
+             mlm_res = mlm_head(result)
+
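+             # Two objectives on the shared backbone: descriptor regression (MTR)
+             # from the [CLS] state against normalized RDKit targets, and token-level
+             # MLM cross-entropy on the masked positions.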
+             loss_mtr = F.huber_loss(mtr_res, (descriptors - means) / stds)
+             loss_mlm = F.cross_entropy(mlm_res.flatten(end_dim=1), labels.flatten(), ignore_index=-100)
+             loss = loss_mtr * args.mtr_weight + loss_mlm * args.mlm_weight
+
+             if num_batches % LOG_GRAD_METRICS_EVERY_N_BATCHES == 0:
+                 metrics = compute_gradient_metrics(
+                     model=model_base, loss1=loss_mtr, loss2=loss_mlm,
+                     exclude_keywords=BACKBONE_EXCLUDE_KEYWORDS
+                 )
+                 postfix = {
+                     "loss_mlm": f"{loss_mlm.item():.4f}",
+                     "loss_mtr": f"{loss_mtr.item():.4f}",
+                 }
+                 if metrics["angle_deg"] is not None:
+                     postfix["angle"] = f"{metrics['angle_deg']:.1f}°"
+                     writer.add_scalar("gradients/backbone_angle_deg", metrics["angle_deg"], global_step)
+                 if metrics["norm_mtr"] is not None:
+                     postfix["‖∇MTR‖"] = f"{metrics['norm_mtr']:.3f}"
+                     writer.add_scalar("gradients/backbone_norm_mtr", metrics["norm_mtr"], global_step)
+                 if metrics["norm_mlm"] is not None:
+                     postfix["‖∇MLM‖"] = f"{metrics['norm_mlm']:.3f}"
+                     writer.add_scalar("gradients/backbone_norm_mlm", metrics["norm_mlm"], global_step)
+
+                 pbar.set_postfix(postfix)
+
+             writer.add_scalar("loss/total", loss.item(), global_step)
+             writer.add_scalar("loss/mtr_l1", loss_mtr.item(), global_step)
+             writer.add_scalar("loss/mlm_ce", loss_mlm.item(), global_step)
+
+             optimizer.zero_grad()
+             loss.backward()
+
+             grad_norm = torch.nn.utils.clip_grad_norm_(model_base.parameters(), clip_grad_norm)
+             torch.nn.utils.clip_grad_norm_(rdkit_head.parameters(), clip_grad_norm)
+             torch.nn.utils.clip_grad_norm_(mlm_head.parameters(), clip_grad_norm)
+
+             writer.add_scalar("training/grad_norm_clipped", grad_norm.item(), global_step)
+             writer.add_scalar("training/learning_rate", optimizer.param_groups[0]["lr"], global_step)
+
+             optimizer.step()
+
+             total_loss += loss.item()
+             total_loss_mtr += loss_mtr.item()
+             total_loss_mlm += loss_mlm.item()
+             num_batches += 1
+             global_step += 1
+
+         epoch_time = time.time() - start_time
+         avg_loss = total_loss / num_batches
+         avg_loss_mtr = total_loss_mtr / num_batches
+         avg_loss_mlm = total_loss_mlm / num_batches
+
+         writer.add_scalar("epoch/avg_total_loss", avg_loss, epoch)
+         writer.add_scalar("epoch/avg_loss_mtr", avg_loss_mtr, epoch)
+         writer.add_scalar("epoch/avg_loss_mlm", avg_loss_mlm, epoch)
+         writer.add_scalar("epoch/time_sec", epoch_time, epoch)
+
+         logger.info(
+             f"Epoch {epoch+1}/{args.epochs} | Time: {epoch_time:.2f}s | "
+             f"Total Loss: {avg_loss:.4f} | L1 (MTR): {avg_loss_mtr:.4f} | "
+             f"CE (MLM): {avg_loss_mlm:.4f} | Grad Norm: {grad_norm:.4f}"
+         )
+
+     writer.close()
+
+     logger.info("Saving checkpoint..")
+     torch.save({
+         "backbone": model_base.state_dict(),
+         "mlm_head": mlm_head.state_dict(),
+         "mtr_head": rdkit_head.state_dict(),
+     }, args.output_file)
+     logger.info("Training is finished")