import argparse
import ast
import copy

import numpy as np
import torch
from omegaconf import OmegaConf
from sklearn.metrics import roc_auc_score
from torch import nn, optim

from barista.data.braintreebank_dataset import BrainTreebankDataset
from barista.models.model import Barista
from barista.models.utils import seed_everything
|
|
|
|
def parse_args():
    """Parse command line arguments for the finetuning script."""
    parser = argparse.ArgumentParser(
        description="Fine-tune Barista model on BrainTreebank dataset"
    )

    # The three config-path options share the same shape; declare them as data.
    config_options = (
        ("--dataset_config", "barista/config/braintreebank.yaml", "Path to dataset configuration file"),
        ("--train_config", "barista/config/train.yaml", "Path to training configuration file"),
        ("--model_config", "barista/config/model.yaml", "Path to model configuration file"),
    )
    for flag, default_path, help_text in config_options:
        parser.add_argument(flag, type=str, default=default_path, help=help_text)

    parser.add_argument(
        "--override",
        type=str,
        nargs="+",
        default=[],
        help="Override config parameters (e.g., --override epochs=50 optimization.finetune_lr=1e-4)",
    )
    return parser.parse_args()
|
|
|
|
def load_configs(args):
    """Load dataset, train, and model configuration files.

    Args:
        args: Parsed arguments carrying ``dataset_config``, ``train_config``,
            and ``model_config`` file paths.

    Returns:
        Tuple of (dataset_config, train_config, model_config) OmegaConf objects.

    Raises:
        ValueError: If the dataset config does not specify exactly one
            finetuning session.
    """
    dataset_config = OmegaConf.load(args.dataset_config)
    train_config = OmegaConf.load(args.train_config)
    model_config = OmegaConf.load(args.model_config)

    # Explicit exception instead of assert: asserts are stripped under
    # ``python -O``, and this is real input validation.
    if len(dataset_config.finetune_sessions) != 1:
        raise ValueError("Specify one session for finetuning")

    return dataset_config, train_config, model_config
|
|
|
|
| def apply_overrides(config_dict, overrides): |
| """Apply command-line overrides to configs using dot notation.""" |
| if not overrides: |
| return config_dict |
|
|
| override_dict = {} |
| for override in overrides: |
| if "=" not in override: |
| raise ValueError( |
| f"Invalid override format: {override}. Expected format: key=value" |
| ) |
|
|
| key, value = override.split("=", 1) |
|
|
| try: |
| if value.isnumeric(): |
| if "." in value: |
| value = float(value) |
| else: |
| value = int(value) |
| elif value.startswith("[") or value in ("True", "False"): |
| value = eval(value) |
| except ValueError as e: |
| print(e) |
| pass |
|
|
| keys = key.split(".") |
| current = override_dict |
| for k in keys[:-1]: |
| if k not in current: |
| current[k] = {} |
| current = current[k] |
| current[keys[-1]] = value |
|
|
| |
| override_conf = OmegaConf.create(override_dict) |
|
|
| |
| merged_configs = {} |
| for config_name, config in config_dict.items(): |
| config_keys = set(OmegaConf.to_container(config).keys()) |
| override_keys = set(override_dict.keys()) |
|
|
| if config_keys.intersection(override_keys): |
| merged_configs[config_name] = OmegaConf.merge(config, override_conf) |
| else: |
| merged_configs[config_name] = config |
|
|
| if merged_configs.get("train") is not None: |
| merged_configs["train"] = OmegaConf.merge( |
| merged_configs["train"], override_conf |
| ) |
|
|
| return merged_configs |
|
|
|
|
def setup_dataloaders(dataset_config, train_config):
    """Build the dataset plus train/val/test dataloaders and verify no overlap."""
    dataset = BrainTreebankDataset(dataset_config)

    loaders = {
        split: dataset.get_dataloader(split, train_config)
        for split in ("train", "val", "test")
    }

    print(f"Train: {len(loaders['train'].dataset.metadata)} samples")
    print(f"Val: {len(loaders['val'].dataset.metadata)} samples")
    print(f"Test: {len(loaders['test'].dataset.metadata)} samples")

    # Guard against data leakage between the three splits.
    dataset.check_no_common_segment(
        loaders["train"], loaders["val"], loaders["test"]
    )

    return dataset, loaders["train"], loaders["val"], loaders["test"]
|
|
|
|
def get_optimizer(model, finetune_lr=1e-4, new_param_lr=1e-3):
    """Build an AdamW optimizer with per-group learning rates.

    Pretrained (upstream) parameters are finetuned at ``finetune_lr`` while
    newly added task-head parameters train at ``new_param_lr``. Frozen
    parameters (``requires_grad == False``) are excluded entirely.
    """
    task_params = [p for _, p in model.get_task_params() if p.requires_grad]
    upstream_params = [p for _, p in model.get_upstream_params() if p.requires_grad]

    param_groups = [
        {"params": upstream_params, "lr": finetune_lr},
        {"params": task_params, "lr": new_param_lr},
    ]
    return optim.AdamW(param_groups, lr=finetune_lr, weight_decay=1e-2)
|
|
|
|
def get_lr_scheduler(optimizer):
    """Warm up linearly for 5 epochs, then decay exponentially (gamma=0.99)."""
    warmup_epochs = 5

    warmup = torch.optim.lr_scheduler.LinearLR(
        optimizer,
        start_factor=0.2,
        end_factor=1.0,
        total_iters=warmup_epochs,
    )
    decay = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.99)

    return torch.optim.lr_scheduler.SequentialLR(
        optimizer,
        [warmup, decay],
        milestones=[warmup_epochs],
    )
|
|
|
|
def load_pretrained_weights(model, checkpoint_path, device):
    """Load a pretrained state dict into the model (strict, full load).

    Note: nothing is excluded here — ``load_state_dict`` runs with default
    strict matching, so the checkpoint must contain exactly the model's
    parameter keys. (An earlier docstring incorrectly claimed masked_recon
    and multi_head_fc layers were skipped.)

    Args:
        model: Module to load weights into (mutated in place).
        checkpoint_path: Path to a ``torch.save``-d state dict.
        device: ``map_location`` for ``torch.load``.

    Returns:
        The same ``model`` instance, for chaining.
    """
    # weights_only=True refuses arbitrary pickled objects in the checkpoint.
    checkpoint = torch.load(checkpoint_path, map_location=device, weights_only=True)
    model.load_state_dict(checkpoint)
    print(f"Pretrained weights loaded from {checkpoint_path}")
    return model
|
|
|
|
def freeze_tokenizer(model):
    """Disable gradient updates for every parameter of ``model.tokenizer``."""
    for param in model.tokenizer.parameters():
        param.requires_grad = False
|
|
|
|
def print_number_of_parmas(model):
    """Print the model's total and trainable parameter counts."""
    total = 0
    trainable = 0
    for param in model.parameters():
        count = param.numel()
        total += count
        if param.requires_grad:
            trainable += count

    print(f"Model parameters: {total}\t Trainable params: {trainable}")
|
|
|
|
def run_epoch(
    model, dataloader, criterion, device, optimizer=None, scheduler=None, train=False
):
    """Run one epoch of training or evaluation.

    Args:
        model: Module called as ``model(x, subject_sessions=...)``, returning
            logits with shape (batch, 2).
        dataloader: Yields batches exposing ``x`` (list of tensors),
            ``labels``, and ``subject_sessions`` attributes; must also have a
            ``dataset`` with ``len()`` for loss averaging.
        criterion: Loss function taking (logits, labels).
        device: Device the tensors are moved to.
        optimizer: Required when ``train`` is True.
        scheduler: Stepped once per epoch; required when ``train`` is True.
        train: When True, enables gradients, backprop, and optimizer steps.

    Returns:
        Tuple of (average loss per sample, ROC AUC). AUC is NaN when it
        cannot be computed (e.g. only one class present in the epoch).
    """
    if train:
        model.train()
    else:
        model.eval()

    all_preds = []
    all_labels = []
    running_loss = 0

    for batch in dataloader:
        x = [x_item.to(device) for x_item in batch.x]
        y = batch.labels.flatten().long().to(device)

        if train:
            optimizer.zero_grad()

        # Gradients only in training; eval runs under no_grad semantics.
        with torch.set_grad_enabled(train):
            logits = model(
                x,
                subject_sessions=batch.subject_sessions,
            )
            loss = criterion(logits, y)

        if train:
            loss.backward()
            optimizer.step()

        # Accumulate sample-weighted loss so the mean is per-sample.
        running_loss += loss.item() * y.size(0)

        # Probability of the positive class for AUC computation.
        probs = torch.softmax(logits, dim=1)[:, 1].detach().cpu().numpy()
        labels = y.detach().cpu().numpy()

        all_preds.append(probs)
        all_labels.append(labels)

    if train:
        # Per-epoch (not per-batch) learning rate schedule.
        scheduler.step()

    all_preds = np.concatenate(all_preds)
    all_labels = np.concatenate(all_labels)

    # roc_auc_score raises ValueError when only one class is present; other
    # exceptions (previously swallowed by a bare except) now propagate.
    try:
        auc = roc_auc_score(all_labels, all_preds)
    except ValueError:
        auc = float("nan")

    avg_loss = running_loss / len(dataloader.dataset)
    return avg_loss, auc
|
|
|
|
def finetune_model(model, train_dataloader, val_dataloader, train_config, device):
    """Finetune the model, snapshotting the state with the best validation AUC."""
    criterion = nn.CrossEntropyLoss()
    optimizer = get_optimizer(
        model,
        finetune_lr=train_config.optimization.finetune_lr,
        new_param_lr=train_config.optimization.new_param_lr,
    )
    scheduler = get_lr_scheduler(optimizer)

    best_state = None
    best_val_auc = -1
    num_epochs = train_config.epochs

    for epoch in range(num_epochs):
        train_loss, train_auc = run_epoch(
            model,
            train_dataloader,
            criterion,
            device,
            optimizer,
            scheduler,
            train=True,
        )
        val_loss, val_auc = evaluate_model(model, val_dataloader, criterion, device)

        print(
            f"Epoch {epoch+1}/{num_epochs} "
            f"- Train Loss: {train_loss:.4f}, AUC: {train_auc:.4f} "
            f"- Val Loss: {val_loss:.4f}, AUC: {val_auc:.4f}"
        )

        # Snapshot on improvement; the first epoch always snapshots so a
        # best state exists even when every val AUC is NaN.
        if best_state is None or val_auc > best_val_auc:
            best_val_auc = val_auc
            best_state = {
                "epoch": epoch + 1,
                "model": copy.deepcopy(model.state_dict()),
                "optimizer": copy.deepcopy(optimizer.state_dict()),
                "scheduler": copy.deepcopy(scheduler.state_dict()),
                "val_auc": val_auc,
            }

    return best_state, criterion
|
|
|
|
def evaluate_model(model, test_dataloader, criterion, device):
    """Run one evaluation epoch and return (average loss, ROC AUC)."""
    return run_epoch(model, test_dataloader, criterion, device, train=False)
|
|
|
|
def main():
    """Main training pipeline: configure, build, finetune, and evaluate."""
    args = parse_args()
    dataset_config, train_config, model_config = load_configs(args)

    # Apply command-line overrides across all three configs.
    merged = apply_overrides(
        {"dataset": dataset_config, "train": train_config, "model": model_config},
        args.override,
    )
    dataset_config = merged["dataset"]
    train_config = merged["train"]
    model_config = merged["model"]

    seed_everything(train_config.seed)

    dataset, train_dataloader, val_dataloader, test_dataloader = setup_dataloaders(
        dataset_config, train_config
    )

    # Channel count of the single finetuning session sizes the downstream head.
    ft_session = dataset_config.finetune_sessions[0]
    session_dims = dataset.metadata.get_subject_session_full_d_data()
    ft_session_n_chans = session_dims[ft_session][-1]

    device = train_config.device
    model = Barista(model_config, dataset.metadata)

    if train_config.checkpoint_path:
        print("Running pretrained model")
        model = load_pretrained_weights(model, train_config.checkpoint_path, device)

        if train_config.optimization.freeze_tokenizer:
            freeze_tokenizer(model)
    else:
        print("Running non-pretrained model")

    # Binary classification head for the finetuning task.
    model.create_downstream_head(n_chans=ft_session_n_chans, output_dim=2)
    model.to(device)

    print_number_of_parmas(model)

    best_state, criterion = finetune_model(
        model, train_dataloader, val_dataloader, train_config, device
    )
    print(f"\nBEST VAL AUC: {best_state['val_auc']:.4f}")

    # Test performance of the final-epoch weights.
    _, last_test_auc = evaluate_model(model, test_dataloader, criterion, device)
    print(f"LAST TEST AUC: {last_test_auc:.4f}")

    # Restore the best-validation snapshot and test it.
    model.load_state_dict(best_state["model"])
    _, test_auc = evaluate_model(model, test_dataloader, criterion, device)
    print(f"BEST TEST AUC: {test_auc:.4f}")
|
|
|
|
# Script entry point: run the full finetuning pipeline.
if __name__ == "__main__":
    main()
|
|