| import os |
| import argparse |
| import sys |
| import yaml |
| from types import SimpleNamespace |
|
|
| from modal_cli import _run_on_modal_cli, _should_auto_run_modal |
| from modal_utils import parse_modal_api_key |
|
|
|
|
def parse_arguments():
    """Parse CLI arguments and optionally merge them with a YAML settings file.

    Returns:
        argparse.Namespace when no ``--yaml_path`` is given; otherwise a
        types.SimpleNamespace built from the YAML contents with CLI overrides
        applied on top.

    Side effects:
        May set MODAL_TOKEN_ID / MODAL_TOKEN_SECRET / HF_TOKEN environment
        variables, and log in to HuggingFace Hub and Weights & Biases when the
        corresponding keys are provided.
    """
    # Raw argv is kept so we can detect which Modal flags were *explicitly*
    # passed on the command line (vs. coming from YAML or defaults).
    raw_argv = sys.argv[1:]
    parser = argparse.ArgumentParser(description="Script with arguments mirroring the provided YAML settings.")

    # --- Credentials / Modal backend ---
    parser.add_argument("--hf_username", default="Synthyra", help="Hugging Face username.")
    parser.add_argument("--hf_token", default=None, help="Hugging Face token.")
    parser.add_argument("--synthyra_api_key", default=None, help="Synthyra API key.")
    parser.add_argument("--wandb_api_key", default=None, help="Wandb API key.")
    parser.add_argument("--modal_token_id", default=None, help="Modal token ID used for authentication.")
    parser.add_argument("--modal_token_secret", default=None, help="Modal token secret used for authentication.")
    parser.add_argument("--modal_api_key", default=None, help="Backward-compatible Modal key formatted as '<modal_token_id>:<modal_token_secret>'.")
    parser.add_argument("--rebuild_modal", action="store_true", default=False, help="Force rebuild and deploy of the Modal backend before running.")
    parser.add_argument("--delete_modal_embeddings", action="store_true", default=False, help="Delete all embedding cache files from the Modal volume before submission.")

    # --- Paths and directories ---
    parser.add_argument("--hf_home", type=str, default=None, help="Customize the HF cache directory.")
    parser.add_argument("--yaml_path", type=str, default=None, help="Path to the YAML file.")
    parser.add_argument("--log_dir", type=str, default="logs", help="Path to the log directory.")
    parser.add_argument("--results_dir", type=str, default="results", help="Path to the results directory.")
    parser.add_argument("--model_save_dir", default="weights", help="Directory to save models.")
    parser.add_argument("--embedding_save_dir", default="embeddings", help="Directory to save embeddings.")
    parser.add_argument("--download_dir", default="Synthyra/vector_embeddings", help="Directory to download embeddings to.")
    parser.add_argument("--plots_dir", default="plots", help="Directory to save plots.")
    parser.add_argument("--replay_path", type=str, default=None, help="Path to the replay file.")
    parser.add_argument("--pretrained_probe_path", type=str, default=None)

    # --- Data loading and sequence translation ---
    parser.add_argument("--delimiter", default=",", help="Delimiter for data.")
    parser.add_argument("--col_names", nargs="+", default=["seqs", "labels"], help="Column names.")
    parser.add_argument("--max_length", type=int, default=2048, help="Maximum sequence length.")
    parser.add_argument("--trim", action="store_true", default=False,
        help="Whether to trim sequences (default: False). If False, sequences are removed from the dataset if they are longer than max length. If True, they are truncated to max length."
    )
    parser.add_argument("--data_names", nargs="+", default=[], help="List of HF dataset names.")
    parser.add_argument("--data_dirs", nargs="+", default=[], help="List of local data directories.")
    parser.add_argument("--aa_to_dna", action="store_true", default=False, help="Translate amino-acid sequences to DNA codon sequences using common human synonymous codons.")
    parser.add_argument("--aa_to_rna", action="store_true", default=False, help="Translate amino-acid sequences to RNA codon sequences using common human synonymous codons.")
    parser.add_argument("--dna_to_aa", action="store_true", default=False, help="Translate DNA codon sequences to amino-acid sequences and drop stop codons.")
    parser.add_argument("--rna_to_aa", action="store_true", default=False, help="Translate RNA codon sequences to amino-acid sequences and drop stop codons.")
    parser.add_argument("--codon_to_aa", action="store_true", default=False, help="Translate codon-token sequences to amino-acid sequences and drop stop codons.")
    parser.add_argument("--aa_to_codon", action="store_true", default=False, help="Translate amino-acid sequences to codon-token sequences.")
    parser.add_argument("--random_pair_flipping", action="store_true", default=False, help="Enable random swapping of paired inputs during training.")

    # --- Base model selection ---
    parser.add_argument("--model_names", nargs="+", default=None, help="List of preset model names to use (e.g. ESM2-8). Mutually exclusive with --model_paths/--model_types.")
    parser.add_argument("--model_paths", nargs="+", default=None, help="List of model paths (HuggingFace or local). Must be paired with --model_types. Mutually exclusive with --model_names.")
    parser.add_argument("--model_types", nargs="+", default=None, help="List of model type keywords paired with --model_paths (e.g. esm2, esmc, protbert, prott5, ankh, glm, dplm, dplm2, protclm, onehot, amplify, e1, vec2vec, calm, custom, random).")
    parser.add_argument("--model_dtype", type=str, choices=["fp32", "fp16", "bf16", "float32", "float16", "bfloat16"], default="bf16", help="Data type for loading base models.")
    parser.add_argument("--use_xformers", action="store_true", default=False, help="Use xformers memory efficient attention for AMPLIFY models (default: False).")

    # --- Probe architecture ---
    parser.add_argument("--probe_type", choices=["linear", "transformer", "retrievalnet", "lyra"], default="linear", help="Type of probe.")
    parser.add_argument("--tokenwise", action="store_true", default=False, help="Tokenwise probe (default: False).")
    parser.add_argument("--hidden_size", type=int, default=8192, help="Hidden dimension size for linear probe MLP.")
    parser.add_argument("--transformer_hidden_size", type=int, default=512, help="Hidden dimension size for transformer probe.")
    parser.add_argument("--dropout", type=float, default=0.2, help="Dropout rate.")
    parser.add_argument("--n_layers", type=int, default=1, help="Number of layers.")
    # NOTE: store_false flag — passing --pre_ln DISABLES pre-layernorm.
    parser.add_argument("--pre_ln", action="store_false", default=True,
        help="Disable pre-layernorm (default: enabled). Use --pre_ln to toggle off.")
    parser.add_argument("--classifier_size", type=int, default=4096, help="Feed-forward dimension.")
    parser.add_argument("--transformer_dropout", type=float, default=0.1, help="Dropout rate for the transformer layers.")
    parser.add_argument("--classifier_dropout", type=float, default=0.2, help="Dropout rate for the classifier.")
    parser.add_argument("--n_heads", type=int, default=4, help="Number of heads in multi-head attention.")
    # NOTE: store_false flag — passing --rotary DISABLES rotary embeddings.
    parser.add_argument("--rotary", action="store_false", default=True,
        help="Disable rotary embeddings (default: enabled). Use --rotary to toggle off.")
    parser.add_argument("--probe_pooling_types", nargs="+", default=["mean", "var"], help="Pooling types to use.")
    parser.add_argument("--use_bias", action="store_true", default=False, help="Use bias in Linear layers (default: False)")
    parser.add_argument("--save_model", action="store_true", default=False, help="Save trained model (default: False).")
    parser.add_argument("--production_model", action="store_true", default=False, help="Production model (default: False).")
    parser.add_argument("--lora", action="store_true", default=False, help="Use LoRA (default: False).")
    parser.add_argument("--lora_r", type=int, default=8, help="Number of trainable parameters in the LoRA model.")
    parser.add_argument("--lora_alpha", type=float, default=32.0, help="Alpha for the LoRA model.")
    parser.add_argument("--lora_dropout", type=float, default=0.01, help="Dropout rate for the LoRA model.")
    parser.add_argument("--sim_type", choices=["dot", "euclidean", "cosine"], default="dot", help="Cross-attention mechanism for token-parameter-attention")
    parser.add_argument("--token_attention", action="store_true", default=False, help="If true, use TokenFormer instead of Transformer blocks")
    parser.add_argument("--add_token_ids", action="store_true", default=False, help="If true, add learned token type embeddings to distinguish protein A vs B in PPI tasks.")

    # --- scikit-learn baselines ---
    parser.add_argument("--scikit_n_iter", type=int, default=10, help="Number of iterations for scikit model.")
    parser.add_argument("--scikit_cv", type=int, default=3, help="Number of cross-validation folds for scikit model.")
    parser.add_argument("--scikit_random_state", type=int, default=None, help="Random state for scikit model (if None, uses global seed).")
    parser.add_argument("--scikit_model_name", type=str, default=None, help="Name of the scikit model to use.")
    parser.add_argument("--scikit_model_args", type=str, default=None, help="JSON string of hyperparameters to use (skips tuning). E.g. '{\"n_estimators\": 500, \"max_depth\": 7}'")
    parser.add_argument("--use_scikit", action="store_true", default=False, help="Use scikit model (default: False).")
    parser.add_argument("--n_jobs", type=int, default=1, help="Number of processes to use in scikit.")

    # --- Embedding generation and storage ---
    parser.add_argument("--embedding_batch_size", type=int, default=16, help="Batch size for embedding generation.")
    parser.add_argument("--embedding_num_workers", type=int, default=0, help="Number of worker processes for embedding generation.")
    parser.add_argument("--num_workers", type=int, default=0, help="Number of worker processes for data loading.")
    parser.add_argument("--download_embeddings", action="store_true", default=False, help="Whether to download embeddings (default: False).")
    parser.add_argument("--matrix_embed", action="store_true", default=False, help="Use matrix embedding (default: False).")
    parser.add_argument("--embedding_pooling_types", nargs="+", default=["mean", "var"], help="Pooling types for embeddings.")
    parser.add_argument("--save_embeddings", action="store_true", default=False, help="Save computed embeddings (default: False).")
    parser.add_argument("--embed_dtype", type=str, choices=["fp32", "fp16", "bf16", "float32", "float16", "bfloat16"], default=None, help="Data type for embeddings. If omitted, uses --model_dtype.")
    parser.add_argument("--sql", action="store_true", default=False, help="Whether to use SQL storage (default: False).")
    parser.add_argument("--read_scaler", type=int, default=100, help="Read scaler for SQL storage.")

    parser.add_argument("--multi_column", nargs="+", default=None, help="If set, list of sequence column names to combine per sample.")

    # --- Training hyperparameters ---
    parser.add_argument("--num_epochs", type=int, default=200, help="Number of epochs to train for.")
    parser.add_argument("--probe_batch_size", type=int, default=64, help="Batch size for probe training.")
    parser.add_argument("--base_batch_size", type=int, default=4, help="Batch size for base model training.")
    parser.add_argument("--probe_grad_accum", type=int, default=1, help='Gradient accumulation steps for probe training.')
    parser.add_argument("--base_grad_accum", type=int, default=8, help='Gradient accumulation steps for base model training.')
    parser.add_argument("--lr", type=float, default=1e-4, help="Learning rate.")
    parser.add_argument("--weight_decay", type=float, default=0.00, help="Weight decay.")
    parser.add_argument("--patience", type=int, default=1, help="Patience for early stopping.")
    parser.add_argument("--seed", type=int, default=None, help="Seed for reproducibility (if omitted, current time is used).")
    parser.add_argument("--deterministic", action="store_true", default=False,
        help="Enable deterministic behavior for reproducibility (can slightly slow down training).")
    parser.add_argument("--full_finetuning", action="store_true", default=False, help="Full finetuning (default: False).")
    parser.add_argument("--hybrid_probe", action="store_true", default=False, help="Hybrid probe (default: False).")
    parser.add_argument("--num_runs", type=int, default=1, help="Number of training runs with different seeds. Results will show mean±std across runs.")

    # --- ProteinGym zero-shot benchmarking ---
    parser.add_argument("--dms_ids", nargs="+", default=["all"],
        help="ProteinGym DMS assay IDs to evaluate (space-separated), or 'all' to run all assays.")
    parser.add_argument("--proteingym", action="store_true", default=False, help="ProteinGym (default: False).")
    parser.add_argument("--mode", type=str, default='benchmark',
        help="ProteinGym zero-shot mode: 'benchmark', 'indels', 'multiples', 'singles'")
    parser.add_argument("--scoring_method", choices=["masked_marginal", "mutant_marginal", "wildtype_marginal", "pll", "global_log_prob"], default="masked_marginal",
        help="Select a scoring method for ProteinGym zero-shot.")
    parser.add_argument("--scoring_window", choices=["optimal", "sliding"], default="optimal",
        help="Select how to slice the sequence for ProteinGym zero-shot.")
    parser.add_argument("--pg_batch_size", type=int, default=32,
        help="Batch size for ProteinGym zero-shot scoring (default: 32).")
    parser.add_argument("--compare_scoring_methods", action="store_true", default=False,
        help="Compare different scoring methods across models and DMS assays (default: False).")
    parser.add_argument("--score_only", action="store_true", default=False,
        help="Only run the ProteinGym benchmarking script on existing CSV files, skip zero-shot scoring (default: False).")

    # --- Weights & Biases sweeps ---
    parser.add_argument("--use_wandb_hyperopt", action="store_true", default=False, help="Use Weights & Biases hyperparameter optimization.")
    parser.add_argument("--wandb_project", type=str, default="Protify", help="W&B project name for sweeps.")
    parser.add_argument("--wandb_entity", type=str, default=None, help="W&B entity (team/user) for sweeps.")
    parser.add_argument("--sweep_config_path", type=str, default="yamls/sweep.yaml", help="Path to W&B sweep config YAML.")
    parser.add_argument("--sweep_count", type=int, default=10, help="Number of hyperparameter trials to run in the sweep.")
    parser.add_argument("--sweep_method", type=str, default="bayes", choices=["bayes", "grid", "random"], help="Sweep method for hyperparameter optimization.")
    parser.add_argument("--sweep_metric_cls", type=str, default="eval_loss", help="Classification metric to optimize during sweep (e.g., eval_f1, eval_accuracy, eval_mcc)")
    parser.add_argument("--sweep_metric_reg", type=str, default="eval_loss", help="Regression metric to optimize during sweep (e.g., eval_r_squared, eval_spearman_rho, eval_pearson_rho)")
    parser.add_argument("--sweep_goal", type=str, default='minimize', choices=['maximize', 'minimize'], help="Goal for the sweep metric (maximize/minimize)")
    args = parser.parse_args()

    # Validate the mutually exclusive model-selection options, defaulting to
    # the smallest preset model when nothing was specified.
    if args.model_paths is not None:
        assert args.model_types is not None, "--model_types is required when --model_paths is provided."
        assert len(args.model_paths) == len(args.model_types), f"--model_paths ({len(args.model_paths)}) and --model_types ({len(args.model_types)}) must have the same length."
        assert args.model_names is None, "--model_names cannot be used together with --model_paths/--model_types."
    elif args.model_types is not None:
        assert args.model_paths is not None, "--model_paths is required when --model_types is provided."
    if args.model_names is None and args.model_paths is None:
        args.model_names = ["ESM2-8"]

    # Record whether Modal credentials were passed explicitly on the command
    # line (either "--flag value" or "--flag=value" form).
    args.modal_cli_credentials_provided = (
        ("--modal_api_key" in raw_argv)
        or ("--modal_token_id" in raw_argv)
        or ("--modal_token_secret" in raw_argv)
        or any(item.startswith("--modal_api_key=") for item in raw_argv)
        or any(item.startswith("--modal_token_id=") for item in raw_argv)
        or any(item.startswith("--modal_token_secret=") for item in raw_argv)
    )

    # Backward-compatible combined key: split "<id>:<secret>" into the two
    # token fields when they were not given individually.
    if args.modal_api_key is not None and (args.modal_token_id is None or args.modal_token_secret is None):
        parsed_modal_token_id, parsed_modal_token_secret = parse_modal_api_key(args.modal_api_key)
        if args.modal_token_id is None:
            args.modal_token_id = parsed_modal_token_id
        if args.modal_token_secret is None:
            args.modal_token_secret = parsed_modal_token_secret

    # Export Modal credentials so the modal client picks them up.
    if args.modal_token_id is not None:
        os.environ["MODAL_TOKEN_ID"] = args.modal_token_id
    if args.modal_token_secret is not None:
        os.environ["MODAL_TOKEN_SECRET"] = args.modal_token_secret

    # Service logins (HF hub, W&B) — best-effort, never fatal.
    if args.hf_token is not None:
        from huggingface_hub import login
        os.environ["HF_TOKEN"] = args.hf_token
        login(args.hf_token)
        print(f"Logged in to HuggingFace Hub with token from arguments")
    else:
        hf_token_env = os.environ.get("HF_TOKEN")
        if hf_token_env:
            print(f"Note: HF_TOKEN found in environment (from Modal secret or other source)")
            print(f"Note: This token will be used for read operations only unless overridden")
    if args.wandb_api_key is not None:
        try:
            import wandb
            wandb.login(key=args.wandb_api_key)
            print('Logged into Weights & Biases')
        except Exception as e:
            print(f'W&B login failed: {e}')
    if args.synthyra_api_key is not None:
        print('Synthyra API not integrated yet')

    # When a YAML file is supplied it becomes the base config, with explicit
    # CLI values layered on top.
    if args.yaml_path is not None:
        with open(args.yaml_path, 'r') as file:
            settings = yaml.safe_load(file)
        yaml_args = SimpleNamespace(**settings)

        def _merge_store_true(cli_value: bool, key: str) -> bool:
            # A store_true flag is True if set on the CLI OR truthy in YAML.
            if cli_value:
                return True
            if key in yaml_args.__dict__:
                return bool(yaml_args.__dict__[key])
            return False

        # Credentials: CLI beats YAML; guarantee the attribute exists.
        if args.hf_token is not None:
            yaml_args.hf_token = args.hf_token
        elif "hf_token" not in yaml_args.__dict__:
            yaml_args.hf_token = None

        if args.hf_home is not None:
            yaml_args.hf_home = args.hf_home
        elif "hf_home" not in yaml_args.__dict__:
            yaml_args.hf_home = None

        if args.synthyra_api_key is not None:
            yaml_args.synthyra_api_key = args.synthyra_api_key
        elif "synthyra_api_key" not in yaml_args.__dict__:
            yaml_args.synthyra_api_key = None

        if args.wandb_api_key is not None:
            yaml_args.wandb_api_key = args.wandb_api_key
        elif "wandb_api_key" not in yaml_args.__dict__:
            yaml_args.wandb_api_key = None

        if args.modal_token_id is not None:
            yaml_args.modal_token_id = args.modal_token_id
        elif "modal_token_id" not in yaml_args.__dict__:
            yaml_args.modal_token_id = None

        if args.modal_token_secret is not None:
            yaml_args.modal_token_secret = args.modal_token_secret
        elif "modal_token_secret" not in yaml_args.__dict__:
            yaml_args.modal_token_secret = None

        if args.modal_api_key is not None:
            yaml_args.modal_api_key = args.modal_api_key
        elif "modal_api_key" not in yaml_args.__dict__:
            yaml_args.modal_api_key = None
        yaml_args.rebuild_modal = _merge_store_true(args.rebuild_modal, "rebuild_modal")
        yaml_args.delete_modal_embeddings = _merge_store_true(args.delete_modal_embeddings, "delete_modal_embeddings")

        yaml_args.use_wandb_hyperopt = _merge_store_true(args.use_wandb_hyperopt, "use_wandb_hyperopt")

        # Sweep settings: take the CLI value when it differs from the parser
        # default (i.e. was explicitly given) or when the YAML omits the key.
        if (args.wandb_project != "Protify") or ("wandb_project" not in yaml_args.__dict__):
            yaml_args.wandb_project = args.wandb_project
        if (args.wandb_entity is not None) or ("wandb_entity" not in yaml_args.__dict__):
            yaml_args.wandb_entity = args.wandb_entity
        if (args.sweep_config_path != "yamls/sweep.yaml") or ("sweep_config_path" not in yaml_args.__dict__):
            yaml_args.sweep_config_path = args.sweep_config_path
        if (args.sweep_count != 10) or ("sweep_count" not in yaml_args.__dict__):
            yaml_args.sweep_count = args.sweep_count
        if (args.sweep_method != "bayes") or ("sweep_method" not in yaml_args.__dict__):
            yaml_args.sweep_method = args.sweep_method
        if (args.sweep_metric_cls != "eval_loss") or ("sweep_metric_cls" not in yaml_args.__dict__):
            yaml_args.sweep_metric_cls = args.sweep_metric_cls
        if (args.sweep_metric_reg != "eval_loss") or ("sweep_metric_reg" not in yaml_args.__dict__):
            yaml_args.sweep_metric_reg = args.sweep_metric_reg
        if (args.sweep_goal != "minimize") or ("sweep_goal" not in yaml_args.__dict__):
            yaml_args.sweep_goal = args.sweep_goal
        yaml_args.yaml_path = args.yaml_path
        yaml_args.aa_to_dna = _merge_store_true(args.aa_to_dna, "aa_to_dna")
        yaml_args.aa_to_rna = _merge_store_true(args.aa_to_rna, "aa_to_rna")
        yaml_args.dna_to_aa = _merge_store_true(args.dna_to_aa, "dna_to_aa")
        yaml_args.rna_to_aa = _merge_store_true(args.rna_to_aa, "rna_to_aa")
        yaml_args.codon_to_aa = _merge_store_true(args.codon_to_aa, "codon_to_aa")
        yaml_args.aa_to_codon = _merge_store_true(args.aa_to_codon, "aa_to_codon")
        yaml_args.random_pair_flipping = _merge_store_true(args.random_pair_flipping, "random_pair_flipping")

        if not hasattr(yaml_args, 'proteingym'):
            yaml_args.proteingym = False
        if not hasattr(yaml_args, 'dms_ids'):
            yaml_args.dms_ids = ["all"]
        if not hasattr(yaml_args, 'mode'):
            yaml_args.mode = None
        if not hasattr(yaml_args, 'scoring_method'):
            yaml_args.scoring_method = "masked_marginal"

        if not hasattr(yaml_args, 'num_runs'):
            yaml_args.num_runs = 1
        if "model_dtype" not in yaml_args.__dict__ or yaml_args.model_dtype is None:
            yaml_args.model_dtype = args.model_dtype
        if "embed_dtype" not in yaml_args.__dict__:
            yaml_args.embed_dtype = args.embed_dtype
        if "model_paths" not in yaml_args.__dict__:
            yaml_args.model_paths = args.model_paths
        if "model_types" not in yaml_args.__dict__:
            yaml_args.model_types = args.model_types
        if "model_names" not in yaml_args.__dict__:
            yaml_args.model_names = args.model_names

        # FIX: previously, any option defined only on the CLI parser and not
        # explicitly merged above (e.g. use_xformers, data_names, data_dirs,
        # deterministic, modal_cli_credentials_provided) was absent from the
        # YAML namespace entirely, raising AttributeError when read downstream.
        # Backfill every still-missing key from the parsed CLI namespace; all
        # explicit merges above keep precedence.
        for key, value in vars(args).items():
            if key not in yaml_args.__dict__:
                setattr(yaml_args, key, value)
        return yaml_args
    else:
        return args
|
|
|
|
if __name__ == "__main__":
    # Parse CLI/YAML settings first, before any heavy imports, so environment
    # variables (HF caches, Modal tokens) are in place when libraries load.
    args = parse_arguments()

    # At least one source of work is required: HF datasets, local data dirs,
    # a ProteinGym run, or a Modal maintenance action with explicit credentials.
    has_datasets = bool(args.data_names or args.data_dirs)
    has_proteingym = bool(args.proteingym)
    has_modal_maintenance = bool(args.modal_cli_credentials_provided and (args.rebuild_modal or args.delete_modal_embeddings))
    if not has_datasets and not has_proteingym and not has_modal_maintenance:
        raise AssertionError("No datasets specified. Provide --data_names or --data_dirs, or run a ProteinGym experiment.")

    # Opt-in switch communicated via the environment to model-loading code.
    if args.use_xformers:
        os.environ["_USE_XFORMERS"] = "1"
        print("xformers memory efficient attention enabled for AMPLIFY models")

    # Redirect all HuggingFace caches under a custom root. This is done before
    # the library imports below so the env vars are read at import time.
    if args.hf_home is not None:
        import pathlib
        base_path = args.hf_home
        cache_root = f"{base_path}/hf_cache"
        tmp_root = f"{base_path}/tmp"
        pathlib.Path(cache_root).mkdir(parents=True, exist_ok=True)
        pathlib.Path(tmp_root).mkdir(parents=True, exist_ok=True)

        os.environ["HF_HOME"] = cache_root
        os.environ["HF_DATASETS_CACHE"] = f"{cache_root}/datasets"
        os.environ["TRANSFORMERS_CACHE"] = f"{cache_root}/transformers"
        os.environ["HF_HUB_CACHE"] = f"{cache_root}/hub"
        print(f"HF_HOME: {os.environ['HF_HOME']}")
        print(f"HF_DATASETS_CACHE: {os.environ['HF_DATASETS_CACHE']}")
        print(f"TRANSFORMERS_CACHE: {os.environ['TRANSFORMERS_CACHE']}")
        print(f"HF_HUB_CACHE: {os.environ['HF_HUB_CACHE']}")

    # Determinism is configured before torch and the project modules are
    # imported below, so it takes effect for everything that follows.
    if args.deterministic:
        from seed_utils import set_determinism
        set_determinism()

    # Imported for its side effects — presumably environment/setup hooks;
    # confirm against entrypoint_setup's module body.
    import entrypoint_setup

    # Heavy imports are deferred until after env/cache/determinism setup.
    import torch
    from torchinfo import summary

    from probes.get_probe import ProbeArguments, get_probe
    from base_models.get_base_models import BaseModelArguments, get_tokenizer, get_base_model_for_training
    from base_models.utils import wrap_lora
    from data.data_mixin import DataMixin, DataArguments
    from probes.trainers import TrainerMixin, TrainerArguments
    from probes.scikit_classes import ScikitArguments, ScikitProbe
    from embedder import EmbeddingArguments, Embedder, get_embedding_filename
    from logger import MetricsLogger, log_method_calls
    from utils import torch_load, print_message, expand_dms_ids_all
    from visualization.plot_result import create_plots
    from hyperopt_utils import HyperoptModule
    from benchmarks.proteingym.scorer import ProteinGymRunner
    from benchmarks.proteingym.compare_scoring_methods import compare_scoring_methods
    from seed_utils import set_global_seed
|
|
|
|
| class MainProcess(MetricsLogger, DataMixin, TrainerMixin): |
        def __init__(self, full_args, GUI=False):
            """Orchestrator combining the logging, data, and training mixins.

            full_args: merged CLI/YAML settings namespace from parse_arguments().
            GUI: when True, skip starting the main log (the caller manages logging).
            """
            # NOTE(review): these chained super() calls advance initialization
            # along the MRO one segment at a time (MetricsLogger receives
            # full_args; the later mixins are initialized with no args) —
            # presumably intentional; confirm before converting to a single
            # cooperative super().__init__() chain.
            super(MainProcess, self).__init__(full_args)
            super(DataMixin, self).__init__()
            super(TrainerMixin, self).__init__()
            self.full_args = full_args
            if not GUI:
                self.start_log_main()

            # String dtype aliases -> torch dtypes, used when loading base
            # models and when storing embeddings.
            self.dtype_map = {
                "fp32": torch.float32,
                "fp16": torch.float16,
                "bf16": torch.bfloat16,
                "float32": torch.float32,
                "float16": torch.float16,
                "bfloat16": torch.bfloat16,
                "float8_e4m3fn": torch.float8_e4m3fn,
                "float8_e5m2": torch.float8_e5m2,

            }
|
|
| def _build_scikit_args(self): |
| if "scikit_n_iter" in self.full_args.__dict__: |
| n_iter = self.full_args.scikit_n_iter |
| else: |
| n_iter = 10 |
|
|
| if "scikit_cv" in self.full_args.__dict__: |
| cv = self.full_args.scikit_cv |
| else: |
| cv = 3 |
|
|
| if "scikit_random_state" in self.full_args.__dict__: |
| random_state = self.full_args.scikit_random_state |
| else: |
| random_state = None |
|
|
| if "scikit_model_name" in self.full_args.__dict__: |
| model_name = self.full_args.scikit_model_name |
| else: |
| model_name = None |
|
|
| if "production_model" in self.full_args.__dict__: |
| production_model = self.full_args.production_model |
| else: |
| production_model = False |
|
|
| return ScikitArguments( |
| n_iter=n_iter, |
| cv=cv, |
| random_state=random_state, |
| model_name=model_name, |
| production_model=production_model, |
| ) |
|
|
| @log_method_calls |
| def apply_current_settings(self): |
| if "model_dtype" not in self.full_args.__dict__: |
| self.full_args.model_dtype = "bf16" |
| if "embed_dtype" not in self.full_args.__dict__: |
| self.full_args.embed_dtype = None |
| if isinstance(self.full_args.model_dtype, str): |
| self.full_args.model_dtype = self.dtype_map[self.full_args.model_dtype] |
| if self.full_args.embed_dtype is None: |
| self.full_args.embed_dtype = self.full_args.model_dtype |
| elif isinstance(self.full_args.embed_dtype, str): |
| self.full_args.embed_dtype = self.dtype_map[self.full_args.embed_dtype] |
| else: |
| self.full_args.embed_dtype = self.full_args.embed_dtype |
| self.data_args = DataArguments(**self.full_args.__dict__) |
| self.embedding_args = EmbeddingArguments(**self.full_args.__dict__) |
| self.model_args = BaseModelArguments(**self.full_args.__dict__) |
| self.probe_args = ProbeArguments(**self.full_args.__dict__) |
| self.trainer_args = TrainerArguments(**self.full_args.__dict__) |
| self.logger_args = SimpleNamespace(**self.full_args.__dict__) |
| self.scikit_args = self._build_scikit_args() |
| self._sql = self.full_args.sql |
| self._full = self.full_args.matrix_embed |
| self._max_length = self.full_args.max_length |
| self._trim = self.full_args.trim |
| self._delimiter = self.full_args.delimiter |
| self._col_names = self.full_args.col_names |
| self._aa_to_dna = self.full_args.aa_to_dna |
| self._aa_to_rna = self.full_args.aa_to_rna |
| self._dna_to_aa = self.full_args.dna_to_aa |
| self._rna_to_aa = self.full_args.rna_to_aa |
| self._codon_to_aa = self.full_args.codon_to_aa |
| self._aa_to_codon = self.full_args.aa_to_codon |
| self._multi_column = getattr(self.full_args, 'multi_column', None) |
|
|
| @log_method_calls |
| def get_datasets(self): |
| self.datasets, self.all_seqs = self.get_data() |
|
|
| @log_method_calls |
| def save_embeddings_to_disk(self): |
| self.embedding_args.save_embeddings = True |
| embedder = Embedder(self.embedding_args, self.all_seqs) |
| for display_name, dispatch_type, model_path in self.model_args.model_entries(): |
| _ = embedder(display_name, model_type=dispatch_type, model_path=model_path) |
|
|
| def _create_model_factory(self, model_name, tokenwise, num_labels, hybrid, model_path=None): |
| """Function for creating fresh models in multi-run mode.""" |
| def factory(): |
| model, _ = get_base_model_for_training( |
| model_name, |
| tokenwise=tokenwise, |
| num_labels=num_labels, |
| hybrid=hybrid, |
| dtype=self.model_args.model_dtype, |
| model_path=model_path, |
| ) |
| if self.probe_args.lora: |
| model = wrap_lora(model, self.probe_args.lora_r, self.probe_args.lora_alpha, self.probe_args.lora_dropout) |
| return model |
| return factory |
| |
| def _create_probe_factory(self): |
| """Function for creating fresh probes in multi-run mode.""" |
| def factory(): |
| return get_probe(self.probe_args) |
| return factory |
|
|
| def _run_nn_probe( |
| self, |
| model_name, |
| data_name, |
| train_set, |
| valid_set, |
| test_set, |
| tokenizer, |
| emb_dict=None, |
| ppi=False, |
| source_model_name=None, |
| sweep_mode: bool = False, |
| ): |
| if source_model_name is None: |
| source_model_name = model_name |
| |
| probe = get_probe(self.probe_args) |
| summary(probe) |
| |
| |
| probe, valid_metrics, test_metrics, _, _ = self.trainer_probe( |
| model=probe, |
| tokenizer=tokenizer, |
| model_name=model_name, |
| data_name=data_name, |
| train_dataset=train_set, |
| valid_dataset=valid_set, |
| test_dataset=test_set, |
| emb_dict=emb_dict, |
| ppi=ppi, |
| log_id=self.random_id, |
| source_model_name=source_model_name, |
| ) |
| if not sweep_mode: |
| self.log_metrics(data_name, model_name, valid_metrics, split_name='valid') |
| self.log_metrics(data_name, model_name, test_metrics, split_name='test') |
| return probe, valid_metrics, test_metrics |
|
|
| def _train_nn_probe_fold(self, model_name, dms_id, subtrain_seqs, subtrain_labels, |
| valid_seqs, valid_labels, test_seqs, test_labels, |
| emb_dict, fold_info): |
| """Trains a neural network probe on a ProteinGym DMS assay CV fold.""" |
|
|
| train_set = {'seqs': subtrain_seqs, 'labels': subtrain_labels} |
| valid_set = None if (valid_seqs is None or valid_labels is None) else {'seqs': valid_seqs, 'labels': valid_labels} |
| test_set = {'seqs': test_seqs, 'labels': test_labels} |
| |
| |
| tokenizer = get_tokenizer(model_name) |
| |
| if self._sql: |
| save_path = os.path.join(self.embedding_args.embedding_save_dir, |
| f'{model_name}_{self._full}.db') |
| input_dim = self.get_embedding_dim_sql(save_path, subtrain_seqs[0], tokenizer) |
| emb_for_training = None |
| else: |
| save_path = os.path.join(self.embedding_args.embedding_save_dir, |
| f'{model_name}_{self._full}.pth') |
| emb_for_training = torch_load(save_path) if os.path.exists(save_path) else emb_dict |
| input_dim = self.get_embedding_dim_pth(emb_for_training, subtrain_seqs[0], tokenizer) |
| |
| |
| self.probe_args.input_size = input_dim |
| self.probe_args.task_type = 'regression' |
| self.probe_args.num_labels = 1 |
| self.trainer_args.task_type = 'regression' |
| |
| probe = get_probe(self.probe_args) |
| _, _, test_metrics = self.trainer_probe( |
| model=probe, |
| tokenizer=tokenizer, |
| model_name=model_name, |
| data_name=f"{dms_id}_{fold_info}", |
| train_dataset=train_set, |
| valid_dataset=valid_set, |
| test_dataset=test_set, |
| emb_dict=emb_for_training, |
| ppi=False, |
| log_id=f"{self.random_id}_{fold_info}", |
| source_model_name=model_name, |
| ) |
| |
| |
| rho = test_metrics.get('spearman_rho', test_metrics.get('test_spearman_rho', None)) |
| mse = test_metrics.get('mse', test_metrics.get('test_mse', None)) |
| return rho, mse |
| |
    def _run_full_finetuning(
        self,
        model_name,
        data_name,
        train_set,
        valid_set,
        test_set,
        ppi=False,
        source_model_name=None,
        sweep_mode: bool = False,
        model_path: str = None,
    ):
        """Fine-tune a full base model on one dataset and return it with metrics.

        Args:
            model_name: Dispatch name used to load the base model and tokenizer.
            data_name: Dataset identifier, used for trainer bookkeeping and logging.
            train_set: Training split passed to the trainer.
            valid_set: Validation split passed to the trainer.
            test_set: Test split passed to the trainer.
            ppi: Whether the task uses paired (PPI-style) inputs.
            source_model_name: Display name recorded in logs; defaults to model_name.
            sweep_mode: When True, skip metric logging (sweeps log on their own).
            model_path: Optional path overriding the preset model weights.

        Returns:
            Tuple of (trained model, valid_metrics dict, test_metrics dict).
        """
        if source_model_name is None:
            source_model_name = model_name
        tokenwise = self.probe_args.tokenwise
        num_labels = self.probe_args.num_labels
        num_runs = getattr(self.trainer_args, 'num_runs', 1)

        # A factory is only needed when the trainer re-instantiates the model for
        # multiple independent runs; for a single run the instance below suffices.
        model_factory = self._create_model_factory(model_name, tokenwise, num_labels, hybrid=False, model_path=model_path) if num_runs > 1 else None
        model, tokenizer = get_base_model_for_training(
            model_name,
            tokenwise=tokenwise,
            num_labels=num_labels,
            hybrid=False,
            dtype=self.model_args.model_dtype,
            model_path=model_path,
        )
        # Optionally wrap the base model with LoRA adapters for
        # parameter-efficient fine-tuning.
        if self.probe_args.lora:
            model = wrap_lora(model, self.probe_args.lora_r, self.probe_args.lora_alpha, self.probe_args.lora_dropout)
        summary(model)
        model, valid_metrics, test_metrics, _, _ = self.trainer_base_model(
            model=model,
            tokenizer=tokenizer,
            model_name=model_name,
            data_name=data_name,
            train_dataset=train_set,
            valid_dataset=valid_set,
            test_dataset=test_set,
            ppi=ppi,
            log_id=self.random_id,
            source_model_name=source_model_name,
            model_factory=model_factory,
        )
        if not sweep_mode:
            self.log_metrics(data_name, model_name, valid_metrics, split_name='valid')
            self.log_metrics(data_name, model_name, test_metrics, split_name='test')
        return model, valid_metrics, test_metrics
|
|
    def _run_hybrid_probe(
        self,
        model_name,
        data_name,
        train_set,
        valid_set,
        test_set,
        tokenizer,
        emb_dict=None,
        ppi=False,
        source_model_name=None,
        sweep_mode: bool = False,
        model_path: str = None,
    ):
        """Train a hybrid (base model + probe) on one dataset and return metrics.

        For "random" baseline models, which have no trainable backbone, this falls
        back to training a plain linear probe on cached embeddings instead.

        Args:
            model_name: Dispatch name used to load the base model.
            data_name: Dataset identifier for trainer bookkeeping and logging.
            train_set: Training split passed to the trainer.
            valid_set: Validation split passed to the trainer.
            test_set: Test split passed to the trainer.
            tokenizer: Tokenizer matching model_name (replaced by the one returned
                from get_base_model_for_training on the hybrid path).
            emb_dict: Optional in-memory embedding cache (None when using SQL cache).
            ppi: Whether the task uses paired (PPI-style) inputs.
            source_model_name: Display name recorded in logs; defaults to model_name.
            sweep_mode: When True, skip metric logging (sweeps log on their own).
            model_path: Optional path overriding the preset model weights.

        Returns:
            Tuple of (trained model-or-probe, valid_metrics dict, test_metrics dict).
        """
        if source_model_name is None:
            source_model_name = model_name

        # Random-embedding baselines cannot be fine-tuned; train a probe instead.
        if "random" in model_name.lower():
            print_message(f"Model {model_name} does not support hybrid training. Training a linear probe instead.")
            probe = get_probe(self.probe_args)
            summary(probe)
            probe, valid_metrics, test_metrics = self.trainer_probe(
                model=probe,
                tokenizer=tokenizer,
                model_name=model_name,
                data_name=data_name,
                train_dataset=train_set,
                valid_dataset=valid_set,
                test_dataset=test_set,
                emb_dict=emb_dict,
                ppi=ppi,
                log_id=self.random_id,
                source_model_name=source_model_name,
            )
            if not sweep_mode:
                self.log_metrics(data_name, model_name, valid_metrics, split_name='valid')
                self.log_metrics(data_name, model_name, test_metrics, split_name='test')
            return probe, valid_metrics, test_metrics

        tokenwise = self.probe_args.tokenwise
        num_labels = self.probe_args.num_labels
        num_runs = getattr(self.trainer_args, 'num_runs', 1)

        # Factories are only needed when the trainer re-instantiates model/probe
        # across multiple independent runs.
        model_factory = self._create_model_factory(model_name, tokenwise, num_labels, hybrid=True, model_path=model_path) if num_runs > 1 else None
        probe_factory = self._create_probe_factory() if num_runs > 1 else None
        model, tokenizer = get_base_model_for_training(
            model_name,
            tokenwise=tokenwise,
            num_labels=num_labels,
            hybrid=True,
            dtype=self.model_args.model_dtype,
            model_path=model_path,
        )
        # Optionally wrap the backbone with LoRA adapters before attaching the probe.
        if self.probe_args.lora:
            model = wrap_lora(model, self.probe_args.lora_r, self.probe_args.lora_alpha, self.probe_args.lora_dropout)
        probe = get_probe(self.probe_args)
        summary(model)
        summary(probe)
        model, valid_metrics, test_metrics, _, _ = self.trainer_hybrid_model(
            model=model,
            tokenizer=tokenizer,
            probe=probe,
            model_name=model_name,
            data_name=data_name,
            train_dataset=train_set,
            valid_dataset=valid_set,
            test_dataset=test_set,
            emb_dict=emb_dict,
            ppi=ppi,
            log_id=self.random_id,
            source_model_name=source_model_name,
            model_factory=model_factory,
            probe_factory=probe_factory,
        )
        if not sweep_mode:
            self.log_metrics(data_name, model_name, valid_metrics, split_name='valid')
            self.log_metrics(data_name, model_name, test_metrics, split_name='test')
        return model, valid_metrics, test_metrics
|
|
| @log_method_calls |
| def run_full_finetuning(self): |
| total_combinations = len(self.model_args.model_names) * len(self.datasets) |
| self.logger.info(f"Processing {total_combinations} model/dataset combinations") |
| for display_name, dispatch_type, model_path in self.model_args.model_entries(): |
| for data_name, dataset in self.datasets.items(): |
| self.logger.info(f"Processing dataset: {data_name}") |
| train_set, valid_set, test_set, num_labels, label_type, ppi = dataset |
| self.probe_args.num_labels = num_labels |
| self.probe_args.task_type = label_type |
| self.trainer_args.task_type = label_type |
| self.logger.info(f'Training probe for {data_name} with {display_name}') |
| _ = self._run_full_finetuning(dispatch_type, data_name, train_set, valid_set, test_set, ppi, model_path=model_path) |
| torch.cuda.empty_cache() |
|
|
    @log_method_calls
    def run_hybrid_probes(self):
        """Train a hybrid (base model + probe) for every model/dataset pair.

        For each model, resolves the cached-embedding file (SQL .db or torch .pth),
        derives the probe input dimension from a sample sequence, then trains a
        hybrid model on each dataset via _run_hybrid_probe.
        """
        probe_args = self.probe_args
        # A sample sequence is needed only to infer the embedding dimension.
        test_seq = self.all_seqs[0]

        total_combinations = len(self.model_args.model_names) * len(self.datasets)
        self.logger.info(f"Processing {total_combinations} model/dataset combinations")

        for display_name, dispatch_type, model_path in self.model_args.model_entries():
            self.logger.info(f"Processing model: {display_name}")

            tokenizer = get_tokenizer(dispatch_type, model_path=model_path)

            pooling_types = self.embedding_args.pooling_types
            if self._sql:
                # SQL-backed cache: embeddings stay on disk; only the dim is probed.
                filename = get_embedding_filename(display_name, self._full, pooling_types, 'db')
                save_path = os.path.join(self.embedding_args.embedding_save_dir, filename)
                input_size = self.get_embedding_dim_sql(save_path, test_seq, tokenizer)
                emb_dict = None
            else:
                # .pth cache: load the full embedding dict into memory.
                # NOTE(review): assumes save_embeddings_to_disk already wrote this
                # file — no existence check here, unlike the fold-wise path above.
                filename = get_embedding_filename(display_name, self._full, pooling_types, 'pth')
                save_path = os.path.join(self.embedding_args.embedding_save_dir, filename)
                emb_dict = torch_load(save_path)
                input_size = self.get_embedding_dim_pth(emb_dict, test_seq, tokenizer)

            # Pooled multi-column embeddings are concatenated, scaling the input dim.
            if (not self._full) and getattr(self.full_args, 'multi_column', None):
                input_size = input_size * len(self.full_args.multi_column)

            for data_name, dataset in self.datasets.items():
                self.logger.info(f"Processing dataset: {data_name}")
                train_set, valid_set, test_set, num_labels, label_type, ppi = dataset
                # PPI pairs concatenate two pooled embeddings, doubling input size.
                if ppi and not self._full:
                    probe_args.input_size = input_size * 2
                else:
                    probe_args.input_size = input_size

                self.probe_args.num_labels = num_labels
                self.probe_args.task_type = label_type

                self.trainer_args.task_type = label_type
                self.logger.info(f'Training probe for {data_name} with {display_name}')

                _ = self._run_hybrid_probe(
                    model_name=dispatch_type,
                    data_name=data_name,
                    train_set=train_set,
                    valid_set=valid_set,
                    test_set=test_set,
                    tokenizer=tokenizer,
                    emb_dict=emb_dict,
                    ppi=ppi,
                    source_model_name=display_name,
                    model_path=model_path,
                )
                # Free GPU memory between combinations.
                torch.cuda.empty_cache()
| |
|
|
    @log_method_calls
    def run_nn_probes(self):
        """Train a neural-network probe on cached embeddings for every model/dataset pair.

        Mirrors run_hybrid_probes but trains only a probe (no backbone fine-tuning):
        resolves the cached-embedding file (SQL .db or torch .pth), derives the
        probe input dimension, then trains via _run_nn_probe for each dataset.
        """
        probe_args = self.probe_args
        # A sample sequence is needed only to infer the embedding dimension.
        test_seq = self.all_seqs[0]

        total_combinations = len(self.model_args.model_names) * len(self.datasets)
        self.logger.info(f"Processing {total_combinations} model/dataset combinations")

        for display_name, dispatch_type, model_path in self.model_args.model_entries():
            self.logger.info(f"Processing model: {display_name}")

            tokenizer = get_tokenizer(dispatch_type, model_path=model_path)

            pooling_types = self.embedding_args.pooling_types
            if self._sql:
                # SQL-backed cache: embeddings stay on disk; only the dim is probed.
                filename = get_embedding_filename(display_name, self._full, pooling_types, 'db')
                save_path = os.path.join(self.embedding_args.embedding_save_dir, filename)
                input_size = self.get_embedding_dim_sql(save_path, test_seq, tokenizer)
                emb_dict = None
            else:
                # .pth cache: load the full embedding dict into memory.
                # NOTE(review): assumes save_embeddings_to_disk already wrote this
                # file — no existence check here.
                filename = get_embedding_filename(display_name, self._full, pooling_types, 'pth')
                save_path = os.path.join(self.embedding_args.embedding_save_dir, filename)
                emb_dict = torch_load(save_path)
                input_size = self.get_embedding_dim_pth(emb_dict, test_seq, tokenizer)

            # Pooled multi-column embeddings are concatenated, scaling the input dim.
            if (not self._full) and getattr(self.full_args, 'multi_column', None):
                input_size = input_size * len(self.full_args.multi_column)

            print(f'Input dim: {input_size}')

            for data_name, dataset in self.datasets.items():
                self.logger.info(f"Processing dataset: {data_name}")
                train_set, valid_set, test_set, num_labels, label_type, ppi = dataset
                # PPI pairs concatenate two pooled embeddings, doubling input size.
                if ppi and not self._full:
                    probe_args.input_size = input_size * 2
                else:
                    probe_args.input_size = input_size

                self.probe_args.num_labels = num_labels
                self.probe_args.task_type = label_type

                self.trainer_args.task_type = label_type
                self.logger.info(f'Training probe for {data_name} with {display_name}')

                # NOTE(review): passes display_name as model_name (hybrid path passes
                # dispatch_type) — presumably fine since no backbone is loaded here;
                # confirm against _run_nn_probe.
                _ = self._run_nn_probe(
                    model_name=display_name,
                    data_name=data_name,
                    train_set=train_set,
                    valid_set=valid_set,
                    test_set=test_set,
                    tokenizer=tokenizer,
                    emb_dict=emb_dict,
                    ppi=ppi,
                    source_model_name=display_name,
                )
                # Free GPU memory between combinations.
                torch.cuda.empty_cache()
| |
|
|
| @log_method_calls |
| def run_scikit_scheme(self): |
| self.scikit_args = self._build_scikit_args() |
| scikit_probe = ScikitProbe(self.scikit_args) |
| if "n_jobs" in self.full_args.__dict__: |
| scikit_probe.n_jobs = self.full_args.n_jobs |
| else: |
| scikit_probe.n_jobs = 1 |
| for display_name, dispatch_type, model_path in self.model_args.model_entries(): |
| for data_name, dataset in self.datasets.items(): |
| |
| X_train, y_train, X_valid, y_valid, X_test, y_test, label_type = self.prepare_scikit_dataset(display_name, dataset) |
| |
| |
| if self.scikit_args.model_name is not None: |
| print_message(f"Skipping LazyPredict, using specified model: {self.scikit_args.model_name}") |
| results = scikit_probe.run_specific_model(X_train, y_train, X_valid, y_valid, X_test, y_test, model_results=None) |
| else: |
| |
| if label_type == 'singlelabel': |
| results = scikit_probe.find_best_classifier(X_train, y_train, X_valid, y_valid) |
| elif label_type == 'regression': |
| results = scikit_probe.find_best_regressor(X_train, y_train, X_valid, y_valid) |
| else: |
| raise ValueError(f'Label type {label_type} not supported') |
| |
| results = scikit_probe.run_specific_model(X_train, y_train, X_valid, y_valid, X_test, y_test, results) |
| |
| |
| metrics_dict = {'test_mcc': results.final_scores} if isinstance(results.final_scores, (int, float)) else results.final_scores |
| self.log_metrics(data_name, display_name, metrics_dict, split_name='test') |
| |
| @log_method_calls |
| def generate_plots(self): |
| print_message("Generating visualization plots...") |
| |
| results_file = os.path.join(self.full_args.results_dir, f"{self.random_id}.tsv") |
| |
| |
| if not os.path.exists(results_file): |
| print_message(f"Results file not found: {results_file}") |
| return |
| |
| |
| output_dir = self.full_args.plots_dir |
|
|
| print_message(f"Generating plots in {output_dir}...") |
| create_plots(results_file, output_dir) |
| print_message("Plots generated successfully!") |
| |
| def run_proteingym_zero_shot(self): |
| """Run ProteinGym zero-shot for all specified models and DMS ids.""" |
| dms_ids = getattr(self.full_args, 'dms_ids', []) or [] |
| mode = getattr(self.full_args, 'mode', 'benchmark') |
| dms_ids = expand_dms_ids_all(dms_ids, mode=mode) |
| if len(dms_ids) == 0: |
| raise ValueError("--dms_ids is required when --proteingym is specified") |
| model_names = self.model_args.model_names |
| if len(model_names) == 0: |
| raise ValueError("--model_names must specify at least one model") |
| assert self.model_args._model_paths is None, "ProteinGym zero-shot requires --model_names (preset names), not --model_paths/--model_types." |
| |
| results_root = getattr(self.full_args, 'results_dir', 'results') |
| results_dir = os.path.join(results_root, 'proteingym') |
| scoring_method = getattr(self.full_args, 'scoring_method', 'masked_marginal') |
| scoring_window = getattr(self.full_args, 'scoring_window', 'optimal') |
| if isinstance(mode, str) and mode.lower() == 'indels': |
| print_message("Only pll is currently supported for indels scoring.") |
| scoring_method = 'pll' |
| |
| |
| self.logger.info(f"Running ProteinGym zero-shot with [{scoring_method}] scoring on {len(dms_ids)} DMS ids with models: {model_names}") |
| |
| runner = ProteinGymRunner( |
| results_dir=results_dir, |
| repo_id="GleghornLab/ProteinGym_DMS", |
| ) |
| self._proteingym_timing = runner.run( |
| dms_ids=dms_ids, |
| model_names=model_names, |
| mode=mode, |
| scoring_method=scoring_method, |
| scoring_window=scoring_window, |
| batch_size=getattr(self.full_args, 'pg_batch_size', 32), |
| ) |
| print_message(f"ProteinGym zero-shot complete. Results in {results_dir}") |
|
|
| |
| runner.run_benchmark(model_names, dms_ids, mode, scoring_method) |
|
|
def main(args: SimpleNamespace):
    """Program entry point: seed, optionally dispatch to Modal or replay a log,
    then run the selected training/evaluation pipeline.

    Args:
        args: Parsed CLI arguments as a namespace.

    Returns:
        0 on normal completion; the Modal CLI's return value when auto-dispatched;
        None when only the scoring-method comparison was requested.
    """
    chosen_seed = set_global_seed(args.seed)
    args.seed = chosen_seed

    # Optionally dispatch the whole run to Modal instead of executing locally.
    if _should_auto_run_modal(args):
        return _run_on_modal_cli(args)

    if args.replay_path is not None:
        from logger import LogReplayer
        replayer = LogReplayer(args.replay_path)
        replay_args = replayer.parse_log()
        replay_args.replay_path = args.replay_path

        try:
            # Re-seed from the replayed settings so the replay is reproducible;
            # best-effort: silently fall back if the log lacks seed information.
            if not hasattr(replay_args, 'seed') or replay_args.seed is None:
                replay_args.seed = None
            if not hasattr(replay_args, 'deterministic') or replay_args.deterministic is None:
                replay_args.deterministic = getattr(args, 'deterministic', False)
            chosen_seed = set_global_seed(replay_args.seed, deterministic=replay_args.deterministic)
            replay_args.seed = chosen_seed
        except Exception:
            pass
        main = MainProcess(replay_args, GUI=False)
        for k, v in main.full_args.__dict__.items():
            print(f"{k}:\t{v}")
        replayer.run_replay(main)
    else:
        main = MainProcess(args, GUI=False)
        for k, v in main.full_args.__dict__.items():
            print(f"{k}:\t{v}")

    # Scoring-method comparison is a standalone mode that short-circuits the
    # normal pipeline.
    if getattr(args, 'compare_scoring_methods', False) and getattr(args, 'proteingym', False):
        print_message("Running scoring method comparison...")
        dms_ids = getattr(args, 'dms_ids', []) or []
        mode = getattr(args, 'mode', 'benchmark')
        dms_ids = expand_dms_ids_all(dms_ids, mode=mode)
        model_names = getattr(args, 'model_names', []) or []
        if len(model_names) == 0:
            raise ValueError("--model_names must specify at least one model")

        results_root = getattr(args, 'results_dir', 'results')
        output_csv = os.path.join(results_root, 'scoring_methods_comparison.csv')

        compare_scoring_methods(
            model_names=model_names,
            device=None,
            methods=None,
            dms_ids=dms_ids,
            progress=True,
            output_csv=output_csv
        )
        print_message(f"Scoring method comparison complete. Results saved to {output_csv}")
        return

    has_datasets = bool(getattr(args, 'data_names', []) or getattr(args, 'data_dirs', []))

    if has_datasets:
        main.apply_current_settings()
        main.get_datasets()
        print_message(f"Number of sequences: {len(main.all_seqs)}")
        # Exactly one of the pipelines below runs, chosen by configuration.
        if main.full_args.use_wandb_hyperopt:
            if not main.full_args.full_finetuning:
                main.save_embeddings_to_disk()
            HyperoptModule.run_wandb_hyperopt(main)
        elif main.full_args.full_finetuning:
            main.run_full_finetuning()
        elif main.full_args.hybrid_probe:
            main.save_embeddings_to_disk()
            main.run_hybrid_probes()
        elif main.full_args.use_scikit:
            main.save_embeddings_to_disk()
            main.run_scikit_scheme()
        else:
            main.save_embeddings_to_disk()
            main.run_nn_probes()
    else:
        # BUGFIX: the original re-computed `has_datasets` here (necessarily False
        # in this branch) and carried an unreachable duplicate of the dataset
        # pipeline above; only this message was ever reachable.
        print_message("No datasets specified; proceeding with ProteinGym.")

    if getattr(args, 'proteingym', False):
        main.run_proteingym_zero_shot()
        try:
            results_root = getattr(args, 'results_dir', 'results')
            results_dir = os.path.join(results_root, 'proteingym')
            pg_scores = ProteinGymRunner.collect_spearman(results_dir, getattr(args, 'model_names', []))
            for model_name, score in pg_scores.items():
                if isinstance(score, (int, float)):
                    training_time = getattr(main, '_proteingym_timing', {}).get(model_name, None)
                    metrics_dict = {'spearman': float(score)}
                    # BUGFIX: float(None) raised TypeError when timing was missing,
                    # aborting metric logging for all remaining models via the
                    # broad except below.
                    if training_time is not None:
                        metrics_dict['training_time_seconds'] = float(training_time)
                    main.log_metrics('proteingym', model_name, metrics_dict)
        except Exception as e:
            print_message(f"Failed to log ProteinGym metrics: {e}")

    main.write_results()
    main.generate_plots()
    main.end_log()
    return 0
|
|
if __name__ == "__main__":
    # NOTE(review): `args` is not defined anywhere in this chunk — presumably
    # assigned at module level (e.g. args = parse_arguments()) elsewhere in the
    # file; verify, otherwise this raises NameError at startup.
    sys.exit(main(args))
|
|