# diffdock/utils/parsing.py
# Commit 4a3f787 ("first commit") by gcorso.
from argparse import ArgumentParser,FileType
def parse_train_args(argv=None):
    """Build the training argument parser and parse the command line.

    Options are registered in argparse groups (general, training, dataset,
    diffusion, model) so that ``--help`` output follows the same sections.

    Args:
        argv: Optional list of argument strings to parse. When ``None`` (the
            default, which is what a no-argument call passes), argparse falls
            back to ``sys.argv[1:]``.

    Returns:
        argparse.Namespace containing every option registered below.

    Note:
        ``--config`` uses ``FileType(mode='r')``, so argparse opens the file
        while parsing. This function neither reads nor closes it; the caller
        owns the returned file object.
    """
    parser = ArgumentParser()

    # General arguments
    general = parser.add_argument_group('General')
    general.add_argument('--config', type=FileType(mode='r'), default=None,
                         help='Optional config file; opened for reading and returned to the caller unread')
    general.add_argument('--log_dir', type=str, default='workdir', help='Folder in which to save model and logs')
    general.add_argument('--restart_dir', type=str, help='Folder of previous training model from which to restart')
    general.add_argument('--cache_path', type=str, default='data/cache', help='Folder from where to load/restore cached dataset')
    general.add_argument('--data_dir', type=str, default='data/PDBBind_processed/', help='Folder containing original structures')
    general.add_argument('--split_train', type=str, default='data/splits/timesplit_no_lig_overlap_train', help='Path of file defining the split')
    general.add_argument('--split_val', type=str, default='data/splits/timesplit_no_lig_overlap_val', help='Path of file defining the split')
    general.add_argument('--split_test', type=str, default='data/splits/timesplit_test', help='Path of file defining the split')
    general.add_argument('--test_sigma_intervals', action='store_true', default=False, help='Whether to log loss per noise interval')
    general.add_argument('--val_inference_freq', type=int, default=5, help='Frequency of epochs for which to run expensive inference on val data')
    general.add_argument('--train_inference_freq', type=int, default=None, help='Frequency of epochs for which to run expensive inference on train data')
    general.add_argument('--inference_steps', type=int, default=20, help='Number of denoising steps for inference on val')
    general.add_argument('--num_inference_complexes', type=int, default=100, help='Number of complexes for which inference is run every val/train_inference_freq epochs (None will run it on all)')
    general.add_argument('--inference_earlystop_metric', type=str, default='valinf_rmsds_lt2', help='This is the metric that is additionally used when val_inference_freq is not None')
    general.add_argument('--inference_earlystop_goal', type=str, default='max', help='Whether to maximize or minimize metric')
    # NOTE(review): the next three options have no help text in the original;
    # their exact semantics are defined by the caller — fill in once confirmed.
    general.add_argument('--wandb', action='store_true', default=False, help='')
    general.add_argument('--project', type=str, default='difdock_train', help='')
    general.add_argument('--run_name', type=str, default='', help='')
    general.add_argument('--cudnn_benchmark', action='store_true', default=False, help='CUDA optimization parameter for faster training')
    general.add_argument('--num_dataloader_workers', type=int, default=0, help='Number of workers for dataloader')
    general.add_argument('--pin_memory', action='store_true', default=False, help='pin_memory arg of dataloader')

    # Training arguments
    training = parser.add_argument_group('Training')
    training.add_argument('--n_epochs', type=int, default=400, help='Number of epochs for training')
    training.add_argument('--batch_size', type=int, default=32, help='Batch size')
    training.add_argument('--scheduler', type=str, default=None, help='LR scheduler')
    training.add_argument('--scheduler_patience', type=int, default=20, help='Patience of the LR scheduler')
    training.add_argument('--lr', type=float, default=1e-3, help='Initial learning rate')
    training.add_argument('--restart_lr', type=float, default=None, help='If this is not none, the lr of the optimizer will be overwritten with this value when restarting from a checkpoint.')
    training.add_argument('--w_decay', type=float, default=0.0, help='Weight decay added to loss')
    training.add_argument('--num_workers', type=int, default=1, help='Number of workers for preprocessing')
    training.add_argument('--use_ema', action='store_true', default=False, help='Whether or not to use ema for the model weights')
    training.add_argument('--ema_rate', type=float, default=0.999, help='decay rate for the exponential moving average model parameters ')

    # Dataset
    # Float options use float literals as defaults: argparse applies `type`
    # only to string defaults, so an int literal would stay an int whenever
    # the flag is omitted.
    dataset = parser.add_argument_group('Dataset')
    dataset.add_argument('--limit_complexes', type=int, default=0, help='If positive, the number of training and validation complexes is capped')
    dataset.add_argument('--all_atoms', action='store_true', default=False, help='Whether to use the all atoms model')
    dataset.add_argument('--receptor_radius', type=float, default=30.0, help='Cutoff on distances for receptor edges')
    dataset.add_argument('--c_alpha_max_neighbors', type=int, default=10, help='Maximum number of neighbors for each residue')
    dataset.add_argument('--atom_radius', type=float, default=5.0, help='Cutoff on distances for atom connections')
    dataset.add_argument('--atom_max_neighbors', type=int, default=8, help='Maximum number of atom neighbors for receptor')
    dataset.add_argument('--matching_popsize', type=int, default=20, help='Differential evolution popsize parameter in matching')
    dataset.add_argument('--matching_maxiter', type=int, default=20, help='Differential evolution maxiter parameter in matching')
    dataset.add_argument('--max_lig_size', type=int, default=None, help='Maximum number of heavy atoms in ligand')
    dataset.add_argument('--remove_hs', action='store_true', default=False, help='remove Hs')
    dataset.add_argument('--num_conformers', type=int, default=1, help='Number of conformers to match to each ligand')
    dataset.add_argument('--esm_embeddings_path', type=str, default=None, help='If this is set then the LM embeddings at that path will be used for the receptor features')

    # Diffusion
    diffusion = parser.add_argument_group('Diffusion')
    diffusion.add_argument('--tr_weight', type=float, default=0.33, help='Weight of translation loss')
    diffusion.add_argument('--rot_weight', type=float, default=0.33, help='Weight of rotation loss')
    diffusion.add_argument('--tor_weight', type=float, default=0.33, help='Weight of torsional loss')
    diffusion.add_argument('--rot_sigma_min', type=float, default=0.1, help='Minimum sigma for rotational component')
    diffusion.add_argument('--rot_sigma_max', type=float, default=1.65, help='Maximum sigma for rotational component')
    diffusion.add_argument('--tr_sigma_min', type=float, default=0.1, help='Minimum sigma for translational component')
    diffusion.add_argument('--tr_sigma_max', type=float, default=30.0, help='Maximum sigma for translational component')
    diffusion.add_argument('--tor_sigma_min', type=float, default=0.0314, help='Minimum sigma for torsional component')
    diffusion.add_argument('--tor_sigma_max', type=float, default=3.14, help='Maximum sigma for torsional component')
    diffusion.add_argument('--no_torsion', action='store_true', default=False, help='If set only rigid matching')

    # Model
    model = parser.add_argument_group('Model')
    model.add_argument('--num_conv_layers', type=int, default=2, help='Number of interaction layers')
    model.add_argument('--max_radius', type=float, default=5.0, help='Radius cutoff for geometric graph')
    # `store_true` with default=True can never turn this off, so an explicit
    # negative flag writes False to the same destination. `--scale_by_sigma`
    # is still accepted for backward compatibility (it is a no-op).
    model.add_argument('--scale_by_sigma', dest='scale_by_sigma', action='store_true', default=True, help='Whether to normalise the score (default: on)')
    model.add_argument('--no_scale_by_sigma', dest='scale_by_sigma', action='store_false', help='Disable normalising the score')
    model.add_argument('--ns', type=int, default=16, help='Number of hidden features per node of order 0')
    model.add_argument('--nv', type=int, default=4, help='Number of hidden features per node of order >0')
    model.add_argument('--distance_embed_dim', type=int, default=32, help='Embedding size for the distance')
    model.add_argument('--cross_distance_embed_dim', type=int, default=32, help='Embedding size for the cross distance')
    model.add_argument('--no_batch_norm', action='store_true', default=False, help='If set, it removes the batch norm')
    model.add_argument('--use_second_order_repr', action='store_true', default=False, help='Whether to use only up to first order representations or also second')
    model.add_argument('--cross_max_distance', type=float, default=80.0, help='Maximum cross distance in case not dynamic')
    model.add_argument('--dynamic_max_cross', action='store_true', default=False, help='Whether to use the dynamic distance cutoff')
    model.add_argument('--dropout', type=float, default=0.0, help='MLP dropout')
    model.add_argument('--embedding_type', type=str, default="sinusoidal", help='Type of diffusion time embedding')
    model.add_argument('--sigma_embed_dim', type=int, default=32, help='Size of the embedding of the diffusion time')
    model.add_argument('--embedding_scale', type=int, default=1000, help='Parameter of the diffusion time embedding')

    # argv=None makes argparse read sys.argv[1:], matching the original call.
    args = parser.parse_args(argv)
    return args