File size: 3,690 Bytes
85bd48b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
# Copyright 2021 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Code to generate processed features."""
import copy
from typing import List, Mapping, Tuple
from alphafold.model.tf import input_pipeline
from alphafold.model.tf import proteins_dataset
import ml_collections
import numpy as np
import tensorflow.compat.v1 as tf

FeatureDict = Mapping[str, np.ndarray]


def make_data_config(
    config: ml_collections.ConfigDict,
    num_res: int,
    ) -> Tuple[ml_collections.ConfigDict, List[str]]:
  """Builds the data config and feature-name list for the input pipeline.

  Args:
    config: Config whose `data` field is deep-copied, so the caller's config
      is left untouched.
    num_res: Value written to `eval.crop_size` on the copied data config.

  Returns:
    A `(data_config, feature_names)` pair. `feature_names` starts from
    `common.unsupervised_features` of the copy and, when
    `common.use_templates` is truthy, has `common.template_features` added
    to it.
  """
  data_cfg = copy.deepcopy(config.data)

  # The copy may be locked against edits, so unlock it for the override.
  with data_cfg.unlocked():
    data_cfg.eval.crop_size = num_res

  common_cfg = data_cfg.common
  names = common_cfg.unsupervised_features
  if common_cfg.use_templates:
    # `+=` is kept on purpose: if the field is a mutable list, the copied
    # config's own `unsupervised_features` is extended as well.
    names += common_cfg.template_features

  return data_cfg, names


def tf_example_to_features(tf_example: tf.train.Example,
                           config: ml_collections.ConfigDict,
                           random_seed: int = 0) -> FeatureDict:
  """Converts tf_example to numpy feature dictionary.

  Args:
    tf_example: Example proto to process. Its 'seq_length' int64 feature
      supplies the number of residues. The deletion-matrix conversion below
      is performed on a private copy, so the caller's proto keeps its
      'deletion_matrix_int' feature.
    config: Config whose `data` field is handed to `make_data_config`.
    random_seed: Seed passed to `set_random_seed` for the processing graph.

  Returns:
    The arrays produced by running the processed tensors in a session, with
    every entry whose dtype is object ('O') dropped.
  """
  num_res = int(tf_example.features.feature['seq_length'].int64_list.value[0])
  cfg, feature_names = make_data_config(config, num_res=num_res)

  # Membership test on the proto map directly; unlike `[]`, `in` never
  # inserts a default entry.
  if 'deletion_matrix_int' in tf_example.features.feature:
    # Copy only when a rewrite is needed, so the caller's proto is never
    # mutated (np_example_to_features copies its input for the same reason).
    converted_example = tf.train.Example()
    converted_example.CopyFrom(tf_example)
    feature_map = converted_example.features.feature

    deletion_matrix_int = feature_map['deletion_matrix_int'].int64_list.value
    feat = tf.train.Feature(float_list=tf.train.FloatList(
        value=map(float, deletion_matrix_int)))
    # Store the values as floats under 'deletion_matrix' and drop the
    # integer-valued feature.
    feature_map['deletion_matrix'].CopyFrom(feat)
    del feature_map['deletion_matrix_int']
    tf_example = converted_example

  tf_graph = tf.Graph()
  with tf_graph.as_default(), tf.device('/device:CPU:0'):
    tf.compat.v1.set_random_seed(random_seed)
    tensor_dict = proteins_dataset.create_tensor_dict(
        raw_data=tf_example.SerializeToString(),
        features=feature_names)
    processed_batch = input_pipeline.process_tensors_from_config(
        tensor_dict, cfg)

  # No ops may be added to the graph past this point.
  tf_graph.finalize()

  with tf.Session(graph=tf_graph) as sess:
    features = sess.run(processed_batch)

  return {k: v for k, v in features.items() if v.dtype != 'O'}


def np_example_to_features(np_example: FeatureDict,
                           config: ml_collections.ConfigDict,
                           random_seed: int = 0) -> FeatureDict:
  """Runs a NumPy feature dict through the TF processing pipeline.

  Args:
    np_example: Mapping of feature names to arrays. It is shallow-copied
      first, so the key changes made here do not reach the caller's mapping.
      The first element of its 'seq_length' entry gives the residue count.
    config: Config whose `data` field is handed to `make_data_config`.
    random_seed: Seed passed to `set_random_seed` for the processing graph.

  Returns:
    The arrays produced by running the processed tensors in a session, with
    every entry whose dtype is object ('O') dropped.
  """
  raw_features = dict(np_example)
  sequence_length = int(raw_features['seq_length'][0])
  data_cfg, wanted_features = make_data_config(config, num_res=sequence_length)

  if 'deletion_matrix_int' in raw_features:
    # Replace the integer-valued entry with a float32 'deletion_matrix'.
    int_deletions = raw_features.pop('deletion_matrix_int')
    raw_features['deletion_matrix'] = int_deletions.astype(np.float32)

  graph = tf.Graph()
  with graph.as_default(), tf.device('/device:CPU:0'):
    tf.compat.v1.set_random_seed(random_seed)
    input_tensors = proteins_dataset.np_to_tensor_dict(
        np_example=raw_features, features=wanted_features)
    output_tensors = input_pipeline.process_tensors_from_config(
        input_tensors, data_cfg)

  # No ops may be added to the graph past this point.
  graph.finalize()

  with tf.Session(graph=graph) as session:
    evaluated = session.run(output_tensors)

  # Keep only entries whose dtype is not object.
  numeric_features = {}
  for name, array in evaluated.items():
    if array.dtype != 'O':
      numeric_features[name] = array
  return numeric_features