# Copyright 2021 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Feature pre-processing input pipeline for AlphaFold."""

from alphafold.model.tf import data_transforms
from alphafold.model.tf import shape_placeholders
import tensorflow.compat.v1 as tf
import tree

# Pylint gets confused by the curry1 decorator because it changes the number
# of arguments to the function.
# pylint:disable=no-value-for-parameter


NUM_RES = shape_placeholders.NUM_RES
NUM_MSA_SEQ = shape_placeholders.NUM_MSA_SEQ
NUM_EXTRA_SEQ = shape_placeholders.NUM_EXTRA_SEQ
NUM_TEMPLATES = shape_placeholders.NUM_TEMPLATES


def nonensembled_map_fns(data_config):
  """Input pipeline functions which are not ensembled.

  These transforms are applied exactly once per example, before any
  ensembling (see `process_tensors_from_config`). Their order matters and
  must be preserved.

  Args:
    data_config: Config with a `common` section. `common.use_templates`
      controls whether the template transforms are included.

  Returns:
    A list of single-argument transforms, to be applied in order via
    `compose`.
  """
  common_cfg = data_config.common

  map_fns = [
      data_transforms.correct_msa_restypes,
      data_transforms.add_distillation_flag(False),
      data_transforms.cast_64bit_ints,
      data_transforms.squeeze_features,
      # Keep to not disrupt RNG.
      data_transforms.randomly_replace_msa_with_unknown(0.0),
      data_transforms.make_seq_mask,
      data_transforms.make_msa_mask,
      # Compute the HHblits profile if it's not set. This has to be run before
      # sampling the MSA.
      data_transforms.make_hhblits_profile,
      data_transforms.make_random_crop_to_size_seed,
  ]

  if common_cfg.use_templates:
    map_fns.extend([
        data_transforms.fix_templates_aatype,
        data_transforms.make_template_mask,
        data_transforms.make_pseudo_beta('template_'),
    ])

  map_fns.append(data_transforms.make_atom14_masks)

  return map_fns


def ensembled_map_fns(data_config):
  """Input pipeline functions that can be ensembled and averaged.

  `process_tensors_from_config` applies these transforms once per ensemble
  member, to a copy of the non-ensembled features.

  Args:
    data_config: Config with `common` and `eval` sections.

  Returns:
    A list of single-argument transforms, to be applied in order via
    `compose`.
  """
  common_cfg = data_config.common
  eval_cfg = data_config.eval

  map_fns = []

  # Optionally give up `max_templates` slots of the MSA cluster budget.
  # NOTE(review): presumably so template rows fit alongside the MSA; confirm.
  if common_cfg.reduce_msa_clusters_by_max_templates:
    pad_msa_clusters = eval_cfg.max_msa_clusters - eval_cfg.max_templates
  else:
    pad_msa_clusters = eval_cfg.max_msa_clusters

  max_extra_msa = common_cfg.max_extra_msa

  map_fns.append(
      data_transforms.sample_msa(
          pad_msa_clusters,
          keep_extra=True))

  if 'masked_msa' in common_cfg:
    # Masked MSA should come *before* MSA clustering so that
    # the clustering and full MSA profile do not leak information about
    # the masked locations and secret corrupted locations.
    map_fns.append(
        data_transforms.make_masked_msa(common_cfg.masked_msa,
                                        eval_cfg.masked_msa_replace_fraction))

  if common_cfg.msa_cluster_features:
    map_fns.append(data_transforms.nearest_neighbor_clusters())
    map_fns.append(data_transforms.summarize_clusters())

  # Crop after creating the cluster profiles.
  if max_extra_msa:
    map_fns.append(data_transforms.crop_extra_msa(max_extra_msa))
  else:
    map_fns.append(data_transforms.delete_extra_msa)

  map_fns.append(data_transforms.make_msa_feat())

  crop_feats = dict(eval_cfg.feat)

  if eval_cfg.fixed_size:
    # Keep only the configured features, crop them, then pad to fixed sizes.
    map_fns.append(data_transforms.select_feat(list(crop_feats)))
    map_fns.append(data_transforms.random_crop_to_size(
        eval_cfg.crop_size,
        eval_cfg.max_templates,
        crop_feats,
        eval_cfg.subsample_templates))
    map_fns.append(data_transforms.make_fixed_size(
        crop_feats,
        pad_msa_clusters,
        max_extra_msa,
        eval_cfg.crop_size,
        eval_cfg.max_templates))
  else:
    map_fns.append(data_transforms.crop_templates(eval_cfg.max_templates))

  return map_fns


def process_tensors_from_config(tensors, data_config):
  """Apply the configured feature transforms to `tensors`.

  The non-ensembled transforms run once. The ensembled transforms then run
  once per ensemble member. When `common.resample_msa_in_recycling` is set,
  they run once per (ensemble member, recycling step) pair. Each run sees
  its integer index under the key 'ensemble_index'.

  Args:
    tensors: Features to transform. Presumably a dict of feature name to
      tensor (the pipeline calls `.copy()` and item assignment on the
      transformed result); confirm against callers.
    data_config: Config with `common` and `eval` sections.

  Returns:
    The transformed features, each stacked along a new leading axis. That
    axis has one entry per ensemble run when `tf.map_fn` is used, or size 1
    when only a single run is needed.
  """

  def wrap_ensemble_fn(data, i):
    """Function to be mapped over the ensemble dimension."""
    d = data.copy()
    fns = ensembled_map_fns(data_config)
    fn = compose(fns)
    d['ensemble_index'] = i
    return fn(d)

  eval_cfg = data_config.eval

  # Keep the pre-ensemble features under their own name. The map_fn closure
  # below must never observe a rebinding of the input variable.
  nonensembled = compose(nonensembled_map_fns(data_config))(tensors)

  # Always build member 0. It supplies the output signature for map_fn, and
  # it is the result itself in the single-member case.
  tensors_0 = wrap_ensemble_fn(nonensembled, tf.constant(0))

  num_ensemble = eval_cfg.num_ensemble
  if data_config.common.resample_msa_in_recycling:
    # Separate batch per ensembling & recycling step.
    num_ensemble *= data_config.common.num_recycle + 1

  # A tensor-valued count is unknown at graph-construction time, so always
  # take the mapped path in that case.
  if isinstance(num_ensemble, tf.Tensor) or num_ensemble > 1:
    fn_output_signature = tree.map_structure(
        tf.TensorSpec.from_tensor, tensors_0)
    return tf.map_fn(
        lambda i: wrap_ensemble_fn(nonensembled, i),
        tf.range(num_ensemble),
        parallel_iterations=1,
        fn_output_signature=fn_output_signature)

  # Single member: add the leading ensemble axis of size 1.
  return tree.map_structure(lambda x: x[None], tensors_0)


@data_transforms.curry1
def compose(x, fs):
  """Apply each function in `fs` to `x` in order, threading the result.

  The `curry1` decorator changes the call shape. Call sites in this module
  invoke it as `compose(fs)(x)`.
  """
  for f in fs:
    x = f(x)
  return x