import argparse
import json
import os
import warnings
from datetime import datetime


def parse_args() -> argparse.Namespace:
    """Parse and validate command line arguments."""
    parser = create_argument_parser()
    args = parser.parse_args()

    # Validate and process arguments
    validate_args(args)
    process_dataset_config(args)
    setup_output_dirs(args)
    setup_wandb_config(args)

    return args


def create_argument_parser() -> argparse.ArgumentParser:
    """Create argument parser with all training arguments."""
    parser = argparse.ArgumentParser()

    # Model parameters
    add_model_args(parser)
    # Dataset parameters
    add_dataset_args(parser)
    # Training parameters
    add_training_args(parser)
    # Output parameters
    add_output_args(parser)
    # Wandb parameters
    add_wandb_args(parser)

    return parser


def add_model_args(parser: argparse.ArgumentParser):
    """Add model-related arguments."""
    model_group = parser.add_argument_group('Model Configuration')
    model_group.add_argument('--hidden_size', type=int, default=None)
    model_group.add_argument('--num_attention_head', type=int, default=8)
    model_group.add_argument('--attention_probs_dropout', type=float, default=0.1)
    model_group.add_argument('--plm_model', type=str, default='facebook/esm2_t33_650M_UR50D')
    model_group.add_argument('--pooling_method', type=str, default='mean',
                             choices=['mean', 'attention1d', 'light_attention'])
    model_group.add_argument('--pooling_dropout', type=float, default=0.1)


def add_dataset_args(parser: argparse.ArgumentParser):
    """Add dataset-related arguments."""
    data_group = parser.add_argument_group('Dataset Configuration')
    data_group.add_argument('--dataset', type=str)
    data_group.add_argument('--dataset_config', type=str)
    data_group.add_argument('--normalize', type=str)
    data_group.add_argument('--num_labels', type=int)
    data_group.add_argument('--problem_type', type=str)
    data_group.add_argument('--pdb_type', type=str)
    data_group.add_argument('--train_file', type=str)
    data_group.add_argument('--valid_file', type=str)
    data_group.add_argument('--test_file', type=str)
    data_group.add_argument('--metrics', type=str)


def add_training_args(parser: argparse.ArgumentParser):
    """Add training-related arguments."""
    train_group = parser.add_argument_group('Training Configuration')
    train_group.add_argument('--seed', type=int, default=3407)
    train_group.add_argument('--learning_rate', type=float, default=1e-3)
    train_group.add_argument('--scheduler', type=str, choices=['linear', 'cosine', 'step'])
    train_group.add_argument('--warmup_steps', type=int, default=0)
    train_group.add_argument('--num_workers', type=int, default=4)
    train_group.add_argument('--batch_size', type=int)
    train_group.add_argument('--batch_token', type=int)
    train_group.add_argument('--num_epochs', type=int, default=100)
    train_group.add_argument('--max_seq_len', type=int, default=-1)
    train_group.add_argument('--gradient_accumulation_steps', type=int, default=1)
    train_group.add_argument('--max_grad_norm', type=float, default=-1)
    train_group.add_argument('--patience', type=int, default=10)
    train_group.add_argument('--monitor', type=str)
    train_group.add_argument('--monitor_strategy', type=str, choices=['max', 'min'])
    train_group.add_argument('--training_method', type=str, default='freeze',
                             choices=['full', 'freeze', 'lora', 'ses-adapter', 'plm-lora',
                                      'plm-qlora', 'plm-adalora', 'plm-dora', 'plm-ia3'])
    train_group.add_argument('--lora_r', type=int, default=8, help='LoRA rank')
    train_group.add_argument('--lora_alpha', type=int, default=32, help='LoRA alpha')
    train_group.add_argument('--lora_dropout', type=float, default=0.1, help='LoRA dropout')
    train_group.add_argument('--feedforward_modules', type=str, default='w0')
    train_group.add_argument('--lora_target_modules', nargs='+', default=['query', 'key', 'value'],
                             help='LoRA target modules')
    train_group.add_argument('--structure_seq', type=str, default='')


def add_output_args(parser: argparse.ArgumentParser):
    """Add output-related arguments."""
    output_group = parser.add_argument_group('Output Configuration')
    output_group.add_argument('--output_model_name', type=str)
    output_group.add_argument('--output_root', default='ckpt')
    output_group.add_argument('--output_dir', default=None)


def add_wandb_args(parser: argparse.ArgumentParser):
    """Add wandb-related arguments."""
    wandb_group = parser.add_argument_group('Wandb Configuration')
    wandb_group.add_argument('--wandb', action='store_true')
    wandb_group.add_argument('--wandb_entity', type=str)
    wandb_group.add_argument('--wandb_project', type=str, default='VenusFactory')
    wandb_group.add_argument('--wandb_run_name', type=str)


def validate_args(args: argparse.Namespace):
    """Validate command line arguments."""
    if args.batch_size is None and args.batch_token is None:
        raise ValueError("batch_size or batch_token must be provided")

    if args.training_method == 'ses-adapter':
        # structure_seq defaults to an empty string, so check for emptiness rather than None
        if not args.structure_seq:
            raise ValueError("structure_seq must be provided for ses-adapter")
        args.structure_seq = args.structure_seq.split(',')
    else:
        args.structure_seq = []


def process_dataset_config(args: argparse.Namespace):
    """Process dataset configuration file."""
    if not args.dataset_config:
        return

    with open(args.dataset_config) as f:
        config = json.load(f)

    # Update args with dataset config values if not already set
    for key in ['dataset', 'pdb_type', 'train_file', 'valid_file', 'test_file',
                'num_labels', 'problem_type', 'monitor', 'monitor_strategy',
                'metrics', 'normalize']:
        if getattr(args, key) is None and key in config:
            setattr(args, key, config[key])

    # Handle metrics specially
    if args.metrics:
        args.metrics = args.metrics.split(',')
        if args.metrics == ['None']:
            args.metrics = ['loss']
            warnings.warn("No metrics provided, using default metrics: loss")


def setup_output_dirs(args: argparse.Namespace):
    """Setup output directories."""
    if args.output_dir is None:
        current_date = datetime.now().strftime("%Y%m%d")
        args.output_dir = os.path.join(args.output_root, current_date)
    else:
        args.output_dir = os.path.join(args.output_root, args.output_dir)
    os.makedirs(args.output_dir, exist_ok=True)


def setup_wandb_config(args: argparse.Namespace):
    """Setup wandb configuration."""
    if args.wandb:
        if args.wandb_run_name is None:
            args.wandb_run_name = f"VenusFactory-{args.dataset}"
        if args.output_model_name is None:
            args.output_model_name = f"{args.wandb_run_name}.pt"
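

# ---------------------------------------------------------------------------
# Illustrative usage sketch (added for documentation; the filenames and flag
# values below are assumptions, not defaults shipped with this module).
#
# The file passed via --dataset_config is expected to be JSON whose keys
# mirror the flags filled in by process_dataset_config, for example:
#
#   {
#     "dataset": "DemoDataset",
#     "train_file": "train.csv",
#     "valid_file": "valid.csv",
#     "test_file": "test.csv",
#     "num_labels": 2,
#     "problem_type": "single_label_classification",
#     "monitor": "accuracy",
#     "monitor_strategy": "max",
#     "metrics": "accuracy,f1"
#   }
#
# A hypothetical invocation (module filename assumed):
#
#   python train_args.py --dataset_config configs/demo.json \
#       --batch_size 32 --training_method plm-lora --wandb
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    parsed = parse_args()
    # Print the fully resolved namespace, including values filled in from the
    # dataset config and the derived output_dir / wandb settings.
    print(json.dumps(vars(parsed), indent=2, default=str))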