|
import pandas as pd |
|
import numpy as np |
|
import torch |
|
import torch.nn as nn |
|
import torch.optim as optim |
|
from torch.utils.data import Dataset, DataLoader |
|
from transformers import BertConfig, BertModel, AutoTokenizer |
|
from rdkit import Chem, RDLogger |
|
from rdkit.Chem.Scaffolds import MurckoScaffold |
|
import copy |
|
from tqdm import tqdm |
|
import os |
|
from sklearn.metrics import roc_auc_score, root_mean_squared_error, mean_absolute_error |
|
from itertools import compress |
|
from collections import defaultdict |
|
from sklearn.metrics.pairwise import cosine_similarity |
|
from sklearn.preprocessing import StandardScaler
|
import optuna |
|
import warnings |
|
warnings.filterwarnings("ignore") |
|
RDLogger.DisableLog('rdApp.*') |
|
|
|
torch.set_float32_matmul_precision('high') |
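# 'high' permits TF32 (or comparable reduced-precision) matmul kernels on supporting GPUs,
# trading a little matmul precision for speed.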
|
|
|
|
|
class PrecomputedContrastiveSmilesDataset(Dataset): |
|
""" |
|
A Dataset class that reads pre-augmented SMILES pairs from a Parquet file. |
|
This is significantly faster as it offloads the expensive SMILES randomization |
|
to a one-time preprocessing step. |
|
""" |
|
def __init__(self, tokenizer, file_path: str, max_length: int = 512): |
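        # The Parquet file is expected to contain 'smiles_1' and 'smiles_2' columns,
        # i.e. the two augmented views of each molecule (see __getitem__).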
|
self.tokenizer = tokenizer |
|
self.max_length = max_length |
|
|
|
|
|
|
|
print(f"Loading pre-computed data from {file_path}...") |
|
self.data = pd.read_parquet(file_path) |
|
print("Data loaded successfully.") |
|
|
|
def __len__(self): |
|
"""Returns the total number of pairs in the dataset.""" |
|
return len(self.data) |
|
|
|
def __getitem__(self, idx): |
|
""" |
|
Retrieves a pre-augmented pair, tokenizes it, and returns it |
|
in the format expected by the DataCollator. |
|
""" |
|
|
|
row = self.data.iloc[idx] |
|
smiles_1 = row['smiles_1'] |
|
smiles_2 = row['smiles_2'] |
|
|
|
|
|
tokens_1 = self.tokenizer(smiles_1, max_length=self.max_length, truncation=True, padding='max_length') |
|
tokens_2 = self.tokenizer(smiles_2, max_length=self.max_length, truncation=True, padding='max_length') |
|
|
|
return { |
|
'input_ids_1': torch.tensor(tokens_1['input_ids']), |
|
'attention_mask_1': torch.tensor(tokens_1['attention_mask']), |
|
'input_ids_2': torch.tensor(tokens_2['input_ids']), |
|
'attention_mask_2': torch.tensor(tokens_2['attention_mask']), |
|
} |
|
|
|
|
|
class SmilesEnumerator: |
|
"""Generates randomized SMILES strings for data augmentation.""" |
|
def randomize_smiles(self, smiles): |
|
try: |
|
mol = Chem.MolFromSmiles(smiles) |
|
return Chem.MolToSmiles(mol, doRandom=True, canonical=False) if mol else smiles |
|
        except Exception:
|
return smiles |
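# Illustrative usage: SmilesEnumerator().randomize_smiles('c1ccccc1O') may yield
# e.g. 'Oc1ccccc1'; invalid or unparsable SMILES are returned unchanged.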
|
|
|
def compute_embedding_similarity_precomputed(encoder, dataset, device): |
|
""" |
|
Compute embedding similarity using pre-computed augmented SMILES pairs |
|
""" |
|
encoder.eval() |
|
similarities = [] |
|
|
|
dataloader = DataLoader(dataset, batch_size=32, shuffle=False) |
|
|
|
with torch.no_grad(): |
|
for batch in dataloader: |
|
input_ids_1 = batch['input_ids_1'].to(device) |
|
attention_mask_1 = batch['attention_mask_1'].to(device) |
|
input_ids_2 = batch['input_ids_2'].to(device) |
|
attention_mask_2 = batch['attention_mask_2'].to(device) |
|
|
|
emb_1 = encoder(input_ids_1, attention_mask_1).cpu().numpy() |
|
emb_2 = encoder(input_ids_2, attention_mask_2).cpu().numpy() |
|
|
|
|
|
batch_similarities = [] |
|
for i in range(len(emb_1)): |
|
sim = cosine_similarity([emb_1[i]], [emb_2[i]])[0][0] |
|
batch_similarities.append(sim) |
|
|
|
similarities.extend(batch_similarities) |
|
|
|
return np.array(similarities) |
|
|
|
def create_augmented_smiles_file(smiles_list, output_path, num_augmentations=1): |
|
""" |
|
Create a parquet file with pre-computed augmented SMILES pairs |
|
""" |
|
enumerator = SmilesEnumerator() |
|
pairs = [] |
|
|
|
print(f"Generating {num_augmentations} augmentations for {len(smiles_list)} SMILES...") |
|
|
|
for smiles in tqdm(smiles_list): |
|
for _ in range(num_augmentations): |
|
augmented = enumerator.randomize_smiles(smiles) |
|
pairs.append({ |
|
'smiles_1': smiles, |
|
'smiles_2': augmented |
|
}) |
|
|
|
df = pd.DataFrame(pairs) |
|
df.to_parquet(output_path, index=False) |
|
print(f"Saved {len(pairs)} augmented pairs to {output_path}") |
|
return output_path |
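# Illustrative usage (hypothetical path):
#   create_augmented_smiles_file(['CCO', 'c1ccccc1'], 'pairs.parquet', num_augmentations=2)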
|
|
|
|
|
def load_lists_from_url(data): |
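    """Download a MoleculeNet benchmark CSV from the DeepChem S3 bucket and return (smiles, labels)."""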
|
|
|
if data == 'bbbp': |
|
df = pd.read_csv('https://deepchemdata.s3-us-west-1.amazonaws.com/datasets/BBBP.csv') |
|
smiles, labels = df.smiles, df.p_np |
|
elif data == 'clintox': |
|
df = pd.read_csv('https://deepchemdata.s3-us-west-1.amazonaws.com/datasets/clintox.csv.gz', compression='gzip') |
|
smiles = df.smiles |
|
labels = df.drop(['smiles'], axis=1) |
|
elif data == 'hiv': |
|
df = pd.read_csv('https://deepchemdata.s3-us-west-1.amazonaws.com/datasets/HIV.csv') |
|
smiles, labels = df.smiles, df.HIV_active |
|
elif data == 'sider': |
|
df = pd.read_csv('https://deepchemdata.s3-us-west-1.amazonaws.com/datasets/sider.csv.gz', compression='gzip') |
|
smiles = df.smiles |
|
labels = df.drop(['smiles'], axis=1) |
|
elif data == 'esol': |
|
df = pd.read_csv('https://deepchemdata.s3-us-west-1.amazonaws.com/datasets/delaney-processed.csv') |
|
smiles = df.smiles |
|
        labels = df['measured log solubility in mols per litre']
|
elif data == 'freesolv': |
|
df = pd.read_csv('https://deepchemdata.s3-us-west-1.amazonaws.com/datasets/SAMPL.csv') |
|
smiles = df.smiles |
|
labels = df.calc |
|
    elif data == 'lipophilicity':
|
df = pd.read_csv('https://deepchemdata.s3-us-west-1.amazonaws.com/datasets/Lipophilicity.csv') |
|
smiles, labels = df.smiles, df['exp'] |
|
elif data == 'tox21': |
|
df = pd.read_csv('https://deepchemdata.s3-us-west-1.amazonaws.com/datasets/tox21.csv.gz', compression='gzip') |
|
df = df.dropna(axis=0, how='any').reset_index(drop=True) |
|
smiles = df.smiles |
|
labels = df.drop(['mol_id', 'smiles'], axis=1) |
|
elif data == 'bace': |
|
df = pd.read_csv('https://deepchemdata.s3-us-west-1.amazonaws.com/datasets/bace.csv') |
|
smiles, labels = df.mol, df.Class |
|
elif data == 'qm8': |
|
df = pd.read_csv('https://deepchemdata.s3-us-west-1.amazonaws.com/datasets/qm8.csv') |
|
df = df.dropna(axis=0, how='any').reset_index(drop=True) |
|
smiles = df.smiles |
|
labels = df.drop(['smiles', 'E2-PBE0.1', 'E1-PBE0.1', 'f1-PBE0.1', 'f2-PBE0.1'], axis=1) |
|
    else:
        raise ValueError(f"Unknown dataset name: {data}")
    return smiles, labels
|
|
|
|
|
class ScaffoldSplitter: |
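    """Splits a MoleculeNet dataset by Bemis-Murcko scaffold so that structurally
    similar molecules do not leak between the train, validation, and test sets."""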
|
def __init__(self, data, seed, train_frac=0.8, val_frac=0.1, test_frac=0.1, include_chirality=True): |
|
self.data = data |
|
self.seed = seed |
|
self.include_chirality = include_chirality |
|
self.train_frac = train_frac |
|
self.val_frac = val_frac |
|
self.test_frac = test_frac |
|
|
|
def generate_scaffold(self, smiles): |
|
mol = Chem.MolFromSmiles(smiles) |
|
scaffold = MurckoScaffold.MurckoScaffoldSmiles(mol=mol, includeChirality=self.include_chirality) |
|
return scaffold |
|
|
|
def scaffold_split(self): |
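        """Drop invalid SMILES (and incompletely labelled rows for multi-task sets),
        group molecule indices by scaffold, shuffle the scaffold groups with the seed,
        and fill the validation, test, and training splits in that order."""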
|
smiles, labels = load_lists_from_url(self.data) |
|
        non_null = np.zeros(len(smiles), dtype=bool)
|
|
|
if self.data in {'tox21', 'sider', 'clintox'}: |
|
for i in range(len(smiles)): |
|
if Chem.MolFromSmiles(smiles[i]) and labels.loc[i].isnull().sum() == 0: |
|
non_null[i] = 1 |
|
else: |
|
for i in range(len(smiles)): |
|
if Chem.MolFromSmiles(smiles[i]): |
|
non_null[i] = 1 |
|
|
|
smiles_list = list(compress(enumerate(smiles), non_null)) |
|
rng = np.random.RandomState(self.seed) |
|
|
|
scaffolds = defaultdict(list) |
|
for i, sms in smiles_list: |
|
scaffold = self.generate_scaffold(sms) |
|
scaffolds[scaffold].append(i) |
|
|
|
scaffold_sets = list(scaffolds.values()) |
|
rng.shuffle(scaffold_sets) |
|
n_total_val = int(np.floor(self.val_frac * len(smiles_list))) |
|
n_total_test = int(np.floor(self.test_frac * len(smiles_list))) |
|
train_idx, val_idx, test_idx = [], [], [] |
|
|
|
for scaffold_set in scaffold_sets: |
|
if len(val_idx) + len(scaffold_set) <= n_total_val: |
|
val_idx.extend(scaffold_set) |
|
elif len(test_idx) + len(scaffold_set) <= n_total_test: |
|
test_idx.extend(scaffold_set) |
|
else: |
|
train_idx.extend(scaffold_set) |
|
return train_idx, val_idx, test_idx |
|
|
|
|
|
def random_split_indices(n, seed=42, train_frac=0.8, val_frac=0.1, test_frac=0.1): |
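    """Randomly permute n indices and split them into train/val/test index lists."""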
|
np.random.seed(seed) |
|
indices = np.random.permutation(n) |
|
n_train = int(n * train_frac) |
|
n_val = int(n * val_frac) |
|
train_idx = indices[:n_train] |
|
val_idx = indices[n_train:n_train+n_val] |
|
test_idx = indices[n_train+n_val:] |
|
return train_idx.tolist(), val_idx.tolist(), test_idx.tolist() |
|
|
|
|
|
class MoleculeDataset(Dataset): |
|
def __init__(self, smiles_list, labels, tokenizer, max_len=512): |
|
self.smiles_list = smiles_list |
|
self.labels = labels |
|
self.tokenizer = tokenizer |
|
self.max_len = max_len |
|
|
|
def __len__(self): |
|
return len(self.smiles_list) |
|
|
|
def __getitem__(self, idx): |
|
smiles = self.smiles_list[idx] |
|
label = self.labels.iloc[idx] |
|
|
|
encoding = self.tokenizer( |
|
smiles, |
|
truncation=True, |
|
padding='max_length', |
|
max_length=self.max_len, |
|
return_tensors='pt' |
|
) |
|
item = {key: val.squeeze(0) for key, val in encoding.items()} |
|
if isinstance(label, pd.Series): |
|
label_values = label.values.astype(np.float32) |
|
else: |
|
label_values = np.array([label], dtype=np.float32) |
|
item['labels'] = torch.tensor(label_values, dtype=torch.float) |
|
return item |
|
|
|
|
|
def global_ap(x): |
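    """Global average pooling over the sequence dimension: (batch, seq_len, hidden) -> (batch, hidden)."""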
|
return torch.mean(x.view(x.size(0), x.size(1), -1), dim=1) |
|
|
|
class SimSonEncoder(nn.Module): |
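    """BERT encoder (no pooling head) whose token states are average-pooled and
    projected by a linear layer to a `max_len`-dimensional embedding."""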
|
def __init__(self, config: BertConfig, max_len: int, dropout: float = 0.1): |
|
super(SimSonEncoder, self).__init__() |
|
self.config = config |
|
self.max_len = max_len |
|
self.bert = BertModel(config, add_pooling_layer=False) |
|
self.linear = nn.Linear(config.hidden_size, max_len) |
|
self.dropout = nn.Dropout(dropout) |
|
def forward(self, input_ids, attention_mask=None): |
|
if attention_mask is None: |
|
attention_mask = input_ids.ne(self.config.pad_token_id) |
|
outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask) |
|
hidden_states = self.dropout(outputs.last_hidden_state) |
|
pooled = global_ap(hidden_states) |
|
return self.linear(pooled) |
|
|
|
class SimSonClassifier(nn.Module): |
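    """Linear prediction head (dropout + ReLU + linear) on top of a SimSonEncoder;
    load_encoder_params restores pretrained encoder weights."""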
|
def __init__(self, encoder: SimSonEncoder, num_labels: int, dropout=0.1): |
|
super(SimSonClassifier, self).__init__() |
|
self.encoder = encoder |
|
self.clf = nn.Linear(encoder.max_len, num_labels) |
|
self.relu = nn.ReLU() |
|
self.dropout = nn.Dropout(dropout) |
|
def forward(self, input_ids, attention_mask=None): |
|
x = self.encoder(input_ids, attention_mask) |
|
x = self.relu(self.dropout(x)) |
|
logits = self.clf(x) |
|
return logits |
|
|
|
def load_encoder_params(self, state_dict_path): |
|
self.encoder.load_state_dict(torch.load(state_dict_path)) |
|
|
|
|
|
def get_criterion(task_type, num_labels): |
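    """Return BCEWithLogitsLoss for (multi-label) classification or MSELoss for regression."""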
|
if task_type == 'classification': |
|
return nn.BCEWithLogitsLoss() |
|
elif task_type == 'regression': |
|
return nn.MSELoss() |
|
else: |
|
raise ValueError(f"Unknown task type: {task_type}") |
|
|
|
def train_epoch(model, dataloader, optimizer, scheduler, criterion, device): |
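    """Train for one epoch; the optional scheduler is stepped per batch (used for the cosine schedule). Returns the mean batch loss."""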
|
model.train() |
|
total_loss = 0 |
|
for batch in dataloader: |
|
inputs = {k: v.to(device) for k, v in batch.items() if k != 'labels'} |
|
labels = batch['labels'].to(device) |
|
optimizer.zero_grad() |
|
outputs = model(**inputs) |
|
loss = criterion(outputs, labels) |
|
loss.backward() |
|
optimizer.step() |
|
if scheduler is not None: |
|
scheduler.step() |
|
total_loss += loss.item() |
|
return total_loss / len(dataloader) |
|
|
|
def calc_val_metrics(model, dataloader, criterion, device, task_type): |
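    """Evaluate on a validation loader; returns (average loss, macro ROC AUC) for classification and (average loss, None) for regression."""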
|
model.eval() |
|
all_labels, all_preds = [], [] |
|
total_loss = 0 |
|
with torch.no_grad(): |
|
for batch in dataloader: |
|
inputs = {k: v.to(device) for k, v in batch.items() if k != 'labels'} |
|
labels = batch['labels'].to(device) |
|
outputs = model(**inputs) |
|
loss = criterion(outputs, labels) |
|
total_loss += loss.item() |
|
if task_type == 'classification': |
|
pred_probs = torch.sigmoid(outputs).cpu().numpy() |
|
all_preds.append(pred_probs) |
|
all_labels.append(labels.cpu().numpy()) |
|
else: |
|
|
|
preds = outputs.cpu().numpy() |
|
all_preds.append(preds) |
|
all_labels.append(labels.cpu().numpy()) |
|
avg_loss = total_loss / len(dataloader) |
|
if task_type == 'classification': |
|
y_true = np.concatenate(all_labels) |
|
y_pred = np.concatenate(all_preds) |
|
try: |
|
score = roc_auc_score(y_true, y_pred, average='macro') |
|
except Exception: |
|
score = 0.0 |
|
return avg_loss, score |
|
else: |
|
return avg_loss, None |
|
|
|
def test_model(model, dataloader, device, task_type): |
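    """Collect test-set predictions (sigmoid probabilities for classification, raw outputs otherwise) and labels as NumPy arrays."""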
|
model.eval() |
|
all_preds, all_labels = [], [] |
|
with torch.no_grad(): |
|
for batch in dataloader: |
|
inputs = {k: v.to(device) for k, v in batch.items() if k != 'labels'} |
|
labels = batch['labels'] |
|
outputs = model(**inputs) |
|
if task_type == 'classification': |
|
preds = torch.sigmoid(outputs) |
|
else: |
|
preds = outputs |
|
all_preds.append(preds.cpu().numpy()) |
|
all_labels.append(labels.numpy()) |
|
return np.concatenate(all_preds), np.concatenate(all_labels) |
|
|
|
|
|
def create_objective(name, info, train_smiles, train_labels, val_smiles, val_labels, |
|
test_smiles, test_labels, scaler, tokenizer, encoder_config, device): |
|
"""Creates objective function for Optuna optimization""" |
|
|
|
def objective(trial): |
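        # Returns the best validation metric observed across epochs (ROC AUC for
        # classification, negative loss for regression); Optuna maximizes this value.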
|
|
|
lr = trial.suggest_float('lr', 1e-6, 1e-4, log=True) |
|
batch_size = trial.suggest_categorical('batch_size', [16, 32, 64, 128, 256]) |
|
dropout = trial.suggest_float('dropout', 0.1, 0.5) |
|
weight_decay = trial.suggest_float('weight_decay', 0.0, 0.1) |
|
scheduler_type = trial.suggest_categorical('scheduler', ['plateau', 'cosine', 'step']) |
|
|
|
|
|
patience_lr = trial.suggest_int('patience_lr', 3, 10) |
|
gamma = trial.suggest_float('gamma', 0.5, 0.9) if scheduler_type == 'step' else 0.1 |
|
|
|
try: |
|
|
|
train_dataset = MoleculeDataset(train_smiles, train_labels, tokenizer, 512) |
|
val_dataset = MoleculeDataset(val_smiles, val_labels, tokenizer, 512) |
|
test_dataset = MoleculeDataset(test_smiles, test_labels, tokenizer, 512) |
|
|
|
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True) |
|
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False) |
|
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False) |
|
|
|
|
|
encoder = SimSonEncoder(encoder_config, 512, dropout=dropout) |
|
encoder = torch.compile(encoder) |
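            # Assumption: the checkpoint loaded below was saved from a compiled encoder,
            # since torch.compile prefixes state-dict keys with `_orig_mod.`.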
|
model = SimSonClassifier(encoder, num_labels=info['num_labels'], dropout=dropout).to(device) |
|
model.load_encoder_params('../simson_checkpoints/checkpoint_best_model.bin') |
|
|
|
criterion = get_criterion(info['task_type'], info['num_labels']) |
|
optimizer = optim.Adam(model.parameters(), lr=lr, weight_decay=weight_decay) |
|
|
|
|
|
if scheduler_type == 'plateau': |
|
scheduler = optim.lr_scheduler.ReduceLROnPlateau( |
|
optimizer, mode='max', factor=gamma, patience=patience_lr |
|
) |
|
elif scheduler_type == 'cosine': |
|
scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=50) |
|
else: |
|
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=gamma) |
|
|
|
|
|
best_val_metric = -np.inf |
|
patience_counter = 0 |
|
patience = 15 |
|
|
|
for epoch in range(50): |
|
train_loss = train_epoch(model, train_loader, optimizer, |
|
scheduler if scheduler_type == 'cosine' else None, |
|
criterion, device) |
|
val_loss, val_metric = calc_val_metrics(model, val_loader, criterion, device, info['task_type']) |
|
|
|
|
|
if scheduler_type == 'plateau': |
|
                    scheduler.step(val_metric if val_metric is not None else -val_loss)
|
elif scheduler_type == 'step': |
|
scheduler.step() |
|
|
|
|
|
if info['task_type'] == 'classification': |
|
                    current_metric = val_metric if val_metric is not None else 0.0
|
else: |
|
current_metric = -val_loss |
|
|
|
|
|
                if current_metric > best_val_metric:
|
best_val_metric = current_metric |
|
patience_counter = 0 |
|
else: |
|
patience_counter += 1 |
|
if patience_counter >= patience: |
|
break |
|
|
|
|
|
trial.report(current_metric, epoch) |
|
if trial.should_prune(): |
|
raise optuna.TrialPruned() |
|
|
|
return best_val_metric |
|
|
|
except Exception as e: |
|
print(f"Trial failed with error: {e}") |
|
return -np.inf |
|
|
|
return objective |
|
|
|
|
|
def main(): |
|
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu') |
|
print(f"Using device: {DEVICE}") |
|
|
|
DATASETS_TO_RUN = { |
|
|
|
|
|
|
|
|
|
'clintox': {'task_type': 'classification', 'num_labels': 2, 'split': 'scaffold'}, |
|
'tox21': {'task_type': 'classification', 'num_labels': 12, 'split': 'random'}, |
|
'bbbp': {'task_type': 'classification', 'num_labels': 1, 'split': 'scaffold'}, |
|
'hiv': {'task_type': 'classification', 'num_labels': 1, 'split': 'scaffold'}, |
|
} |
|
|
|
MAX_LEN = 512 |
|
N_TRIALS = 100 |
|
|
|
TOKENIZER = AutoTokenizer.from_pretrained('DeepChem/ChemBERTa-77M-MTR') |
|
ENCODER_CONFIG = BertConfig( |
|
vocab_size=TOKENIZER.vocab_size, |
|
hidden_size=768, |
|
num_hidden_layers=4, |
|
num_attention_heads=12, |
|
intermediate_size=2048, |
|
max_position_embeddings=512 |
|
) |
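    # Note: this config is assumed to match the architecture of the pretrained
    # SimSon checkpoint loaded via load_encoder_params() below.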
|
|
|
aggregated_results = {} |
|
|
|
for name, info in DATASETS_TO_RUN.items(): |
|
print(f"\n{'='*20} Processing Dataset: {name.upper()} ({info['split']} split) {'='*20}") |
|
smiles, labels = load_lists_from_url(name) |
|
|
|
|
|
scaler = None |
|
if info["task_type"] == "regression": |
|
scaler = StandardScaler() |
|
all_labels = labels.values.reshape(-1, 1) |
|
scaler.fit(all_labels) |
|
labels = pd.Series(scaler.transform(all_labels).flatten(), index=labels.index) |
|
|
|
|
|
if info.get('split', 'scaffold') == 'scaffold': |
|
splitter = ScaffoldSplitter(data=name, seed=42) |
|
train_idx, val_idx, test_idx = splitter.scaffold_split() |
|
elif info['split'] == 'random': |
|
train_idx, val_idx, test_idx = random_split_indices(len(smiles), seed=42) |
|
else: |
|
raise ValueError(f"Unknown split type for {name}: {info['split']}") |
|
|
|
train_smiles = smiles.iloc[train_idx].reset_index(drop=True) |
|
train_labels = labels.iloc[train_idx].reset_index(drop=True) |
|
val_smiles = smiles.iloc[val_idx].reset_index(drop=True) |
|
val_labels = labels.iloc[val_idx].reset_index(drop=True) |
|
test_smiles = smiles.iloc[test_idx].reset_index(drop=True) |
|
test_labels = labels.iloc[test_idx].reset_index(drop=True) |
|
print(f"Data split - Train: {len(train_smiles)}, Val: {len(val_smiles)}, Test: {len(test_smiles)}") |
|
|
|
|
|
study = optuna.create_study( |
|
direction='maximize', |
|
pruner=optuna.pruners.MedianPruner(n_startup_trials=5, n_warmup_steps=10) |
|
) |
|
|
|
|
|
objective_func = create_objective( |
|
name, info, train_smiles, train_labels, val_smiles, val_labels, |
|
test_smiles, test_labels, scaler, TOKENIZER, ENCODER_CONFIG, DEVICE |
|
) |
|
|
|
|
|
print(f"Starting Optuna optimization with {N_TRIALS} trials...") |
|
study.optimize(objective_func, n_trials=N_TRIALS, timeout=None) |
|
|
|
|
|
best_params = study.best_params |
|
best_score = study.best_value |
|
print(f"Best parameters: {best_params}") |
|
print(f"Best validation score: {0:.4f}") |
|
|
|
|
|
print("Training final model with best parameters...") |
|
train_dataset = MoleculeDataset(train_smiles, train_labels, TOKENIZER, MAX_LEN) |
|
val_dataset = MoleculeDataset(val_smiles, val_labels, TOKENIZER, MAX_LEN) |
|
test_dataset = MoleculeDataset(test_smiles, test_labels, TOKENIZER, MAX_LEN) |
|
|
|
train_loader = DataLoader(train_dataset, batch_size=best_params['batch_size'], shuffle=True) |
|
val_loader = DataLoader(val_dataset, batch_size=best_params['batch_size'], shuffle=False) |
|
test_loader = DataLoader(test_dataset, batch_size=best_params['batch_size'], shuffle=False) |
|
|
|
|
|
encoder = SimSonEncoder(ENCODER_CONFIG, 512, dropout=best_params['dropout']) |
|
encoder = torch.compile(encoder) |
|
model = SimSonClassifier(encoder, num_labels=info['num_labels'], dropout=best_params['dropout']).to(DEVICE) |
|
model.load_encoder_params('../simson_checkpoints/checkpoint_best_model.bin') |
|
|
|
criterion = get_criterion(info['task_type'], info['num_labels']) |
|
optimizer = optim.Adam(model.parameters(), lr=best_params['lr'], weight_decay=best_params['weight_decay']) |
|
|
|
|
|
if best_params['scheduler'] == 'plateau': |
|
scheduler = optim.lr_scheduler.ReduceLROnPlateau( |
|
optimizer, mode='max', factor=best_params.get('gamma', 0.7), |
|
patience=best_params.get('patience_lr', 5) |
|
) |
|
elif best_params['scheduler'] == 'cosine': |
|
scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=50) |
|
else: |
|
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=best_params.get('gamma', 0.1)) |
|
|
|
|
|
best_val_metric = -np.inf |
|
best_model_state = None |
|
patience_counter = 0 |
|
patience = 15 |
|
|
|
for epoch in range(50): |
|
train_loss = train_epoch(model, train_loader, optimizer, |
|
scheduler if best_params['scheduler'] == 'cosine' else None, |
|
criterion, DEVICE) |
|
val_loss, val_metric = calc_val_metrics(model, val_loader, criterion, DEVICE, info['task_type']) |
|
|
|
if best_params['scheduler'] == 'plateau': |
|
                scheduler.step(val_metric if val_metric is not None else -val_loss)
|
elif best_params['scheduler'] == 'step': |
|
scheduler.step() |
|
|
|
if info['task_type'] == 'classification': |
|
print(f"Epoch {epoch+1}/50 | Train Loss: {train_loss:.4f} | Val Loss: {val_loss:.4f} | ROC AUC: {val_metric:.4f}") |
|
current_metric = val_metric if val_metric is not None else 0.0 |
|
else: |
|
print(f"Epoch {epoch+1}/50 | Train Loss: {train_loss:.4f} | Val Loss: {val_loss:.4f}") |
|
current_metric = -val_loss |
|
|
|
            if current_metric > best_val_metric:
|
best_val_metric = current_metric |
|
best_model_state = copy.deepcopy(model.state_dict()) |
|
patience_counter = 0 |
|
else: |
|
patience_counter += 1 |
|
if patience_counter >= patience: |
|
print(f'Early stopping at epoch {epoch+1}') |
|
break |
|
|
|
|
|
if best_model_state is not None: |
|
model.load_state_dict(best_model_state) |
|
|
|
test_preds, test_true = test_model(model, test_loader, DEVICE, info['task_type']) |
|
|
|
|
|
if info['task_type'] == 'regression' and scaler is not None: |
|
test_preds = scaler.inverse_transform(test_preds.reshape(-1, 1)).flatten() |
|
test_true = scaler.inverse_transform(test_true.reshape(-1, 1)).flatten() |
|
rmse = root_mean_squared_error(test_true, test_preds) |
|
mae = mean_absolute_error(test_true, test_preds) |
|
final_score = -rmse |
|
print(f"Test RMSE: {rmse:.4f}, MAE: {mae:.4f}") |
|
else: |
|
try: |
|
final_score = roc_auc_score(test_true, test_preds, average='macro') |
|
print(f"Test ROC AUC: {final_score:.4f}") |
|
except Exception: |
|
final_score = 0.0 |
|
|
|
|
|
print("Creating pre-computed augmented SMILES for similarity computation...") |
|
test_smiles_list = list(test_smiles) |
|
similarity_file_path = f"{name}_test_augmented.parquet" |
|
create_augmented_smiles_file(test_smiles_list, similarity_file_path, num_augmentations=1) |
|
|
|
|
|
similarity_dataset = PrecomputedContrastiveSmilesDataset( |
|
TOKENIZER, similarity_file_path, max_length=MAX_LEN |
|
) |
|
|
|
similarities = compute_embedding_similarity_precomputed( |
|
model.encoder, similarity_dataset, DEVICE |
|
) |
|
print(f"Similarity score: {similarities.mean():.4f}") |
|
|
|
|
|
if os.path.exists(similarity_file_path): |
|
os.remove(similarity_file_path) |
|
|
|
aggregated_results[name] = { |
|
'best_score': final_score, |
|
'best_params': best_params, |
|
'optuna_trials': len(study.trials), |
|
'study': study, |
|
'similarity_score': similarities.mean() |
|
} |
|
|
|
if name == 'do_not_save': |
|
torch.save(model.encoder.state_dict(), 'moleculenet_clintox_encoder.bin') |
|
|
|
print(f"\n{'='*20} AGGREGATED RESULTS {'='*20}") |
|
for name, result in aggregated_results.items(): |
|
print(f"{name}: Best score: {result['best_score']:.4f}") |
|
print(f" Best parameters: {result['best_params']}") |
|
print(f" Total trials: {result['optuna_trials']}") |
|
print(f" Similarity score: {result['similarity_score']:.4f}") |
|
|
|
print("\nScript finished.") |
|
|
|
if __name__ == '__main__': |
|
main() |
|
|