Joblib
Commit 1c0b97e by ynuozhang
Parent: baf3373

Delete unused embeddings, metrics, and training scripts

Files changed (47)
  1. .gitattributes +1 -0
  2. embeddings/binding/data-00000-of-00001.arrow +0 -3
  3. embeddings/binding/dataset_info.json +0 -3
  4. embeddings/binding/state.json +0 -3
  5. embeddings/fast_embedding_generation.py +0 -3
  6. embeddings/hemolysis/data-00000-of-00001.arrow +0 -3
  7. embeddings/hemolysis/dataset_info.json +0 -3
  8. embeddings/hemolysis/shuffled_hemo.csv +0 -3
  9. embeddings/hemolysis/state.json +0 -3
  10. embeddings/nonfouling/combined_nonfouling.csv +0 -3
  11. embeddings/nonfouling/data-00000-of-00001.arrow +0 -3
  12. embeddings/nonfouling/dataset_info.json +0 -3
  13. embeddings/nonfouling/state.json +0 -3
  14. embeddings/permeability/data-00000-of-00001.arrow +0 -3
  15. embeddings/permeability/dataset_info.json +0 -3
  16. embeddings/permeability/nc-CPP-processed.csv +0 -3
  17. embeddings/permeability/state.json +0 -3
  18. embeddings/solubility/data-00000-of-00001.arrow +0 -3
  19. embeddings/solubility/dataset_info.json +0 -3
  20. embeddings/solubility/shuffled_sol.csv +0 -3
  21. embeddings/solubility/state.json +0 -3
  22. metrics/binding/best_model_val_correlation.png +0 -3
  23. metrics/binding/binding_train_correlation.png +0 -3
  24. metrics/hemolysis/optimization_metrics.txt +0 -3
  25. metrics/hemolysis/train_classification_plot.png +0 -3
  26. metrics/hemolysis/train_predictions_binary.csv +0 -3
  27. metrics/hemolysis/val_classification_plot.png +0 -3
  28. metrics/hemolysis/val_predictions_binary.csv +0 -3
  29. metrics/nonfouling/optimization_metrics.txt +0 -22
  30. metrics/nonfouling/train_classification_plot.png +0 -0
  31. metrics/nonfouling/train_predictions_binary.csv +0 -3
  32. metrics/nonfouling/val_classification_plot.png +0 -0
  33. metrics/nonfouling/val_predictions_binary.csv +0 -3
  34. metrics/permeability/optimization_metrics.txt +0 -3
  35. metrics/permeability/train_correlation.png +0 -3
  36. metrics/permeability/train_predictions.csv +0 -3
  37. metrics/permeability/val_correlation.png +0 -3
  38. metrics/permeability/val_predictions.csv +0 -3
  39. metrics/solubility/optimization_metrics.txt +0 -3
  40. metrics/solubility/train_classification_plot.png +0 -3
  41. metrics/solubility/train_predictions_binary.csv +0 -3
  42. metrics/solubility/val_classification_plot.png +0 -3
  43. metrics/solubility/val_predictions_binary.csv +0 -3
  44. train/binary_xg.py +0 -223
  45. train/binding_affinity_model_clean.ipynb +0 -0
  46. train/binding_utils.py +0 -291
  47. train/permeability_xg.py +0 -186
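Note: most of the "+0 -3" deletions above remove a three-line Git LFS pointer stub rather than the payload itself. Per the LFS spec the stubs reference, a pointer has the form (values illustrative):

    version https://git-lfs.github.com/spec/v1
    oid sha256:<SHA-256 of the actual file contents>
    size <payload size in bytes>

so the "size" fields in the diffs below give the true payload sizes (about 107 MB for embeddings/solubility/data-00000-of-00001.arrow), while each stub is only a few hundred bytes.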
.gitattributes CHANGED
@@ -88,3 +88,4 @@ training_data filter=lfs diff=lfs merge=lfs -text
  README.md filter=lfs diff=lfs merge=lfs -text
  embeddings filter=lfs diff=lfs merge=lfs -text
  models/binding_affinity_for_smiles.pt filter=lfs diff=lfs merge=lfs -text
+ *.csv filter=lfs diff=lfs merge=lfs -text
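For context, the added *.csv rule is exactly the line that running git lfs track "*.csv" writes into .gitattributes; from then on, git add stores any CSV as an LFS pointer stub instead of committing the full file.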
embeddings/binding/data-00000-of-00001.arrow DELETED
@@ -1,3 +0,0 @@
- version https://git-lfs.github.com/spec/v1
- oid sha256:d9b08ce28b452e9767dfc7c60bd6285421bdc6b791150a5f55158da89c7bda4f
- size 15746448
embeddings/binding/dataset_info.json DELETED
@@ -1,3 +0,0 @@
- version https://git-lfs.github.com/spec/v1
- oid sha256:bbe033d1c9b2ec182afa2b1682a4c41c4fdc0cd548c08d7a21c4364dc68b3595
- size 784
embeddings/binding/state.json DELETED
@@ -1,3 +0,0 @@
- version https://git-lfs.github.com/spec/v1
- oid sha256:c2b6d15e6c6cf1f18c4f4cdb0ea339adaf7083a1c754266b8ad6a6468484f693
- size 247
embeddings/fast_embedding_generation.py DELETED
@@ -1,3 +0,0 @@
- version https://git-lfs.github.com/spec/v1
- oid sha256:0396104ebf1dc28b0d297bdfebede1927aba9a23417c9f06cd8f39d999d099d3
- size 3900
embeddings/hemolysis/data-00000-of-00001.arrow DELETED
@@ -1,3 +0,0 @@
- version https://git-lfs.github.com/spec/v1
- oid sha256:bef85bc99bc3c81c99fe290c0b2ef6b0d43f50c0089c59be7bf24219dd428d05
- size 20965576
embeddings/hemolysis/dataset_info.json DELETED
@@ -1,3 +0,0 @@
- version https://git-lfs.github.com/spec/v1
- oid sha256:a0708b61d5f0aa0a11b62205cbd08b504ad6957c271ff3b984c3e3b9457ce9bf
- size 370
embeddings/hemolysis/shuffled_hemo.csv DELETED
@@ -1,3 +0,0 @@
- version https://git-lfs.github.com/spec/v1
- oid sha256:994d4faa1f29c59cb7fa9d18faa8605766b4033ae2bc432ea76e0bbdf4876b29
- size 2236370
embeddings/hemolysis/state.json DELETED
@@ -1,3 +0,0 @@
- version https://git-lfs.github.com/spec/v1
- oid sha256:17a8a2f0f7a7c272396e10572248a45faeaae1c6390ef311b551ab97f5eb72b8
- size 247
embeddings/nonfouling/combined_nonfouling.csv DELETED
@@ -1,3 +0,0 @@
- version https://git-lfs.github.com/spec/v1
- oid sha256:90f770673946d37f5c8362b3956c789f62ec71a514ab8c46bb2523b3b3d5be2e
- size 28623153
embeddings/nonfouling/data-00000-of-00001.arrow DELETED
@@ -1,3 +0,0 @@
- version https://git-lfs.github.com/spec/v1
- oid sha256:44c869d7a6a03ccdf9143c4007bc85c996146e12e40da3efbd455eaf49eba016
- size 81645736
embeddings/nonfouling/dataset_info.json DELETED
@@ -1,3 +0,0 @@
- version https://git-lfs.github.com/spec/v1
- oid sha256:2a2517b8ba7deb4b26855d073360ab0c6a38c2780bc02211c03d2f51f9ccbb00
- size 368
embeddings/nonfouling/state.json DELETED
@@ -1,3 +0,0 @@
- version https://git-lfs.github.com/spec/v1
- oid sha256:c5bf4d95c1fd57f172217a453408d1339a87f7eca5d02f4cd718bfdde6b519fb
- size 247
embeddings/permeability/data-00000-of-00001.arrow DELETED
@@ -1,3 +0,0 @@
- version https://git-lfs.github.com/spec/v1
- oid sha256:82e749eafb2e903ef2dc47255dbe4e489e6db8055b3ba6af4c876d9b1a0f1b38
- size 22250496
embeddings/permeability/dataset_info.json DELETED
@@ -1,3 +0,0 @@
- version https://git-lfs.github.com/spec/v1
- oid sha256:a0708b61d5f0aa0a11b62205cbd08b504ad6957c271ff3b984c3e3b9457ce9bf
- size 370
embeddings/permeability/nc-CPP-processed.csv DELETED
@@ -1,3 +0,0 @@
- version https://git-lfs.github.com/spec/v1
- oid sha256:f0bbf79d5a78023460de72087bbadd8d1f5b841b21b05b2d148d9ccd7fc3a254
- size 1083167
embeddings/permeability/state.json DELETED
@@ -1,3 +0,0 @@
- version https://git-lfs.github.com/spec/v1
- oid sha256:e7a5a8e0db7a08e1c26e0115a0c3225d1850b620d00ab24b42769d78ee6208fd
- size 247
embeddings/solubility/data-00000-of-00001.arrow DELETED
@@ -1,3 +0,0 @@
- version https://git-lfs.github.com/spec/v1
- oid sha256:36ac428037f8d09d1f45fcd6a61517428c4409638d63230b3ff1d375bdd0e5cb
- size 106655176
embeddings/solubility/dataset_info.json DELETED
@@ -1,3 +0,0 @@
- version https://git-lfs.github.com/spec/v1
- oid sha256:a0708b61d5f0aa0a11b62205cbd08b504ad6957c271ff3b984c3e3b9457ce9bf
- size 370
embeddings/solubility/shuffled_sol.csv DELETED
@@ -1,3 +0,0 @@
- version https://git-lfs.github.com/spec/v1
- oid sha256:dc5a2139e52a6deb8a1eda45eb5f8b8abfd260724dbbd274a824a9aecb929890
- size 49775729
embeddings/solubility/state.json DELETED
@@ -1,3 +0,0 @@
- version https://git-lfs.github.com/spec/v1
- oid sha256:403b965b3b855f52632a00e828046dd90749edc7808ef4725d54606673f1d5cb
- size 247
metrics/binding/best_model_val_correlation.png DELETED

Git LFS Details

  • SHA256: 3ae1a98e66a9cf4fdc0557cdd821024a8bf68ca180920488d128fea247dc0c70
  • Pointer size: 131 Bytes
  • Size of remote file: 442 kB
metrics/binding/binding_train_correlation.png DELETED

Git LFS Details

  • SHA256: 84e4ef160a333ebcf6d2bad1e794eebd3823095c8ea484d462badd461e7c0415
  • Pointer size: 131 Bytes
  • Size of remote file: 834 kB
metrics/hemolysis/optimization_metrics.txt DELETED
@@ -1,3 +0,0 @@
- version https://git-lfs.github.com/spec/v1
- oid sha256:2dc387c4d486b3c94a4237ceb04d8f04e18d9198e0b3c069e04c940b459f5422
- size 612
metrics/hemolysis/train_classification_plot.png DELETED

Git LFS Details

  • SHA256: ff405228b4d43a652dcf1e9a634c7dc2050f18b39084d23029dc4256faf70b5a
  • Pointer size: 130 Bytes
  • Size of remote file: 29.8 kB
metrics/hemolysis/train_predictions_binary.csv DELETED
@@ -1,3 +0,0 @@
- version https://git-lfs.github.com/spec/v1
- oid sha256:5604b65e9f922252ccc92f1bbe9be33270245ada96fef2f3e92dc50c30a5487a
- size 87035
metrics/hemolysis/val_classification_plot.png DELETED

Git LFS Details

  • SHA256: 232a28293e6c2674b4a00d1404ace0a226ac967b2c7c52d638a806a5db5a3c89
  • Pointer size: 130 Bytes
  • Size of remote file: 49.2 kB
metrics/hemolysis/val_predictions_binary.csv DELETED
@@ -1,3 +0,0 @@
- version https://git-lfs.github.com/spec/v1
- oid sha256:00aca5ad646502e56930f2473c596b1cf754dd96ec13c7d2b630eb2e8cec3d95
- size 21607
metrics/nonfouling/optimization_metrics.txt DELETED
@@ -1,22 +0,0 @@
-
- ============================================================
- OPTIMIZATION COMPLETE
- ============================================================
- Number of finished trials: 200
-
- Best Trial: #52
- Best F1 Score: 0.8774
- Best AUC Score: 0.9327
- Optuna Best Trial Value: 0.8774
-
- Best hyperparameters:
-     lambda: 3.1278404540677405e-06
-     alpha: 2.865349682111457
-     colsample_bytree: 0.6388434847100901
-     subsample: 0.975052331668336
-     learning_rate: 0.1046988967097677
-     max_depth: 5
-     min_child_weight: 283
-     gamma: 0.7863860752901305
-     num_boost_round: 876
- ============================================================
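To reproduce the nonfouling classifier from this record, the listed hyperparameters can be passed straight back to XGBoost. A minimal sketch, assuming an embedding matrix X and 0/1 labels y (both hypothetical names; parameter keys match the deleted training scripts below):

    import xgboost as xgb

    params = {
        "objective": "binary:logistic",  # binary task, as in the deleted binary_xg.py
        "lambda": 3.1278404540677405e-06,
        "alpha": 2.865349682111457,
        "colsample_bytree": 0.6388434847100901,
        "subsample": 0.975052331668336,
        "learning_rate": 0.1046988967097677,
        "max_depth": 5,
        "min_child_weight": 283,
        "gamma": 0.7863860752901305,
        "tree_method": "hist",  # assumed, matching the deleted scripts
    }
    dtrain = xgb.DMatrix(X, label=y)  # X: (n_samples, n_features) embeddings, y: 0/1 labels
    booster = xgb.train(params, dtrain, num_boost_round=876)  # 876 rounds per best trial #52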
metrics/nonfouling/train_classification_plot.png DELETED
Binary file (32.9 kB)
 
metrics/nonfouling/train_predictions_binary.csv DELETED
@@ -1,3 +0,0 @@
- version https://git-lfs.github.com/spec/v1
- oid sha256:07931d9e8b63b4284bc7ce77cdcb719ea3b81d16125e14c277fce06c0e3d6b1d
- size 219852
metrics/nonfouling/val_classification_plot.png DELETED
Binary file (39.2 kB)
 
metrics/nonfouling/val_predictions_binary.csv DELETED
@@ -1,3 +0,0 @@
- version https://git-lfs.github.com/spec/v1
- oid sha256:ef9845cf64c142ff16fc915402953a1383e36ecb1c76b6174fae75c0dec59cd4
- size 54904
metrics/permeability/optimization_metrics.txt DELETED
@@ -1,3 +0,0 @@
- version https://git-lfs.github.com/spec/v1
- oid sha256:aa91d37a6e116d4ff44eb2606e439f3243ddbeed613abc567128dd8112901e0c
- size 676
metrics/permeability/train_correlation.png DELETED

Git LFS Details

  • SHA256: f378e9d3f6a45873b666263d2f5cf1258f19f0959c6c63f6efa740e4bc9f302d
  • Pointer size: 131 Bytes
  • Size of remote file: 101 kB
metrics/permeability/train_predictions.csv DELETED
@@ -1,3 +0,0 @@
- version https://git-lfs.github.com/spec/v1
- oid sha256:bd3817ce0aa54ce50cb3d56c83ae2a3157cd4d4deca7af743f9e24d2b0bfb675
- size 89716
metrics/permeability/val_correlation.png DELETED

Git LFS Details

  • SHA256: 3862f998d06f9c017df8dc1a9e8b22ba0c1fe8aecb5fb7a2a95706df57699bfc
  • Pointer size: 131 Bytes
  • Size of remote file: 118 kB
metrics/permeability/val_predictions.csv DELETED
@@ -1,3 +0,0 @@
- version https://git-lfs.github.com/spec/v1
- oid sha256:26e268ec0f9b5290ecc7b76c9b41e5819e4bc45bb1bb29843f2ae9b781ecab46
- size 22487
metrics/solubility/optimization_metrics.txt DELETED
@@ -1,3 +0,0 @@
- version https://git-lfs.github.com/spec/v1
- oid sha256:d935afa867779b07a7c99e51c12cdcd49828dd61e4be1ef357926e280fadf648
- size 612
metrics/solubility/train_classification_plot.png DELETED

Git LFS Details

  • SHA256: 0aaeaaf89a42cc837cd6db07cb0728fd04f6af855f8a5368b48a29c5401b7ea0
  • Pointer size: 130 Bytes
  • Size of remote file: 35.3 kB
metrics/solubility/train_predictions_binary.csv DELETED
@@ -1,3 +0,0 @@
- version https://git-lfs.github.com/spec/v1
- oid sha256:852f71fe2fe196274e25866bfbea03874c7e75fec923556c96536ada85a71c7c
- size 244268
metrics/solubility/val_classification_plot.png DELETED

Git LFS Details

  • SHA256: a28973d586e0f8461cfd14996acc905a644820069537920475b5f5e906692c68
  • Pointer size: 130 Bytes
  • Size of remote file: 40.1 kB
metrics/solubility/val_predictions_binary.csv DELETED
@@ -1,3 +0,0 @@
- version https://git-lfs.github.com/spec/v1
- oid sha256:8489b5b047b630571f3959d96e5fbb003f923854e7db777dd266331f586d115d
- size 61087
train/binary_xg.py DELETED
@@ -1,223 +0,0 @@
- import pandas as pd
- import numpy as np
- import torch
- from sklearn.model_selection import train_test_split
- from sklearn.metrics import precision_recall_curve, f1_score
- import optuna
- from optuna.trial import TrialState
- import xgboost as xgb
- import os
- from datasets import load_from_disk
- from lightning.pytorch import seed_everything
- from rdkit import Chem, rdBase, DataStructs
- from typing import List
- from rdkit.Chem import AllChem
- import matplotlib.pyplot as plt
- from sklearn.metrics import accuracy_score, roc_auc_score
-
- base_path = "/scratch/pranamlab/sophtang/home/scoring/PeptiVerse"
-
- def save_and_plot_binary_predictions(y_true_train, y_pred_train, y_true_val, y_pred_val, threshold, output_path):
-     """
-     Saves the true and predicted values for training and validation sets, and generates binary classification plots.
-
-     Parameters:
-         y_true_train (array): True labels for the training set.
-         y_pred_train (array): Predicted probabilities for the training set.
-         y_true_val (array): True labels for the validation set.
-         y_pred_val (array): Predicted probabilities for the validation set.
-         threshold (float): Classification threshold for predictions.
-         output_path (str): Directory to save the CSV files and plots.
-     """
-     os.makedirs(output_path, exist_ok=True)
-
-     # Convert probabilities to binary predictions
-     y_pred_train_binary = (y_pred_train >= threshold).astype(int)
-     y_pred_val_binary = (y_pred_val >= threshold).astype(int)
-
-     # Save training predictions
-     train_df = pd.DataFrame({
-         'True Label': y_true_train,
-         'Predicted Probability': y_pred_train,
-         'Predicted Label': y_pred_train_binary
-     })
-     train_df.to_csv(os.path.join(output_path, 'train_predictions_binary.csv'), index=False)
-
-     # Save validation predictions
-     val_df = pd.DataFrame({
-         'True Label': y_true_val,
-         'Predicted Probability': y_pred_val,
-         'Predicted Label': y_pred_val_binary
-     })
-     val_df.to_csv(os.path.join(output_path, 'val_predictions_binary.csv'), index=False)
-
-     # Plot training predictions
-     plot_binary_correlation(
-         y_true_train,
-         y_pred_train,
-         threshold,
-         title="Training Set Binary Classification Plot",
-         output_file=os.path.join(output_path, 'train_classification_plot.png')
-     )
-
-     # Plot validation predictions
-     plot_binary_correlation(
-         y_true_val,
-         y_pred_val,
-         threshold,
-         title="Validation Set Binary Classification Plot",
-         output_file=os.path.join(output_path, 'val_classification_plot.png')
-     )
-
- def plot_binary_correlation(y_true, y_pred, threshold, title, output_file):
-     """
-     Generates a scatter plot for binary classification and saves it to a file.
-
-     Parameters:
-         y_true (array): True labels.
-         y_pred (array): Predicted probabilities.
-         threshold (float): Classification threshold for predictions.
-         title (str): Title of the plot.
-         output_file (str): Path to save the plot.
-     """
-     # Scatter plot
-     plt.figure(figsize=(10, 8))
-     plt.scatter(y_true, y_pred, alpha=0.5, label='Data points', color='#BC80FF')
-
-     # Add threshold line
-     plt.axhline(y=threshold, color='red', linestyle='--', label=f'Threshold = {threshold}')
-
-     # Add annotations
-     plt.title(title)
-     plt.xlabel("True Labels")
-     plt.ylabel("Predicted Probability")
-     plt.legend()
-
-     # Save and show the plot
-     plt.tight_layout()
-     plt.savefig(output_file)
-     plt.show()
-
- seed_everything(42)
-
- dataset = load_from_disk(f'{base_path}/data/solubility')
-
- sequences = np.stack(dataset['sequence'])  # Ensure sequences are SMILES strings
- labels = np.stack(dataset['labels'])
- embeddings = np.stack(dataset['embedding'])
-
- # Initialize best F1 score and model path
- best_f1 = -np.inf
- best_model_path = f"{base_path}/src/solubility"
-
- # Trial callback
- def trial_info_callback(study, trial):
-     if study.best_trial == trial:
-         print(f"Trial {trial.number}:")
-         print(f" Weighted F1 Score: {trial.value}")
-
- def objective(trial):
-     # Define hyperparameters
-     params = {
-         'objective': 'binary:logistic',
-         'lambda': trial.suggest_float('lambda', 1e-8, 50.0, log=True),
-         'alpha': trial.suggest_float('alpha', 1e-8, 50.0, log=True),
-         'colsample_bytree': trial.suggest_float('colsample_bytree', 0.3, 1.0),
-         'subsample': trial.suggest_float('subsample', 0.5, 1.0),
-         'learning_rate': trial.suggest_float('learning_rate', 0.001, 0.3),
-         'max_depth': trial.suggest_int('max_depth', 2, 15),
-         'min_child_weight': trial.suggest_int('min_child_weight', 1, 500),
-         'gamma': trial.suggest_float('gamma', 0, 10.0),
-         'tree_method': 'hist',
-         'device': 'cuda:6',
-     }
-
-     # Suggest number of boosting rounds
-     num_boost_round = trial.suggest_int('num_boost_round', 10, 1000)
-     threshold = 0.5  # Initial classification threshold
-
-     # Split the data
-     train_idx, val_idx = train_test_split(
-         np.arange(len(sequences)), test_size=0.2, stratify=labels, random_state=42
-     )
-     train_subset = dataset.select(train_idx).with_format("torch")
-     val_subset = dataset.select(val_idx).with_format("torch")
-
-     # Extract embeddings and labels for train/validation
-     train_embeddings = np.array(train_subset['embedding'])
-     valid_embeddings = np.array(val_subset['embedding'])
-     train_labels = np.array(train_subset['labels'])
-     valid_labels = np.array(val_subset['labels'])
-
-     # Prepare training and validation sets
-     dtrain = xgb.DMatrix(train_embeddings, label=train_labels)
-     dvalid = xgb.DMatrix(valid_embeddings, label=valid_labels)
-
-     # Train the model
-     model = xgb.train(
-         params=params,
-         dtrain=dtrain,
-         num_boost_round=num_boost_round,
-         evals=[(dvalid, "validation")],
-         early_stopping_rounds=50,
-         verbose_eval=False,
-     )
-
-     # Predict probabilities
-     preds_train = model.predict(dtrain)
-     preds_val = model.predict(dvalid)
-
-     # Calculate metrics
-     f1_val = f1_score(valid_labels, (preds_val >= threshold).astype(int), average="weighted")
-     auc_val = roc_auc_score(valid_labels, preds_val)
-     print(f"Trial {trial.number}: AUC: {auc_val:.3f}, F1 Score: {f1_val:.3f}")
-
-     # Save the model if it has the best F1 score
-     current_best = trial.study.user_attrs.get("best_f1", -np.inf)
-     if f1_val > current_best:
-         trial.study.set_user_attr("best_f1", f1_val)
-         trial.study.set_user_attr("best_auc", auc_val)
-         trial.study.set_user_attr("best_trial", trial.number)
-         os.makedirs(best_model_path, exist_ok=True)
-
-         # Save the model
-         model.save_model(os.path.join(best_model_path, "best_model_f1.json"))
-         print(f"✓ NEW BEST! Trial {trial.number}: F1={f1_val:.4f}, AUC={auc_val:.4f} - Model saved!")
-
-         # Save and plot binary predictions
-         save_and_plot_binary_predictions(
-             train_labels, preds_train, valid_labels, preds_val, threshold, best_model_path
-         )
-
-     return f1_val
-
- if __name__ == "__main__":
-     study = optuna.create_study(direction="maximize", pruner=optuna.pruners.MedianPruner())
-     study.optimize(objective, n_trials=200)
-
-     # Prepare summary text
-     summary = []
-     summary.append("\n" + "="*60)
-     summary.append("OPTIMIZATION COMPLETE")
-     summary.append("="*60)
-     summary.append(f"Number of finished trials: {len(study.trials)}")
-     summary.append(f"\nBest Trial: #{study.user_attrs.get('best_trial', 'N/A')}")
-     summary.append(f"Best F1 Score: {study.user_attrs.get('best_f1', None):.4f}")
-     summary.append(f"Best AUC Score: {study.user_attrs.get('best_auc', None):.4f}")
-     summary.append(f"Optuna Best Trial Value: {study.best_trial.value:.4f}")
-     summary.append(f"\nBest hyperparameters:")
-     for key, value in study.best_trial.params.items():
-         summary.append(f" {key}: {value}")
-     summary.append("="*60)
-
-     # Print to console
-     for line in summary:
-         print(line)
-
-     # Save to file
-     metrics_file = os.path.join(best_model_path, "optimization_metrics.txt")
-     with open(metrics_file, 'w') as f:
-         f.write('\n'.join(summary))
-     print(f"\n✓ Metrics saved to: {metrics_file}")
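The script above persists only the best booster, written as best_model_f1.json under best_model_path. A minimal sketch of reloading it for inference (the path and the new_embeddings array are assumptions; the 0.5 threshold matches the script's default):

    import xgboost as xgb

    booster = xgb.Booster()
    booster.load_model("best_model_f1.json")  # path assumed; file written by the script above
    probs = booster.predict(xgb.DMatrix(new_embeddings))  # new_embeddings: (n, d) float array
    preds = (probs >= 0.5).astype(int)  # same threshold the objective used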
train/binding_affinity_model_clean.ipynb DELETED
The diff for this file is too large to render. See raw diff
 
train/binding_utils.py DELETED
@@ -1,291 +0,0 @@
- from torch import nn
- import pdb
- import torch
- import numpy as np
-
- def to_var(x):
-     if torch.cuda.is_available():
-         x = x.cuda()
-     return x
-
- class MultiHeadAttentionSequence(nn.Module):
-
-     def __init__(self, n_head, d_model, d_k, d_v, dropout=0.1):
-         super().__init__()
-
-         self.n_head = n_head
-         self.d_model = d_model
-         self.d_k = d_k
-         self.d_v = d_v
-
-         self.W_Q = nn.Linear(d_model, n_head*d_k)
-         self.W_K = nn.Linear(d_model, n_head*d_k)
-         self.W_V = nn.Linear(d_model, n_head*d_v)
-         self.W_O = nn.Linear(n_head*d_v, d_model)
-
-         self.layer_norm = nn.LayerNorm(d_model)
-         self.dropout = nn.Dropout(dropout)
-
-     def forward(self, q, k, v):
-         batch, len_q, _ = q.size()
-         batch, len_k, _ = k.size()
-         batch, len_v, _ = v.size()
-
-         Q = self.W_Q(q).view([batch, len_q, self.n_head, self.d_k])
-         K = self.W_K(k).view([batch, len_k, self.n_head, self.d_k])
-         V = self.W_V(v).view([batch, len_v, self.n_head, self.d_v])
-
-         Q = Q.transpose(1, 2)
-         K = K.transpose(1, 2).transpose(2, 3)
-         V = V.transpose(1, 2)
-
-         attention = torch.matmul(Q, K)
-         attention = attention / np.sqrt(self.d_k)
-         attention = F.softmax(attention, dim=-1)
-
-         output = torch.matmul(attention, V)
-         output = output.transpose(1, 2).reshape([batch, len_q, self.d_v*self.n_head])
-         output = self.W_O(output)
-         output = self.dropout(output)
-         output = self.layer_norm(output + q)
-
-         return output, attention
-
- class MultiHeadAttentionReciprocal(nn.Module):
-
-     def __init__(self, n_head, d_model, d_k, d_v, dropout=0.1):
-         super().__init__()
-
-         self.n_head = n_head
-         self.d_model = d_model
-         self.d_k = d_k
-         self.d_v = d_v
-
-         self.W_Q = nn.Linear(d_model, n_head*d_k)
-         self.W_K = nn.Linear(d_model, n_head*d_k)
-         self.W_V = nn.Linear(d_model, n_head*d_v)
-         self.W_O = nn.Linear(n_head*d_v, d_model)
-         self.W_V_2 = nn.Linear(d_model, n_head*d_v)
-         self.W_O_2 = nn.Linear(n_head*d_v, d_model)
-
-         self.layer_norm = nn.LayerNorm(d_model)
-         self.dropout = nn.Dropout(dropout)
-         self.layer_norm_2 = nn.LayerNorm(d_model)
-         self.dropout_2 = nn.Dropout(dropout)
-
-     def forward(self, q, k, v, v_2):
-         batch, len_q, _ = q.size()
-         batch, len_k, _ = k.size()
-         batch, len_v, _ = v.size()
-         batch, len_v_2, _ = v_2.size()
-
-         Q = self.W_Q(q).view([batch, len_q, self.n_head, self.d_k])
-         K = self.W_K(k).view([batch, len_k, self.n_head, self.d_k])
-         V = self.W_V(v).view([batch, len_v, self.n_head, self.d_v])
-         V_2 = self.W_V_2(v_2).view([batch, len_v_2, self.n_head, self.d_v])
-
-         Q = Q.transpose(1, 2)
-         K = K.transpose(1, 2).transpose(2, 3)
-         V = V.transpose(1, 2)
-         V_2 = V_2.transpose(1, 2)
-
-         attention = torch.matmul(Q, K)
-         attention = attention / np.sqrt(self.d_k)
-         attention_2 = attention.transpose(-2, -1)
-
-         attention = F.softmax(attention, dim=-1)
-         attention_2 = F.softmax(attention_2, dim=-1)
-
-         output = torch.matmul(attention, V)
-         output_2 = torch.matmul(attention_2, V_2)
-
-         output = output.transpose(1, 2).reshape([batch, len_q, self.d_v*self.n_head])
-         output_2 = output_2.transpose(1, 2).reshape([batch, len_k, self.d_v*self.n_head])
-
-         output = self.W_O(output)
-         output_2 = self.W_O_2(output_2)
-
-         output = self.dropout(output)
-         output = self.layer_norm(output + q)
-
-         output_2 = self.dropout(output_2)
-         output_2 = self.layer_norm(output_2 + k)
-
-         return output, output_2, attention, attention_2
-
- class FFN(nn.Module):
-
-     def __init__(self, d_in, d_hid, dropout=0.1):
-         super().__init__()
-
-         self.layer_1 = nn.Conv1d(d_in, d_hid, 1)
-         self.layer_2 = nn.Conv1d(d_hid, d_in, 1)
-         self.relu = nn.ReLU()
-         self.layer_norm = nn.LayerNorm(d_in)
-         self.dropout = nn.Dropout(dropout)
-
-     def forward(self, x):
-         residual = x
-         output = self.layer_1(x.transpose(1, 2))
-         output = self.relu(output)
-         output = self.layer_2(output)
-         output = self.dropout(output)
-         output = self.layer_norm(output.transpose(1, 2) + residual)
-
-         return output
-
- class ConvLayer(nn.Module):
-     def __init__(self, in_channels, out_channels, kernel_size, padding, dilation):
-         super(ConvLayer, self).__init__()
-         self.conv = nn.Conv1d(in_channels, out_channels, kernel_size, padding=padding, dilation=dilation)
-         self.relu = nn.ReLU()
-
-     def forward(self, x):
-         out = self.conv(x)
-         out = self.relu(out)
-         return out
-
- class DilatedCNN(nn.Module):
-     def __init__(self, d_model, d_hidden):
-         super(DilatedCNN, self).__init__()
-         self.first_ = nn.ModuleList()
-         self.second_ = nn.ModuleList()
-         self.third_ = nn.ModuleList()
-
-         dilation_tuple = (1, 2, 3)
-         dim_in_tuple = (d_model, d_hidden, d_hidden)
-         dim_out_tuple = (d_hidden, d_hidden, d_hidden)
-
-         for i, dilation_rate in enumerate(dilation_tuple):
-             self.first_.append(ConvLayer(dim_in_tuple[i], dim_out_tuple[i], kernel_size=3, padding=dilation_rate,
-                                          dilation=dilation_rate))
-
-         for i, dilation_rate in enumerate(dilation_tuple):
-             self.second_.append(ConvLayer(dim_in_tuple[i], dim_out_tuple[i], kernel_size=5, padding=2*dilation_rate,
-                                           dilation=dilation_rate))
-
-         for i, dilation_rate in enumerate(dilation_tuple):
-             self.third_.append(ConvLayer(dim_in_tuple[i], dim_out_tuple[i], kernel_size=7, padding=3*dilation_rate,
-                                          dilation=dilation_rate))
-
-     def forward(self, protein_seq_enc):
-         # pdb.set_trace()
-         protein_seq_enc = protein_seq_enc.transpose(1, 2)  # protein_seq_enc's shape: B*L*d_model -> B*d_model*L
-
-         first_embedding = protein_seq_enc
-         second_embedding = protein_seq_enc
-         third_embedding = protein_seq_enc
-
-         for i in range(len(self.first_)):
-             first_embedding = self.first_[i](first_embedding)
-
-         for i in range(len(self.second_)):
-             second_embedding = self.second_[i](second_embedding)
-
-         for i in range(len(self.third_)):
-             third_embedding = self.third_[i](third_embedding)
-
-         # pdb.set_trace()
-         protein_seq_enc = first_embedding + second_embedding + third_embedding
-
-         return protein_seq_enc.transpose(1, 2)
-
- class ReciprocalLayerwithCNN(nn.Module):
-
-     def __init__(self, d_model, d_inner, d_hidden, n_head, d_k, d_v):
-         super().__init__()
-
-         self.cnn = DilatedCNN(d_model, d_hidden)
-         self.sequence_attention_layer = MultiHeadAttentionSequence(n_head, d_hidden, d_k, d_v)
-         self.protein_attention_layer = MultiHeadAttentionSequence(n_head, d_hidden, d_k, d_v)
-         self.reciprocal_attention_layer = MultiHeadAttentionReciprocal(n_head, d_hidden, d_k, d_v)
-         self.ffn_seq = FFN(d_hidden, d_inner)
-         self.ffn_protein = FFN(d_hidden, d_inner)
-
-     def forward(self, sequence_enc, protein_seq_enc):
-         # pdb.set_trace()  # protein_seq_enc.shape = B * L * d_model
-         protein_seq_enc = self.cnn(protein_seq_enc)
-         prot_enc, prot_attention = self.protein_attention_layer(protein_seq_enc, protein_seq_enc, protein_seq_enc)
-
-         seq_enc, sequence_attention = self.sequence_attention_layer(sequence_enc, sequence_enc, sequence_enc)
-
-         prot_enc, seq_enc, prot_seq_attention, seq_prot_attention = self.reciprocal_attention_layer(prot_enc, seq_enc, seq_enc, prot_enc)
-
-         prot_enc = self.ffn_protein(prot_enc)
-         seq_enc = self.ffn_seq(seq_enc)
-
-         return prot_enc, seq_enc, prot_attention, sequence_attention, prot_seq_attention, seq_prot_attention
-
- class ReciprocalLayer(nn.Module):
-
-     def __init__(self, d_model, d_inner, n_head, d_k, d_v):
-         super().__init__()
-
-         self.sequence_attention_layer = MultiHeadAttentionSequence(n_head, d_model, d_k, d_v)
-         self.protein_attention_layer = MultiHeadAttentionSequence(n_head, d_model, d_k, d_v)
-         self.reciprocal_attention_layer = MultiHeadAttentionReciprocal(n_head, d_model, d_k, d_v)
-         self.ffn_seq = FFN(d_model, d_inner)
-         self.ffn_protein = FFN(d_model, d_inner)
-
-     def forward(self, sequence_enc, protein_seq_enc):
-         prot_enc, prot_attention = self.protein_attention_layer(protein_seq_enc, protein_seq_enc, protein_seq_enc)
-
-         seq_enc, sequence_attention = self.sequence_attention_layer(sequence_enc, sequence_enc, sequence_enc)
-
-         prot_enc, seq_enc, prot_seq_attention, seq_prot_attention = self.reciprocal_attention_layer(prot_enc, seq_enc, seq_enc, prot_enc)
-         prot_enc = self.ffn_protein(prot_enc)
-         seq_enc = self.ffn_seq(seq_enc)
-
-         return prot_enc, seq_enc, prot_attention, sequence_attention, prot_seq_attention, seq_prot_attention
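A caveat for anyone reviving this module: both attention classes call F.softmax, but the file never imports torch.nn.functional as F, so it would raise a NameError as committed. A minimal self-contained sketch of the scaled dot-product step those classes implement, with the missing import restored (tensor shapes inferred from the forward methods above):

    import torch
    import torch.nn.functional as F  # the import the deleted file was missing

    def scaled_dot_product_attention(Q, K, V, d_k):
        # Q: (B, n_head, len_q, d_k); K: (B, n_head, len_k, d_k); V: (B, n_head, len_k, d_v)
        scores = torch.matmul(Q, K.transpose(-2, -1)) / d_k ** 0.5
        attn = F.softmax(scores, dim=-1)  # normalize attention weights over the keys
        return torch.matmul(attn, V), attn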
train/permeability_xg.py DELETED
@@ -1,186 +0,0 @@
- import pandas as pd
- import numpy as np
- import optuna
- from optuna.trial import TrialState
- from rdkit import Chem
- from rdkit.Chem import AllChem
- from sklearn.metrics import mean_squared_error
- from sklearn.model_selection import train_test_split
- import xgboost as xgb
- import os
- from datasets import load_from_disk
- from scipy.stats import spearmanr
- import matplotlib.pyplot as plt
-
- base_path = "/scratch/pranamlab/sophtang/home/scoring/PeptiVerse"
-
- def save_and_plot_predictions(y_true_train, y_pred_train, y_true_val, y_pred_val, output_path):
-     os.makedirs(output_path, exist_ok=True)
-
-     # Save training predictions
-     train_df = pd.DataFrame({'True Permeability': y_true_train, 'Predicted Permeability': y_pred_train})
-     train_df.to_csv(os.path.join(output_path, 'train_predictions.csv'), index=False)
-
-     # Save validation predictions
-     val_df = pd.DataFrame({'True Permeability': y_true_val, 'Predicted Permeability': y_pred_val})
-     val_df.to_csv(os.path.join(output_path, 'val_predictions.csv'), index=False)
-
-     # Plot training predictions
-     plot_correlation(
-         y_true_train,
-         y_pred_train,
-         title="Training Set Correlation Plot",
-         output_file=os.path.join(output_path, 'train_correlation.png'),
-     )
-
-     # Plot validation predictions
-     plot_correlation(
-         y_true_val,
-         y_pred_val,
-         title="Validation Set Correlation Plot",
-         output_file=os.path.join(output_path, 'val_correlation.png'),
-     )
-
- def plot_correlation(y_true, y_pred, title, output_file):
-     spearman_corr, _ = spearmanr(y_true, y_pred)
-
-     # Scatter plot
-     plt.figure(figsize=(10, 8))
-     plt.scatter(y_true, y_pred, alpha=0.5, label='Data points', color='#BC80FF')
-     plt.plot([min(y_true), max(y_true)], [min(y_true), max(y_true)], color='teal', linestyle='--', label='Ideal fit')
-
-     # Add annotations
-     plt.title(f"{title}\nSpearman Correlation: {spearman_corr:.3f}")
-     plt.xlabel("True Permeability (logP)")
-     plt.ylabel("Predicted Affinity (logP)")
-     plt.legend()
-
-     # Save and show the plot
-     plt.tight_layout()
-     plt.savefig(output_file)
-     plt.show()
-
- # Load dataset
- dataset = load_from_disk(f'{base_path}/data/permeability')
-
- # Extract sequences, labels, and embeddings
- sequences = np.stack(dataset['sequence'])
- labels = np.stack(dataset['labels'])  # Regression labels
- embeddings = np.stack(dataset['embedding'])  # Pre-trained embeddings
-
- # Function to compute Morgan fingerprints
- def compute_morgan_fingerprints(smiles_list, radius=2, n_bits=2048):
-     fps = []
-     for smiles in smiles_list:
-         mol = Chem.MolFromSmiles(smiles)
-         if mol is not None:
-             fp = AllChem.GetMorganFingerprintAsBitVect(mol, radius, nBits=n_bits)
-             fps.append(np.array(fp))
-         else:
-             # If the SMILES string is invalid, use a zero vector
-             fps.append(np.zeros(n_bits))
-             print(f"Invalid SMILES: {smiles}")
-     return np.array(fps)
-
- # Compute Morgan fingerprints for the sequences
- #morgan_fingerprints = compute_morgan_fingerprints(sequences)
-
- # Concatenate embeddings with Morgan fingerprints
- #input_features = np.concatenate([embeddings, morgan_fingerprints], axis=1)
- input_features = embeddings
-
- # Initialize global variables
- best_model_path = f"{base_path}/src/permeability"
- os.makedirs(best_model_path, exist_ok=True)
-
- def trial_info_callback(study, trial):
-     if study.best_trial == trial:
-         print(f"Trial {trial.number}:")
-         print(f" MSE: {trial.value}")
-
- def objective(trial):
-     # Define hyperparameters
-     params = {
-         'objective': 'reg:squarederror',
-         'lambda': trial.suggest_float('lambda', 0.1, 10.0, log=True),
-         'alpha': trial.suggest_float('alpha', 0.1, 10.0, log=True),
-         'gamma': trial.suggest_float('gamma', 0, 5),
-         'colsample_bytree': trial.suggest_float('colsample_bytree', 0.5, 1.0),
-         'subsample': trial.suggest_float('subsample', 0.6, 0.9),
-         'learning_rate': trial.suggest_float('learning_rate', 1e-5, 0.1),
-         'max_depth': trial.suggest_int('max_depth', 2, 30),
-         'min_child_weight': trial.suggest_int('min_child_weight', 1, 20),
-         'tree_method': 'hist',
-         'scale_pos_weight': trial.suggest_float('scale_pos_weight', 0.5, 10.0, log=True),
-         'device': 'cuda:6',
-     }
-     num_boost_round = trial.suggest_int('num_boost_round', 10, 1000)
-
-     # Train-validation split
-     X_train, X_val, y_train, y_val = train_test_split(input_features, labels, test_size=0.2, random_state=42)
-
-     # Convert data to DMatrix
-     dtrain = xgb.DMatrix(X_train, label=y_train)
-     dvalid = xgb.DMatrix(X_val, label=y_val)
-
-     # Train XGBoost
-     model = xgb.train(
-         params=params,
-         dtrain=dtrain,
-         num_boost_round=num_boost_round,
-         evals=[(dvalid, "validation")],
-         early_stopping_rounds=50,
-         verbose_eval=False,
-     )
-
-     # Predict and evaluate
-     preds_train = model.predict(dtrain)
-     preds_val = model.predict(dvalid)
-
-     mse = mean_squared_error(y_val, preds_val)
-
-     # Calculate Spearman Rank Correlation for both train and validation
-     spearman_train, _ = spearmanr(y_train, preds_train)
-     spearman_val, _ = spearmanr(y_val, preds_val)
-     print(f"Train Spearman: {spearman_train:.4f}, Val Spearman: {spearman_val:.4f}")
-
-     # Save the best model
-     if trial.study.user_attrs.get("best_mse", np.inf) > mse:
-         trial.study.set_user_attr("best_mse", mse)
-         trial.study.set_user_attr("best_spearman_train", spearman_train)
-         trial.study.set_user_attr("best_spearman_val", spearman_val)
-         trial.study.set_user_attr("best_trial", trial.number)
-         model.save_model(os.path.join(best_model_path, "best_model.json"))
-         save_and_plot_predictions(y_train, preds_train, y_val, preds_val, best_model_path)
-         print(f"✓ NEW BEST! Trial {trial.number}: MSE={mse:.4f}, Train Spearman={spearman_train:.4f}, Val Spearman={spearman_val:.4f}")
-
-     return mse
-
- if __name__ == "__main__":
-     study = optuna.create_study(direction="minimize", pruner=optuna.pruners.MedianPruner())
-     study.optimize(objective, n_trials=200, callbacks=[trial_info_callback])
-
-     # Prepare summary text
-     summary = []
-     summary.append("\n" + "="*60)
-     summary.append("OPTIMIZATION COMPLETE")
-     summary.append("="*60)
-     summary.append(f"Number of finished trials: {len(study.trials)}")
-     summary.append(f"\nBest Trial: #{study.user_attrs.get('best_trial', 'N/A')}")
-     summary.append(f"Best MSE: {study.best_trial.value:.4f}")
-     summary.append(f"Best Training Spearman Correlation: {study.user_attrs.get('best_spearman_train', None):.4f}")
-     summary.append(f"Best Validation Spearman Correlation: {study.user_attrs.get('best_spearman_val', None):.4f}")
-     summary.append(f"\nBest hyperparameters:")
-     for key, value in study.best_trial.params.items():
-         summary.append(f" {key}: {value}")
-     summary.append("="*60)
-
-     # Print to console
-     for line in summary:
-         print(line)
-
-     # Save to file
-     metrics_file = os.path.join(best_model_path, "optimization_metrics.txt")
-     with open(metrics_file, 'w') as f:
-         f.write('\n'.join(summary))
-     print(f"\n✓ Metrics saved to: {metrics_file}")
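A note on compute_morgan_fingerprints above: AllChem.GetMorganFingerprintAsBitVect is deprecated in recent RDKit releases in favor of the fingerprint-generator API. A rough modern equivalent, assuming RDKit 2023.09 or newer (the helper name morgan_fp is hypothetical):

    import numpy as np
    from rdkit import Chem
    from rdkit.Chem import rdFingerprintGenerator

    gen = rdFingerprintGenerator.GetMorganGenerator(radius=2, fpSize=2048)

    def morgan_fp(smiles: str) -> np.ndarray:
        mol = Chem.MolFromSmiles(smiles)
        if mol is None:
            return np.zeros(2048)  # zero vector for invalid SMILES, as in the deleted script
        return gen.GetFingerprintAsNumPy(mol)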