Spaces:
Running
on
T4
Running
on
T4
# Copyright 2021 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Contains descriptions of various protein features.""" | |
import enum | |
from typing import Dict, Optional, Sequence, Tuple, Union | |
from alphafold.common import residue_constants | |
import tensorflow.compat.v1 as tf | |
# Type aliases.
# Maps a feature name to its (dtype, size) pair; each size entry is either a
# concrete int or a string placeholder.
FeaturesMetadata = Dict[str, Tuple[tf.dtypes.DType, Sequence[Union[str, int]]]]
class FeatureType(enum.Enum):
    """Layout categories of a feature, identified by its leading dimensions."""

    ZERO_DIM = 0  # Shape [x]
    ONE_DIM = 1  # Shape [num_res, x]
    TWO_DIM = 2  # Shape [num_res, num_res, x]
    MSA = 3  # Shape [msa_length, num_res, x]
# Placeholder values that will be replaced with their true value at runtime.
# They are plain strings so that an unresolved placeholder can be told apart
# from a concrete integer dimension.
NUM_RES = "num residues placeholder"
NUM_SEQ = "length msa placeholder"
NUM_TEMPLATES = "num templates placeholder"
# Sizes of the protein features. NUM_RES, NUM_SEQ and NUM_TEMPLATES are allowed
# as placeholders to be replaced with the number of residues, the number of
# sequences in the multiple sequence alignment, and the number of templates,
# respectively.
FEATURES = {
    #### Static features of a protein sequence ####
    "aatype": (tf.float32, [NUM_RES, 21]),
    "between_segment_residues": (tf.int64, [NUM_RES, 1]),
    "deletion_matrix": (tf.float32, [NUM_SEQ, NUM_RES, 1]),
    "domain_name": (tf.string, [1]),
    "msa": (tf.int64, [NUM_SEQ, NUM_RES, 1]),
    "num_alignments": (tf.int64, [NUM_RES, 1]),
    "residue_index": (tf.int64, [NUM_RES, 1]),
    "seq_length": (tf.int64, [NUM_RES, 1]),
    "sequence": (tf.string, [1]),
    "all_atom_positions": (tf.float32,
                           [NUM_RES, residue_constants.atom_type_num, 3]),
    "all_atom_mask": (tf.int64, [NUM_RES, residue_constants.atom_type_num]),
    "resolution": (tf.float32, [1]),
    "template_domain_names": (tf.string, [NUM_TEMPLATES]),
    "template_sum_probs": (tf.float32, [NUM_TEMPLATES, 1]),
    "template_aatype": (tf.float32, [NUM_TEMPLATES, NUM_RES, 22]),
    "template_all_atom_positions": (tf.float32, [
        NUM_TEMPLATES, NUM_RES, residue_constants.atom_type_num, 3
    ]),
    "template_all_atom_masks": (tf.float32, [
        NUM_TEMPLATES, NUM_RES, residue_constants.atom_type_num, 1
    ]),
}

# Per-feature lookups split out of FEATURES: name -> dtype and name -> size.
FEATURE_TYPES = {name: dtype for name, (dtype, _) in FEATURES.items()}
FEATURE_SIZES = {name: size for name, (_, size) in FEATURES.items()}
def register_feature(name: str,
                     type_: tf.dtypes.DType,
                     shape_: Sequence[Union[str, int]]) -> None:
    """Register extra features used in custom datasets.

    Writes the feature into all three module-level lookups (FEATURES,
    FEATURE_TYPES and FEATURE_SIZES) so they stay in agreement. An existing
    entry with the same name is overwritten.

    Args:
      name: String identifier for the feature.
      type_: Dtype stored for the feature.
      shape_: Size of the feature; each entry is either a concrete int or one
        of the string placeholders (e.g. NUM_RES).
    """
    FEATURES[name] = (type_, shape_)
    FEATURE_TYPES[name] = type_
    FEATURE_SIZES[name] = shape_
def shape(feature_name: str,
          num_residues: int,
          msa_length: Optional[int],
          num_templates: Optional[int] = None,
          features: Optional[FeaturesMetadata] = None) -> list:
    """Get the shape for the given feature name.

    This is near identical to _get_tf_shape_no_placeholders() but with 2
    differences:
    * This method does not calculate a single placeholder from the total
      number of elements (eg given <NUM_RES, 3> and size := 12, this won't
      deduce NUM_RES must be 4)
    * This method will work with tensors

    Args:
      feature_name: String identifier for the feature. If the feature name
        ends with "_unnormalized", this suffix is stripped off.
      num_residues: The number of residues in the current domain - some
        elements of the shape can be dynamic and will be replaced by this
        value.
      msa_length: The number of sequences in the multiple sequence alignment,
        some elements of the shape can be dynamic and will be replaced by
        this value. If the number of alignments is unknown / not read, please
        pass None for msa_length.
      num_templates (optional): The number of templates in this tfexample.
      features: A feature_name to (tf_dtype, shape) lookup; defaults to
        FEATURES. A falsy value (None or an empty lookup) selects the
        default.

    Returns:
      List with one entry per dimension of the feature. Placeholders are
      replaced by the supplied values and all other entries are copied
      unchanged. A placeholder whose value was passed as None (e.g.
      msa_length=None) yields None in the result rather than an error.

    Raises:
      ValueError: If a feature is requested but no concrete placeholder value
        is given (e.g. a NUM_TEMPLATES dimension without num_templates).
      KeyError: If feature_name (after suffix stripping) is missing from a
        dict-based lookup.
    """
    features = features or FEATURES

    suffix = "_unnormalized"
    if feature_name.endswith(suffix):
        feature_name = feature_name[:-len(suffix)]

    unused_dtype, raw_sizes = features[feature_name]

    # NUM_TEMPLATES is only registered when a value was supplied, so a
    # template feature requested without it is caught by the check below.
    replacements = {NUM_RES: num_residues, NUM_SEQ: msa_length}
    if num_templates is not None:
        replacements[NUM_TEMPLATES] = num_templates

    sizes = [replacements.get(dimension, dimension) for dimension in raw_sizes]

    # Any string still present is a placeholder that could not be resolved.
    if any(isinstance(dimension, str) for dimension in sizes):
        raise ValueError("Could not parse %s (shape: %s) with values: %s" % (
            feature_name, raw_sizes, replacements))
    return sizes