import os
import random
import numpy as np
import pandas as pd
import torch
from torch.utils.tensorboard import SummaryWriter
from tqdm import tqdm
# nn, AdamW, the HF warmup schedulers, and optuna used below are expected to
# come in via this wildcard import
from .imports import *
from .model import GeneformerMultiTask
from .utils import calculate_task_specific_metrics, get_layer_freeze_range


def set_seed(seed):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False


def initialize_wandb(config):
    if config.get("use_wandb", False):
        import wandb

        wandb.init(project=config["wandb_project"], config=config)
        print("Weights & Biases (wandb) initialized and will be used for logging.")
    else:
        print(
            "Weights & Biases (wandb) is not enabled. Logging will use other methods."
        )


def create_model(config, num_labels_list, device):
    model = GeneformerMultiTask(
        config["pretrained_path"],
        num_labels_list,
        dropout_rate=config["dropout_rate"],
        use_task_weights=config["use_task_weights"],
        task_weights=config["task_weights"],
        max_layers_to_freeze=config["max_layers_to_freeze"],
        use_attention_pooling=config["use_attention_pooling"],
    )
    if config["use_data_parallel"]:
        model = nn.DataParallel(model)
    return model.to(device)
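
# Note on create_model: when config["use_data_parallel"] is set, the returned
# model is an nn.DataParallel wrapper, so its state_dict keys carry a
# "module." prefix; code that later reloads a saved state dict should account
# for that.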


def setup_optimizer_and_scheduler(model, config, total_steps):
    optimizer = AdamW(
        model.parameters(),
        lr=config["learning_rate"],
        weight_decay=config["weight_decay"],
    )
    warmup_steps = int(config["warmup_ratio"] * total_steps)

    if config["lr_scheduler_type"] == "linear":
        scheduler = get_linear_schedule_with_warmup(
            optimizer, num_warmup_steps=warmup_steps, num_training_steps=total_steps
        )
    elif config["lr_scheduler_type"] == "cosine":
        scheduler = get_cosine_schedule_with_warmup(
            optimizer,
            num_warmup_steps=warmup_steps,
            num_training_steps=total_steps,
            num_cycles=0.5,
        )
    else:
        # Fail fast instead of hitting an UnboundLocalError at the return below
        raise ValueError(
            f"Unsupported lr_scheduler_type: {config['lr_scheduler_type']}"
        )

    return optimizer, scheduler
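
# Minimal usage sketch (illustrative values only; the keys shown are the ones
# this helper reads from config, and num_epochs stands in for config["epochs"]):
#
#   config = {
#       "learning_rate": 5e-5,
#       "weight_decay": 0.01,
#       "warmup_ratio": 0.1,
#       "lr_scheduler_type": "linear",  # or "cosine"
#   }
#   total_steps = len(train_loader) * num_epochs
#   optimizer, scheduler = setup_optimizer_and_scheduler(model, config, total_steps)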


def train_epoch(
    model, train_loader, optimizer, scheduler, device, config, writer, epoch
):
    model.train()
    progress_bar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{config['epochs']}")
    for batch_idx, batch in enumerate(progress_bar):
        optimizer.zero_grad()
        input_ids = batch["input_ids"].to(device)
        attention_mask = batch["attention_mask"].to(device)
        labels = [
            batch["labels"][task_name].to(device) for task_name in config["task_names"]
        ]
        loss, _, _ = model(input_ids, attention_mask, labels)
        loss.backward()
        if config["gradient_clipping"]:
            torch.nn.utils.clip_grad_norm_(model.parameters(), config["max_grad_norm"])
        optimizer.step()
        scheduler.step()
        writer.add_scalar(
            "Training Loss", loss.item(), epoch * len(train_loader) + batch_idx
        )
        if config.get("use_wandb", False):
            import wandb

            wandb.log({"Training Loss": loss.item()})
        # Update progress bar
        progress_bar.set_postfix({"loss": f"{loss.item():.4f}"})
    return loss.item()  # Return the last batch loss


def validate_model(model, val_loader, device, config):
    model.eval()
    val_loss = 0.0
    task_true_labels = {task_name: [] for task_name in config["task_names"]}
    task_pred_labels = {task_name: [] for task_name in config["task_names"]}
    task_pred_probs = {task_name: [] for task_name in config["task_names"]}
    with torch.no_grad():
        for batch in val_loader:
            input_ids = batch["input_ids"].to(device)
            attention_mask = batch["attention_mask"].to(device)
            labels = [
                batch["labels"][task_name].to(device)
                for task_name in config["task_names"]
            ]
            loss, logits, _ = model(input_ids, attention_mask, labels)
            val_loss += loss.item()
            for sample_idx in range(len(batch["input_ids"])):
                for i, task_name in enumerate(config["task_names"]):
                    true_label = batch["labels"][task_name][sample_idx].item()
                    pred_label = torch.argmax(logits[i][sample_idx], dim=-1).item()
                    pred_prob = (
                        torch.softmax(logits[i][sample_idx], dim=-1).cpu().numpy()
                    )
                    task_true_labels[task_name].append(true_label)
                    task_pred_labels[task_name].append(pred_label)
                    task_pred_probs[task_name].append(pred_prob)
    val_loss /= len(val_loader)
    return val_loss, task_true_labels, task_pred_labels, task_pred_probs
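
# Shape of the returned containers (illustrative, assuming a hypothetical task
# named "cell_type" with a 3-class head; name and class count are examples only):
#   val_loss                      -> mean loss over validation batches (float)
#   task_true_labels["cell_type"] -> [1, 0, 2, ...]             (ints)
#   task_pred_labels["cell_type"] -> [1, 0, 1, ...]             (ints)
#   task_pred_probs["cell_type"]  -> [array([0.1, 0.8, 0.1]), ...]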


def log_metrics(task_metrics, val_loss, config, writer, epochs):
    for task_name, metrics in task_metrics.items():
        print(
            f"{task_name} - Validation F1 Macro: {metrics['f1']:.4f}, "
            f"Validation Accuracy: {metrics['accuracy']:.4f}"
        )
        if config.get("use_wandb", False):
            import wandb

            wandb.log(
                {
                    f"{task_name} Validation F1 Macro": metrics["f1"],
                    f"{task_name} Validation Accuracy": metrics["accuracy"],
                }
            )
    writer.add_scalar("Validation Loss", val_loss, epochs)
    for task_name, metrics in task_metrics.items():
        writer.add_scalar(f"{task_name} - Validation F1 Macro", metrics["f1"], epochs)
        writer.add_scalar(
            f"{task_name} - Validation Accuracy", metrics["accuracy"], epochs
        )


def save_validation_predictions(
    val_cell_id_mapping,
    task_true_labels,
    task_pred_labels,
    task_pred_probs,
    config,
    trial_number=None,
):
    if trial_number is not None:
        trial_results_dir = os.path.join(config["results_dir"], f"trial_{trial_number}")
        os.makedirs(trial_results_dir, exist_ok=True)
        val_preds_file = os.path.join(trial_results_dir, "val_preds.csv")
    else:
        val_preds_file = os.path.join(config["results_dir"], "manual_run_val_preds.csv")

    rows = []
    for sample_idx in range(len(val_cell_id_mapping)):
        row = {"Cell ID": val_cell_id_mapping[sample_idx]}
        for task_name in config["task_names"]:
            row[f"{task_name} True"] = task_true_labels[task_name][sample_idx]
            row[f"{task_name} Pred"] = task_pred_labels[task_name][sample_idx]
            row[f"{task_name} Probabilities"] = ",".join(
                map(str, task_pred_probs[task_name][sample_idx])
            )
        rows.append(row)

    df = pd.DataFrame(rows)
    df.to_csv(val_preds_file, index=False)
    print(f"Validation predictions saved to {val_preds_file}")
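
# Resulting CSV layout, one row per validation cell (illustrative values, with
# a hypothetical single task named "cell_type"):
#
#   Cell ID,cell_type True,cell_type Pred,cell_type Probabilities
#   cell_0,1,1,"0.02,0.95,0.03"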


def train_model(
    config,
    device,
    train_loader,
    val_loader,
    train_cell_id_mapping,
    val_cell_id_mapping,
    num_labels_list,
):
    set_seed(config["seed"])
    initialize_wandb(config)

    model = create_model(config, num_labels_list, device)
    total_steps = len(train_loader) * config["epochs"]
    optimizer, scheduler = setup_optimizer_and_scheduler(model, config, total_steps)

    log_dir = os.path.join(config["tensorboard_log_dir"], "manual_run")
    writer = SummaryWriter(log_dir=log_dir)

    epoch_progress = tqdm(range(config["epochs"]), desc="Training Progress")
    for epoch in epoch_progress:
        last_loss = train_epoch(
            model, train_loader, optimizer, scheduler, device, config, writer, epoch
        )
        epoch_progress.set_postfix({"last_loss": f"{last_loss:.4f}"})

    val_loss, task_true_labels, task_pred_labels, task_pred_probs = validate_model(
        model, val_loader, device, config
    )
    task_metrics = calculate_task_specific_metrics(task_true_labels, task_pred_labels)
    log_metrics(task_metrics, val_loss, config, writer, config["epochs"])
    writer.close()

    save_validation_predictions(
        val_cell_id_mapping, task_true_labels, task_pred_labels, task_pred_probs, config
    )

    if config.get("use_wandb", False):
        import wandb

        wandb.finish()

    print(f"\nFinal Validation Loss: {val_loss:.4f}")
    return val_loss, model  # Return both the validation loss and the trained model
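
# Minimal sketch of a manual (non-Optuna) run; assumes the DataLoaders,
# cell-ID mappings, per-task label counts, and a fully populated config dict
# were prepared upstream:
#
#   device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
#   val_loss, trained_model = train_model(
#       config, device, train_loader, val_loader,
#       train_cell_id_mapping, val_cell_id_mapping, num_labels_list,
#   )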


def objective(
    trial,
    train_loader,
    val_loader,
    train_cell_id_mapping,
    val_cell_id_mapping,
    num_labels_list,
    config,
    device,
):
    set_seed(config["seed"])  # Set the seed before each trial
    initialize_wandb(config)

    # Suggest hyperparameters from the search space defined in the config;
    # the shared config dict is updated in place with this trial's values
    config["learning_rate"] = trial.suggest_float(
        "learning_rate",
        config["hyperparameters"]["learning_rate"]["low"],
        config["hyperparameters"]["learning_rate"]["high"],
        log=config["hyperparameters"]["learning_rate"]["log"],
    )
    config["warmup_ratio"] = trial.suggest_float(
        "warmup_ratio",
        config["hyperparameters"]["warmup_ratio"]["low"],
        config["hyperparameters"]["warmup_ratio"]["high"],
    )
    config["weight_decay"] = trial.suggest_float(
        "weight_decay",
        config["hyperparameters"]["weight_decay"]["low"],
        config["hyperparameters"]["weight_decay"]["high"],
    )
    config["dropout_rate"] = trial.suggest_float(
        "dropout_rate",
        config["hyperparameters"]["dropout_rate"]["low"],
        config["hyperparameters"]["dropout_rate"]["high"],
    )
    config["lr_scheduler_type"] = trial.suggest_categorical(
        "lr_scheduler_type", config["hyperparameters"]["lr_scheduler_type"]["choices"]
    )
    # Currently pinned to False; kept as a categorical suggestion, presumably
    # so the value is still recorded with the trial's parameters
    config["use_attention_pooling"] = trial.suggest_categorical(
        "use_attention_pooling", [False]
    )

    if config["use_task_weights"]:
        config["task_weights"] = [
            trial.suggest_float(
                f"task_weight_{i}",
                config["hyperparameters"]["task_weights"]["low"],
                config["hyperparameters"]["task_weights"]["high"],
            )
            for i in range(len(num_labels_list))
        ]
        # Normalize the suggested weights so they sum to 1
        weight_sum = sum(config["task_weights"])
        config["task_weights"] = [
            weight / weight_sum for weight in config["task_weights"]
        ]
    else:
        config["task_weights"] = None

    # Dynamic range for max_layers_to_freeze; get_layer_freeze_range returns a
    # dict with "min" and "max" keys for the given pretrained model
    freeze_range = get_layer_freeze_range(config["pretrained_path"])
    config["max_layers_to_freeze"] = trial.suggest_int(
        "max_layers_to_freeze",
        freeze_range["min"],
        freeze_range["max"],
    )
    model = create_model(config, num_labels_list, device)
    total_steps = len(train_loader) * config["epochs"]
    optimizer, scheduler = setup_optimizer_and_scheduler(model, config, total_steps)

    log_dir = os.path.join(config["tensorboard_log_dir"], f"trial_{trial.number}")
    writer = SummaryWriter(log_dir=log_dir)

    for epoch in range(config["epochs"]):
        train_epoch(
            model, train_loader, optimizer, scheduler, device, config, writer, epoch
        )

    val_loss, task_true_labels, task_pred_labels, task_pred_probs = validate_model(
        model, val_loader, device, config
    )
    task_metrics = calculate_task_specific_metrics(task_true_labels, task_pred_labels)
    log_metrics(task_metrics, val_loss, config, writer, config["epochs"])
    writer.close()

    save_validation_predictions(
        val_cell_id_mapping,
        task_true_labels,
        task_pred_labels,
        task_pred_probs,
        config,
        trial.number,
    )

    trial.set_user_attr("model_state_dict", model.state_dict())
    trial.set_user_attr("task_weights", config["task_weights"])

    trial.report(val_loss, config["epochs"])
    if trial.should_prune():
        raise optuna.TrialPruned()

    if config.get("use_wandb", False):
        import wandb

        wandb.log(
            {
                "trial_number": trial.number,
                "val_loss": val_loss,
                **{
                    f"{task_name}_f1": metrics["f1"]
                    for task_name, metrics in task_metrics.items()
                },
                **{
                    f"{task_name}_accuracy": metrics["accuracy"]
                    for task_name, metrics in task_metrics.items()
                },
                **{
                    k: v
                    for k, v in config.items()
                    if k
                    in [
                        "learning_rate",
                        "warmup_ratio",
                        "weight_decay",
                        "dropout_rate",
                        "lr_scheduler_type",
                        "use_attention_pooling",
                        "max_layers_to_freeze",
                    ]
                },
            }
        )
        wandb.finish()

    return val_loss
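
# Example of driving this objective with Optuna (a minimal sketch; assumes the
# loaders, mappings, label counts, config, and device were built upstream, and
# the study settings shown here are illustrative, not prescribed):
#
#   import optuna
#
#   study = optuna.create_study(
#       direction="minimize", pruner=optuna.pruners.MedianPruner()
#   )
#   study.optimize(
#       lambda trial: objective(
#           trial, train_loader, val_loader, train_cell_id_mapping,
#           val_cell_id_mapping, num_labels_list, config, device,
#       ),
#       n_trials=20,
#   )
#   best_state = study.best_trial.user_attrs["model_state_dict"]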