nikraf's picture
Upload folder using huggingface_hub
714cf46 verified
import os
import argparse
import sys
import yaml
from types import SimpleNamespace
from modal_cli import _run_on_modal_cli, _should_auto_run_modal
from modal_utils import parse_modal_api_key
def parse_arguments():
    """Parse command-line options and optionally merge them over a YAML config.

    Side effects driven by the parsed CLI values:
      * ``MODAL_TOKEN_ID`` / ``MODAL_TOKEN_SECRET`` are exported to the
        environment when Modal credentials are supplied.
      * ``HF_TOKEN`` is exported and ``huggingface_hub.login`` is called when
        ``--hf_token`` is supplied.
      * A Weights & Biases login is attempted (best effort) when
        ``--wandb_api_key`` is supplied.

    Returns:
        The ``argparse.Namespace`` when ``--yaml_path`` is not given. Otherwise
        a ``SimpleNamespace`` built from the YAML mapping, with selected CLI
        values layered on top and a few defaults filled in. Both return types
        carry ``modal_cli_credentials_provided``.

    Raises:
        AssertionError: If ``--model_names``, ``--model_paths`` and
            ``--model_types`` are combined inconsistently.
    """
    raw_argv = sys.argv[1:]
    parser = argparse.ArgumentParser(description="Script with arguments mirroring the provided YAML settings.")

    # ----------------- ID ----------------- #
    parser.add_argument("--hf_username", default="Synthyra", help="Hugging Face username.")
    parser.add_argument("--hf_token", default=None, help="Hugging Face token.")
    parser.add_argument("--synthyra_api_key", default=None, help="Synthyra API key.")
    parser.add_argument("--wandb_api_key", default=None, help="Wandb API key.")
    parser.add_argument("--modal_token_id", default=None, help="Modal token ID used for authentication.")
    parser.add_argument("--modal_token_secret", default=None, help="Modal token secret used for authentication.")
    parser.add_argument("--modal_api_key", default=None, help="Backward-compatible Modal key formatted as '<modal_token_id>:<modal_token_secret>'.")
    parser.add_argument("--rebuild_modal", action="store_true", default=False, help="Force rebuild and deploy of the Modal backend before running.")
    parser.add_argument("--delete_modal_embeddings", action="store_true", default=False, help="Delete all embedding cache files from the Modal volume before submission.")

    # ----------------- Paths ----------------- #
    parser.add_argument("--hf_home", type=str, default=None, help="Customize the HF cache directory.")
    parser.add_argument("--yaml_path", type=str, default=None, help="Path to the YAML file.")
    parser.add_argument("--log_dir", type=str, default="logs", help="Path to the log directory.")
    parser.add_argument("--results_dir", type=str, default="results", help="Path to the results directory.")
    parser.add_argument("--model_save_dir", default="weights", help="Directory to save models.")
    parser.add_argument("--embedding_save_dir", default="embeddings", help="Directory to save embeddings.")
    parser.add_argument("--download_dir", default="Synthyra/vector_embeddings", help="Directory to download embeddings to.")
    parser.add_argument("--plots_dir", default="plots", help="Directory to save plots.")
    parser.add_argument("--replay_path", type=str, default=None, help="Path to the replay file.")
    parser.add_argument("--pretrained_probe_path", type=str, default=None)  # TODO not used right now

    # ----------------- DataArguments ----------------- #
    parser.add_argument("--delimiter", default=",", help="Delimiter for data.")
    parser.add_argument("--col_names", nargs="+", default=["seqs", "labels"], help="Column names.")  # DEPRECATED, found automatically now
    parser.add_argument("--max_length", type=int, default=2048, help="Maximum sequence length.")
    parser.add_argument("--trim", action="store_true", default=False,
                        help="Whether to trim sequences (default: False). If False, sequences are removed from the dataset if they are longer than max length. If True, they are truncated to max length."
                        )
    parser.add_argument("--data_names", nargs="+", default=[], help="List of HF dataset names.")
    parser.add_argument("--data_dirs", nargs="+", default=[], help="List of local data directories.")
    parser.add_argument("--aa_to_dna", action="store_true", default=False, help="Translate amino-acid sequences to DNA codon sequences using common human synonymous codons.")
    parser.add_argument("--aa_to_rna", action="store_true", default=False, help="Translate amino-acid sequences to RNA codon sequences using common human synonymous codons.")
    parser.add_argument("--dna_to_aa", action="store_true", default=False, help="Translate DNA codon sequences to amino-acid sequences and drop stop codons.")
    parser.add_argument("--rna_to_aa", action="store_true", default=False, help="Translate RNA codon sequences to amino-acid sequences and drop stop codons.")
    parser.add_argument("--codon_to_aa", action="store_true", default=False, help="Translate codon-token sequences to amino-acid sequences and drop stop codons.")
    parser.add_argument("--aa_to_codon", action="store_true", default=False, help="Translate amino-acid sequences to codon-token sequences.")
    parser.add_argument("--random_pair_flipping", action="store_true", default=False, help="Enable random swapping of paired inputs during training.")

    # ----------------- BaseModelArguments ----------------- #
    parser.add_argument("--model_names", nargs="+", default=None, help="List of preset model names to use (e.g. ESM2-8). Mutually exclusive with --model_paths/--model_types.")
    parser.add_argument("--model_paths", nargs="+", default=None, help="List of model paths (HuggingFace or local). Must be paired with --model_types. Mutually exclusive with --model_names.")
    parser.add_argument("--model_types", nargs="+", default=None, help="List of model type keywords paired with --model_paths (e.g. esm2, esmc, protbert, prott5, ankh, glm, dplm, dplm2, protclm, onehot, amplify, e1, vec2vec, calm, custom, random).")
    parser.add_argument("--model_dtype", type=str, choices=["fp32", "fp16", "bf16", "float32", "float16", "bfloat16"], default="bf16", help="Data type for loading base models.")
    parser.add_argument("--use_xformers", action="store_true", default=False, help="Use xformers memory efficient attention for AMPLIFY models (default: False).")

    # ----------------- ProbeArguments ----------------- #
    parser.add_argument("--probe_type", choices=["linear", "transformer", "retrievalnet", "lyra"], default="linear", help="Type of probe.")
    parser.add_argument("--tokenwise", action="store_true", default=False, help="Tokenwise probe (default: False).")
    parser.add_argument("--hidden_size", type=int, default=8192, help="Hidden dimension size for linear probe MLP.")
    parser.add_argument("--transformer_hidden_size", type=int, default=512, help="Hidden dimension size for transformer probe.")
    parser.add_argument("--dropout", type=float, default=0.2, help="Dropout rate.")
    parser.add_argument("--n_layers", type=int, default=1, help="Number of layers.")
    # NOTE: store_false on purpose -- passing the flag turns the feature OFF.
    parser.add_argument("--pre_ln", action="store_false", default=True,
                        help="Disable pre-layernorm (default: enabled). Use --pre_ln to toggle off.")
    parser.add_argument("--classifier_size", type=int, default=4096, help="Feed-forward dimension.")
    parser.add_argument("--transformer_dropout", type=float, default=0.1, help="Dropout rate for the transformer layers.")
    parser.add_argument("--classifier_dropout", type=float, default=0.2, help="Dropout rate for the classifier.")
    parser.add_argument("--n_heads", type=int, default=4, help="Number of heads in multi-head attention.")
    # NOTE: store_false on purpose -- passing the flag turns the feature OFF.
    parser.add_argument("--rotary", action="store_false", default=True,
                        help="Disable rotary embeddings (default: enabled). Use --rotary to toggle off.")
    parser.add_argument("--probe_pooling_types", nargs="+", default=["mean", "var"], help="Pooling types to use.")
    parser.add_argument("--use_bias", action="store_true", default=False, help="Use bias in Linear layers (default: False)")
    parser.add_argument("--save_model", action="store_true", default=False, help="Save trained model (default: False).")
    parser.add_argument("--production_model", action="store_true", default=False, help="Production model (default: False).")
    parser.add_argument("--lora", action="store_true", default=False, help="Use LoRA (default: False).")
    parser.add_argument("--lora_r", type=int, default=8, help="Number of trainable parameters in the LoRA model.")
    parser.add_argument("--lora_alpha", type=float, default=32.0, help="Alpha for the LoRA model.")
    parser.add_argument("--lora_dropout", type=float, default=0.01, help="Dropout rate for the LoRA model.")
    parser.add_argument("--sim_type", choices=["dot", "euclidean", "cosine"], default="dot", help="Cross-attention mechanism for token-parameter-attention")
    parser.add_argument("--token_attention", action="store_true", default=False, help="If true, use TokenFormer instead of Transformer blocks")
    parser.add_argument("--add_token_ids", action="store_true", default=False, help="If true, add learned token type embeddings to distinguish protein A vs B in PPI tasks.")

    # ----------------- ScikitArguments ----------------- #
    parser.add_argument("--scikit_n_iter", type=int, default=10, help="Number of iterations for scikit model.")
    parser.add_argument("--scikit_cv", type=int, default=3, help="Number of cross-validation folds for scikit model.")
    parser.add_argument("--scikit_random_state", type=int, default=None, help="Random state for scikit model (if None, uses global seed).")
    parser.add_argument("--scikit_model_name", type=str, default=None, help="Name of the scikit model to use.")
    parser.add_argument("--scikit_model_args", type=str, default=None, help="JSON string of hyperparameters to use (skips tuning). E.g. '{\"n_estimators\": 500, \"max_depth\": 7}'")
    parser.add_argument("--use_scikit", action="store_true", default=False, help="Use scikit model (default: False).")
    parser.add_argument("--n_jobs", type=int, default=1, help="Number of processes to use in scikit.")  # TODO integrate with GUI and main

    # ----------------- EmbeddingArguments ----------------- #
    parser.add_argument("--embedding_batch_size", type=int, default=16, help="Batch size for embedding generation.")
    parser.add_argument("--embedding_num_workers", type=int, default=0, help="Number of worker processes for embedding generation.")
    parser.add_argument("--num_workers", type=int, default=0, help="Number of worker processes for data loading.")
    parser.add_argument("--download_embeddings", action="store_true", default=False, help="Whether to download embeddings (default: False).")
    parser.add_argument("--matrix_embed", action="store_true", default=False, help="Use matrix embedding (default: False).")
    parser.add_argument("--embedding_pooling_types", nargs="+", default=["mean", "var"], help="Pooling types for embeddings.")
    parser.add_argument("--save_embeddings", action="store_true", default=False, help="Save computed embeddings (default: False).")
    parser.add_argument("--embed_dtype", type=str, choices=["fp32", "fp16", "bf16", "float32", "float16", "bfloat16"], default=None, help="Data type for embeddings. If omitted, uses --model_dtype.")
    parser.add_argument("--sql", action="store_true", default=False, help="Whether to use SQL storage (default: False).")
    parser.add_argument("--read_scaler", type=int, default=100, help="Read scaler for SQL storage.")

    # ----------------- Multi-Column Sequences ----------------- #
    parser.add_argument("--multi_column", nargs="+", default=None, help="If set, list of sequence column names to combine per sample.")

    # ----------------- TrainerArguments ----------------- #
    parser.add_argument("--num_epochs", type=int, default=200, help="Number of epochs to train for.")
    parser.add_argument("--probe_batch_size", type=int, default=64, help="Batch size for probe training.")
    parser.add_argument("--base_batch_size", type=int, default=4, help="Batch size for base model training.")
    parser.add_argument("--probe_grad_accum", type=int, default=1, help='Gradient accumulation steps for probe training.')
    parser.add_argument("--base_grad_accum", type=int, default=8, help='Gradient accumulation steps for base model training.')
    parser.add_argument("--lr", type=float, default=1e-4, help="Learning rate.")
    ### TODO integrate
    #parser.add_argument("--probe_lr", type=float, default=1e-4, help="Learning rate for probe training.")
    #parser.add_argument("--base_lr", type=float, default=1e-5, help="Learning rate for base model training.")
    #parser.add_argument("--lr_scheduler", type=str, default='cosine', help='Learning rate scheduler.')
    #parser.add_argument("--optimizer", type=str, default='adamw', help='Optimizer.')
    parser.add_argument("--weight_decay", type=float, default=0.00, help="Weight decay.")
    parser.add_argument("--patience", type=int, default=1, help="Patience for early stopping.")
    parser.add_argument("--seed", type=int, default=None, help="Seed for reproducibility (if omitted, current time is used).")
    parser.add_argument("--deterministic", action="store_true", default=False,
                        help="Enable deterministic behavior for reproducibility (can slightly slow down training).")
    parser.add_argument("--full_finetuning", action="store_true", default=False, help="Full finetuning (default: False).")
    parser.add_argument("--hybrid_probe", action="store_true", default=False, help="Hybrid probe (default: False).")
    parser.add_argument("--num_runs", type=int, default=1, help="Number of training runs with different seeds. Results will show mean±std across runs.")

    # ----------------- ProteinGym Arguments ----------------- #
    parser.add_argument("--dms_ids", nargs="+", default=["all"],
                        help="ProteinGym DMS assay IDs to evaluate (space-separated), or 'all' to run all assays.")
    parser.add_argument("--proteingym", action="store_true", default=False, help="ProteinGym (default: False).")
    parser.add_argument("--mode", type=str, default='benchmark',
                        help="ProteinGym zero-shot mode: 'benchmark', 'indels', 'multiples', 'singles'")
    parser.add_argument("--scoring_method", choices=["masked_marginal", "mutant_marginal", "wildtype_marginal", "pll", "global_log_prob"], default="masked_marginal",
                        help="Select a scoring method for ProteinGym zero-shot.")
    parser.add_argument("--scoring_window", choices=["optimal", "sliding"], default="optimal",
                        help="Select how to slice the sequence for ProteinGym zero-shot.")
    parser.add_argument("--pg_batch_size", type=int, default=32,
                        help="Batch size for ProteinGym zero-shot scoring (default: 32).")
    parser.add_argument("--compare_scoring_methods", action="store_true", default=False,
                        help="Compare different scoring methods across models and DMS assays (default: False).")
    parser.add_argument("--score_only", action="store_true", default=False,
                        help="Only run the ProteinGym benchmarking script on existing CSV files, skip zero-shot scoring (default: False).")

    # ----------------- W&B Arguments ----------------- #
    parser.add_argument("--use_wandb_hyperopt", action="store_true", default=False, help="Use Weights & Biases hyperparameter optimization.")
    parser.add_argument("--wandb_project", type=str, default="Protify", help="W&B project name for sweeps.")
    parser.add_argument("--wandb_entity", type=str, default=None, help="W&B entity (team/user) for sweeps.")
    parser.add_argument("--sweep_config_path", type=str, default="yamls/sweep.yaml", help="Path to W&B sweep config YAML.")
    parser.add_argument("--sweep_count", type=int, default=10, help="Number of hyperparameter trials to run in the sweep.")
    parser.add_argument("--sweep_method", type=str, default="bayes", choices=["bayes", "grid", "random"], help="Sweep method for hyperparameter optimization.")
    parser.add_argument("--sweep_metric_cls", type=str, default="eval_loss", help="Classification metric to optimize during sweep (e.g., eval_f1, eval_accuracy, eval_mcc)")
    parser.add_argument("--sweep_metric_reg", type=str, default="eval_loss", help="Regression metric to optimize during sweep (e.g., eval_r_squared, eval_spearman_rho, eval_pearson_rho)")
    parser.add_argument("--sweep_goal", type=str, default='minimize', choices=['maximize', 'minimize'], help="Goal for the sweep metric (maximize/minimize)")

    args = parser.parse_args()

    # Validate model_names vs model_paths/model_types mutual exclusivity.
    # Raised explicitly (rather than via `assert`) so the checks survive `python -O`.
    if args.model_paths is not None:
        if args.model_types is None:
            raise AssertionError("--model_types is required when --model_paths is provided.")
        if len(args.model_paths) != len(args.model_types):
            raise AssertionError(f"--model_paths ({len(args.model_paths)}) and --model_types ({len(args.model_types)}) must have the same length.")
        if args.model_names is not None:
            raise AssertionError("--model_names cannot be used together with --model_paths/--model_types.")
    elif args.model_types is not None:
        raise AssertionError("--model_paths is required when --model_types is provided.")

    if args.model_names is None and args.model_paths is None:
        args.model_names = ["ESM2-8"]

    # Record whether Modal credentials were typed on the command line, in either
    # the "--flag value" or the "--flag=value" spelling.
    credential_flags = ("--modal_api_key", "--modal_token_id", "--modal_token_secret")
    credential_prefixes = tuple(f"{flag}=" for flag in credential_flags)
    args.modal_cli_credentials_provided = any(
        item in credential_flags or item.startswith(credential_prefixes)
        for item in raw_argv
    )

    # The combined key only fills in whichever half was not given explicitly.
    if args.modal_api_key is not None and (args.modal_token_id is None or args.modal_token_secret is None):
        parsed_modal_token_id, parsed_modal_token_secret = parse_modal_api_key(args.modal_api_key)
        if args.modal_token_id is None:
            args.modal_token_id = parsed_modal_token_id
        if args.modal_token_secret is None:
            args.modal_token_secret = parsed_modal_token_secret

    if args.modal_token_id is not None:
        os.environ["MODAL_TOKEN_ID"] = args.modal_token_id
    if args.modal_token_secret is not None:
        os.environ["MODAL_TOKEN_SECRET"] = args.modal_token_secret

    if args.hf_token is not None:
        from huggingface_hub import login
        # Override environment variable to ensure this token is used
        os.environ["HF_TOKEN"] = args.hf_token
        login(args.hf_token)
        print("Logged in to HuggingFace Hub with token from arguments")
    else:
        # Check if token exists in environment (from Modal secret or other source)
        hf_token_env = os.environ.get("HF_TOKEN")
        if hf_token_env:
            print("Note: HF_TOKEN found in environment (from Modal secret or other source)")
            print("Note: This token will be used for read operations only unless overridden")

    if args.wandb_api_key is not None:
        # Best effort: a failed W&B login is reported but does not abort parsing.
        try:
            import wandb
            wandb.login(key=args.wandb_api_key)
            print('Logged into Weights & Biases')
        except Exception as e:
            print(f'W&B login failed: {e}')

    if args.synthyra_api_key is not None:
        print('Synthyra API not integrated yet')

    if args.yaml_path is None:
        return args

    with open(args.yaml_path, 'r') as file:
        settings = yaml.safe_load(file)
    if settings is None:
        # An empty YAML document loads as None; treat it as an empty mapping.
        settings = {}
    yaml_args = SimpleNamespace(**settings)
    yaml_values = yaml_args.__dict__

    # Credentials and paths: CLI wins when given, otherwise keep the YAML value
    # (or None when the YAML does not mention the key).
    for key in (
        "hf_token",
        "hf_home",
        "synthyra_api_key",
        "wandb_api_key",
        "modal_token_id",
        "modal_token_secret",
        "modal_api_key",
    ):
        cli_value = getattr(args, key)
        if cli_value is not None:
            setattr(yaml_args, key, cli_value)
        elif key not in yaml_values:
            setattr(yaml_args, key, None)

    # store_true flags: true if set on the CLI, else the truthiness of the YAML
    # value, else False.
    for key in (
        "rebuild_modal",
        "delete_modal_embeddings",
        "use_wandb_hyperopt",
        "aa_to_dna",
        "aa_to_rna",
        "dna_to_aa",
        "rna_to_aa",
        "codon_to_aa",
        "aa_to_codon",
        "random_pair_flipping",
    ):
        setattr(yaml_args, key, bool(getattr(args, key) or yaml_values.get(key, False)))

    # Sweep options: the CLI value is used when it differs from the parser
    # default or when the YAML does not define the key.
    for key in (
        "wandb_project",
        "wandb_entity",
        "sweep_config_path",
        "sweep_count",
        "sweep_method",
        "sweep_metric_cls",
        "sweep_metric_reg",
        "sweep_goal",
    ):
        cli_value = getattr(args, key)
        if cli_value != parser.get_default(key) or key not in yaml_values:
            setattr(yaml_args, key, cli_value)

    yaml_args.yaml_path = args.yaml_path
    # Derived from argv, never present in the YAML itself; carried over so the
    # attribute exists on both return types.
    yaml_args.modal_cli_credentials_provided = args.modal_cli_credentials_provided

    # Ensure ProteinGym and num_runs defaults exist when using YAML configs
    for key, fallback in (
        ("proteingym", False),
        ("dms_ids", ["all"]),
        ("mode", None),
        ("scoring_method", "masked_marginal"),
        ("num_runs", 1),
    ):
        if not hasattr(yaml_args, key):
            setattr(yaml_args, key, fallback)

    if "model_dtype" not in yaml_values or yaml_args.model_dtype is None:
        yaml_args.model_dtype = args.model_dtype
    for key in ("embed_dtype", "model_paths", "model_types", "model_names"):
        if key not in yaml_values:
            setattr(yaml_args, key, getattr(args, key))

    return yaml_args
if __name__ == "__main__":
    # Settings that need to happen pre-imports
    args = parse_arguments()

    # Require that either datasets are specified or a ProteinGym experiment is chosen
    has_datasets = bool(args.data_names or args.data_dirs)
    has_proteingym = bool(args.proteingym)
    # The flag is derived from argv; a YAML-built namespace may not carry it,
    # so a missing attribute is treated as "no credentials on the CLI".
    modal_credentials_on_cli = getattr(args, "modal_cli_credentials_provided", False)
    has_modal_maintenance = bool(modal_credentials_on_cli and (args.rebuild_modal or args.delete_modal_embeddings))
    if not has_datasets and not has_proteingym and not has_modal_maintenance:
        raise AssertionError("No datasets specified. Provide --data_names or --data_dirs, or run a ProteinGym experiment.")

    if args.use_xformers:
        os.environ["_USE_XFORMERS"] = "1"
        print("xformers memory efficient attention enabled for AMPLIFY models")

    if args.hf_home is not None:
        # Needs to happen before any HF imports
        import pathlib
        base_path = args.hf_home
        cache_root = f"{base_path}/hf_cache"
        tmp_root = f"{base_path}/tmp"
        pathlib.Path(cache_root).mkdir(parents=True, exist_ok=True)
        pathlib.Path(tmp_root).mkdir(parents=True, exist_ok=True)
        os.environ["HF_HOME"] = cache_root
        os.environ["HF_DATASETS_CACHE"] = f"{cache_root}/datasets"
        os.environ["TRANSFORMERS_CACHE"] = f"{cache_root}/transformers"  # this is deprecated, but does not hurt anything
        os.environ["HF_HUB_CACHE"] = f"{cache_root}/hub"
        print(f"HF_HOME: {os.environ['HF_HOME']}")
        print(f"HF_DATASETS_CACHE: {os.environ['HF_DATASETS_CACHE']}")
        print(f"TRANSFORMERS_CACHE: {os.environ['TRANSFORMERS_CACHE']}")
        print(f"HF_HUB_CACHE: {os.environ['HF_HUB_CACHE']}")

    # Set global seed before doing anything else
    # If seed is None, set_global_seed will derive it from current time
    if args.deterministic:
        from seed_utils import set_determinism
        set_determinism()
import entrypoint_setup # needs to happen after set_determinism()
import torch
from torchinfo import summary
from probes.get_probe import ProbeArguments, get_probe
from base_models.get_base_models import BaseModelArguments, get_tokenizer, get_base_model_for_training
from base_models.utils import wrap_lora
from data.data_mixin import DataMixin, DataArguments
from probes.trainers import TrainerMixin, TrainerArguments
from probes.scikit_classes import ScikitArguments, ScikitProbe
from embedder import EmbeddingArguments, Embedder, get_embedding_filename
from logger import MetricsLogger, log_method_calls
from utils import torch_load, print_message, expand_dms_ids_all
from visualization.plot_result import create_plots
from hyperopt_utils import HyperoptModule
from benchmarks.proteingym.scorer import ProteinGymRunner
from benchmarks.proteingym.compare_scoring_methods import compare_scoring_methods
from seed_utils import set_global_seed
class MainProcess(MetricsLogger, DataMixin, TrainerMixin):
def __init__(self, full_args, GUI=False):
    """Initialise the process from a namespace of settings.

    Args:
        full_args: Settings namespace; passed to the first base initialiser
            and stored as ``self.full_args``.
        GUI: When falsy, ``start_log_main()`` is invoked during construction.
    """
    # The three explicit super() calls start from different points in the MRO;
    # their order is preserved exactly.
    super(MainProcess, self).__init__(full_args)
    super(DataMixin, self).__init__()
    super(TrainerMixin, self).__init__()
    self.full_args = full_args
    if not GUI:
        self.start_log_main()

    # (user-facing alias, torch attribute name); int8 is intentionally left out.
    alias_to_torch_attr = (
        ("fp32", "float32"),
        ("fp16", "float16"),
        ("bf16", "bfloat16"),
        ("float32", "float32"),
        ("float16", "float16"),
        ("bfloat16", "bfloat16"),
        ("float8_e4m3fn", "float8_e4m3fn"),
        ("float8_e5m2", "float8_e5m2"),
    )
    self.dtype_map = {
        alias: getattr(torch, attr_name)
        for alias, attr_name in alias_to_torch_attr
    }
def _build_scikit_args(self):
    """Build ``ScikitArguments`` from ``self.full_args``.

    Each setting falls back to a fixed default when the namespace does not
    define it: 10 iterations, 3 CV folds, no random state, no model name,
    and ``production_model=False``.
    """
    settings = self.full_args.__dict__
    return ScikitArguments(
        n_iter=settings.get("scikit_n_iter", 10),
        cv=settings.get("scikit_cv", 3),
        random_state=settings.get("scikit_random_state", None),
        model_name=settings.get("scikit_model_name", None),
        production_model=settings.get("production_model", False),
    )
@log_method_calls
def apply_current_settings(self):
    """Resolve dtypes on ``self.full_args`` and rebuild all argument objects.

    String dtypes are translated through ``self.dtype_map``; a missing or
    ``None`` ``embed_dtype`` takes the (already resolved) ``model_dtype``.
    The argument containers and the private mirror attributes are then
    refreshed from the namespace.
    """
    cfg = self.full_args
    cfg_values = cfg.__dict__

    # Fill in dtype keys that the namespace does not define.
    cfg_values.setdefault("model_dtype", "bf16")
    cfg_values.setdefault("embed_dtype", None)

    if isinstance(cfg.model_dtype, str):
        cfg.model_dtype = self.dtype_map[cfg.model_dtype]

    resolved_embed_dtype = cfg.embed_dtype
    if resolved_embed_dtype is None:
        resolved_embed_dtype = cfg.model_dtype
    elif isinstance(resolved_embed_dtype, str):
        resolved_embed_dtype = self.dtype_map[resolved_embed_dtype]
    cfg.embed_dtype = resolved_embed_dtype

    # Every container receives the full settings mapping as keyword arguments.
    self.data_args = DataArguments(**cfg_values)
    self.embedding_args = EmbeddingArguments(**cfg_values)
    self.model_args = BaseModelArguments(**cfg_values)
    self.probe_args = ProbeArguments(**cfg_values)
    self.trainer_args = TrainerArguments(**cfg_values)
    self.logger_args = SimpleNamespace(**cfg_values)
    self.scikit_args = self._build_scikit_args()

    # (private attribute on self, source attribute on the namespace)
    mirrored = (
        ("_sql", "sql"),
        ("_full", "matrix_embed"),
        ("_max_length", "max_length"),
        ("_trim", "trim"),
        ("_delimiter", "delimiter"),
        ("_col_names", "col_names"),
        ("_aa_to_dna", "aa_to_dna"),
        ("_aa_to_rna", "aa_to_rna"),
        ("_dna_to_aa", "dna_to_aa"),
        ("_rna_to_aa", "rna_to_aa"),
        ("_codon_to_aa", "codon_to_aa"),
        ("_aa_to_codon", "aa_to_codon"),
    )
    for private_name, source_name in mirrored:
        setattr(self, private_name, getattr(cfg, source_name))
    # multi_column is optional on the namespace.
    self._multi_column = getattr(cfg, 'multi_column', None)
@log_method_calls
def get_datasets(self):
    """Call ``get_data()`` and store its two results on the instance."""
    loaded_datasets, loaded_seqs = self.get_data()
    self.datasets = loaded_datasets
    self.all_seqs = loaded_seqs
@log_method_calls
def save_embeddings_to_disk(self):
    """Run the embedder once per configured model entry with saving enabled."""
    self.embedding_args.save_embeddings = True
    run_embedder = Embedder(self.embedding_args, self.all_seqs)
    for entry in self.model_args.model_entries():
        display_name, dispatch_type, model_path = entry
        # The return value is not needed here; it is discarded.
        run_embedder(display_name, model_type=dispatch_type, model_path=model_path)
def _create_model_factory(self, model_name, tokenwise, num_labels, hybrid, model_path=None):
    """Return a zero-argument callable that builds a fresh base model.

    Intended for multi-run mode. Each call loads the model again and wraps
    it with LoRA when ``self.probe_args.lora`` is set.
    """
    def factory():
        fresh_model, _unused_tokenizer = get_base_model_for_training(
            model_name,
            tokenwise=tokenwise,
            num_labels=num_labels,
            hybrid=hybrid,
            dtype=self.model_args.model_dtype,
            model_path=model_path,
        )
        if not self.probe_args.lora:
            return fresh_model
        lora_cfg = self.probe_args
        return wrap_lora(fresh_model, lora_cfg.lora_r, lora_cfg.lora_alpha, lora_cfg.lora_dropout)

    return factory
def _create_probe_factory(self):
    """Return a zero-argument callable that builds a fresh probe (multi-run mode)."""
    def factory():
        # probe_args is read when the factory is called, not when it is created.
        current_probe_args = self.probe_args
        return get_probe(current_probe_args)

    return factory
def _run_nn_probe(
    self,
    model_name,
    data_name,
    train_set,
    valid_set,
    test_set,
    tokenizer,
    emb_dict=None,
    ppi=False,
    source_model_name=None,
    sweep_mode: bool = False,
):
    """Train a probe via ``trainer_probe`` and return it with its metrics.

    ``source_model_name`` defaults to ``model_name``. Validation and test
    metrics are logged unless ``sweep_mode`` is set.

    Returns:
        Tuple of ``(probe, valid_metrics, test_metrics)``.
    """
    source_model_name = model_name if source_model_name is None else source_model_name

    # Initial probe (used for a single run, or as the template for multi-run).
    probe = get_probe(self.probe_args)
    summary(probe)

    # trainer_probe handles multi-run internally if num_runs > 1
    trainer_kwargs = dict(
        model=probe,
        tokenizer=tokenizer,
        model_name=model_name,
        data_name=data_name,
        train_dataset=train_set,
        valid_dataset=valid_set,
        test_dataset=test_set,
        emb_dict=emb_dict,
        ppi=ppi,
        log_id=self.random_id,
        source_model_name=source_model_name,
    )
    probe, valid_metrics, test_metrics, _, _ = self.trainer_probe(**trainer_kwargs)

    if not sweep_mode:
        for split_name, split_metrics in (('valid', valid_metrics), ('test', test_metrics)):
            self.log_metrics(data_name, model_name, split_metrics, split_name=split_name)

    return probe, valid_metrics, test_metrics
def _train_nn_probe_fold(self, model_name, dms_id, subtrain_seqs, subtrain_labels,
                         valid_seqs, valid_labels, test_seqs, test_labels,
                         emb_dict, fold_info):
    """Train a regression probe on one ProteinGym DMS assay CV fold.

    Returns:
        Tuple ``(rho, mse)`` read from the test metrics; either value is
        ``None`` when neither the plain nor the ``test_``-prefixed key exists.
    """
    train_set = {'seqs': subtrain_seqs, 'labels': subtrain_labels}
    if valid_seqs is not None and valid_labels is not None:
        valid_set = {'seqs': valid_seqs, 'labels': valid_labels}
    else:
        valid_set = None
    test_set = {'seqs': test_seqs, 'labels': test_labels}

    # Get tokenizer and determine input dimensions
    tokenizer = get_tokenizer(model_name)
    embedding_dir = self.embedding_args.embedding_save_dir
    first_seq = subtrain_seqs[0]
    if self._sql:
        db_path = os.path.join(embedding_dir, f'{model_name}_{self._full}.db')
        input_dim = self.get_embedding_dim_sql(db_path, first_seq, tokenizer)
        emb_for_training = None
    else:
        pth_path = os.path.join(embedding_dir, f'{model_name}_{self._full}.pth')
        # Prefer embeddings saved on disk; otherwise use the dict passed in.
        if os.path.exists(pth_path):
            emb_for_training = torch_load(pth_path)
        else:
            emb_for_training = emb_dict
        input_dim = self.get_embedding_dim_pth(emb_for_training, first_seq, tokenizer)

    # Configure probe for regression
    self.probe_args.input_size = input_dim
    self.probe_args.task_type = 'regression'
    self.probe_args.num_labels = 1
    self.trainer_args.task_type = 'regression'

    fold_probe = get_probe(self.probe_args)
    trainer_outputs = self.trainer_probe(
        model=fold_probe,
        tokenizer=tokenizer,
        model_name=model_name,
        data_name=f"{dms_id}_{fold_info}",
        train_dataset=train_set,
        valid_dataset=valid_set,
        test_dataset=test_set,
        emb_dict=emb_for_training,
        ppi=False,
        log_id=f"{self.random_id}_{fold_info}",
        source_model_name=model_name,
    )
    _, _, test_metrics = trainer_outputs

    # Handle both plain and test-prefixed metric keys returned by HF Trainer
    prefixed_rho = test_metrics.get('test_spearman_rho', None)
    rho = test_metrics.get('spearman_rho', prefixed_rho)
    prefixed_mse = test_metrics.get('test_mse', None)
    mse = test_metrics.get('mse', prefixed_mse)
    return rho, mse
def _run_full_finetuning(
    self,
    model_name,
    data_name,
    train_set,
    valid_set,
    test_set,
    ppi=False,
    source_model_name=None,
    sweep_mode: bool = False,
    model_path: str = None,
):
    """Fine-tune the base model via ``trainer_base_model``.

    ``source_model_name`` defaults to ``model_name``. A model factory is
    supplied to the trainer only when ``num_runs`` on the trainer args is
    greater than one. Metrics are logged unless ``sweep_mode`` is set.

    Returns:
        Tuple of ``(model, valid_metrics, test_metrics)``.
    """
    source_model_name = model_name if source_model_name is None else source_model_name

    tokenwise = self.probe_args.tokenwise
    num_labels = self.probe_args.num_labels

    model_factory = None
    if getattr(self.trainer_args, 'num_runs', 1) > 1:
        model_factory = self._create_model_factory(
            model_name, tokenwise, num_labels, hybrid=False, model_path=model_path
        )

    model, tokenizer = get_base_model_for_training(
        model_name,
        tokenwise=tokenwise,
        num_labels=num_labels,
        hybrid=False,
        dtype=self.model_args.model_dtype,
        model_path=model_path,
    )
    if self.probe_args.lora:
        lora_cfg = self.probe_args
        model = wrap_lora(model, lora_cfg.lora_r, lora_cfg.lora_alpha, lora_cfg.lora_dropout)
    summary(model)

    trainer_kwargs = dict(
        model=model,
        tokenizer=tokenizer,
        model_name=model_name,
        data_name=data_name,
        train_dataset=train_set,
        valid_dataset=valid_set,
        test_dataset=test_set,
        ppi=ppi,
        log_id=self.random_id,
        source_model_name=source_model_name,
        model_factory=model_factory,
    )
    model, valid_metrics, test_metrics, _, _ = self.trainer_base_model(**trainer_kwargs)

    if not sweep_mode:
        for split_name, split_metrics in (('valid', valid_metrics), ('test', test_metrics)):
            self.log_metrics(data_name, model_name, split_metrics, split_name=split_name)

    return model, valid_metrics, test_metrics
def _run_hybrid_probe(
    self,
    model_name,
    data_name,
    train_set,
    valid_set,
    test_set,
    tokenizer,
    emb_dict=None,
    ppi=False,
    source_model_name=None,
    sweep_mode: bool = False,
    model_path: str = None,
):
    """Train a base model together with a probe on one dataset.

    Model names containing "random" (case-insensitive) take a fallback path:
    only a probe is trained through ``self.trainer_probe``. Otherwise the
    base model is loaded with ``hybrid=True``, optionally LoRA-wrapped, and
    trained with a fresh probe through ``self.trainer_hybrid_model``.

    Args:
        model_name: Identifier given to the model loader and used when
            logging metrics.
        data_name: Dataset label forwarded to the trainer and used when
            logging metrics.
        train_set: Training split, forwarded to the trainer unchanged.
        valid_set: Validation split, forwarded to the trainer unchanged.
        test_set: Test split, forwarded to the trainer unchanged.
        tokenizer: Used as-is on the fallback path; on the hybrid path it is
            replaced by the tokenizer returned from the model loader.
        emb_dict: Forwarded to the trainer (may be ``None``).
        ppi: Forwarded to the trainer.
        source_model_name: Forwarded to the trainer; falls back to
            ``model_name`` when ``None``.
        sweep_mode: When True, ``self.log_metrics`` is not called.
        model_path: Forwarded to the model loader and the model factory.

    Returns:
        Tuple ``(trained_object, valid_metrics, test_metrics)`` where the
        first element is the probe on the fallback path and the model
        returned by the hybrid trainer otherwise.
    """
    source_model_name = model_name if source_model_name is None else source_model_name

    def _log_unless_sweeping(valid_metrics, test_metrics):
        # Sweeps skip metric logging here.
        if sweep_mode:
            return
        for split_name, split_metrics in (('valid', valid_metrics), ('test', test_metrics)):
            self.log_metrics(data_name, model_name, split_metrics, split_name=split_name)

    # Random models don't have a trainable base model, so fall back to regular probe
    is_random_model = "random" in model_name.lower()
    if is_random_model:
        print_message(f"Model {model_name} does not support hybrid training. Training a linear probe instead.")
        fallback_probe = get_probe(self.probe_args)
        summary(fallback_probe)
        fallback_probe, valid_metrics, test_metrics = self.trainer_probe(
            model=fallback_probe,
            tokenizer=tokenizer,
            model_name=model_name,
            data_name=data_name,
            train_dataset=train_set,
            valid_dataset=valid_set,
            test_dataset=test_set,
            emb_dict=emb_dict,
            ppi=ppi,
            log_id=self.random_id,
            source_model_name=source_model_name,
        )
        _log_unless_sweeping(valid_metrics, test_metrics)
        return fallback_probe, valid_metrics, test_metrics

    head_cfg = self.probe_args
    tokenwise = head_cfg.tokenwise
    num_labels = head_cfg.num_labels

    # Factories are only built for repeated runs (num_runs > 1).
    model_factory = None
    probe_factory = None
    if getattr(self.trainer_args, 'num_runs', 1) > 1:
        model_factory = self._create_model_factory(
            model_name, tokenwise, num_labels, hybrid=True, model_path=model_path
        )
        probe_factory = self._create_probe_factory()

    base_model, tokenizer = get_base_model_for_training(
        model_name,
        tokenwise=tokenwise,
        num_labels=num_labels,
        hybrid=True,
        dtype=self.model_args.model_dtype,
        model_path=model_path,
    )
    if head_cfg.lora:
        base_model = wrap_lora(
            base_model, head_cfg.lora_r, head_cfg.lora_alpha, head_cfg.lora_dropout
        )
    probe = get_probe(head_cfg)
    summary(base_model)
    summary(probe)

    trained_model, valid_metrics, test_metrics, _, _ = self.trainer_hybrid_model(
        model=base_model,
        tokenizer=tokenizer,
        probe=probe,
        model_name=model_name,
        data_name=data_name,
        train_dataset=train_set,
        valid_dataset=valid_set,
        test_dataset=test_set,
        emb_dict=emb_dict,
        ppi=ppi,
        log_id=self.random_id,
        source_model_name=source_model_name,
        model_factory=model_factory,
        probe_factory=probe_factory,
    )
    _log_unless_sweeping(valid_metrics, test_metrics)
    return trained_model, valid_metrics, test_metrics
@log_method_calls
def run_full_finetuning(self):
    """Run full fine-tuning for every configured model on every loaded dataset.

    For each (model, dataset) pair the label count and label type of the
    dataset are written onto ``self.probe_args`` / ``self.trainer_args``
    before ``self._run_full_finetuning`` is called. The CUDA cache is
    emptied after each pair.
    """
    n_combos = len(self.model_args.model_names) * len(self.datasets)
    self.logger.info(f"Processing {n_combos} model/dataset combinations")

    for display_name, dispatch_type, model_path in self.model_args.model_entries():
        for data_name, splits in self.datasets.items():
            self.logger.info(f"Processing dataset: {data_name}")
            train_set, valid_set, test_set, num_labels, label_type, ppi = splits

            # Both arg objects carry the task type.
            self.probe_args.num_labels = num_labels
            self.probe_args.task_type = label_type
            self.trainer_args.task_type = label_type

            self.logger.info(f'Training probe for {data_name} with {display_name}')
            # NOTE(review): unlike run_hybrid_probes, display_name is not
            # forwarded as source_model_name, so the helper falls back to
            # dispatch_type -- confirm this is intended.
            self._run_full_finetuning(
                dispatch_type,
                data_name,
                train_set,
                valid_set,
                test_set,
                ppi,
                model_path=model_path,
            )
            torch.cuda.empty_cache()
@log_method_calls
def run_hybrid_probes(self):
    """Run hybrid (base model + probe) training for every model/dataset pair.

    Per model, the embedding width is read from the saved embedding store
    (a 'db' file when ``self._sql`` is set, otherwise a 'pth' file loaded
    with ``torch_load``) and used to size the probe input. Embeddings are
    assumed to have been written by ``save_embeddings_to_disk`` beforehand.
    """
    probe_args = self.probe_args
    test_seq = self.all_seqs[0]

    # Report how many combinations will be processed.
    n_combos = len(self.model_args.model_names) * len(self.datasets)
    self.logger.info(f"Processing {n_combos} model/dataset combinations")

    for display_name, dispatch_type, model_path in self.model_args.model_entries():
        self.logger.info(f"Processing model: {display_name}")
        tokenizer = get_tokenizer(dispatch_type, model_path=model_path)

        # Locate the stored embeddings and read their width.
        pooling_types = self.embedding_args.pooling_types
        use_sql = self._sql
        filename = get_embedding_filename(
            display_name, self._full, pooling_types, 'db' if use_sql else 'pth'
        )
        save_path = os.path.join(self.embedding_args.embedding_save_dir, filename)
        if use_sql:
            # SQL store: nothing is loaded up front, embeddings are fetched during training.
            input_size = self.get_embedding_dim_sql(save_path, test_seq, tokenizer)
            emb_dict = None
        else:
            # pth store: the whole dictionary is loaded into memory.
            emb_dict = torch_load(save_path)
            input_size = self.get_embedding_dim_pth(emb_dict, test_seq, tokenizer)

        # Multi-column vector embeddings scale the input width by the column count.
        if not self._full:
            multi_column = getattr(self.full_args, 'multi_column', None)
            if multi_column:
                input_size = input_size * len(multi_column)

        for data_name, splits in self.datasets.items():
            self.logger.info(f"Processing dataset: {data_name}")
            train_set, valid_set, test_set, num_labels, label_type, ppi = splits

            # PPI inputs with non-full embeddings use twice the input width.
            probe_args.input_size = input_size * 2 if (ppi and not self._full) else input_size
            self.probe_args.num_labels = num_labels
            self.probe_args.task_type = label_type
            # TODO: probe_args and trainer_args both need the task type; consolidate.
            self.trainer_args.task_type = label_type

            self.logger.info(f'Training probe for {data_name} with {display_name}')
            # TODO: add optimizer/scheduler options; training schemes could branch here.
            run_kwargs = dict(
                model_name=dispatch_type,
                data_name=data_name,
                train_set=train_set,
                valid_set=valid_set,
                test_set=test_set,
                tokenizer=tokenizer,
                emb_dict=emb_dict,
                ppi=ppi,
                source_model_name=display_name,
                model_path=model_path,
            )
            self._run_hybrid_probe(**run_kwargs)
            torch.cuda.empty_cache()
            # TODO: possibly chain into inference on an input csv or HF dataset.
@log_method_calls
def run_nn_probes(self):
    """Train a neural-network probe for every model/dataset pair.

    Per model, the embedding width is read from the saved embedding store
    (a 'db' file when ``self._sql`` is set, otherwise a 'pth' file loaded
    with ``torch_load``), printed, and used to size the probe input.
    Embeddings are assumed to have been written by
    ``save_embeddings_to_disk`` beforehand.
    """
    probe_args = self.probe_args
    test_seq = self.all_seqs[0]

    # Report how many combinations will be processed.
    n_combos = len(self.model_args.model_names) * len(self.datasets)
    self.logger.info(f"Processing {n_combos} model/dataset combinations")

    for display_name, dispatch_type, model_path in self.model_args.model_entries():
        self.logger.info(f"Processing model: {display_name}")
        tokenizer = get_tokenizer(dispatch_type, model_path=model_path)

        # Locate the stored embeddings and read their width.
        pooling_types = self.embedding_args.pooling_types
        use_sql = self._sql
        filename = get_embedding_filename(
            display_name, self._full, pooling_types, 'db' if use_sql else 'pth'
        )
        save_path = os.path.join(self.embedding_args.embedding_save_dir, filename)
        if use_sql:
            # SQL store: nothing is loaded up front, embeddings are fetched during training.
            input_size = self.get_embedding_dim_sql(save_path, test_seq, tokenizer)
            emb_dict = None
        else:
            # pth store: the whole dictionary is loaded into memory.
            emb_dict = torch_load(save_path)
            input_size = self.get_embedding_dim_pth(emb_dict, test_seq, tokenizer)

        # Multi-column vector embeddings scale the input width by the column count.
        if not self._full:
            multi_column = getattr(self.full_args, 'multi_column', None)
            if multi_column:
                input_size = input_size * len(multi_column)
        print(f'Input dim: {input_size}')

        for data_name, splits in self.datasets.items():
            self.logger.info(f"Processing dataset: {data_name}")
            train_set, valid_set, test_set, num_labels, label_type, ppi = splits

            # PPI inputs with non-full embeddings use twice the input width.
            probe_args.input_size = input_size * 2 if (ppi and not self._full) else input_size
            self.probe_args.num_labels = num_labels
            self.probe_args.task_type = label_type
            # TODO: probe_args and trainer_args both need the task type; consolidate.
            self.trainer_args.task_type = label_type

            self.logger.info(f'Training probe for {data_name} with {display_name}')
            # TODO: add optimizer/scheduler options; training schemes could branch here.
            run_kwargs = dict(
                model_name=display_name,
                data_name=data_name,
                train_set=train_set,
                valid_set=valid_set,
                test_set=test_set,
                tokenizer=tokenizer,
                emb_dict=emb_dict,
                ppi=ppi,
                source_model_name=display_name,
            )
            self._run_nn_probe(**run_kwargs)
            torch.cuda.empty_cache()
            # TODO: possibly chain into inference on an input csv or HF dataset.
@log_method_calls
def run_scikit_scheme(self):
    """Fit and evaluate a scikit-learn style probe for every model/dataset pair.

    When ``self.scikit_args.model_name`` is set, that model is run directly.
    Otherwise a best classifier ('singlelabel') or regressor ('regression')
    is searched for first and its result is passed on to
    ``run_specific_model``. Final scores are logged under the 'test' split.

    Raises:
        ValueError: If no model is specified and the dataset label type is
            neither 'singlelabel' nor 'regression'.
    """
    self.scikit_args = self._build_scikit_args()
    scikit_probe = ScikitProbe(self.scikit_args)
    # Fall back to a single job when n_jobs was never set on full_args.
    scikit_probe.n_jobs = self.full_args.n_jobs if "n_jobs" in self.full_args.__dict__ else 1

    for display_name, _dispatch_type, _model_path in self.model_args.model_entries():
        for data_name, dataset in self.datasets.items():
            (
                X_train, y_train,
                X_valid, y_valid,
                X_test, y_test,
                label_type,
            ) = self.prepare_scikit_dataset(display_name, dataset)

            if self.scikit_args.model_name is None:
                # No model requested: pick the search routine from the label type.
                if label_type == 'singlelabel':
                    find_best = scikit_probe.find_best_classifier
                elif label_type == 'regression':
                    find_best = scikit_probe.find_best_regressor
                else:
                    raise ValueError(f'Label type {label_type} not supported')
                search_results = find_best(X_train, y_train, X_valid, y_valid)
                # Train and evaluate the selected model.
                results = scikit_probe.run_specific_model(
                    X_train, y_train, X_valid, y_valid, X_test, y_test, search_results
                )
            else:
                # A specific model was requested, so the search is skipped.
                print_message(f"Skipping LazyPredict, using specified model: {self.scikit_args.model_name}")
                results = scikit_probe.run_specific_model(
                    X_train, y_train, X_valid, y_valid, X_test, y_test, model_results=None
                )

            # A bare number is wrapped under 'test_mcc'; anything else is logged as-is.
            if isinstance(results.final_scores, (int, float)):
                metrics_dict = {'test_mcc': results.final_scores}
            else:
                metrics_dict = results.final_scores
            self.log_metrics(data_name, display_name, metrics_dict, split_name='test')
@log_method_calls
def generate_plots(self):
    """Create plots from this run's results TSV, if that file exists.

    The results file is ``<results_dir>/<random_id>.tsv``; plots are written
    to ``self.full_args.plots_dir`` by ``create_plots``. When the results
    file is missing a message is printed and nothing else happens.
    """
    print_message("Generating visualization plots...")

    tsv_name = f"{self.random_id}.tsv"
    results_file = os.path.join(self.full_args.results_dir, tsv_name)

    if os.path.exists(results_file):
        output_dir = self.full_args.plots_dir
        print_message(f"Generating plots in {output_dir}...")
        create_plots(results_file, output_dir)
        print_message("Plots generated successfully!")
    else:
        # Nothing to plot without a results file.
        print_message(f"Results file not found: {results_file}")
def run_proteingym_zero_shot(self):
    """Run ProteinGym zero-shot for all specified models and DMS ids.

    DMS ids from ``self.full_args`` are expanded with ``expand_dms_ids_all``
    and scored by a ``ProteinGymRunner``; the value returned by
    ``runner.run`` is stored on ``self._proteingym_timing`` and
    ``runner.run_benchmark`` is invoked afterwards. Results are written under
    ``<results_dir>/proteingym``.

    Raises:
        ValueError: If no DMS ids remain after expansion, if no model names
            are configured, or if models were supplied through
            ``--model_paths``/``--model_types`` instead of preset names.
    """
    dms_ids = getattr(self.full_args, 'dms_ids', []) or []
    mode = getattr(self.full_args, 'mode', 'benchmark')
    dms_ids = expand_dms_ids_all(dms_ids, mode=mode)
    if len(dms_ids) == 0:
        raise ValueError("--dms_ids is required when --proteingym is specified")

    model_names = self.model_args.model_names
    if len(model_names) == 0:
        raise ValueError("--model_names must specify at least one model")

    # Explicit raise instead of `assert`: assertions are stripped under
    # `python -O`, which would silently disable this input validation.
    if self.model_args._model_paths is not None:
        raise ValueError(
            "ProteinGym zero-shot requires --model_names (preset names), not --model_paths/--model_types."
        )

    # Where to write results
    results_root = getattr(self.full_args, 'results_dir', 'results')
    results_dir = os.path.join(results_root, 'proteingym')

    scoring_method = getattr(self.full_args, 'scoring_method', 'masked_marginal')
    scoring_window = getattr(self.full_args, 'scoring_window', 'optimal')
    if isinstance(mode, str) and mode.lower() == 'indels':
        # Indels mode always overrides the requested scoring method with pll.
        print_message("Only pll is currently supported for indels scoring.")
        scoring_method = 'pll'

    # Log the run
    self.logger.info(f"Running ProteinGym zero-shot with [{scoring_method}] scoring on {len(dms_ids)} DMS ids with models: {model_names}")

    runner = ProteinGymRunner(
        results_dir=results_dir,
        repo_id="GleghornLab/ProteinGym_DMS",
    )
    # Kept on the instance so callers can read it back after the run.
    self._proteingym_timing = runner.run(
        dms_ids=dms_ids,
        model_names=model_names,
        mode=mode,
        scoring_method=scoring_method,
        scoring_window=scoring_window,
        batch_size=getattr(self.full_args, 'pg_batch_size', 32),
    )
    print_message(f"ProteinGym zero-shot complete. Results in {results_dir}")

    # After all models are scored, run standardized performance benchmarking
    runner.run_benchmark(model_names, dms_ids, mode, scoring_method)
def main(args: SimpleNamespace):
    """Run one experiment, a log replay, or a scoring-method comparison.

    Flow, in order:
      1. Seed globally and store the chosen seed back on ``args``.
      2. If ``_should_auto_run_modal(args)`` is true, delegate to
         ``_run_on_modal_cli`` and return its result.
      3. If ``args.replay_path`` is set, rebuild the arguments from the log,
         re-apply the seed, and replay the logged run.
      4. Otherwise build a ``MainProcess`` and either run the scoring-method
         comparison (returning early) or train on the configured datasets,
         optionally run ProteinGym zero-shot, then write results, generate
         plots and close the log.

    NOTE(review): writing results, plotting and ``end_log`` are placed on the
    non-replay path only -- confirm this matches the intended layout.

    Args:
        args: Parsed settings namespace.

    Returns:
        ``0`` on the normal path, the result of ``_run_on_modal_cli`` when
        delegated, or ``None`` after a scoring-method comparison.

    Raises:
        ValueError: If a scoring-method comparison is requested without any
            model names.
    """
    chosen_seed = set_global_seed(args.seed)
    args.seed = chosen_seed

    if _should_auto_run_modal(args):
        return _run_on_modal_cli(args)

    if args.replay_path is not None:
        from logger import LogReplayer
        replayer = LogReplayer(args.replay_path)
        replay_args = replayer.parse_log()
        replay_args.replay_path = args.replay_path

        # Re-apply seed using the replayed settings to ensure exact reproducibility
        try:
            # If no seed is present in replay, fall back to whatever
            # set_global_seed chooses for None.
            if getattr(replay_args, 'seed', None) is None:
                replay_args.seed = None
            if getattr(replay_args, 'deterministic', None) is None:
                replay_args.deterministic = getattr(args, 'deterministic', False)
            chosen_seed = set_global_seed(replay_args.seed, deterministic=replay_args.deterministic)
            replay_args.seed = chosen_seed
        except Exception as e:
            # Still best-effort, but report the failure instead of hiding it:
            # a replay that could not be re-seeded may not be reproducible.
            print_message(f"Warning: failed to re-apply seed settings for replay: {e}")

        # Named `process` so the local does not shadow this function's name.
        process = MainProcess(replay_args, GUI=False)
        for k, v in process.full_args.__dict__.items():
            print(f"{k}:\t{v}")
        replayer.run_replay(process)

    else:
        process = MainProcess(args, GUI=False)
        for k, v in process.full_args.__dict__.items():
            print(f"{k}:\t{v}")

        if getattr(args, 'compare_scoring_methods', False) and getattr(args, 'proteingym', False):
            # Run scoring method comparison
            print_message("Running scoring method comparison...")
            dms_ids = getattr(args, 'dms_ids', []) or []
            mode = getattr(args, 'mode', 'benchmark')
            dms_ids = expand_dms_ids_all(dms_ids, mode=mode)
            model_names = getattr(args, 'model_names', []) or []
            if len(model_names) == 0:
                raise ValueError("--model_names must specify at least one model")

            # Set up output path
            results_root = getattr(args, 'results_dir', 'results')
            output_csv = os.path.join(results_root, 'scoring_methods_comparison.csv')

            # The returned summary is not used here; results go to output_csv.
            compare_scoring_methods(
                model_names=model_names,
                device=None,
                methods=None,
                dms_ids=dms_ids,
                progress=True,
                output_csv=output_csv
            )
            print_message(f"Scoring method comparison complete. Results saved to {output_csv}")
            return

        # Determine if current experiment passed datasets
        has_datasets = bool(getattr(args, 'data_names', []) or getattr(args, 'data_dirs', []))

        # Run through datasets first (if any)
        if has_datasets:
            process.apply_current_settings()
            process.get_datasets()
            print_message(f"Number of sequences: {len(process.all_seqs)}")

            if process.full_args.use_wandb_hyperopt:
                # Embeddings are saved unless the sweep does full fine-tuning.
                if not process.full_args.full_finetuning:
                    process.save_embeddings_to_disk()
                HyperoptModule.run_wandb_hyperopt(process)
            elif process.full_args.full_finetuning:
                process.run_full_finetuning()
            elif process.full_args.hybrid_probe:
                process.save_embeddings_to_disk()
                process.run_hybrid_probes()
            elif process.full_args.use_scikit:
                process.save_embeddings_to_disk()
                process.run_scikit_scheme()
            else:
                process.save_embeddings_to_disk()
                process.run_nn_probes()
        else:
            # The original re-evaluated has_datasets here and repeated the
            # training dispatch; that branch could never execute, so only
            # the message remains.
            print_message("No datasets specified; proceeding with ProteinGym.")

        if getattr(args, 'proteingym', False):
            process.run_proteingym_zero_shot()
            try:
                results_root = getattr(args, 'results_dir', 'results')
                results_dir = os.path.join(results_root, 'proteingym')
                pg_scores = ProteinGymRunner.collect_spearman(results_dir, getattr(args, 'model_names', []))
                # Tolerate a missing or None timing mapping.
                timings = getattr(process, '_proteingym_timing', None) or {}
                for model_name, score in pg_scores.items():
                    if not isinstance(score, (int, float)):
                        continue
                    metrics_dict = {'spearman': float(score)}
                    # float(None) raises TypeError, which previously aborted
                    # logging for every remaining model; only record the
                    # timing when one exists.
                    training_time = timings.get(model_name)
                    if training_time is not None:
                        metrics_dict['training_time_seconds'] = float(training_time)
                    process.log_metrics('proteingym', model_name, metrics_dict)
            except Exception as e:
                print_message(f"Failed to log ProteinGym metrics: {e}")

        # Write results and generate plots
        process.write_results()
        process.generate_plots()
        process.end_log()

    return 0
# Script entry point: run main() and use its return value as the exit status.
# NOTE(review): `args` is not defined in this part of the file; presumably it
# is bound at module level elsewhere (e.g. from parse_arguments()) -- confirm.
if __name__ == "__main__":
    exit_status = main(args)
    sys.exit(exit_status)