# Copyright 2021 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Model config."""

import copy

from alphafold.model.tf import shape_placeholders
import ml_collections

NUM_RES = shape_placeholders.NUM_RES
NUM_MSA_SEQ = shape_placeholders.NUM_MSA_SEQ
NUM_EXTRA_SEQ = shape_placeholders.NUM_EXTRA_SEQ
NUM_TEMPLATES = shape_placeholders.NUM_TEMPLATES


def model_config(name: str) -> ml_collections.ConfigDict:
  """Get the ConfigDict of a CASP14 model."""
  if name not in CONFIG_DIFFS:
    raise ValueError(f'Invalid model name {name}.')
  cfg = copy.deepcopy(CONFIG)
  cfg.update_from_flattened_dict(CONFIG_DIFFS[name])
  return cfg
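

# Example (illustrative usage sketch only): a config is looked up by model name
# and its nested fields can then be read or overridden via attribute access,
# e.g.
#
#   cfg = model_config('model_1')
#   cfg.data.eval.num_ensemble = 8   # override an evaluation-time setting
#   use_templates = cfg.model.embeddings_and_evoformer.template.enabled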


CONFIG_DIFFS = {
    'model_1': {
        # Jumper et al. (2021) Suppl. Table 5, Model 1.1.1
        'data.common.max_extra_msa': 5120,
        'data.common.reduce_msa_clusters_by_max_templates': True,
        'data.common.use_templates': True,
        'model.embeddings_and_evoformer.template.embed_torsion_angles': True,
        'model.embeddings_and_evoformer.template.enabled': True
    },
    'model_2': {
        # Jumper et al. (2021) Suppl. Table 5, Model 1.1.2
        'data.common.reduce_msa_clusters_by_max_templates': True,
        'data.common.use_templates': True,
        'model.embeddings_and_evoformer.template.embed_torsion_angles': True,
        'model.embeddings_and_evoformer.template.enabled': True
    },
    'model_3': {
        # Jumper et al. (2021) Suppl. Table 5, Model 1.2.1
        'data.common.max_extra_msa': 5120,
    },
    'model_4': {
        # Jumper et al. (2021) Suppl. Table 5, Model 1.2.2
        'data.common.max_extra_msa': 5120,
    },
    'model_5': {
        # Jumper et al. (2021) Suppl. Table 5, Model 1.2.3
    },

    # The following models are fine-tuned from the corresponding models above
    # with an additional predicted_aligned_error head that can produce
    # predicted TM-score (pTM) and predicted aligned errors.
    'model_1_ptm': {
        'data.common.max_extra_msa': 5120,
        'data.common.reduce_msa_clusters_by_max_templates': True,
        'data.common.use_templates': True,
        'model.embeddings_and_evoformer.template.embed_torsion_angles': True,
        'model.embeddings_and_evoformer.template.enabled': True,
        'model.heads.predicted_aligned_error.weight': 0.1
    },
    'model_2_ptm': {
        'data.common.reduce_msa_clusters_by_max_templates': True,
        'data.common.use_templates': True,
        'model.embeddings_and_evoformer.template.embed_torsion_angles': True,
        'model.embeddings_and_evoformer.template.enabled': True,
        'model.heads.predicted_aligned_error.weight': 0.1
    },
    'model_3_ptm': {
        'data.common.max_extra_msa': 5120,
        'model.heads.predicted_aligned_error.weight': 0.1
    },
    'model_4_ptm': {
        'data.common.max_extra_msa': 5120,
        'model.heads.predicted_aligned_error.weight': 0.1
    },
    'model_5_ptm': {
        'model.heads.predicted_aligned_error.weight': 0.1
    }
}
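
# Note (explanatory sketch): each flattened key above addresses a nested field of
# CONFIG below, so applying a diff such as
#   'model.heads.predicted_aligned_error.weight': 0.1
# via update_from_flattened_dict is equivalent to setting
#   cfg.model.heads.predicted_aligned_error.weight = 0.1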


CONFIG = ml_collections.ConfigDict({
    'data': {
        'common': {
            'masked_msa': {
                'profile_prob': 0.1,
                'same_prob': 0.1,
                'uniform_prob': 0.1
            },
            'max_extra_msa': 1024,
            'msa_cluster_features': True,
            'num_recycle': 3,
            'reduce_msa_clusters_by_max_templates': False,
            'resample_msa_in_recycling': True,
            'template_features': [
                'template_all_atom_positions', 'template_sum_probs',
                'template_aatype', 'template_all_atom_masks',
                'template_domain_names'
            ],
            'unsupervised_features': [
                'aatype', 'residue_index', 'sequence', 'msa', 'domain_name',
                'num_alignments', 'seq_length', 'between_segment_residues',
                'deletion_matrix'
            ],
            'use_templates': False,
        },
        'eval': {
            'feat': {
                'aatype': [NUM_RES],
                'all_atom_mask': [NUM_RES, None],
                'all_atom_positions': [NUM_RES, None, None],
                'alt_chi_angles': [NUM_RES, None],
                'atom14_alt_gt_exists': [NUM_RES, None],
                'atom14_alt_gt_positions': [NUM_RES, None, None],
                'atom14_atom_exists': [NUM_RES, None],
                'atom14_atom_is_ambiguous': [NUM_RES, None],
                'atom14_gt_exists': [NUM_RES, None],
                'atom14_gt_positions': [NUM_RES, None, None],
                'atom37_atom_exists': [NUM_RES, None],
                'backbone_affine_mask': [NUM_RES],
                'backbone_affine_tensor': [NUM_RES, None],
                'bert_mask': [NUM_MSA_SEQ, NUM_RES],
                'chi_angles': [NUM_RES, None],
                'chi_mask': [NUM_RES, None],
                'extra_deletion_value': [NUM_EXTRA_SEQ, NUM_RES],
                'extra_has_deletion': [NUM_EXTRA_SEQ, NUM_RES],
                'extra_msa': [NUM_EXTRA_SEQ, NUM_RES],
                'extra_msa_mask': [NUM_EXTRA_SEQ, NUM_RES],
                'extra_msa_row_mask': [NUM_EXTRA_SEQ],
                'is_distillation': [],
                'msa_feat': [NUM_MSA_SEQ, NUM_RES, None],
                'msa_mask': [NUM_MSA_SEQ, NUM_RES],
                'msa_row_mask': [NUM_MSA_SEQ],
                'pseudo_beta': [NUM_RES, None],
                'pseudo_beta_mask': [NUM_RES],
                'random_crop_to_size_seed': [None],
                'residue_index': [NUM_RES],
                'residx_atom14_to_atom37': [NUM_RES, None],
                'residx_atom37_to_atom14': [NUM_RES, None],
                'resolution': [],
                'rigidgroups_alt_gt_frames': [NUM_RES, None, None],
                'rigidgroups_group_exists': [NUM_RES, None],
                'rigidgroups_group_is_ambiguous': [NUM_RES, None],
                'rigidgroups_gt_exists': [NUM_RES, None],
                'rigidgroups_gt_frames': [NUM_RES, None, None],
                'seq_length': [],
                'seq_mask': [NUM_RES],
                'target_feat': [NUM_RES, None],
                'template_aatype': [NUM_TEMPLATES, NUM_RES],
                'template_all_atom_masks': [NUM_TEMPLATES, NUM_RES, None],
                'template_all_atom_positions': [
                    NUM_TEMPLATES, NUM_RES, None, None],
                'template_backbone_affine_mask': [NUM_TEMPLATES, NUM_RES],
                'template_backbone_affine_tensor': [
                    NUM_TEMPLATES, NUM_RES, None],
                'template_mask': [NUM_TEMPLATES],
                'template_pseudo_beta': [NUM_TEMPLATES, NUM_RES, None],
                'template_pseudo_beta_mask': [NUM_TEMPLATES, NUM_RES],
                'template_sum_probs': [NUM_TEMPLATES, None],
                'true_msa': [NUM_MSA_SEQ, NUM_RES]
            },
            'fixed_size': True,
            'subsample_templates': False,  # We want top templates.
            'masked_msa_replace_fraction': 0.15,
            'max_msa_clusters': 512,
            'max_templates': 4,
            'num_ensemble': 1,
        },
    },
    'model': {
        'embeddings_and_evoformer': {
            'evoformer_num_block': 48,
            'evoformer': {
                'msa_row_attention_with_pair_bias': {
                    'dropout_rate': 0.15,
                    'gating': True,
                    'num_head': 8,
                    'orientation': 'per_row',
                    'shared_dropout': True
                },
                'msa_column_attention': {
                    'dropout_rate': 0.0,
                    'gating': True,
                    'num_head': 8,
                    'orientation': 'per_column',
                    'shared_dropout': True
                },
                'msa_transition': {
                    'dropout_rate': 0.0,
                    'num_intermediate_factor': 4,
                    'orientation': 'per_row',
                    'shared_dropout': True
                },
                'outer_product_mean': {
                    'chunk_size': 128,
                    'dropout_rate': 0.0,
                    'num_outer_channel': 32,
                    'orientation': 'per_row',
                    'shared_dropout': True
                },
                'triangle_attention_starting_node': {
                    'dropout_rate': 0.25,
                    'gating': True,
                    'num_head': 4,
                    'orientation': 'per_row',
                    'shared_dropout': True
                },
                'triangle_attention_ending_node': {
                    'dropout_rate': 0.25,
                    'gating': True,
                    'num_head': 4,
                    'orientation': 'per_column',
                    'shared_dropout': True
                },
                'triangle_multiplication_outgoing': {
                    'dropout_rate': 0.25,
                    'equation': 'ikc,jkc->ijc',
                    'num_intermediate_channel': 128,
                    'orientation': 'per_row',
                    'shared_dropout': True
                },
                'triangle_multiplication_incoming': {
                    'dropout_rate': 0.25,
                    'equation': 'kjc,kic->ijc',
                    'num_intermediate_channel': 128,
                    'orientation': 'per_row',
                    'shared_dropout': True
                },
                'pair_transition': {
                    'dropout_rate': 0.0,
                    'num_intermediate_factor': 4,
                    'orientation': 'per_row',
                    'shared_dropout': True
                }
            },
            'extra_msa_channel': 64,
            'extra_msa_stack_num_block': 4,
            'max_relative_feature': 32,
            'custom_relative_features': False,
            'msa_channel': 256,
            'pair_channel': 128,
            'prev_pos': {
                'min_bin': 3.25,
                'max_bin': 20.75,
                'num_bins': 15
            },
            'recycle_features': True,
            'recycle_pos': True,
            'recycle_dgram': False,
            'backprop_dgram': False,
            'backprop_dgram_temp': 1.0,
            'seq_channel': 384,
            'template': {
                'attention': {
                    'gating': False,
                    'key_dim': 64,
                    'num_head': 4,
                    'value_dim': 64
                },
                'dgram_features': {
                    'min_bin': 3.25,
                    'max_bin': 50.75,
                    'num_bins': 39
                },
                'backprop_dgram': False,
                'backprop_dgram_temp': 1.0,
                'embed_torsion_angles': False,
                'enabled': False,
                'template_pair_stack': {
                    'num_block': 2,
                    'triangle_attention_starting_node': {
                        'dropout_rate': 0.25,
                        'gating': True,
                        'key_dim': 64,
                        'num_head': 4,
                        'orientation': 'per_row',
                        'shared_dropout': True,
                        'value_dim': 64
                    },
                    'triangle_attention_ending_node': {
                        'dropout_rate': 0.25,
                        'gating': True,
                        'key_dim': 64,
                        'num_head': 4,
                        'orientation': 'per_column',
                        'shared_dropout': True,
                        'value_dim': 64
                    },
                    'triangle_multiplication_outgoing': {
                        'dropout_rate': 0.25,
                        'equation': 'ikc,jkc->ijc',
                        'num_intermediate_channel': 64,
                        'orientation': 'per_row',
                        'shared_dropout': True
                    },
                    'triangle_multiplication_incoming': {
                        'dropout_rate': 0.25,
                        'equation': 'kjc,kic->ijc',
                        'num_intermediate_channel': 64,
                        'orientation': 'per_row',
                        'shared_dropout': True
                    },
                    'pair_transition': {
                        'dropout_rate': 0.0,
                        'num_intermediate_factor': 2,
                        'orientation': 'per_row',
                        'shared_dropout': True
                    }
                },
                'max_templates': 4,
                'subbatch_size': 128,
                'use_template_unit_vector': False,
            }
        },
        'global_config': {
            'mixed_precision': False,
            'deterministic': False,
            'subbatch_size': 4,
            'use_remat': False,
            'zero_init': True
        },
        'heads': {
            'distogram': {
                'first_break': 2.3125,
                'last_break': 21.6875,
                'num_bins': 64,
                'weight': 0.3
            },
            'predicted_aligned_error': {
                # `num_bins - 1` bins uniformly space the
                # [0, max_error_bin A] range.
                # The final bin covers [max_error_bin A, +infty].
                # 31A gives bins with 0.5A width.
                'max_error_bin': 31.,
                'num_bins': 64,
                'num_channels': 128,
                'filter_by_resolution': True,
                'min_resolution': 0.1,
                'max_resolution': 3.0,
                'weight': 0.0,
            },
            'experimentally_resolved': {
                'filter_by_resolution': True,
                'max_resolution': 3.0,
                'min_resolution': 0.1,
                'weight': 0.01
            },
            'structure_module': {
                'num_layer': 8,
                'fape': {
                    'clamp_distance': 10.0,
                    'clamp_type': 'relu',
                    'loss_unit_distance': 10.0
                },
                'angle_norm_weight': 0.01,
                'chi_weight': 0.5,
                'clash_overlap_tolerance': 1.5,
                'compute_in_graph_metrics': True,
                'dropout': 0.1,
                'num_channel': 384,
                'num_head': 12,
                'num_layer_in_transition': 3,
                'num_point_qk': 4,
                'num_point_v': 8,
                'num_scalar_qk': 16,
                'num_scalar_v': 16,
                'position_scale': 10.0,
                'sidechain': {
                    'atom_clamp_distance': 10.0,
                    'num_channel': 128,
                    'num_residual_block': 2,
                    'weight_frac': 0.5,
                    'length_scale': 10.,
                },
                'structural_violation_loss_weight': 1.0,
                'violation_tolerance_factor': 12.0,
                'weight': 1.0
            },
            'predicted_lddt': {
                'filter_by_resolution': True,
                'max_resolution': 3.0,
                'min_resolution': 0.1,
                'num_bins': 50,
                'num_channels': 128,
                'weight': 0.01
            },
            'masked_msa': {
                'num_output': 23,
                'weight': 2.0
            },
        },
        'num_recycle': 3,
        'backprop_recycle': False,
        'resample_msa_in_recycling': True,
        'add_prev': False,
        'use_struct': True,
    },
})
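

# Minimal usage sketch (an illustrative addition, assuming the module is run
# directly with ml_collections and the alphafold package available): fetches
# one of the pTM configs and prints two fields that CONFIG_DIFFS overrides.
if __name__ == '__main__':
  example_cfg = model_config('model_1_ptm')
  # Expected: 5120 (overridden from the default of 1024).
  print('max_extra_msa:', example_cfg.data.common.max_extra_msa)
  # Expected: 0.1 (enables the predicted aligned error / pTM head).
  print('predicted_aligned_error weight:',
        example_cfg.model.heads.predicted_aligned_error.weight)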