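# NOTE: the scraped page header ("Spaces: Sleeping") preceded this file; it is
# hosting-page status text, not part of the module source.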
import copy | |
import ml_collections as mlc | |
from dockformerpp.utils.config_tools import set_inf, enforce_config_constraints | |
def model_config(
    name,
    train=False,
    low_prec=False,
    long_sequence_inference=False,
):
    """Build a model configuration for the named preset.

    A deep copy of the module-level ``config`` is taken, so the shared
    default is never mutated, then the preset and the mode flags are applied
    on top of it.

    Args:
        name: Preset name. ``"initial_training"`` keeps the defaults;
            ``"finetune_affinity"`` sets the ``affinity2d``, ``binding_site``
            and ``positions_inter_distogram`` loss weights to 0.5.
        train: When true, sets ``globals.blocks_per_ckpt`` to 1 and keeps
            ``globals.use_lma`` disabled.
        low_prec: When true, sets ``globals.eps`` to 1e-4 and calls
            ``set_inf(c, 1e4)``.
        long_sequence_inference: When true, enables ``globals.use_lma``.
            Must not be combined with ``train``.

    Returns:
        The adjusted config copy, after ``enforce_config_constraints`` has
        been run on it.

    Raises:
        ValueError: If ``name`` is not one of the known presets.
        AssertionError: If ``long_sequence_inference`` and ``train`` are both
            set (only when assertions are enabled).
    """
    c = copy.deepcopy(config)

    # TRAINING PRESETS
    if name == "initial_training":
        # AF2 Suppl. Table 4, "initial training" setting
        pass
    elif name == "finetune_affinity":
        c.loss.affinity2d.weight = 0.5
        c.loss.binding_site.weight = 0.5
        # this is not essential given fape?
        c.loss.positions_inter_distogram.weight = 0.5
    else:
        raise ValueError(f"Invalid model name: {name!r}")

    c.globals.use_lma = False

    if long_sequence_inference:
        assert not train, (
            "long_sequence_inference is only supported when train=False"
        )
        c.globals.use_lma = True

    if train:
        c.globals.blocks_per_ckpt = 1
        # Also covers the case where the assert above is stripped by -O.
        c.globals.use_lma = False

    if low_prec:
        c.globals.eps = 1e-4
        # If we want exact numerical parity with the original, inf can't be
        # a global constant
        set_inf(c, 1e4)

    enforce_config_constraints(c)

    return c
# Shared FieldReferences. The same reference object is placed at several
# locations in ``config`` below, so those entries share a single value.
# NOTE(review): the names presumably follow the AlphaFold/OpenFold channel
# naming (c_z pair, c_m MSA/single input, c_s single) -- confirm against the
# model code.
c_z = mlc.FieldReference(128, field_type=int)
c_m = mlc.FieldReference(256, field_type=int)
c_t = mlc.FieldReference(64, field_type=int)
c_e = mlc.FieldReference(64, field_type=int)
c_s = mlc.FieldReference(384, field_type=int)

# No default value here; model_config() assigns 1 when train=True.
blocks_per_ckpt = mlc.FieldReference(None, field_type=int)

# Bin counts shared by the auxiliary heads and, for affinity, their losses.
aux_distogram_bins = mlc.FieldReference(64, field_type=int)
aux_affinity_bins = mlc.FieldReference(32, field_type=int)

# Default epsilon; model_config() sets it to 1e-4 when low_prec=True.
eps = mlc.FieldReference(1e-8, field_type=float)

# Placeholder strings used in the "feat" shape lists of ``config``.
NUM_RES = "num residues placeholder"
NUM_TOKEN = "num tokens placeholder"
# Default configuration. model_config() deep-copies this object and applies
# preset- and mode-specific overrides to the copy.
# Top-level sections: "data", "globals", "model", "relax", "loss", "ema".
config = mlc.ConfigDict(
    {
        "data": {
            "common": {
                # Feature name -> list describing its dimensions. NUM_TOKEN
                # is the placeholder for the token axis.
                # NOTE(review): None entries presumably mean "any size" --
                # confirm against the data pipeline.
                "feat": {
                    "aatype": [NUM_TOKEN],
                    "all_atom_mask": [NUM_TOKEN, None],
                    "all_atom_positions": [NUM_TOKEN, None, None],
                    "atom14_alt_gt_exists": [NUM_TOKEN, None],
                    "atom14_alt_gt_positions": [NUM_TOKEN, None, None],
                    "atom14_atom_exists": [NUM_TOKEN, None],
                    "atom14_atom_is_ambiguous": [NUM_TOKEN, None],
                    "atom14_gt_exists": [NUM_TOKEN, None],
                    "atom14_gt_positions": [NUM_TOKEN, None, None],
                    "atom37_atom_exists": [NUM_TOKEN, None],
                    "backbone_rigid_mask": [NUM_TOKEN],
                    "backbone_rigid_tensor": [NUM_TOKEN, None, None],
                    "chi_angles_sin_cos": [NUM_TOKEN, None, None],
                    "chi_mask": [NUM_TOKEN, None],
                    "no_recycling_iters": [],
                    "pseudo_beta": [NUM_TOKEN, None],
                    "pseudo_beta_mask": [NUM_TOKEN],
                    "residue_index": [NUM_TOKEN],
                    "in_chain_residue_index": [NUM_TOKEN],
                    "chain_index": [NUM_TOKEN],
                    "residx_atom14_to_atom37": [NUM_TOKEN, None],
                    "residx_atom37_to_atom14": [NUM_TOKEN, None],
                    "resolution": [],
                    "rigidgroups_alt_gt_frames": [NUM_TOKEN, None, None, None],
                    "rigidgroups_group_exists": [NUM_TOKEN, None],
                    "rigidgroups_group_is_ambiguous": [NUM_TOKEN, None],
                    "rigidgroups_gt_exists": [NUM_TOKEN, None],
                    "rigidgroups_gt_frames": [NUM_TOKEN, None, None, None],
                    "seq_length": [],
                    "token_mask": [NUM_TOKEN],
                    "target_feat": [NUM_TOKEN, None],
                    "use_clamped_fape": [],
                },
                "max_recycling_iters": 1,
                "unsupervised_features": [
                    "aatype",
                    "residue_index",
                    "in_chain_residue_index",
                    "chain_index",
                    "seq_length",
                    "no_recycling_iters",
                    "all_atom_mask",
                    "all_atom_positions",
                ],
            },
            "supervised": {
                "clamp_prob": 0.9,
                "supervised_features": [
                    "resolution",
                    "use_clamped_fape",
                ],
            },
            # Per-mode data settings: "predict", "eval" and "train".
            "predict": {
                "fixed_size": True,
                "crop": False,
                "crop_size": None,
                "supervised": False,
                "uniform_recycling": False,
            },
            "eval": {
                "fixed_size": True,
                "crop": False,
                "crop_size": None,
                "supervised": True,
                "uniform_recycling": False,
            },
            "train": {
                "fixed_size": True,
                "crop": True,
                "crop_size": 355,
                "supervised": True,
                "clamp_prob": 0.9,
                "uniform_recycling": True,
                "distogram_mask_prob": 0.1,
            },
            "data_module": {
                "data_loaders": {
                    "batch_size": 1,
                    # "batch_size": 2,
                    "num_workers": 16,
                    "pin_memory": True,
                    "should_verify": False,
                },
            },
        },
        # Recurring FieldReferences that can be changed globally here
        "globals": {
            "blocks_per_ckpt": blocks_per_ckpt,
            # Use Staats & Rabe's low-memory attention algorithm.
            "use_lma": False,
            "max_lr": 1e-3,
            "c_z": c_z,
            "c_m": c_m,
            "c_t": c_t,
            "c_e": c_e,
            "c_s": c_s,
            "eps": eps,
        },
        "model": {
            "_mask_trans": False,
            "structure_input_embedder": {
                "protein_tf_dim": 20,
                "additional_tf_dim": 3,  # number of classes (prot_r, prot_l, aff)
                "c_z": c_z,
                "c_m": c_m,
                "relpos_k": 32,
                "prot_min_bin": 3.25,
                "prot_max_bin": 20.75,
                "prot_no_bins": 15,
                # NOTE(review): each module carries its own "inf" literal;
                # model_config(low_prec=True) calls set_inf(c, 1e4), which
                # presumably rewrites all of them -- confirm in config_tools.
                "inf": 1e8,
            },
            "recycling_embedder": {
                "c_z": c_z,
                "c_m": c_m,
                "min_bin": 3.25,
                "max_bin": 20.75,
                "no_bins": 15,
                "inf": 1e8,
            },
            "evoformer_stack": {
                "c_m": c_m,
                "c_z": c_z,
                "c_hidden_single_att": 32,
                "c_hidden_mul": 128,
                "c_hidden_pair_att": 32,
                "c_s": c_s,
                "no_heads_single": 8,
                "no_heads_pair": 4,
                # "no_blocks": 48,
                "no_blocks": 2,
                "transition_n": 4,
                "single_dropout": 0.15,
                "pair_dropout": 0.25,
                "blocks_per_ckpt": blocks_per_ckpt,
                "clear_cache_between_blocks": False,
                "inf": 1e9,
                "eps": eps,  # 1e-10,
            },
            "structure_module": {
                "c_s": c_s,
                "c_z": c_z,
                "c_ipa": 16,
                "c_resnet": 128,
                "no_heads_ipa": 12,
                "no_qk_points": 4,
                "no_v_points": 8,
                "dropout_rate": 0.1,
                "no_blocks": 8,
                "no_transition_layers": 1,
                "no_resnet_blocks": 2,
                "no_angles": 7,
                "trans_scale_factor": 10,
                "epsilon": eps,  # 1e-12,
                "inf": 1e5,
            },
            "heads": {
                "lddt": {
                    "no_bins": 50,
                    "c_in": c_s,
                    "c_hidden": 128,
                },
                "distogram": {
                    "c_z": c_z,
                    "no_bins": aux_distogram_bins,
                },
                "affinity_2d": {
                    "c_z": c_z,
                    "num_bins": aux_affinity_bins,
                },
                "affinity_1d": {
                    "c_s": c_s,
                    "num_bins": aux_affinity_bins,
                },
                "affinity_cls": {
                    "c_s": c_s,
                    "num_bins": aux_affinity_bins,
                },
                "binding_site": {
                    "c_s": c_s,
                    "c_out": 1,
                },
                "inter_contact": {
                    "c_s": c_s,
                    "c_z": c_z,
                    "c_out": 1,
                },
            },
            # A negative value indicates that no early stopping will occur, i.e.
            # the model will always run `max_recycling_iters` number of recycling
            # iterations. A positive value will enable early stopping if the
            # difference in pairwise distances is less than the tolerance between
            # recycling steps.
            "recycle_early_stop_tolerance": -1.0,
        },
        "relax": {
            "max_iterations": 0,  # no max
            "tolerance": 2.39,
            "stiffness": 10.0,
            "max_outer_iterations": 20,
            "exclude_residues": [],
        },
        # Loss terms. A term with "weight": 0.0 is off by default; the
        # "finetune_affinity" preset in model_config() raises some of them.
        "loss": {
            "distogram": {
                "min_bin": 2.3125,
                "max_bin": 21.6875,
                # NOTE(review): literal 64 equals the aux_distogram_bins
                # default used by heads.distogram; presumably the two must
                # stay equal -- confirm before changing either.
                "no_bins": 64,
                "eps": eps,  # 1e-6,
                "weight": 0.3,
            },
            "positions_inter_distogram": {
                "max_dist": 20.0,
                "weight": 0.0,
            },
            "positions_intra_distogram": {
                "max_dist": 10.0,
                "weight": 0.0,
            },
            "binding_site": {
                "weight": 0.0,
                "pos_class_weight": 20.0,
            },
            "inter_contact": {
                "weight": 0.0,
                "pos_class_weight": 200.0,
            },
            "affinity2d": {
                "min_bin": 0,
                "max_bin": 15,
                "no_bins": aux_affinity_bins,
                "weight": 0.0,
            },
            "affinity_cls": {
                "min_bin": 0,
                "max_bin": 15,
                "no_bins": aux_affinity_bins,
                "weight": 0.0,
            },
            "fape_backbone": {
                "clamp_distance": 10.0,
                "loss_unit_distance": 10.0,
                "weight": 0.5,
            },
            "fape_sidechain": {
                "clamp_distance": 10.0,
                "length_scale": 10.0,
                "weight": 0.5,
            },
            "fape_interface": {
                "clamp_distance": 10.0,
                "length_scale": 10.0,
                "weight": 0.0,
            },
            "plddt_loss": {
                "min_resolution": 0.1,
                "max_resolution": 3.0,
                "cutoff": 15.0,
                # NOTE(review): same literal as heads.lddt.no_bins (50);
                # presumably must match -- confirm.
                "no_bins": 50,
                "eps": eps,  # 1e-10,
                "weight": 0.01,
            },
            "supervised_chi": {
                "chi_weight": 0.5,
                "angle_norm_weight": 0.01,
                "eps": eps,  # 1e-6,
                "weight": 1.0,
            },
            "chain_center_of_mass": {
                "clamp_distance": -4.0,
                "weight": 0.0,
                "eps": eps,
                "enabled": False,
            },
            "eps": eps,
        },
        "ema": {"decay": 0.999},
    }
)