# Copyright 2021 DeepMind Technologies Limited # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """Datasets consisting of proteins.""" from typing import Dict, Mapping, Optional, Sequence from alphafold.model.tf import protein_features import numpy as np import tensorflow.compat.v1 as tf TensorDict = Dict[str, tf.Tensor] def parse_tfexample( raw_data: bytes, features: protein_features.FeaturesMetadata, key: Optional[str] = None) -> Dict[str, tf.train.Feature]: """Read a single TF Example proto and return a subset of its features. Args: raw_data: A serialized tf.Example proto. features: A dictionary of features, mapping string feature names to a tuple (dtype, shape). This dictionary should be a subset of protein_features.FEATURES (or the dictionary itself for all features). key: Optional string with the SSTable key of that tf.Example. This will be added into features as a 'key' but only if requested in features. Returns: A dictionary of features mapping feature names to features. Only the given features are returned, all other ones are filtered out. """ feature_map = { k: tf.io.FixedLenSequenceFeature(shape=(), dtype=v[0], allow_missing=True) for k, v in features.items() } parsed_features = tf.io.parse_single_example(raw_data, feature_map) reshaped_features = parse_reshape_logic(parsed_features, features, key=key) return reshaped_features def _first(tensor: tf.Tensor) -> tf.Tensor: """Returns the 1st element - the input can be a tensor or a scalar.""" return tf.reshape(tensor, shape=(-1,))[0] def parse_reshape_logic( parsed_features: TensorDict, features: protein_features.FeaturesMetadata, key: Optional[str] = None) -> TensorDict: """Transforms parsed serial features to the correct shape.""" # Find out what is the number of sequences and the number of alignments. num_residues = tf.cast(_first(parsed_features["seq_length"]), dtype=tf.int32) if "num_alignments" in parsed_features: num_msa = tf.cast(_first(parsed_features["num_alignments"]), dtype=tf.int32) else: num_msa = 0 if "template_domain_names" in parsed_features: num_templates = tf.cast( tf.shape(parsed_features["template_domain_names"])[0], dtype=tf.int32) else: num_templates = 0 if key is not None and "key" in features: parsed_features["key"] = [key] # Expand dims from () to (1,). # Reshape the tensors according to the sequence length and num alignments. for k, v in parsed_features.items(): new_shape = protein_features.shape( feature_name=k, num_residues=num_residues, msa_length=num_msa, num_templates=num_templates, features=features) new_shape_size = tf.constant(1, dtype=tf.int32) for dim in new_shape: new_shape_size *= tf.cast(dim, tf.int32) assert_equal = tf.assert_equal( tf.size(v), new_shape_size, name="assert_%s_shape_correct" % k, message="The size of feature %s (%s) could not be reshaped " "into %s" % (k, tf.size(v), new_shape)) if "template" not in k: # Make sure the feature we are reshaping is not empty. assert_non_empty = tf.assert_greater( tf.size(v), 0, name="assert_%s_non_empty" % k, message="The feature %s is not set in the tf.Example. Either do not " "request the feature or use a tf.Example that has the " "feature set." % k) with tf.control_dependencies([assert_non_empty, assert_equal]): parsed_features[k] = tf.reshape(v, new_shape, name="reshape_%s" % k) else: with tf.control_dependencies([assert_equal]): parsed_features[k] = tf.reshape(v, new_shape, name="reshape_%s" % k) return parsed_features def _make_features_metadata( feature_names: Sequence[str]) -> protein_features.FeaturesMetadata: """Makes a feature name to type and shape mapping from a list of names.""" # Make sure these features are always read. required_features = ["aatype", "sequence", "seq_length"] feature_names = list(set(feature_names) | set(required_features)) features_metadata = {name: protein_features.FEATURES[name] for name in feature_names} return features_metadata def create_tensor_dict( raw_data: bytes, features: Sequence[str], key: Optional[str] = None, ) -> TensorDict: """Creates a dictionary of tensor features. Args: raw_data: A serialized tf.Example proto. features: A list of strings of feature names to be returned in the dataset. key: Optional string with the SSTable key of that tf.Example. This will be added into features as a 'key' but only if requested in features. Returns: A dictionary of features mapping feature names to features. Only the given features are returned, all other ones are filtered out. """ features_metadata = _make_features_metadata(features) return parse_tfexample(raw_data, features_metadata, key) def np_to_tensor_dict( np_example: Mapping[str, np.ndarray], features: Sequence[str], ) -> TensorDict: """Creates dict of tensors from a dict of NumPy arrays. Args: np_example: A dict of NumPy feature arrays. features: A list of strings of feature names to be returned in the dataset. Returns: A dictionary of features mapping feature names to features. Only the given features are returned, all other ones are filtered out. """ features_metadata = _make_features_metadata(features) tensor_dict = {k: tf.constant(v) for k, v in np_example.items() if k in features_metadata} # Ensures shapes are as expected. Needed for setting size of empty features # e.g. when no template hits were found. tensor_dict = parse_reshape_logic(tensor_dict, features_metadata) return tensor_dict