Spaces:
Runtime error
Runtime error
import argparse | |
import json | |
import os | |
import warnings | |
from typing import Dict, Any | |
from datetime import datetime | |
def parse_args() -> Dict[str, Any]: | |
"""Parse and validate command line arguments.""" | |
parser = create_argument_parser() | |
args = parser.parse_args() | |
# Validate and process arguments | |
validate_args(args) | |
process_dataset_config(args) | |
setup_output_dirs(args) | |
setup_wandb_config(args) | |
return args | |
def create_argument_parser() -> argparse.ArgumentParser: | |
"""Create argument parser with all training arguments.""" | |
parser = argparse.ArgumentParser() | |
# Model parameters | |
add_model_args(parser) | |
# Dataset parameters | |
add_dataset_args(parser) | |
# Training parameters | |
add_training_args(parser) | |
# Output parameters | |
add_output_args(parser) | |
# Wandb parameters | |
add_wandb_args(parser) | |
return parser | |
def add_model_args(parser: argparse.ArgumentParser): | |
"""Add model-related arguments.""" | |
model_group = parser.add_argument_group('Model Configuration') | |
model_group.add_argument('--hidden_size', type=int, default=None) | |
model_group.add_argument('--num_attention_head', type=int, default=8) | |
model_group.add_argument('--attention_probs_dropout', type=float, default=0.1) | |
model_group.add_argument('--plm_model', type=str, default='facebook/esm2_t33_650M_UR50D') | |
model_group.add_argument('--pooling_method', type=str, default='mean', | |
choices=['mean', 'attention1d', 'light_attention']) | |
model_group.add_argument('--pooling_dropout', type=float, default=0.1) | |
def add_dataset_args(parser: argparse.ArgumentParser): | |
"""Add dataset-related arguments.""" | |
data_group = parser.add_argument_group('Dataset Configuration') | |
data_group.add_argument('--dataset', type=str) | |
data_group.add_argument('--dataset_config', type=str) | |
data_group.add_argument('--normalize', type=str) | |
data_group.add_argument('--num_labels', type=int) | |
data_group.add_argument('--problem_type', type=str) | |
data_group.add_argument('--pdb_type', type=str) | |
data_group.add_argument('--train_file', type=str) | |
data_group.add_argument('--valid_file', type=str) | |
data_group.add_argument('--test_file', type=str) | |
data_group.add_argument('--metrics', type=str) | |
def add_training_args(parser: argparse.ArgumentParser): | |
"""Add training-related arguments.""" | |
train_group = parser.add_argument_group('Training Configuration') | |
train_group.add_argument('--seed', type=int, default=3407) | |
train_group.add_argument('--learning_rate', type=float, default=1e-3) | |
train_group.add_argument('--scheduler', type=str, choices=['linear', 'cosine', 'step']) | |
train_group.add_argument('--warmup_steps', type=int, default=0) | |
train_group.add_argument('--num_workers', type=int, default=4) | |
train_group.add_argument('--batch_size', type=int) | |
train_group.add_argument('--batch_token', type=int) | |
train_group.add_argument('--num_epochs', type=int, default=100) | |
train_group.add_argument('--max_seq_len', type=int, default=-1) | |
train_group.add_argument('--gradient_accumulation_steps', type=int, default=1) | |
train_group.add_argument('--max_grad_norm', type=float, default=-1) | |
train_group.add_argument('--patience', type=int, default=10) | |
train_group.add_argument('--monitor', type=str) | |
train_group.add_argument('--monitor_strategy', type=str, choices=['max', 'min']) | |
train_group.add_argument('--training_method', type=str, default='freeze', | |
choices=['full', 'freeze', 'lora', 'ses-adapter', 'plm-lora', 'plm-qlora', 'plm-adalora', 'plm-dora', 'plm-ia3']) | |
parser.add_argument("--lora_r", type=int, default=8, help="lora r") | |
parser.add_argument("--lora_alpha", type=int, default=32, help="lora_alpha") | |
parser.add_argument("--lora_dropout", type=float, default=0.1, help="lora_dropout") | |
parser.add_argument("--feedforward_modules", type=str, default="w0") | |
parser.add_argument( | |
"--lora_target_modules", | |
nargs="+", | |
default=["query", "key", "value"], | |
help="lora target module", | |
) | |
train_group.add_argument('--structure_seq', type=str, default='') | |
def add_output_args(parser: argparse.ArgumentParser): | |
"""Add output-related arguments.""" | |
output_group = parser.add_argument_group('Output Configuration') | |
output_group.add_argument('--output_model_name', type=str) | |
output_group.add_argument('--output_root', default="ckpt") | |
output_group.add_argument('--output_dir', default=None) | |
def add_wandb_args(parser: argparse.ArgumentParser): | |
"""Add wandb-related arguments.""" | |
wandb_group = parser.add_argument_group('Wandb Configuration') | |
wandb_group.add_argument('--wandb', action='store_true') | |
wandb_group.add_argument('--wandb_entity', type=str) | |
wandb_group.add_argument('--wandb_project', type=str, default='VenusFactory') | |
wandb_group.add_argument('--wandb_run_name', type=str) | |
def validate_args(args: argparse.Namespace): | |
"""Validate command line arguments.""" | |
if args.batch_size is None and args.batch_token is None: | |
raise ValueError("batch_size or batch_token must be provided") | |
if args.training_method == 'ses-adapter': | |
if args.structure_seq is None: | |
raise ValueError("structure_seq must be provided for ses-adapter") | |
args.structure_seq = args.structure_seq.split(',') | |
else: | |
args.structure_seq = [] | |
def process_dataset_config(args: argparse.Namespace): | |
"""Process dataset configuration file.""" | |
if not args.dataset_config: | |
return | |
config = json.load(open(args.dataset_config)) | |
# Update args with dataset config values if not already set | |
for key in ['dataset', 'pdb_type', 'train_file', 'valid_file', 'test_file', | |
'num_labels', 'problem_type', 'monitor', 'monitor_strategy', | |
'metrics', 'normalize']: | |
if getattr(args, key) is None and key in config: | |
setattr(args, key, config[key]) | |
# Handle metrics specially | |
if args.metrics: | |
args.metrics = args.metrics.split(',') | |
if args.metrics == ['None']: | |
args.metrics = ['loss'] | |
warnings.warn("No metrics provided, using default metrics: loss") | |
def setup_output_dirs(args: argparse.Namespace): | |
"""Setup output directories.""" | |
if args.output_dir is None: | |
current_date = strftime("%Y%m%d", localtime()) | |
args.output_dir = os.path.join(args.output_root, current_date) | |
else: | |
args.output_dir = os.path.join(args.output_root, args.output_dir) | |
os.makedirs(args.output_dir, exist_ok=True) | |
def setup_wandb_config(args: argparse.Namespace): | |
"""Setup wandb configuration.""" | |
if args.wandb: | |
if args.wandb_run_name is None: | |
args.wandb_run_name = f"VenusFactory-{args.dataset}" | |
if args.output_model_name is None: | |
args.output_model_name = f"{args.wandb_run_name}.pt" | |