Spaces:
Running
on
T4
Running
on
T4
File size: 6,344 Bytes
85bd48b |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 |
# Copyright 2021 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Datasets consisting of proteins."""
from typing import Dict, Mapping, Optional, Sequence
from alphafold.model.tf import protein_features
import numpy as np
import tensorflow.compat.v1 as tf
# Mapping from feature name to the tensor holding that feature's values.
TensorDict = Dict[str, tf.Tensor]
def parse_tfexample(
    raw_data: bytes,
    features: protein_features.FeaturesMetadata,
    key: Optional[str] = None) -> TensorDict:
  """Read a single TF Example proto and return a subset of its features.

  Args:
    raw_data: A serialized tf.Example proto.
    features: A dictionary of features, mapping string feature names to a tuple
      (dtype, shape). This dictionary should be a subset of
      protein_features.FEATURES (or the dictionary itself for all features).
    key: Optional string with the SSTable key of that tf.Example. This will be
      added into features as a 'key' but only if requested in features.

  Returns:
    A dictionary mapping feature names to tensors, reshaped by
    `parse_reshape_logic`. Only the given features are returned, all other
    ones are filtered out.
  """
  # Every feature is parsed as a flat, variable-length sequence of its dtype
  # (only the dtype, element [0] of the metadata tuple, is used here). The
  # real shapes are restored afterwards by parse_reshape_logic, which also
  # asserts that non-template features are non-empty.
  feature_map = {
      name: tf.io.FixedLenSequenceFeature(
          shape=(), dtype=dtype_and_shape[0], allow_missing=True)
      for name, dtype_and_shape in features.items()
  }
  parsed_features = tf.io.parse_single_example(raw_data, feature_map)
  return parse_reshape_logic(parsed_features, features, key=key)
def _first(tensor: tf.Tensor) -> tf.Tensor:
  """Returns the leading element of `tensor` after flattening it.

  Scalars and tensors of any rank are handled alike, because the input is
  first reshaped into a rank-1 vector.
  """
  flat = tf.reshape(tensor, shape=(-1,))
  return flat[0]
def parse_reshape_logic(
    parsed_features: TensorDict,
    features: protein_features.FeaturesMetadata,
    key: Optional[str] = None) -> TensorDict:
  """Transforms parsed serial features to the correct shape.

  Each value is reshaped to the shape returned by `protein_features.shape`
  for that feature. That shape is computed from the number of residues, MSA
  rows and templates found in `parsed_features`. Every reshape is guarded by a
  runtime assertion that the element count matches the target shape.
  Non-template features additionally get an assertion that they are
  non-empty.

  Note: `parsed_features` is updated in place and then returned.

  Args:
    parsed_features: Mapping from feature name to a tensor (or tensor-like)
      value. Must contain "seq_length". "num_alignments" and
      "template_domain_names" are optional and default the MSA / template
      counts to 0 when absent.
    features: Feature metadata forwarded to `protein_features.shape`.
    key: Optional string with the SSTable key of the example. It is stored
      under "key" only if "key" is one of the requested `features`.

  Returns:
    The same `parsed_features` dict, with every value reshaped.
  """
  # Read the counts that determine the target shape of every feature.
  num_residues = tf.cast(_first(parsed_features["seq_length"]), dtype=tf.int32)
  if "num_alignments" in parsed_features:
    num_msa = tf.cast(_first(parsed_features["num_alignments"]), dtype=tf.int32)
  else:
    num_msa = 0

  if "template_domain_names" in parsed_features:
    num_templates = tf.cast(
        tf.shape(parsed_features["template_domain_names"])[0], dtype=tf.int32)
  else:
    num_templates = 0

  if key is not None and "key" in features:
    parsed_features["key"] = [key]  # Expand dims from () to (1,).

  # Reshape the tensors according to the sequence length and num alignments.
  # Iterate over a snapshot because the loop writes back into the dict.
  for k, v in list(parsed_features.items()):
    new_shape = protein_features.shape(
        feature_name=k,
        num_residues=num_residues,
        msa_length=num_msa,
        num_templates=num_templates,
        features=features)
    new_shape_size = tf.constant(1, dtype=tf.int32)
    for dim in new_shape:
      new_shape_size *= tf.cast(dim, tf.int32)

    # Build the size op once and share it between the assertions and the
    # error message, instead of adding a fresh op to the graph for each use.
    v_size = tf.size(v)
    assert_equal = tf.assert_equal(
        v_size, new_shape_size,
        name="assert_%s_shape_correct" % k,
        message="The size of feature %s (%s) could not be reshaped "
        "into %s" % (k, v_size, new_shape))
    control_deps = [assert_equal]

    if "template" not in k:
      # Make sure the feature we are reshaping is not empty. Template features
      # are exempt because they may legitimately be empty, e.g. when no
      # template hits were found.
      assert_non_empty = tf.assert_greater(
          v_size, 0, name="assert_%s_non_empty" % k,
          message="The feature %s is not set in the tf.Example. Either do not "
          "request the feature or use a tf.Example that has the "
          "feature set." % k)
      control_deps = [assert_non_empty, assert_equal]

    with tf.control_dependencies(control_deps):
      parsed_features[k] = tf.reshape(v, new_shape, name="reshape_%s" % k)

  return parsed_features
def _make_features_metadata(
    feature_names: Sequence[str]) -> protein_features.FeaturesMetadata:
  """Makes a feature name to type and shape mapping from a list of names.

  The features "aatype", "sequence" and "seq_length" are always included,
  even when not requested. The result keeps the caller's order (duplicates
  dropped), followed by any of those required features that were not
  requested. The mapping, and everything built by iterating it, is therefore
  the same from run to run.

  Args:
    feature_names: Names of the features to read. Each is looked up in
      protein_features.FEATURES.

  Returns:
    A dict mapping each feature name to its protein_features.FEATURES entry.
  """
  # Make sure these features are always read.
  required_features = ["aatype", "sequence", "seq_length"]
  # dict.fromkeys de-duplicates while preserving first-seen order. A set union
  # would make the order depend on the per-process string hash seed.
  ordered_names = dict.fromkeys(list(feature_names) + required_features)
  return {name: protein_features.FEATURES[name] for name in ordered_names}
def create_tensor_dict(
    raw_data: bytes,
    features: Sequence[str],
    key: Optional[str] = None,
) -> TensorDict:
  """Parses a serialized tf.Example into a dictionary of shaped tensors.

  The requested names are first expanded into feature metadata by
  `_make_features_metadata`, which also adds the always-required features.
  The example is then parsed with `parse_tfexample`.

  Args:
    raw_data: A serialized tf.Example proto.
    features: Names of the features to return.
    key: Optional string with the SSTable key of that tf.Example. It is added
      to the output as 'key' only when 'key' is among the requested features.

  Returns:
    A dictionary from feature name to tensor, restricted to the selected
    features.
  """
  return parse_tfexample(raw_data, _make_features_metadata(features), key=key)
def np_to_tensor_dict(
    np_example: Mapping[str, np.ndarray],
    features: Sequence[str],
) -> TensorDict:
  """Converts a dict of NumPy arrays into a dict of correctly shaped tensors.

  Only arrays whose names are among the selected features (the requested ones
  plus those `_make_features_metadata` always adds) are kept. Each kept array
  becomes a constant tensor.

  Args:
    np_example: A dict of NumPy feature arrays.
    features: Names of the features to return.

  Returns:
    A dictionary from feature name to tensor, restricted to the selected
    features.
  """
  metadata = _make_features_metadata(features)
  selected = {}
  for name, array in np_example.items():
    if name in metadata:
      selected[name] = tf.constant(array)
  # Reshaping gives empty features (e.g. no template hits found) the size
  # their metadata expects.
  return parse_reshape_logic(selected, metadata)
|