bshor committed on
Commit 0fdcb79
1 Parent(s): 0e06bb8
This view is limited to 50 files because it contains too many changes.
Files changed (50)
  1. LICENSE +202 -0
  2. dockformerpp/.DS_Store +0 -0
  3. dockformerpp/__init__.py +6 -0
  4. dockformerpp/config.py +344 -0
  5. dockformerpp/data/.DS_Store +0 -0
  6. dockformerpp/data/data_modules.py +643 -0
  7. dockformerpp/data/data_pipeline.py +360 -0
  8. dockformerpp/data/data_transforms.py +731 -0
  9. dockformerpp/data/errors.py +22 -0
  10. dockformerpp/data/parsers.py +53 -0
  11. dockformerpp/data/protein_features.py +71 -0
  12. dockformerpp/data/utils.py +54 -0
  13. dockformerpp/model/.DS_Store +0 -0
  14. dockformerpp/model/__init__.py +0 -0
  15. dockformerpp/model/dropout.py +69 -0
  16. dockformerpp/model/embedders.py +320 -0
  17. dockformerpp/model/evoformer.py +468 -0
  18. dockformerpp/model/heads.py +233 -0
  19. dockformerpp/model/model.py +317 -0
  20. dockformerpp/model/pair_transition.py +81 -0
  21. dockformerpp/model/primitives.py +598 -0
  22. dockformerpp/model/single_attention.py +184 -0
  23. dockformerpp/model/structure_module.py +837 -0
  24. dockformerpp/model/torchscript.py +171 -0
  25. dockformerpp/model/triangular_attention.py +104 -0
  26. dockformerpp/model/triangular_multiplicative_update.py +173 -0
  27. dockformerpp/resources/.DS_Store +0 -0
  28. dockformerpp/resources/__init__.py +0 -0
  29. dockformerpp/resources/stereo_chemical_props.txt +345 -0
  30. dockformerpp/utils/.DS_Store +0 -0
  31. dockformerpp/utils/__init__.py +0 -0
  32. dockformerpp/utils/callbacks.py +15 -0
  33. dockformerpp/utils/checkpointing.py +78 -0
  34. dockformerpp/utils/config_tools.py +32 -0
  35. dockformerpp/utils/consts.py +25 -0
  36. dockformerpp/utils/exponential_moving_average.py +71 -0
  37. dockformerpp/utils/feats.py +174 -0
  38. dockformerpp/utils/geometry/__init__.py +28 -0
  39. dockformerpp/utils/geometry/__pycache__/__init__.cpython-39.pyc +0 -0
  40. dockformerpp/utils/geometry/__pycache__/quat_rigid.cpython-39.pyc +0 -0
  41. dockformerpp/utils/geometry/__pycache__/rigid_matrix_vector.cpython-39.pyc +0 -0
  42. dockformerpp/utils/geometry/__pycache__/rotation_matrix.cpython-39.pyc +0 -0
  43. dockformerpp/utils/geometry/__pycache__/utils.cpython-39.pyc +0 -0
  44. dockformerpp/utils/geometry/__pycache__/vector.cpython-39.pyc +0 -0
  45. dockformerpp/utils/geometry/quat_rigid.py +38 -0
  46. dockformerpp/utils/geometry/rigid_matrix_vector.py +181 -0
  47. dockformerpp/utils/geometry/rotation_matrix.py +208 -0
  48. dockformerpp/utils/geometry/test_utils.py +97 -0
  49. dockformerpp/utils/geometry/utils.py +22 -0
  50. dockformerpp/utils/geometry/vector.py +261 -0
LICENSE ADDED
@@ -0,0 +1,202 @@
+
+                                 Apache License
+                           Version 2.0, January 2004
+                        http://www.apache.org/licenses/
+
+   TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
+
+   1. Definitions.
+
+      "License" shall mean the terms and conditions for use, reproduction,
+      and distribution as defined by Sections 1 through 9 of this document.
+
+      "Licensor" shall mean the copyright owner or entity authorized by
+      the copyright owner that is granting the License.
+
+      "Legal Entity" shall mean the union of the acting entity and all
+      other entities that control, are controlled by, or are under common
+      control with that entity. For the purposes of this definition,
+      "control" means (i) the power, direct or indirect, to cause the
+      direction or management of such entity, whether by contract or
+      otherwise, or (ii) ownership of fifty percent (50%) or more of the
+      outstanding shares, or (iii) beneficial ownership of such entity.
+
+      "You" (or "Your") shall mean an individual or Legal Entity
+      exercising permissions granted by this License.
+
+      "Source" form shall mean the preferred form for making modifications,
+      including but not limited to software source code, documentation
+      source, and configuration files.
+
+      "Object" form shall mean any form resulting from mechanical
+      transformation or translation of a Source form, including but
+      not limited to compiled object code, generated documentation,
+      and conversions to other media types.
+
+      "Work" shall mean the work of authorship, whether in Source or
+      Object form, made available under the License, as indicated by a
+      copyright notice that is included in or attached to the work
+      (an example is provided in the Appendix below).
+
+      "Derivative Works" shall mean any work, whether in Source or Object
+      form, that is based on (or derived from) the Work and for which the
+      editorial revisions, annotations, elaborations, or other modifications
+      represent, as a whole, an original work of authorship. For the purposes
+      of this License, Derivative Works shall not include works that remain
+      separable from, or merely link (or bind by name) to the interfaces of,
+      the Work and Derivative Works thereof.
+
+      "Contribution" shall mean any work of authorship, including
+      the original version of the Work and any modifications or additions
+      to that Work or Derivative Works thereof, that is intentionally
+      submitted to Licensor for inclusion in the Work by the copyright owner
+      or by an individual or Legal Entity authorized to submit on behalf of
+      the copyright owner. For the purposes of this definition, "submitted"
+      means any form of electronic, verbal, or written communication sent
+      to the Licensor or its representatives, including but not limited to
+      communication on electronic mailing lists, source code control systems,
+      and issue tracking systems that are managed by, or on behalf of, the
+      Licensor for the purpose of discussing and improving the Work, but
+      excluding communication that is conspicuously marked or otherwise
+      designated in writing by the copyright owner as "Not a Contribution."
+
+      "Contributor" shall mean Licensor and any individual or Legal Entity
+      on behalf of whom a Contribution has been received by Licensor and
+      subsequently incorporated within the Work.
+
+   2. Grant of Copyright License. Subject to the terms and conditions of
+      this License, each Contributor hereby grants to You a perpetual,
+      worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+      copyright license to reproduce, prepare Derivative Works of,
+      publicly display, publicly perform, sublicense, and distribute the
+      Work and such Derivative Works in Source or Object form.
+
+   3. Grant of Patent License. Subject to the terms and conditions of
+      this License, each Contributor hereby grants to You a perpetual,
+      worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+      (except as stated in this section) patent license to make, have made,
+      use, offer to sell, sell, import, and otherwise transfer the Work,
+      where such license applies only to those patent claims licensable
+      by such Contributor that are necessarily infringed by their
+      Contribution(s) alone or by combination of their Contribution(s)
+      with the Work to which such Contribution(s) was submitted. If You
+      institute patent litigation against any entity (including a
+      cross-claim or counterclaim in a lawsuit) alleging that the Work
+      or a Contribution incorporated within the Work constitutes direct
+      or contributory patent infringement, then any patent licenses
+      granted to You under this License for that Work shall terminate
+      as of the date such litigation is filed.
+
+   4. Redistribution. You may reproduce and distribute copies of the
+      Work or Derivative Works thereof in any medium, with or without
+      modifications, and in Source or Object form, provided that You
+      meet the following conditions:
+
+      (a) You must give any other recipients of the Work or
+          Derivative Works a copy of this License; and
+
+      (b) You must cause any modified files to carry prominent notices
+          stating that You changed the files; and
+
+      (c) You must retain, in the Source form of any Derivative Works
+          that You distribute, all copyright, patent, trademark, and
+          attribution notices from the Source form of the Work,
+          excluding those notices that do not pertain to any part of
+          the Derivative Works; and
+
+      (d) If the Work includes a "NOTICE" text file as part of its
+          distribution, then any Derivative Works that You distribute must
+          include a readable copy of the attribution notices contained
+          within such NOTICE file, excluding those notices that do not
+          pertain to any part of the Derivative Works, in at least one
+          of the following places: within a NOTICE text file distributed
+          as part of the Derivative Works; within the Source form or
+          documentation, if provided along with the Derivative Works; or,
+          within a display generated by the Derivative Works, if and
+          wherever such third-party notices normally appear. The contents
+          of the NOTICE file are for informational purposes only and
+          do not modify the License. You may add Your own attribution
+          notices within Derivative Works that You distribute, alongside
+          or as an addendum to the NOTICE text from the Work, provided
+          that such additional attribution notices cannot be construed
+          as modifying the License.
+
+      You may add Your own copyright statement to Your modifications and
+      may provide additional or different license terms and conditions
+      for use, reproduction, or distribution of Your modifications, or
+      for any such Derivative Works as a whole, provided Your use,
+      reproduction, and distribution of the Work otherwise complies with
+      the conditions stated in this License.
+
+   5. Submission of Contributions. Unless You explicitly state otherwise,
+      any Contribution intentionally submitted for inclusion in the Work
+      by You to the Licensor shall be under the terms and conditions of
+      this License, without any additional terms or conditions.
+      Notwithstanding the above, nothing herein shall supersede or modify
+      the terms of any separate license agreement you may have executed
+      with Licensor regarding such Contributions.
+
+   6. Trademarks. This License does not grant permission to use the trade
+      names, trademarks, service marks, or product names of the Licensor,
+      except as required for reasonable and customary use in describing the
+      origin of the Work and reproducing the content of the NOTICE file.
+
+   7. Disclaimer of Warranty. Unless required by applicable law or
+      agreed to in writing, Licensor provides the Work (and each
+      Contributor provides its Contributions) on an "AS IS" BASIS,
+      WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
+      implied, including, without limitation, any warranties or conditions
+      of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
+      PARTICULAR PURPOSE. You are solely responsible for determining the
+      appropriateness of using or redistributing the Work and assume any
+      risks associated with Your exercise of permissions under this License.
+
+   8. Limitation of Liability. In no event and under no legal theory,
+      whether in tort (including negligence), contract, or otherwise,
+      unless required by applicable law (such as deliberate and grossly
+      negligent acts) or agreed to in writing, shall any Contributor be
+      liable to You for damages, including any direct, indirect, special,
+      incidental, or consequential damages of any character arising as a
+      result of this License or out of the use or inability to use the
+      Work (including but not limited to damages for loss of goodwill,
+      work stoppage, computer failure or malfunction, or any and all
+      other commercial damages or losses), even if such Contributor
+      has been advised of the possibility of such damages.
+
+   9. Accepting Warranty or Additional Liability. While redistributing
+      the Work or Derivative Works thereof, You may choose to offer,
+      and charge a fee for, acceptance of support, warranty, indemnity,
+      or other liability obligations and/or rights consistent with this
+      License. However, in accepting such obligations, You may act only
+      on Your own behalf and on Your sole responsibility, not on behalf
+      of any other Contributor, and only if You agree to indemnify,
+      defend, and hold each Contributor harmless for any liability
+      incurred by, or claims asserted against, such Contributor by reason
+      of your accepting any such warranty or additional liability.
+
+   END OF TERMS AND CONDITIONS
+
+   APPENDIX: How to apply the Apache License to your work.
+
+      To apply the Apache License to your work, attach the following
+      boilerplate notice, with the fields enclosed by brackets "[]"
+      replaced with your own identifying information. (Don't include
+      the brackets!) The text should be enclosed in the appropriate
+      comment syntax for the file format. We also recommend that a
+      file or class name and description of purpose be included on the
+      same "printed page" as the copyright notice for easier
+      identification within third-party archives.
+
+   Copyright [yyyy] [name of copyright owner]
+
+   Licensed under the Apache License, Version 2.0 (the "License");
+   you may not use this file except in compliance with the License.
+   You may obtain a copy of the License at
+
+       http://www.apache.org/licenses/LICENSE-2.0
+
+   Unless required by applicable law or agreed to in writing, software
+   distributed under the License is distributed on an "AS IS" BASIS,
+   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+   See the License for the specific language governing permissions and
+   limitations under the License.
dockformerpp/.DS_Store ADDED
Binary file (6.15 kB).
 
dockformerpp/__init__.py ADDED
@@ -0,0 +1,6 @@
+from . import model
+from . import utils
+from . import data
+from . import resources
+
+__all__ = ["model", "utils", "data", "resources"]
dockformerpp/config.py ADDED
@@ -0,0 +1,344 @@
+import copy
+import ml_collections as mlc
+
+from dockformerpp.utils.config_tools import set_inf, enforce_config_constraints
+
+
+def model_config(
+    name,
+    train=False,
+    low_prec=False,
+    long_sequence_inference=False
+):
+    c = copy.deepcopy(config)
+    # TRAINING PRESETS
+    if name == "initial_training":
+        # AF2 Suppl. Table 4, "initial training" setting
+
+        pass
+    elif name == "finetune_affinity":
+        c.loss.affinity2d.weight = 0.5
+        c.loss.binding_site.weight = 0.5
+        c.loss.positions_inter_distogram.weight = 0.5  # this is not essential given fape?
+    else:
+        raise ValueError("Invalid model name")
+
+    c.globals.use_lma = False
+
+    if long_sequence_inference:
+        assert(not train)
+        c.globals.use_lma = True
+
+    if train:
+        c.globals.blocks_per_ckpt = 1
+        c.globals.use_lma = False
+
+    if low_prec:
+        c.globals.eps = 1e-4
+        # If we want exact numerical parity with the original, inf can't be
+        # a global constant
+        set_inf(c, 1e4)
+
+    enforce_config_constraints(c)
+
+    return c
+
+
+c_z = mlc.FieldReference(128, field_type=int)
+c_m = mlc.FieldReference(256, field_type=int)
+c_t = mlc.FieldReference(64, field_type=int)
+c_e = mlc.FieldReference(64, field_type=int)
+c_s = mlc.FieldReference(384, field_type=int)
+
+
+blocks_per_ckpt = mlc.FieldReference(None, field_type=int)
+aux_distogram_bins = mlc.FieldReference(64, field_type=int)
+aux_affinity_bins = mlc.FieldReference(32, field_type=int)
+eps = mlc.FieldReference(1e-8, field_type=float)
+
+NUM_RES = "num residues placeholder"
+NUM_TOKEN = "num tokens placeholder"
+
+
+config = mlc.ConfigDict(
+    {
+        "data": {
+            "common": {
+                "feat": {
+                    "aatype": [NUM_TOKEN],
+                    "all_atom_mask": [NUM_TOKEN, None],
+                    "all_atom_positions": [NUM_TOKEN, None, None],
+                    "atom14_alt_gt_exists": [NUM_TOKEN, None],
+                    "atom14_alt_gt_positions": [NUM_TOKEN, None, None],
+                    "atom14_atom_exists": [NUM_TOKEN, None],
+                    "atom14_atom_is_ambiguous": [NUM_TOKEN, None],
+                    "atom14_gt_exists": [NUM_TOKEN, None],
+                    "atom14_gt_positions": [NUM_TOKEN, None, None],
+                    "atom37_atom_exists": [NUM_TOKEN, None],
+                    "backbone_rigid_mask": [NUM_TOKEN],
+                    "backbone_rigid_tensor": [NUM_TOKEN, None, None],
+                    "chi_angles_sin_cos": [NUM_TOKEN, None, None],
+                    "chi_mask": [NUM_TOKEN, None],
+                    "no_recycling_iters": [],
+                    "pseudo_beta": [NUM_TOKEN, None],
+                    "pseudo_beta_mask": [NUM_TOKEN],
+                    "residue_index": [NUM_TOKEN],
+                    "in_chain_residue_index": [NUM_TOKEN],
+                    "chain_index": [NUM_TOKEN],
+                    "residx_atom14_to_atom37": [NUM_TOKEN, None],
+                    "residx_atom37_to_atom14": [NUM_TOKEN, None],
+                    "resolution": [],
+                    "rigidgroups_alt_gt_frames": [NUM_TOKEN, None, None, None],
+                    "rigidgroups_group_exists": [NUM_TOKEN, None],
+                    "rigidgroups_group_is_ambiguous": [NUM_TOKEN, None],
+                    "rigidgroups_gt_exists": [NUM_TOKEN, None],
+                    "rigidgroups_gt_frames": [NUM_TOKEN, None, None, None],
+                    "seq_length": [],
+                    "token_mask": [NUM_TOKEN],
+                    "target_feat": [NUM_TOKEN, None],
+                    "use_clamped_fape": [],
+                },
+                "max_recycling_iters": 1,
+                "unsupervised_features": [
+                    "aatype",
+                    "residue_index",
+                    "in_chain_residue_index",
+                    "chain_index",
+                    "seq_length",
+                    "no_recycling_iters",
+                    "all_atom_mask",
+                    "all_atom_positions",
+                ],
+            },
+            "supervised": {
+                "clamp_prob": 0.9,
+                "supervised_features": [
+                    "resolution",
+                    "use_clamped_fape",
+                ],
+            },
+            "predict": {
+                "fixed_size": True,
+                "crop": False,
+                "crop_size": None,
+                "supervised": False,
+                "uniform_recycling": False,
+            },
+            "eval": {
+                "fixed_size": True,
+                "crop": False,
+                "crop_size": None,
+                "supervised": True,
+                "uniform_recycling": False,
+            },
+            "train": {
+                "fixed_size": True,
+                "crop": True,
+                "crop_size": 355,
+                "supervised": True,
+                "clamp_prob": 0.9,
+                "uniform_recycling": True,
+                "distogram_mask_prob": 0.1,
+            },
+            "data_module": {
+                "data_loaders": {
+                    "batch_size": 1,
+                    # "batch_size": 2,
+                    "num_workers": 16,
+                    "pin_memory": True,
+                    "should_verify": False,
+                },
+            },
+        },
+        # Recurring FieldReferences that can be changed globally here
+        "globals": {
+            "blocks_per_ckpt": blocks_per_ckpt,
+            # Use Staats & Rabe's low-memory attention algorithm.
+            "use_lma": False,
+            "max_lr": 1e-3,
+            "c_z": c_z,
+            "c_m": c_m,
+            "c_t": c_t,
+            "c_e": c_e,
+            "c_s": c_s,
+            "eps": eps,
+        },
+        "model": {
+            "_mask_trans": False,
+            "structure_input_embedder": {
+                "protein_tf_dim": 20,
+                "additional_tf_dim": 3,  # number of classes (prot_r, prot_l, aff)
+                "c_z": c_z,
+                "c_m": c_m,
+                "relpos_k": 32,
+                "prot_min_bin": 3.25,
+                "prot_max_bin": 20.75,
+                "prot_no_bins": 15,
+                "inf": 1e8,
+            },
+            "recycling_embedder": {
+                "c_z": c_z,
+                "c_m": c_m,
+                "min_bin": 3.25,
+                "max_bin": 20.75,
+                "no_bins": 15,
+                "inf": 1e8,
+            },
+            "evoformer_stack": {
+                "c_m": c_m,
+                "c_z": c_z,
+                "c_hidden_single_att": 32,
+                "c_hidden_mul": 128,
+                "c_hidden_pair_att": 32,
+                "c_s": c_s,
+                "no_heads_single": 8,
+                "no_heads_pair": 4,
+                # "no_blocks": 48,
+                "no_blocks": 2,
+                "transition_n": 4,
+                "single_dropout": 0.15,
+                "pair_dropout": 0.25,
+                "blocks_per_ckpt": blocks_per_ckpt,
+                "clear_cache_between_blocks": False,
+                "inf": 1e9,
+                "eps": eps,  # 1e-10,
+            },
+            "structure_module": {
+                "c_s": c_s,
+                "c_z": c_z,
+                "c_ipa": 16,
+                "c_resnet": 128,
+                "no_heads_ipa": 12,
+                "no_qk_points": 4,
+                "no_v_points": 8,
+                "dropout_rate": 0.1,
+                "no_blocks": 8,
+                "no_transition_layers": 1,
+                "no_resnet_blocks": 2,
+                "no_angles": 7,
+                "trans_scale_factor": 10,
+                "epsilon": eps,  # 1e-12,
+                "inf": 1e5,
+            },
+            "heads": {
+                "lddt": {
+                    "no_bins": 50,
+                    "c_in": c_s,
+                    "c_hidden": 128,
+                },
+                "distogram": {
+                    "c_z": c_z,
+                    "no_bins": aux_distogram_bins,
+                },
+                "affinity_2d": {
+                    "c_z": c_z,
+                    "num_bins": aux_affinity_bins,
+                },
+                "affinity_1d": {
+                    "c_s": c_s,
+                    "num_bins": aux_affinity_bins,
+                },
+                "affinity_cls": {
+                    "c_s": c_s,
+                    "num_bins": aux_affinity_bins,
+                },
+                "binding_site": {
+                    "c_s": c_s,
+                    "c_out": 1,
+                },
+                "inter_contact": {
+                    "c_s": c_s,
+                    "c_z": c_z,
+                    "c_out": 1,
+                },
+            },
+            # A negative value indicates that no early stopping will occur, i.e.
+            # the model will always run `max_recycling_iters` number of recycling
+            # iterations. A positive value will enable early stopping if the
+            # difference in pairwise distances is less than the tolerance between
+            # recycling steps.
+            "recycle_early_stop_tolerance": -1.
+        },
+        "relax": {
+            "max_iterations": 0,  # no max
+            "tolerance": 2.39,
+            "stiffness": 10.0,
+            "max_outer_iterations": 20,
+            "exclude_residues": [],
+        },
+        "loss": {
+            "distogram": {
+                "min_bin": 2.3125,
+                "max_bin": 21.6875,
+                "no_bins": 64,
+                "eps": eps,  # 1e-6,
+                "weight": 0.3,
+            },
+            "positions_inter_distogram": {
+                "max_dist": 20.0,
+                "weight": 0.0,
+            },
+            "positions_intra_distogram": {
+                "max_dist": 10.0,
+                "weight": 0.0,
+            },
+            "binding_site": {
+                "weight": 0.0,
+                "pos_class_weight": 20.0,
+            },
+            "inter_contact": {
+                "weight": 0.0,
+                "pos_class_weight": 200.0,
+            },
+            "affinity2d": {
+                "min_bin": 0,
+                "max_bin": 15,
+                "no_bins": aux_affinity_bins,
+                "weight": 0.0,
+            },
+            "affinity_cls": {
+                "min_bin": 0,
+                "max_bin": 15,
+                "no_bins": aux_affinity_bins,
+                "weight": 0.0,
+            },
+            "fape_backbone": {
+                "clamp_distance": 10.0,
+                "loss_unit_distance": 10.0,
+                "weight": 0.5,
+            },
+            "fape_sidechain": {
+                "clamp_distance": 10.0,
+                "length_scale": 10.0,
+                "weight": 0.5,
+            },
+            "fape_interface": {
+                "clamp_distance": 10.0,
+                "length_scale": 10.0,
+                "weight": 0.0,
+            },
+            "plddt_loss": {
+                "min_resolution": 0.1,
+                "max_resolution": 3.0,
+                "cutoff": 15.0,
+                "no_bins": 50,
+                "eps": eps,  # 1e-10,
+                "weight": 0.01,
+            },
+            "supervised_chi": {
+                "chi_weight": 0.5,
+                "angle_norm_weight": 0.01,
+                "eps": eps,  # 1e-6,
+                "weight": 1.0,
+            },
+            "chain_center_of_mass": {
+                "clamp_distance": -4.0,
+                "weight": 0.,
+                "eps": eps,
+                "enabled": False,
+            },
+            "eps": eps,
+        },
+        "ema": {"decay": 0.999},
+    }
+)
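
A minimal usage sketch of the presets above (an illustration, assuming the package is importable as `dockformerpp` and that `enforce_config_constraints` passes for these presets; the printed values echo the defaults in this diff):

    from dockformerpp.config import model_config

    # "initial_training" keeps the default loss weights; "finetune_affinity"
    # raises the affinity-related loss weights (affinity2d, binding_site,
    # positions_inter_distogram) to 0.5, per the presets above.
    c = model_config("initial_training", train=True)
    print(c.globals.c_z)              # 128, shared everywhere via mlc.FieldReference
    print(c.data.train.crop_size)     # 355
    print(c.globals.blocks_per_ckpt)  # 1 when train=True (activation checkpointing)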
dockformerpp/data/.DS_Store ADDED
Binary file (6.15 kB).
 
dockformerpp/data/data_modules.py ADDED
@@ -0,0 +1,643 @@
+import copy
+import itertools
+import time
+import traceback
+from collections import Counter
+from functools import partial
+import json
+import os
+import pickle
+from typing import Optional, Sequence, Any
+
+import ml_collections as mlc
+import lightning as L
+import torch
+from torch.utils.data import RandomSampler
+
+from dockformerpp.data.data_pipeline import parse_input_json
+from dockformerpp.data import data_pipeline
+from dockformerpp.utils.tensor_utils import dict_multimap
+from dockformerpp.utils.tensor_utils import (
+    tensor_tree_map,
+)
+
+
+class OpenFoldSingleDataset(torch.utils.data.Dataset):
+    def __init__(self,
+                 data_dir: str,
+                 config: mlc.ConfigDict,
+                 mode: str = "train",
+                 ):
+        """
+        Args:
+            data_dir:
+                A path to a directory containing input JSON files
+                (one per example).
+            config:
+                A dataset config object. See dockformerpp.config
+            mode:
+                "train", "eval", or "predict"
+        """
+        super(OpenFoldSingleDataset, self).__init__()
+        self.data_dir = data_dir
+
+        self.config = config
+        self.mode = mode
+
+        valid_modes = ["train", "eval", "predict"]
+        if mode not in valid_modes:
+            raise ValueError(f'mode must be one of {valid_modes}')
+
+        self._all_input_files = [i for i in os.listdir(data_dir) if i.endswith(".json")]
+        if self.config.data_module.data_loaders.should_verify:
+            self._all_input_files = [i for i in self._all_input_files if self._verify_json_input_file(i)]
+
+        self.data_pipeline = data_pipeline.DataPipeline(config, mode)
+
+    def _verify_json_input_file(self, file_name: str) -> bool:
+        with open(os.path.join(self.data_dir, file_name), "r") as f:
+            try:
+                loaded = json.load(f)
+                for i in ["input_structure"]:
+                    if i not in loaded:
+                        return False
+                if self.mode != "predict":
+                    for i in ["gt_structure", "resolution"]:
+                        if i not in loaded:
+                            return False
+            except json.JSONDecodeError:
+                return False
+        return True
+
+    def get_metadata_for_idx(self, idx: int) -> dict:
+        input_path = os.path.join(self.data_dir, self._all_input_files[idx])
+        input_data = json.load(open(input_path, "r"))
+        metadata = {
+            "resolution": input_data.get("resolution", 99.0),
+            "input_path": input_path,
+            "input_name": os.path.basename(input_path).split(".json")[0],
+        }
+        return metadata
+
+    def __getitem__(self, idx):
+        return parse_input_json(
+            input_path=os.path.join(self.data_dir, self._all_input_files[idx]),
+            mode=self.mode,
+            config=self.config,
+            data_pipeline=self.data_pipeline,
+            data_dir=os.path.dirname(self.data_dir),
+            idx=idx,
+        )
+
+    def __len__(self):
+        return len(self._all_input_files)
+
+
+def resolution_filter(resolution: int, max_resolution: float) -> bool:
+    """Check that the resolution is <= max_resolution permitted"""
+    return resolution is not None and resolution <= max_resolution
+
+
+def all_seq_len_filter(seqs: list, minimum_number_of_residues: int) -> bool:
+    """Check if the total combined sequence lengths are >= minimum_number_of_residues"""
+    total_len = sum([len(i) for i in seqs])
+    return total_len >= minimum_number_of_residues
+
+
+class OpenFoldDataset(torch.utils.data.Dataset):
+    """
+    Implements the stochastic filters applied during AlphaFold's training.
+    Because samples are selected from constituent datasets randomly, the
+    length of an OpenFoldDataset is arbitrary. Samples are selected
+    and filtered once at initialization.
+    """
+
+    def __init__(self,
+                 datasets: Sequence[OpenFoldSingleDataset],
+                 probabilities: Sequence[float],
+                 epoch_len: int,
+                 generator: torch.Generator = None,
+                 _roll_at_init: bool = True,
+                 ):
+        self.datasets = datasets
+        self.probabilities = probabilities
+        self.epoch_len = epoch_len
+        self.generator = generator
+
+        self._samples = [self.looped_samples(i) for i in range(len(self.datasets))]
+        if _roll_at_init:
+            self.reroll()
+
+    @staticmethod
+    def deterministic_train_filter(
+        cache_entry: Any,
+        max_resolution: float = 9.,
+        max_single_aa_prop: float = 0.8,
+        *args, **kwargs
+    ) -> bool:
+        # Hard filters
+        resolution = cache_entry["resolution"]
+
+        return all([
+            resolution_filter(resolution=resolution,
+                              max_resolution=max_resolution)
+        ])
+
+    @staticmethod
+    def get_stochastic_train_filter_prob(
+        cache_entry: Any,
+        *args, **kwargs
+    ) -> float:
+        # Stochastic filters
+        probabilities = []
+
+        cluster_size = cache_entry.get("cluster_size", None)
+        if cluster_size is not None and cluster_size > 0:
+            probabilities.append(1 / cluster_size)
+
+        # Risk of underflow here?
+        out = 1
+        for p in probabilities:
+            out *= p
+
+        return out
+
+    def looped_shuffled_dataset_idx(self, dataset_len):
+        while True:
+            # Uniformly shuffle each dataset's indices
+            weights = [1. for _ in range(dataset_len)]
+            shuf = torch.multinomial(
+                torch.tensor(weights),
+                num_samples=dataset_len,
+                replacement=False,
+                generator=self.generator,
+            )
+            for idx in shuf:
+                yield idx
+
+    def looped_samples(self, dataset_idx):
+        max_cache_len = int(self.epoch_len * self.probabilities[dataset_idx])
+        dataset = self.datasets[dataset_idx]
+        idx_iter = self.looped_shuffled_dataset_idx(len(dataset))
+        while True:
+            weights = []
+            idx = []
+            for _ in range(max_cache_len):
+                candidate_idx = next(idx_iter)
+                # chain_id = dataset.idx_to_chain_id(candidate_idx)
+                # chain_data_cache_entry = chain_data_cache[chain_id]
+                # data_entry = dataset[candidate_idx.item()]
+                entry_metadata_for_filter = dataset.get_metadata_for_idx(candidate_idx.item())
+                if not self.deterministic_train_filter(entry_metadata_for_filter):
+                    continue
+
+                p = self.get_stochastic_train_filter_prob(
+                    entry_metadata_for_filter,
+                )
+                weights.append([1. - p, p])
+                idx.append(candidate_idx)
+
+            samples = torch.multinomial(
+                torch.tensor(weights),
+                num_samples=1,
+                generator=self.generator,
+            )
+            samples = samples.squeeze()
+
+            cache = [i for i, s in zip(idx, samples) if s]
+
+            for datapoint_idx in cache:
+                yield datapoint_idx
+
+    def __getitem__(self, idx):
+        dataset_idx, datapoint_idx = self.datapoints[idx]
+        return self.datasets[dataset_idx][datapoint_idx]
+
+    def __len__(self):
+        return self.epoch_len
+
+    def reroll(self):
+        # TODO bshor: I have removed support for filters (currently done in preprocess) and for weighting clusters
+        # now it is much faster, because it doesn't call looped_samples
+        dataset_choices = torch.multinomial(
+            torch.tensor(self.probabilities),
+            num_samples=self.epoch_len,
+            replacement=True,
+            generator=self.generator,
+        )
+        self.datapoints = []
+        counter_datasets = Counter(dataset_choices.tolist())
+        for dataset_idx, num_samples in counter_datasets.items():
+            dataset = self.datasets[dataset_idx]
+            sample_choices = torch.randint(0, len(dataset), (num_samples,), generator=self.generator)
+            for datapoint_idx in sample_choices:
+                self.datapoints.append((dataset_idx, datapoint_idx))
+
+
+class OpenFoldBatchCollator:
+    def __call__(self, prots):
+        stack_fn = partial(torch.stack, dim=0)
+        return dict_multimap(stack_fn, prots)
+
+
+class OpenFoldDataLoader(torch.utils.data.DataLoader):
+    def __init__(self, *args, config, stage="train", generator=None, **kwargs):
+        super().__init__(*args, **kwargs)
+        self.config = config
+        self.stage = stage
+        self.generator = generator
+        self._prep_batch_properties_probs()
+
+    def _prep_batch_properties_probs(self):
+        keyed_probs = []
+        stage_cfg = self.config[self.stage]
+
+        max_iters = self.config.common.max_recycling_iters
+
+        if stage_cfg.uniform_recycling:
+            recycling_probs = [
+                1. / (max_iters + 1) for _ in range(max_iters + 1)
+            ]
+        else:
+            recycling_probs = [
+                0. for _ in range(max_iters + 1)
+            ]
+            recycling_probs[-1] = 1.
+
+        keyed_probs.append(
+            ("no_recycling_iters", recycling_probs)
+        )
+
+        keys, probs = zip(*keyed_probs)
+        max_len = max([len(p) for p in probs])
+        padding = [[0.] * (max_len - len(p)) for p in probs]
+
+        self.prop_keys = keys
+        self.prop_probs_tensor = torch.tensor(
+            [p + pad for p, pad in zip(probs, padding)],
+            dtype=torch.float32,
+        )
+
+    def _add_batch_properties(self, batch):
+        # gt_features = batch.pop('gt_features', None)
+        samples = torch.multinomial(
+            self.prop_probs_tensor,
+            num_samples=1,  # 1 per row
+            replacement=True,
+            generator=self.generator
+        )
+
+        aatype = batch["aatype"]
+        batch_dims = aatype.shape[:-2]
+        recycling_dim = aatype.shape[-1]
+        no_recycling = recycling_dim
+        for i, key in enumerate(self.prop_keys):
+            sample = int(samples[i][0])
+            sample_tensor = torch.tensor(
+                sample,
+                device=aatype.device,
+                requires_grad=False
+            )
+            orig_shape = sample_tensor.shape
+            sample_tensor = sample_tensor.view(
+                (1,) * len(batch_dims) + sample_tensor.shape + (1,)
+            )
+            sample_tensor = sample_tensor.expand(
+                batch_dims + orig_shape + (recycling_dim,)
+            )
+            batch[key] = sample_tensor
+
+            if key == "no_recycling_iters":
+                no_recycling = sample
+
+        resample_recycling = lambda t: t[..., :no_recycling + 1]
+        batch = tensor_tree_map(resample_recycling, batch)
+        # batch['gt_features'] = gt_features
+
+        return batch
+
+    def __iter__(self):
+        it = super().__iter__()
+
+        def _batch_prop_gen(iterator):
+            for batch in iterator:
+                yield self._add_batch_properties(batch)
+
+        return _batch_prop_gen(it)
+
+
+class OpenFoldDataModule(L.LightningDataModule):
+    def __init__(self,
+                 config: mlc.ConfigDict,
+                 train_data_dir: Optional[str] = None,
+                 val_data_dir: Optional[str] = None,
+                 predict_data_dir: Optional[str] = None,
+                 batch_seed: Optional[int] = None,
+                 train_epoch_len: int = 50000,
+                 **kwargs
+                 ):
+        super(OpenFoldDataModule, self).__init__()
+
+        self.config = config
+        self.train_data_dir = train_data_dir
+        self.val_data_dir = val_data_dir
+        self.predict_data_dir = predict_data_dir
+        self.batch_seed = batch_seed
+        self.train_epoch_len = train_epoch_len
+
+        if self.train_data_dir is None and self.predict_data_dir is None:
+            raise ValueError(
+                'At least one of train_data_dir or predict_data_dir must be '
+                'specified'
+            )
+
+        self.training_mode = self.train_data_dir is not None
+
+        # if not self.training_mode and predict_alignment_dir is None:
+        #     raise ValueError(
+        #         'In inference mode, predict_alignment_dir must be specified'
+        #     )
+        # elif val_data_dir is not None and val_alignment_dir is None:
+        #     raise ValueError(
+        #         'If val_data_dir is specified, val_alignment_dir must '
+        #         'be specified as well'
+        #     )
+
+    def setup(self, stage):
+        # Most of the arguments are the same for the three datasets
+        dataset_gen = partial(OpenFoldSingleDataset,
+                              config=self.config)
+
+        if self.training_mode:
+            train_dataset = dataset_gen(
+                data_dir=self.train_data_dir,
+                mode="train",
+            )
+
+            datasets = [train_dataset]
+            probabilities = [1.]
+
+            generator = None
+            if self.batch_seed is not None:
+                generator = torch.Generator()
+                generator = generator.manual_seed(self.batch_seed + 1)
+
+            self.train_dataset = OpenFoldDataset(
+                datasets=datasets,
+                probabilities=probabilities,
+                epoch_len=self.train_epoch_len,
+                generator=generator,
+                _roll_at_init=False,
+            )
+
+            if self.val_data_dir is not None:
+                self.eval_dataset = dataset_gen(
+                    data_dir=self.val_data_dir,
+                    mode="eval",
+                )
+            else:
+                self.eval_dataset = None
+        else:
+            self.predict_dataset = dataset_gen(
+                data_dir=self.predict_data_dir,
+                mode="predict",
+            )
+
+    def _gen_dataloader(self, stage):
+        generator = None
+        if self.batch_seed is not None:
+            generator = torch.Generator()
+            generator = generator.manual_seed(self.batch_seed)
+
+        if stage == "train":
+            dataset = self.train_dataset
+            # Filter the dataset, if necessary
+            dataset.reroll()
+        elif stage == "eval":
+            dataset = self.eval_dataset
+        elif stage == "predict":
+            dataset = self.predict_dataset
+        else:
+            raise ValueError("Invalid stage")
+
+        batch_collator = OpenFoldBatchCollator()
+
+        dl = OpenFoldDataLoader(
+            dataset,
+            config=self.config,
+            stage=stage,
+            generator=generator,
+            batch_size=self.config.data_module.data_loaders.batch_size,
+            # num_workers=self.config.data_module.data_loaders.num_workers,
+            num_workers=0,  # TODO bshor: solve generator pickling issue and then bring back num_workers, or just remove generator
+            collate_fn=batch_collator,
+        )
+
+        return dl
+
+    def train_dataloader(self):
+        return self._gen_dataloader("train")
+
+    def val_dataloader(self):
+        if self.eval_dataset is not None:
+            return self._gen_dataloader("eval")
+        return None
+
+    def predict_dataloader(self):
+        return self._gen_dataloader("predict")
+
+
+class DummyDataset(torch.utils.data.Dataset):
+    def __init__(self, batch_path):
+        with open(batch_path, "rb") as f:
+            self.batch = pickle.load(f)
+
+    def __getitem__(self, idx):
+        return copy.deepcopy(self.batch)
+
+    def __len__(self):
+        return 1000
+
+
+class DummyDataLoader(L.LightningDataModule):
+    def __init__(self, batch_path):
+        super().__init__()
+        self.dataset = DummyDataset(batch_path)
+
+    def train_dataloader(self):
+        return torch.utils.data.DataLoader(self.dataset)
+
+
+class DockFormerSimpleDataset(torch.utils.data.Dataset):
+    def __init__(self, clusters_json: str, config: mlc.ConfigDict, mode: str = "train"):
+        clusters = json.load(open(clusters_json, "r"))
+        self.config = config
+        self.mode = mode
+        self._data_dir = os.path.dirname(clusters_json)
+        print("Data dir", self._data_dir)
+        self._clusters = clusters
+        self._all_input_files = sum(clusters.values(), [])
+        self.data_pipeline = data_pipeline.DataPipeline(config, mode)
+
+    def __getitem__(self, idx):
+        return parse_input_json(
+            input_path=os.path.join(self._data_dir, self._all_input_files[idx]),
+            mode=self.mode,
+            config=self.config,
+            data_pipeline=self.data_pipeline,
+            data_dir=self._data_dir,
+            idx=idx,
+        )
+
+    def __len__(self):
+        return len(self._all_input_files)
+
+
+class DockFormerClusteredDataset(torch.utils.data.Dataset):
+    def __init__(self, clusters_json: str, config: mlc.ConfigDict, mode: str = "train", generator=None):
+        clusters = json.load(open(clusters_json, "r"))
+        self.config = config
+        self.mode = mode
+        self._data_dir = os.path.dirname(clusters_json)
+        self._clusters = list(clusters.values())
+        self.data_pipeline = data_pipeline.DataPipeline(config, mode)
+        self._generator = generator
+
+    def __getitem__(self, idx):
+        try:
+            cluster = self._clusters[idx]
+            # choose random from cluster
+            input_file = cluster[torch.randint(0, len(cluster), (1,), generator=self._generator).item()]
+
+            return parse_input_json(
+                input_path=os.path.join(self._data_dir, input_file),
+                mode=self.mode,
+                config=self.config,
+                data_pipeline=self.data_pipeline,
+                data_dir=self._data_dir,
+                idx=idx,
+            )
+        except Exception as e:
+            print("ERROR in loading", e)
+            traceback.print_exc()
+            return parse_input_json(
+                input_path=os.path.join(self._data_dir, self._clusters[0][0]),
+                mode=self.mode,
+                config=self.config,
+                data_pipeline=self.data_pipeline,
+                data_dir=self._data_dir,
+                idx=idx,
+            )
+
+    def __len__(self):
+        return len(self._clusters)
+
+
+class DockFormerDataLoader(torch.utils.data.DataLoader):
+    def __init__(self, *args, config, stage="train", generator=None, **kwargs):
+        super().__init__(*args, **kwargs)
+        self.config = config
+        self.stage = stage
+        # self.generator = generator
+
+    def _add_batch_properties(self, batch):
+        if self.config[self.stage].uniform_recycling:
+            aatype = batch["aatype"]
+            max_recycling_dim = aatype.shape[-1]
+
+            # num_recycles = torch.randint(0, max_recycling_dim, (1,), generator=self.generator)
+            num_recycles = torch.randint(0, max_recycling_dim, (1,)).item()
+
+            resample_recycling = lambda t: t[..., :num_recycles + 1]
+            batch = tensor_tree_map(resample_recycling, batch)
+
+        return batch
+
+    def __iter__(self):
+        it = super().__iter__()
+
+        def _batch_prop_gen(iterator):
+            for batch in iterator:
+                yield self._add_batch_properties(batch)
+
+        return _batch_prop_gen(it)
+
+
+class DockFormerDataModule(L.LightningDataModule):
+    def __init__(self,
+                 config: mlc.ConfigDict,
+                 train_data_file: Optional[str] = None,
+                 val_data_file: Optional[str] = None,
+                 batch_seed: Optional[int] = None,
+                 **kwargs
+                 ):
+        super(DockFormerDataModule, self).__init__()
+
+        self.config = config
+        self.train_data_file = train_data_file
+        self.val_data_file = val_data_file
+        self.batch_seed = batch_seed
+
+        assert self.train_data_file is not None, "train_data_file must be specified"
+        assert self.val_data_file is not None, "val_data_file must be specified"
+
+        self.train_dataset = None
+        self.val_dataset = None
+
+    def setup(self, stage):
+        generator = None
+        if self.batch_seed is not None:
+            generator = torch.Generator()
+            generator = generator.manual_seed(self.batch_seed + 1)
+
+        self.train_dataset = DockFormerClusteredDataset(
+            clusters_json=self.train_data_file,
+            config=self.config,
+            mode="train",
+            generator=generator,
+        )
+
+        self.val_dataset = DockFormerSimpleDataset(
+            clusters_json=self.val_data_file,
+            config=self.config,
+            mode="eval",
+        )
+
+    def _gen_dataloader(self, stage):
+        generator = None
+        if self.batch_seed is not None:
+            generator = torch.Generator()
+            generator = generator.manual_seed(self.batch_seed)
+
+        should_shuffle = stage == "train"
+        if stage == "train":
+            dataset = self.train_dataset
+        elif stage == "eval":
+            dataset = self.val_dataset
+        else:
+            raise ValueError("Invalid stage")
+
+        batch_collator = OpenFoldBatchCollator()
+
+        dl = DockFormerDataLoader(
+            dataset,
+            config=self.config,
+            stage=stage,
+            # generator=generator,
+            batch_size=self.config.data_module.data_loaders.batch_size,
+            # num_workers=self.config.data_module.data_loaders.num_workers,
+            num_workers=0,  # TODO bshor: solve generator pickling issue and then bring back num_workers, or just remove generator
+            collate_fn=batch_collator,
+            shuffle=should_shuffle,
+        )
+
+        return dl
+
+    def train_dataloader(self):
+        return self._gen_dataloader("train")
+
+    def val_dataloader(self):
+        if self.val_dataset is not None:
+            return self._gen_dataloader("eval")
+        return None
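
A minimal wiring sketch for the module above. The cluster-file layout is inferred from DockFormerClusteredDataset: a JSON object mapping cluster ids to lists of per-example JSON paths, relative to the cluster file's directory. The file paths below are hypothetical, and passing the `data` sub-config follows how the module indexes `config[stage]` and `config.data_module` internally:

    from dockformerpp.config import model_config
    from dockformerpp.data.data_modules import DockFormerDataModule

    cfg = model_config("initial_training", train=True)
    dm = DockFormerDataModule(
        config=cfg.data,                             # module reads config[stage] and config.data_module
        train_data_file="data/train_clusters.json",  # hypothetical path
        val_data_file="data/val_clusters.json",      # hypothetical path
        batch_seed=17,
    )
    dm.setup("fit")
    batch = next(iter(dm.train_dataloader()))        # dict of stacked feature tensors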
dockformerpp/data/data_pipeline.py ADDED
@@ -0,0 +1,360 @@
1
+ # Copyright 2021 AlQuraishi Laboratory
2
+ # Copyright 2021 DeepMind Technologies Limited
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ import json
16
+ import os
17
+ import time
18
+ from typing import List
19
+
20
+ import numpy as np
21
+ import torch
22
+ import ml_collections as mlc
23
+ from rdkit import Chem
24
+
25
+ from dockformerpp.data import data_transforms
26
+ from dockformerpp.data.data_transforms import get_restype_atom37_mask, get_restypes
27
+ from dockformerpp.data.protein_features import make_protein_features
28
+ from dockformerpp.data.utils import FeatureTensorDict, FeatureDict
29
+ from dockformerpp.utils import protein
30
+
31
+
32
+ def _np_filter_and_to_tensor_dict(np_example: FeatureDict, features_to_keep: List[str]) -> FeatureTensorDict:
33
+ """Creates dict of tensors from a dict of NumPy arrays.
34
+
35
+ Args:
36
+ np_example: A dict of NumPy feature arrays.
37
+ features: A list of strings of feature names to be returned in the dataset.
38
+
39
+ Returns:
40
+ A dictionary of features mapping feature names to features. Only the given
41
+ features are returned, all other ones are filtered out.
42
+ """
43
+ # torch generates warnings if feature is already a torch Tensor
44
+ to_tensor = lambda t: torch.tensor(t) if type(t) != torch.Tensor else t.clone().detach()
45
+ tensor_dict = {
46
+ k: to_tensor(v) for k, v in np_example.items() if k in features_to_keep
47
+ }
48
+
49
+ return tensor_dict
50
+
51
+
52
+ def _add_protein_probablistic_features(features: FeatureDict, cfg: mlc.ConfigDict, mode: str) -> FeatureDict:
53
+ if mode == "train":
54
+ p = torch.rand(1).item()
55
+ use_clamped_fape_value = float(p < cfg.supervised.clamp_prob)
56
+ features["use_clamped_fape"] = np.float32(use_clamped_fape_value)
57
+ else:
58
+ features["use_clamped_fape"] = np.float32(0.0)
59
+ return features
60
+
61
+
62
+ @data_transforms.curry1
63
+ def compose(x, fs):
64
+ for f in fs:
65
+ x = f(x)
66
+ return x
67
+
68
+
69
+ def _apply_protein_transforms(tensors: FeatureTensorDict) -> FeatureTensorDict:
70
+ transforms = [
71
+ data_transforms.cast_to_64bit_ints,
72
+ data_transforms.squeeze_features,
73
+ data_transforms.make_atom14_masks,
74
+ data_transforms.make_atom14_positions,
75
+ data_transforms.atom37_to_frames,
76
+ data_transforms.atom37_to_torsion_angles(""),
77
+ data_transforms.make_pseudo_beta(),
78
+ data_transforms.get_backbone_frames,
79
+ data_transforms.get_chi_angles,
80
+ ]
81
+
82
+ tensors = compose(transforms)(tensors)
83
+
84
+ return tensors
85
+
86
+
87
+ def _apply_protein_probablistic_transforms(tensors: FeatureTensorDict, cfg: mlc.ConfigDict, mode: str) \
88
+ -> FeatureTensorDict:
89
+ transforms = [data_transforms.make_target_feat()]
90
+
91
+ crop_feats = dict(cfg.common.feat)
92
+
93
+ if cfg[mode].fixed_size:
94
+ transforms.append(data_transforms.select_feat(list(crop_feats)))
95
+ # TODO bshor: restore transforms for training on cropped proteins, need to handle pocket somehow
96
+ # if so, look for random_crop_to_size and make_fixed_size in data_transforms.py
97
+
98
+ compose(transforms)(tensors)
99
+
100
+ return tensors
101
+
102
+
103
+ class DataPipeline:
104
+ """Assembles input features."""
105
+ def __init__(self, config: mlc.ConfigDict, mode: str):
106
+ self.config = config
107
+ self.mode = mode
108
+
109
+ self.feature_names = config.common.unsupervised_features
110
+ if config[mode].supervised:
111
+ self.feature_names += config.supervised.supervised_features
112
+
113
+ def process_pdb(self, pdb_path: str) -> FeatureTensorDict:
114
+ """
115
+ Assembles features for a protein in a PDB file.
116
+ """
117
+ with open(pdb_path, 'r') as f:
118
+ pdb_str = f.read()
119
+
120
+ protein_object = protein.from_pdb_string(pdb_str)
121
+ description = os.path.splitext(os.path.basename(pdb_path))[0].upper()
122
+ pdb_feats = make_protein_features(protein_object, description)
123
+ pdb_feats = _add_protein_probablistic_features(pdb_feats, self.config, self.mode)
124
+
125
+ tensor_feats = _np_filter_and_to_tensor_dict(pdb_feats, self.feature_names)
126
+
127
+ tensor_feats = _apply_protein_transforms(tensor_feats)
128
+ tensor_feats = _apply_protein_probablistic_transforms(tensor_feats, self.config, self.mode)
129
+
130
+ return tensor_feats
131
+
132
+ def _prepare_recycles(feat: torch.Tensor, num_recycles: int) -> torch.Tensor:
133
+ return feat.unsqueeze(-1).repeat(*([1] * len(feat.shape)), num_recycles)
134
+
135
+
136
+ def _fit_to_crop(target_tensor: torch.Tensor, crop_size: int, start_ind: int) -> torch.Tensor:
137
+ if len(target_tensor.shape) == 1:
138
+ ret = torch.zeros((crop_size, ), dtype=target_tensor.dtype)
139
+ ret[start_ind:start_ind + target_tensor.shape[0]] = target_tensor
140
+ return ret
141
+ elif len(target_tensor.shape) == 2:
142
+ ret = torch.zeros((crop_size, target_tensor.shape[-1]), dtype=target_tensor.dtype)
143
+ ret[start_ind:start_ind + target_tensor.shape[0], :] = target_tensor
144
+ return ret
145
+ else:
146
+ ret = torch.zeros((crop_size, *target_tensor.shape[1:]), dtype=target_tensor.dtype)
147
+ ret[start_ind:start_ind + target_tensor.shape[0], ...] = target_tensor
148
+ return ret
149
+
150
+
151
+ def parse_input_json(input_path: str, mode: str, config: mlc.ConfigDict, data_pipeline: DataPipeline,
152
+ data_dir: str, idx: int) -> FeatureTensorDict:
153
+ start_load_time = time.time()
154
+ input_data = json.load(open(input_path, "r"))
155
+ if mode == "train" or mode == "eval":
156
+ print("loading", input_data["pdb_id"], end=" ")
157
+
158
+ num_recycles = config.common.max_recycling_iters + 1
159
+
160
+ input_protein_r_feats = data_pipeline.process_pdb(pdb_path=os.path.join(data_dir, input_data["input_r_structure"]))
161
+ input_protein_l_feats = data_pipeline.process_pdb(pdb_path=os.path.join(data_dir, input_data["input_l_structure"]))
162
+
163
+ n_res_r = input_protein_r_feats["protein_target_feat"].shape[0]
164
+ n_res_l = input_protein_l_feats["protein_target_feat"].shape[0]
165
+ n_res_total = n_res_r + n_res_l
166
+ n_affinity = 1
167
+
168
+ # add 1 for affinity token
169
+ crop_size = n_res_total + n_affinity
170
+ if (mode == "train" or mode == "eval") and config.train.fixed_size:
171
+ crop_size = config.train.crop_size
172
+
173
+ assert crop_size >= n_res_total + n_affinity, f"crop_size: {crop_size}, n_res_r: {n_res_r}, n_res_l: {n_res_l}"
174
+
175
+ token_mask = torch.zeros((crop_size,), dtype=torch.float32)
176
+ token_mask[:n_res_total + n_affinity] = 1
177
+
178
+ protein_r_mask = torch.zeros((crop_size,), dtype=torch.float32)
179
+ protein_r_mask[:n_res_r] = 1
180
+
181
+ protein_l_mask = torch.zeros((crop_size,), dtype=torch.float32)
182
+ protein_l_mask[n_res_r:n_res_total] = 1
183
+
184
+ affinity_mask = torch.zeros((crop_size,), dtype=torch.float32)
185
+ affinity_mask[n_res_total] = 1
186
+
187
+ structural_mask = torch.zeros((crop_size,), dtype=torch.float32)
188
+ structural_mask[:n_res_total] = 1
189
+
190
+ inter_pair_mask = torch.zeros((crop_size, crop_size), dtype=torch.float32)
191
+ inter_pair_mask[:n_res_r, n_res_r:n_res_total] = 1
192
+ inter_pair_mask[n_res_r:n_res_total, :n_res_r] = 1
193
+
194
+ tf_dim = input_protein_r_feats["protein_target_feat"].shape[-1]
195
+
196
+ target_feat = torch.zeros((crop_size, tf_dim + 3), dtype=torch.float32)
197
+ target_feat[:n_res_r, :tf_dim] = input_protein_r_feats["protein_target_feat"]
198
+ target_feat[n_res_r:n_res_total, :tf_dim] = input_protein_l_feats["protein_target_feat"]
199
+
200
+ target_feat[:n_res_r, tf_dim] = 1 # Set "is_protein_r" flag for protein rows
201
+ target_feat[n_res_r:n_res_total, tf_dim + 1] = 1 # Set "is_protein_l" flag for ligand rows
202
+ target_feat[n_res_total, tf_dim + 2] = 1 # Set "is_affinity" flag for affinity row
203
+
204
+ input_positions = torch.zeros((crop_size, 3), dtype=torch.float32)
205
+ input_positions[:n_res_r] = input_protein_r_feats["pseudo_beta"]
206
+ input_positions[n_res_r:n_res_total] = input_protein_l_feats["pseudo_beta"]
207
+
208
+ distogram_mask = torch.zeros(crop_size)
209
+ if mode == "train":
210
+ ones_indices = torch.randperm(n_res_total)[:int(n_res_total * config.train.distogram_mask_prob)]
211
+ # print(ones_indices)
212
+ distogram_mask[ones_indices] = 1
213
+ input_positions = input_positions * (1 - distogram_mask).unsqueeze(-1)
214
+ elif mode == "predict":
215
+ # ignore all positions where pseudo_beta is 0, 0, 0
216
+ distogram_mask = (input_positions == 0).all(dim=-1).float()
217
+ # print("Ignoring residues", torch.nonzero(distogram_mask).flatten())
218
+
219
+ # Implement ligand as amino acid type 20
220
+ aatype = torch.cat([input_protein_r_feats["aatype"], input_protein_l_feats["aatype"]], dim=0)
221
+ residue_index = torch.cat([input_protein_r_feats["residue_index"], input_protein_l_feats["residue_index"]], dim=0)
222
+ residx_atom37_to_atom14 = torch.cat([input_protein_r_feats["residx_atom37_to_atom14"],
223
+ input_protein_l_feats["residx_atom37_to_atom14"]],
224
+ dim=0)
225
+ atom37_atom_exists = torch.cat([input_protein_r_feats["atom37_atom_exists"],
226
+ input_protein_l_feats["atom37_atom_exists"]], dim=0)
227
+
+     feats = {
+         "token_mask": token_mask,
+         "protein_r_mask": protein_r_mask,
+         "protein_l_mask": protein_l_mask,
+         "affinity_mask": affinity_mask,
+         "structural_mask": structural_mask,
+         "inter_pair_mask": inter_pair_mask,
+
+         "target_feat": target_feat,
+         "input_positions": input_positions,
+         "distogram_mask": distogram_mask,
+         "residue_index": _fit_to_crop(residue_index, crop_size, 0),
+         "aatype": _fit_to_crop(aatype, crop_size, 0),
+         "residx_atom37_to_atom14": _fit_to_crop(residx_atom37_to_atom14, crop_size, 0),
+         "atom37_atom_exists": _fit_to_crop(atom37_atom_exists, crop_size, 0),
+     }
+
+     if mode == "predict":
+         feats.update({
+             "in_chain_residue_index_r": input_protein_r_feats["in_chain_residue_index"],
+             "chain_index_r": input_protein_r_feats["chain_index"],
+             "in_chain_residue_index_l": input_protein_l_feats["in_chain_residue_index"],
+             "chain_index_l": input_protein_l_feats["chain_index"],
+         })
+
+     if mode == 'train' or mode == 'eval':
+         gt_protein_r_feats = data_pipeline.process_pdb(pdb_path=os.path.join(data_dir, input_data["gt_r_structure"]))
+         gt_protein_l_feats = data_pipeline.process_pdb(pdb_path=os.path.join(data_dir, input_data["gt_l_structure"]))
+
+         affinity_loss_factor = torch.tensor([1.0], dtype=torch.float32)
+         if input_data.get("affinity") is None:
+             eps = 1e-6
+             affinity_loss_factor = torch.tensor([eps], dtype=torch.float32)
+             affinity = torch.tensor([0.0], dtype=torch.float32)
+         else:
+             affinity = torch.tensor([input_data["affinity"]], dtype=torch.float32)
+
+         resolution = torch.tensor(input_data["resolution"], dtype=torch.float32)
+
+         # prepare ground-truth inter-chain contacts
+         expanded_prot_r_pos = gt_protein_r_feats["pseudo_beta"].unsqueeze(1)  # Shape: (n_res_r, 1, 3)
+         expanded_prot_l_pos = gt_protein_l_feats["pseudo_beta"].unsqueeze(0)  # Shape: (1, n_res_l, 3)
+         distances = torch.sqrt(torch.sum((expanded_prot_r_pos - expanded_prot_l_pos) ** 2, dim=-1))
+         inter_contact = (distances < 8.0).float()
+         binding_site_mask_r = inter_contact.any(dim=1).float()
+         binding_site_mask_l = inter_contact.any(dim=0).float()
+         binding_site_mask = torch.cat([binding_site_mask_r, binding_site_mask_l], dim=0)
+
+         inter_contact_reshaped_to_crop = torch.zeros((crop_size, crop_size), dtype=torch.float32)
+         inter_contact_reshaped_to_crop[:n_res_r, n_res_r:n_res_total] = inter_contact
+         inter_contact_reshaped_to_crop[n_res_r:n_res_total, :n_res_r] = inter_contact.T
+
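+         # A residue pair is a ground-truth contact when the two pseudo-beta atoms
+         # lie within 8 Angstrom; the binding-site masks are the row/column
+         # any-reductions of that contact map.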
+         atom37_gt_positions = torch.cat([gt_protein_r_feats["all_atom_positions"],
+                                          gt_protein_l_feats["all_atom_positions"]], dim=0)
+         atom37_atom_exists_in_res = torch.cat([gt_protein_r_feats["atom37_atom_exists"],
+                                                gt_protein_l_feats["atom37_atom_exists"]], dim=0)
+         atom37_atom_exists_in_gt = torch.cat([gt_protein_r_feats["all_atom_mask"],
+                                               gt_protein_l_feats["all_atom_mask"]], dim=0)
+
+         atom14_gt_positions = torch.cat([gt_protein_r_feats["atom14_gt_positions"],
+                                          gt_protein_l_feats["atom14_gt_positions"]], dim=0)
+         atom14_atom_exists_in_res = torch.cat([gt_protein_r_feats["atom14_atom_exists"],
+                                                gt_protein_l_feats["atom14_atom_exists"]], dim=0)
+         atom14_atom_exists_in_gt = torch.cat([gt_protein_r_feats["atom14_gt_exists"],
+                                               gt_protein_l_feats["atom14_gt_exists"]], dim=0)
+
+         gt_pseudo_beta_joined = torch.cat([gt_protein_r_feats["pseudo_beta"], gt_protein_l_feats["pseudo_beta"]], dim=0)
+         gt_pseudo_beta_joined_mask = torch.cat([gt_protein_r_feats["pseudo_beta_mask"],
+                                                 gt_protein_l_feats["pseudo_beta_mask"]], dim=0)
+
+         # IGNORES: residx_atom14_to_atom37, rigidgroups_group_exists,
+         # rigidgroups_group_is_ambiguous, pseudo_beta_mask, backbone_rigid_mask, protein_target_feat
+         gt_protein_feats = {
+             "atom37_gt_positions": atom37_gt_positions,  # torch.Size([n_struct, 37, 3])
+             "atom37_atom_exists_in_res": atom37_atom_exists_in_res,  # torch.Size([n_struct, 37])
+             "atom37_atom_exists_in_gt": atom37_atom_exists_in_gt,  # torch.Size([n_struct, 37])
+
+             "atom14_gt_positions": atom14_gt_positions,  # torch.Size([n_struct, 14, 3])
+             "atom14_atom_exists_in_res": atom14_atom_exists_in_res,  # torch.Size([n_struct, 14])
+             "atom14_atom_exists_in_gt": atom14_atom_exists_in_gt,  # torch.Size([n_struct, 14])
+
+             "gt_pseudo_beta_joined": gt_pseudo_beta_joined,  # torch.Size([n_struct, 3])
+             "gt_pseudo_beta_joined_mask": gt_pseudo_beta_joined_mask,  # torch.Size([n_struct])
+
+             "atom14_alt_gt_positions": torch.cat([gt_protein_r_feats["atom14_alt_gt_positions"],
+                                                   gt_protein_l_feats["atom14_alt_gt_positions"]], dim=0),  # torch.Size([n_res, 14, 3])
+             "atom14_alt_gt_exists": torch.cat([gt_protein_r_feats["atom14_alt_gt_exists"],
+                                                gt_protein_l_feats["atom14_alt_gt_exists"]], dim=0),  # torch.Size([n_res, 14])
+             "atom14_atom_is_ambiguous": torch.cat([gt_protein_r_feats["atom14_atom_is_ambiguous"],
+                                                    gt_protein_l_feats["atom14_atom_is_ambiguous"]], dim=0),  # torch.Size([n_res, 14])
+             "rigidgroups_gt_frames": torch.cat([gt_protein_r_feats["rigidgroups_gt_frames"],
+                                                 gt_protein_l_feats["rigidgroups_gt_frames"]], dim=0),  # torch.Size([n_res, 8, 4, 4])
+             "rigidgroups_gt_exists": torch.cat([gt_protein_r_feats["rigidgroups_gt_exists"],
+                                                 gt_protein_l_feats["rigidgroups_gt_exists"]], dim=0),  # torch.Size([n_res, 8])
+             "rigidgroups_alt_gt_frames": torch.cat([gt_protein_r_feats["rigidgroups_alt_gt_frames"],
+                                                     gt_protein_l_feats["rigidgroups_alt_gt_frames"]], dim=0),  # torch.Size([n_res, 8, 4, 4])
+             "backbone_rigid_tensor": torch.cat([gt_protein_r_feats["backbone_rigid_tensor"],
+                                                 gt_protein_l_feats["backbone_rigid_tensor"]], dim=0),  # torch.Size([n_res, 4, 4])
+             "backbone_rigid_mask": torch.cat([gt_protein_r_feats["backbone_rigid_mask"],
+                                               gt_protein_l_feats["backbone_rigid_mask"]], dim=0),  # torch.Size([n_res])
+             "chi_angles_sin_cos": torch.cat([gt_protein_r_feats["chi_angles_sin_cos"],
+                                              gt_protein_l_feats["chi_angles_sin_cos"]], dim=0),
+             "chi_mask": torch.cat([gt_protein_r_feats["chi_mask"], gt_protein_l_feats["chi_mask"]], dim=0),
+         }
+
+         for k, v in gt_protein_feats.items():
+             gt_protein_feats[k] = _fit_to_crop(v, crop_size, 0)
+
+         feats = {
+             **feats,
+             **gt_protein_feats,
+             "resolution": resolution,
+             "affinity": affinity,
+             "affinity_loss_factor": affinity_loss_factor,
+             "seq_length": torch.tensor(n_res_total),
+             "binding_site_mask": _fit_to_crop(binding_site_mask, crop_size, 0),
+             "gt_inter_contacts": inter_contact_reshaped_to_crop,
+         }
+
+     for k, v in feats.items():
+         feats[k] = _prepare_recycles(v, num_recycles)
+
+     feats["batch_idx"] = torch.tensor(
+         [idx for _ in range(crop_size)], dtype=torch.int64, device=feats["aatype"].device
+     )
+
+     print("load time", round(time.time() - start_load_time, 4))
+
+     return feats
dockformerpp/data/data_transforms.py ADDED
@@ -0,0 +1,731 @@
+ # Copyright 2021 AlQuraishi Laboratory
+ # Copyright 2021 DeepMind Technologies Limited
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ import itertools
+ from functools import reduce, wraps
+ from operator import add
+
+ import numpy as np
+ import torch
+
+ from dockformerpp.config import NUM_RES
+ from dockformerpp.utils import residue_constants as rc
+ from dockformerpp.utils.residue_constants import restypes
+ from dockformerpp.utils.rigid_utils import Rotation, Rigid
+ from dockformerpp.utils.geometry.rigid_matrix_vector import Rigid3Array
+ from dockformerpp.utils.geometry.rotation_matrix import Rot3Array
+ from dockformerpp.utils.geometry.vector import Vec3Array
+ from dockformerpp.utils.tensor_utils import (
+     tree_map,
+     tensor_tree_map,
+     batched_gather,
+ )
+
+
+ def cast_to_64bit_ints(protein):
+     # We keep all ints as int64
+     for k, v in protein.items():
+         if v.dtype == torch.int32:
+             protein[k] = v.type(torch.int64)
+
+     return protein
+
+
+ def make_one_hot(x, num_classes):
+     x_one_hot = torch.zeros(*x.shape, num_classes, device=x.device)
+     x_one_hot.scatter_(-1, x.unsqueeze(-1), 1)
+     return x_one_hot
+
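+ # e.g. make_one_hot(torch.tensor([0, 2]), 3) -> [[1., 0., 0.], [0., 0., 1.]]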
+
+ def curry1(f):
+     """Supply all arguments but the first."""
+     @wraps(f)
+     def fc(*args, **kwargs):
+         return lambda x: f(x, *args, **kwargs)
+
+     return fc
+
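+ # curry1 lets a transform be configured once and applied later, e.g.
+ # crop = random_crop_to_size(256, shape_schema); protein = crop(protein)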
+
+ def squeeze_features(protein):
+     """Remove singleton and repeated dimensions in protein features."""
+     protein["aatype"] = torch.argmax(protein["aatype"], dim=-1)
+     for k in [
+         "domain_name",
+         "seq_length",
+         "sequence",
+         "resolution",
+         "residue_index",
+     ]:
+         if k in protein:
+             final_dim = protein[k].shape[-1]
+             if isinstance(final_dim, int) and final_dim == 1:
+                 if torch.is_tensor(protein[k]):
+                     protein[k] = torch.squeeze(protein[k], dim=-1)
+                 else:
+                     protein[k] = np.squeeze(protein[k], axis=-1)
+
+     for k in ["seq_length"]:
+         if k in protein:
+             protein[k] = protein[k][0]
+
+     return protein
+
+
+ def pseudo_beta_fn(aatype, all_atom_positions, all_atom_mask):
+     """Create pseudo-beta features (CB coordinates, or CA for glycine)."""
+     is_gly = torch.eq(aatype, rc.restype_order["G"])
+     ca_idx = rc.atom_order["CA"]
+     cb_idx = rc.atom_order["CB"]
+     pseudo_beta = torch.where(
+         torch.tile(is_gly[..., None], [1] * len(is_gly.shape) + [3]),
+         all_atom_positions[..., ca_idx, :],
+         all_atom_positions[..., cb_idx, :],
+     )
+
+     if all_atom_mask is not None:
+         pseudo_beta_mask = torch.where(
+             is_gly, all_atom_mask[..., ca_idx], all_atom_mask[..., cb_idx]
+         )
+         return pseudo_beta, pseudo_beta_mask
+     else:
+         return pseudo_beta
+
+
+ @curry1
+ def make_pseudo_beta(protein):
+     """Create pseudo-beta (alpha for glycine) position and mask."""
+     (protein["pseudo_beta"], protein["pseudo_beta_mask"]) = pseudo_beta_fn(
+         protein["aatype"],
+         protein["all_atom_positions"],
+         protein["all_atom_mask"],
+     )
+     return protein
+
+
+ @curry1
+ def make_target_feat(protein):
+     """Create the one-hot per-residue target features."""
+     aatype_1hot = make_one_hot(protein["aatype"], 20)
+
+     protein["protein_target_feat"] = aatype_1hot
+
+     return protein
+
+
+ @curry1
+ def select_feat(protein, feature_list):
+     return {k: v for k, v in protein.items() if k in feature_list}
+
+
+ def get_restypes(device):
+     restype_atom14_to_atom37 = []
+     restype_atom37_to_atom14 = []
+     restype_atom14_mask = []
+
+     for rt in rc.restypes:
+         atom_names = rc.restype_name_to_atom14_names[rc.restype_1to3[rt]]
+         restype_atom14_to_atom37.append(
+             [(rc.atom_order[name] if name else 0) for name in atom_names]
+         )
+         atom_name_to_idx14 = {name: i for i, name in enumerate(atom_names)}
+         restype_atom37_to_atom14.append(
+             [
+                 (atom_name_to_idx14[name] if name in atom_name_to_idx14 else 0)
+                 for name in rc.atom_types
+             ]
+         )
+
+         restype_atom14_mask.append(
+             [(1.0 if name else 0.0) for name in atom_names]
+         )
+
+     # Add dummy mapping for restype 'UNK'
+     restype_atom14_to_atom37.append([0] * 14)
+     restype_atom37_to_atom14.append([0] * 37)
+     restype_atom14_mask.append([0.0] * 14)
+
+     restype_atom14_to_atom37 = torch.tensor(
+         restype_atom14_to_atom37,
+         dtype=torch.int32,
+         device=device,
+     )
+     restype_atom37_to_atom14 = torch.tensor(
+         restype_atom37_to_atom14,
+         dtype=torch.int32,
+         device=device,
+     )
+     restype_atom14_mask = torch.tensor(
+         restype_atom14_mask,
+         dtype=torch.float32,
+         device=device,
+     )
+
+     return restype_atom14_to_atom37, restype_atom37_to_atom14, restype_atom14_mask
+
+
+ def get_restype_atom37_mask(device):
+     # create the corresponding mask
+     restype_atom37_mask = torch.zeros(
+         [len(restypes) + 1, 37], dtype=torch.float32, device=device
+     )
+     for restype, restype_letter in enumerate(rc.restypes):
+         restype_name = rc.restype_1to3[restype_letter]
+         atom_names = rc.residue_atoms[restype_name]
+         for atom_name in atom_names:
+             atom_type = rc.atom_order[atom_name]
+             restype_atom37_mask[restype, atom_type] = 1
+     return restype_atom37_mask
+
+
+ def make_atom14_masks(protein):
+     """Construct denser atom positions (14 dimensions instead of 37)."""
+     restype_atom14_to_atom37, restype_atom37_to_atom14, restype_atom14_mask = get_restypes(protein["aatype"].device)
+
+     protein_aatype = protein['aatype'].to(torch.long)
+
+     # create the mapping for (residx, atom14) --> atom37, i.e. an array
+     # with shape (num_res, 14) containing the atom37 indices for this protein
+     residx_atom14_to_atom37 = restype_atom14_to_atom37[protein_aatype]
+     residx_atom14_mask = restype_atom14_mask[protein_aatype]
+
+     protein["atom14_atom_exists"] = residx_atom14_mask
+     protein["residx_atom14_to_atom37"] = residx_atom14_to_atom37.long()
+
+     # create the gather indices for mapping back
+     residx_atom37_to_atom14 = restype_atom37_to_atom14[protein_aatype]
+     protein["residx_atom37_to_atom14"] = residx_atom37_to_atom14.long()
+
+     restype_atom37_mask = get_restype_atom37_mask(protein["aatype"].device)
+
+     residx_atom37_mask = restype_atom37_mask[protein_aatype]
+     protein["atom37_atom_exists"] = residx_atom37_mask
+
+     return protein
+
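+ # e.g. for a glycine residue, atom14 slots beyond N, CA, C and O map to index 0
+ # and atom14_atom_exists is zero there, so downstream losses ignore those slots.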
+
+ def make_atom14_positions(protein):
+     """Constructs denser atom positions (14 dimensions instead of 37)."""
+     residx_atom14_mask = protein["atom14_atom_exists"]
+     residx_atom14_to_atom37 = protein["residx_atom14_to_atom37"]
+
+     # Create a mask for known ground truth positions.
+     residx_atom14_gt_mask = residx_atom14_mask * batched_gather(
+         protein["all_atom_mask"],
+         residx_atom14_to_atom37,
+         dim=-1,
+         no_batch_dims=len(protein["all_atom_mask"].shape[:-1]),
+     )
+
+     # Gather the ground truth positions.
+     residx_atom14_gt_positions = residx_atom14_gt_mask[..., None] * (
+         batched_gather(
+             protein["all_atom_positions"],
+             residx_atom14_to_atom37,
+             dim=-2,
+             no_batch_dims=len(protein["all_atom_positions"].shape[:-2]),
+         )
+     )
+
+     protein["atom14_atom_exists"] = residx_atom14_mask
+     protein["atom14_gt_exists"] = residx_atom14_gt_mask
+     protein["atom14_gt_positions"] = residx_atom14_gt_positions
+
+     # As the atom naming is ambiguous for 7 of the 20 amino acids, provide
+     # alternative ground truth coordinates where the naming is swapped.
+     restype_3 = [rc.restype_1to3[res] for res in rc.restypes]
+     restype_3 += ["UNK"]
+
+     # Matrices for renaming ambiguous atoms.
+     all_matrices = {
+         res: torch.eye(
+             14,
+             dtype=protein["all_atom_mask"].dtype,
+             device=protein["all_atom_mask"].device,
+         )
+         for res in restype_3
+     }
+     for resname, swap in rc.residue_atom_renaming_swaps.items():
+         correspondences = torch.arange(
+             14, device=protein["all_atom_mask"].device
+         )
+         for source_atom_swap, target_atom_swap in swap.items():
+             source_index = rc.restype_name_to_atom14_names[resname].index(
+                 source_atom_swap
+             )
+             target_index = rc.restype_name_to_atom14_names[resname].index(
+                 target_atom_swap
+             )
+             correspondences[source_index] = target_index
+             correspondences[target_index] = source_index
+         renaming_matrix = protein["all_atom_mask"].new_zeros((14, 14))
+         for index, correspondence in enumerate(correspondences):
+             renaming_matrix[index, correspondence] = 1.0
+         all_matrices[resname] = renaming_matrix
+
+     renaming_matrices = torch.stack(
+         [all_matrices[restype] for restype in restype_3]
+     )
+
+     # Pick the transformation matrices for the given residue sequence
+     # shape (num_res, 14, 14).
+     renaming_transform = renaming_matrices[protein["aatype"]]
+
+     # Apply it to the ground truth positions. shape (num_res, 14, 3).
+     alternative_gt_positions = torch.einsum(
+         "...rac,...rab->...rbc", residx_atom14_gt_positions, renaming_transform
+     )
+     protein["atom14_alt_gt_positions"] = alternative_gt_positions
+
+     # Create the mask for the alternative ground truth (differs from the
+     # ground truth mask, if only one of the atoms in an ambiguous pair has a
+     # ground truth position).
+     alternative_gt_mask = torch.einsum(
+         "...ra,...rab->...rb", residx_atom14_gt_mask, renaming_transform
+     )
+     protein["atom14_alt_gt_exists"] = alternative_gt_mask
+
+     # Create an ambiguous atoms mask. shape: (21, 14).
+     restype_atom14_is_ambiguous = protein["all_atom_mask"].new_zeros((21, 14))
+     for resname, swap in rc.residue_atom_renaming_swaps.items():
+         for atom_name1, atom_name2 in swap.items():
+             restype = rc.restype_order[rc.restype_3to1[resname]]
+             atom_idx1 = rc.restype_name_to_atom14_names[resname].index(
+                 atom_name1
+             )
+             atom_idx2 = rc.restype_name_to_atom14_names[resname].index(
+                 atom_name2
+             )
+             restype_atom14_is_ambiguous[restype, atom_idx1] = 1
+             restype_atom14_is_ambiguous[restype, atom_idx2] = 1
+
+     # From this create an ambiguous_mask for the given sequence.
+     protein["atom14_atom_is_ambiguous"] = restype_atom14_is_ambiguous[
+         protein["aatype"]
+     ]
+
+     return protein
+
+
+ def atom37_to_frames(protein, eps=1e-8):
+     aatype = protein["aatype"]
+     all_atom_positions = protein["all_atom_positions"]
+     all_atom_mask = protein["all_atom_mask"]
+
+     batch_dims = len(aatype.shape[:-1])
+
+     restype_rigidgroup_base_atom_names = np.full([21, 8, 3], "", dtype=object)
+     restype_rigidgroup_base_atom_names[:, 0, :] = ["C", "CA", "N"]
+     restype_rigidgroup_base_atom_names[:, 3, :] = ["CA", "C", "O"]
+
+     for restype, restype_letter in enumerate(rc.restypes):
+         resname = rc.restype_1to3[restype_letter]
+         for chi_idx in range(4):
+             if rc.chi_angles_mask[restype][chi_idx]:
+                 names = rc.chi_angles_atoms[resname][chi_idx]
+                 restype_rigidgroup_base_atom_names[
+                     restype, chi_idx + 4, :
+                 ] = names[1:]
+
+     restype_rigidgroup_mask = all_atom_mask.new_zeros(
+         (*aatype.shape[:-1], 21, 8),
+     )
+     restype_rigidgroup_mask[..., 0] = 1
+     restype_rigidgroup_mask[..., 3] = 1
+     restype_rigidgroup_mask[..., :len(restypes), 4:] = all_atom_mask.new_tensor(
+         rc.chi_angles_mask
+     )
+
+     lookuptable = rc.atom_order.copy()
+     lookuptable[""] = 0
+     lookup = np.vectorize(lambda x: lookuptable[x])
+     restype_rigidgroup_base_atom37_idx = lookup(
+         restype_rigidgroup_base_atom_names,
+     )
+     restype_rigidgroup_base_atom37_idx = aatype.new_tensor(
+         restype_rigidgroup_base_atom37_idx,
+     )
+     restype_rigidgroup_base_atom37_idx = (
+         restype_rigidgroup_base_atom37_idx.view(
+             *((1,) * batch_dims), *restype_rigidgroup_base_atom37_idx.shape
+         )
+     )
+
+     residx_rigidgroup_base_atom37_idx = batched_gather(
+         restype_rigidgroup_base_atom37_idx,
+         aatype,
+         dim=-3,
+         no_batch_dims=batch_dims,
+     )
+
+     base_atom_pos = batched_gather(
+         all_atom_positions,
+         residx_rigidgroup_base_atom37_idx,
+         dim=-2,
+         no_batch_dims=len(all_atom_positions.shape[:-2]),
+     )
+
+     gt_frames = Rigid.from_3_points(
+         p_neg_x_axis=base_atom_pos[..., 0, :],
+         origin=base_atom_pos[..., 1, :],
+         p_xy_plane=base_atom_pos[..., 2, :],
+         eps=eps,
+     )
+
+     group_exists = batched_gather(
+         restype_rigidgroup_mask,
+         aatype,
+         dim=-2,
+         no_batch_dims=batch_dims,
+     )
+
+     gt_atoms_exist = batched_gather(
+         all_atom_mask,
+         residx_rigidgroup_base_atom37_idx,
+         dim=-1,
+         no_batch_dims=len(all_atom_mask.shape[:-1]),
+     )
+     gt_exists = torch.min(gt_atoms_exist, dim=-1)[0] * group_exists
+
+     rots = torch.eye(3, dtype=all_atom_mask.dtype, device=aatype.device)
+     rots = torch.tile(rots, (*((1,) * batch_dims), 8, 1, 1))
+     rots[..., 0, 0, 0] = -1
+     rots[..., 0, 2, 2] = -1
+
+     rots = Rotation(rot_mats=rots)
+     gt_frames = gt_frames.compose(Rigid(rots, None))
+
+     restype_rigidgroup_is_ambiguous = all_atom_mask.new_zeros(
+         *((1,) * batch_dims), 21, 8
+     )
+     restype_rigidgroup_rots = torch.eye(
+         3, dtype=all_atom_mask.dtype, device=aatype.device
+     )
+     restype_rigidgroup_rots = torch.tile(
+         restype_rigidgroup_rots,
+         (*((1,) * batch_dims), 21, 8, 1, 1),
+     )
+
+     for resname, _ in rc.residue_atom_renaming_swaps.items():
+         restype = rc.restype_order[rc.restype_3to1[resname]]
+         chi_idx = int(sum(rc.chi_angles_mask[restype]) - 1)
+         restype_rigidgroup_is_ambiguous[..., restype, chi_idx + 4] = 1
+         restype_rigidgroup_rots[..., restype, chi_idx + 4, 1, 1] = -1
+         restype_rigidgroup_rots[..., restype, chi_idx + 4, 2, 2] = -1
+
+     residx_rigidgroup_is_ambiguous = batched_gather(
+         restype_rigidgroup_is_ambiguous,
+         aatype,
+         dim=-2,
+         no_batch_dims=batch_dims,
+     )
+
+     residx_rigidgroup_ambiguity_rot = batched_gather(
+         restype_rigidgroup_rots,
+         aatype,
+         dim=-4,
+         no_batch_dims=batch_dims,
+     )
+
+     residx_rigidgroup_ambiguity_rot = Rotation(
+         rot_mats=residx_rigidgroup_ambiguity_rot
+     )
+     alt_gt_frames = gt_frames.compose(
+         Rigid(residx_rigidgroup_ambiguity_rot, None)
+     )
+
+     gt_frames_tensor = gt_frames.to_tensor_4x4()
+     alt_gt_frames_tensor = alt_gt_frames.to_tensor_4x4()
+
+     protein["rigidgroups_gt_frames"] = gt_frames_tensor
+     protein["rigidgroups_gt_exists"] = gt_exists
+     protein["rigidgroups_group_exists"] = group_exists
+     protein["rigidgroups_group_is_ambiguous"] = residx_rigidgroup_is_ambiguous
+     protein["rigidgroups_alt_gt_frames"] = alt_gt_frames_tensor
+
+     return protein
+
+
+ def get_chi_atom_indices():
+     """Returns atom indices needed to compute chi angles for all residue types.
+
+     Returns:
+         A tensor of shape [residue_types=21, chis=4, atoms=4]. The residue types are
+         in the order specified in rc.restypes + unknown residue type
+         at the end. For chi angles which are not defined on the residue, the
+         positions indices are by default set to 0.
+     """
+     chi_atom_indices = []
+     for residue_name in rc.restypes:
+         residue_name = rc.restype_1to3[residue_name]
+         residue_chi_angles = rc.chi_angles_atoms[residue_name]
+         atom_indices = []
+         for chi_angle in residue_chi_angles:
+             atom_indices.append([rc.atom_order[atom] for atom in chi_angle])
+         for _ in range(4 - len(atom_indices)):
+             atom_indices.append(
+                 [0, 0, 0, 0]
+             )  # For chi angles not defined on the AA.
+         chi_atom_indices.append(atom_indices)
+
+     chi_atom_indices.append([[0, 0, 0, 0]] * 4)  # For UNKNOWN residue.
+
+     return chi_atom_indices
+
+
+ @curry1
+ def atom37_to_torsion_angles(
+     protein,
+     prefix="",
+ ):
+     """
+     Convert coordinates to torsion angles.
+
+     This function is extremely sensitive to floating point imprecisions
+     and should be run with double precision whenever possible.
+
+     Args:
+         Dict containing:
+             * (prefix)aatype:
+                 [*, N_res] residue indices
+             * (prefix)all_atom_positions:
+                 [*, N_res, 37, 3] atom positions (in atom37 format)
+             * (prefix)all_atom_mask:
+                 [*, N_res, 37] atom position mask
+     Returns:
+         The same dictionary updated with the following features:
+
+         "(prefix)torsion_angles_sin_cos" ([*, N_res, 7, 2])
+             Torsion angles
+         "(prefix)alt_torsion_angles_sin_cos" ([*, N_res, 7, 2])
+             Alternate torsion angles (accounting for 180-degree symmetry)
+         "(prefix)torsion_angles_mask" ([*, N_res, 7])
+             Torsion angles mask
+     """
+     aatype = protein[prefix + "aatype"]
+     all_atom_positions = protein[prefix + "all_atom_positions"]
+     all_atom_mask = protein[prefix + "all_atom_mask"]
+
+     aatype = torch.clamp(aatype, max=20)
+
+     pad = all_atom_positions.new_zeros(
+         [*all_atom_positions.shape[:-3], 1, 37, 3]
+     )
+     prev_all_atom_positions = torch.cat(
+         [pad, all_atom_positions[..., :-1, :, :]], dim=-3
+     )
+
+     pad = all_atom_mask.new_zeros([*all_atom_mask.shape[:-2], 1, 37])
+     prev_all_atom_mask = torch.cat([pad, all_atom_mask[..., :-1, :]], dim=-2)
+
+     pre_omega_atom_pos = torch.cat(
+         [prev_all_atom_positions[..., 1:3, :], all_atom_positions[..., :2, :]],
+         dim=-2,
+     )
+     phi_atom_pos = torch.cat(
+         [prev_all_atom_positions[..., 2:3, :], all_atom_positions[..., :3, :]],
+         dim=-2,
+     )
+     psi_atom_pos = torch.cat(
+         [all_atom_positions[..., :3, :], all_atom_positions[..., 4:5, :]],
+         dim=-2,
+     )
+
+     pre_omega_mask = torch.prod(
+         prev_all_atom_mask[..., 1:3], dim=-1
+     ) * torch.prod(all_atom_mask[..., :2], dim=-1)
+     phi_mask = prev_all_atom_mask[..., 2] * torch.prod(
+         all_atom_mask[..., :3], dim=-1, dtype=all_atom_mask.dtype
+     )
+     psi_mask = (
+         torch.prod(all_atom_mask[..., :3], dim=-1, dtype=all_atom_mask.dtype)
+         * all_atom_mask[..., 4]
+     )
+
+     chi_atom_indices = torch.as_tensor(
+         get_chi_atom_indices(), device=aatype.device
+     )
+
+     atom_indices = chi_atom_indices[..., aatype, :, :]
+     chis_atom_pos = batched_gather(
+         all_atom_positions, atom_indices, -2, len(atom_indices.shape[:-2])
+     )
+
+     chi_angles_mask = list(rc.chi_angles_mask)
+     chi_angles_mask.append([0.0, 0.0, 0.0, 0.0])
+     chi_angles_mask = all_atom_mask.new_tensor(chi_angles_mask)
+
+     chis_mask = chi_angles_mask[aatype, :]
+
+     chi_angle_atoms_mask = batched_gather(
+         all_atom_mask,
+         atom_indices,
+         dim=-1,
+         no_batch_dims=len(atom_indices.shape[:-2]),
+     )
+     chi_angle_atoms_mask = torch.prod(
+         chi_angle_atoms_mask, dim=-1, dtype=chi_angle_atoms_mask.dtype
+     )
+     chis_mask = chis_mask * chi_angle_atoms_mask
+
+     torsions_atom_pos = torch.cat(
+         [
+             pre_omega_atom_pos[..., None, :, :],
+             phi_atom_pos[..., None, :, :],
+             psi_atom_pos[..., None, :, :],
+             chis_atom_pos,
+         ],
+         dim=-3,
+     )
+
+     torsion_angles_mask = torch.cat(
+         [
+             pre_omega_mask[..., None],
+             phi_mask[..., None],
+             psi_mask[..., None],
+             chis_mask,
+         ],
+         dim=-1,
+     )
+
+     torsion_frames = Rigid.from_3_points(
+         torsions_atom_pos[..., 1, :],
+         torsions_atom_pos[..., 2, :],
+         torsions_atom_pos[..., 0, :],
+         eps=1e-8,
+     )
+
+     fourth_atom_rel_pos = torsion_frames.invert().apply(
+         torsions_atom_pos[..., 3, :]
+     )
+
+     torsion_angles_sin_cos = torch.stack(
+         [fourth_atom_rel_pos[..., 2], fourth_atom_rel_pos[..., 1]], dim=-1
+     )
+
+     denom = torch.sqrt(
+         torch.sum(
+             torch.square(torsion_angles_sin_cos),
+             dim=-1,
+             dtype=torsion_angles_sin_cos.dtype,
+             keepdims=True,
+         )
+         + 1e-8
+     )
+     torsion_angles_sin_cos = torsion_angles_sin_cos / denom
+
+     torsion_angles_sin_cos = torsion_angles_sin_cos * all_atom_mask.new_tensor(
+         [1.0, 1.0, -1.0, 1.0, 1.0, 1.0, 1.0],
+     )[((None,) * len(torsion_angles_sin_cos.shape[:-2])) + (slice(None), None)]
+
+     chi_is_ambiguous = torsion_angles_sin_cos.new_tensor(
+         rc.chi_pi_periodic,
+     )[aatype, ...]
+
+     mirror_torsion_angles = torch.cat(
+         [
+             all_atom_mask.new_ones(*aatype.shape, 3),
+             1.0 - 2.0 * chi_is_ambiguous,
+         ],
+         dim=-1,
+     )
+
+     alt_torsion_angles_sin_cos = (
+         torsion_angles_sin_cos * mirror_torsion_angles[..., None]
+     )
+
+     protein[prefix + "torsion_angles_sin_cos"] = torsion_angles_sin_cos
+     protein[prefix + "alt_torsion_angles_sin_cos"] = alt_torsion_angles_sin_cos
+     protein[prefix + "torsion_angles_mask"] = torsion_angles_mask
+
+     return protein
+
+
+ def get_backbone_frames(protein):
+     # DISCREPANCY: AlphaFold uses tensor_7s here. I don't know why.
+     protein["backbone_rigid_tensor"] = protein["rigidgroups_gt_frames"][
+         ..., 0, :, :
+     ]
+     protein["backbone_rigid_mask"] = protein["rigidgroups_gt_exists"][..., 0]
+
+     return protein
+
+
+ def get_chi_angles(protein):
+     dtype = protein["all_atom_mask"].dtype
+     protein["chi_angles_sin_cos"] = (
+         protein["torsion_angles_sin_cos"][..., 3:, :]
+     ).to(dtype)
+     protein["chi_mask"] = protein["torsion_angles_mask"][..., 3:].to(dtype)
+
+     return protein
+
+
+ @curry1
+ def random_crop_to_size(
+     protein,
+     crop_size,
+     shape_schema,
+     seed=None,
+ ):
+     """Crop randomly to `crop_size`, or keep as is if shorter than that."""
+     # We want each ensemble to be cropped the same way
+
+     g = None
+     if seed is not None:
+         g = torch.Generator(device=protein["seq_length"].device)
+         g.manual_seed(seed)
+
+     seq_length = protein["seq_length"]
+
+     num_res_crop_size = min(int(seq_length), crop_size)
+
+     def _randint(lower, upper):
+         return int(torch.randint(
+             lower,
+             upper + 1,
+             (1,),
+             device=protein["seq_length"].device,
+             generator=g,
+         )[0])
+
+     n = seq_length - num_res_crop_size
+     if "use_clamped_fape" in protein and protein["use_clamped_fape"] == 1.:
+         right_anchor = n
+     else:
+         x = _randint(0, n)
+         right_anchor = n - x
+
+     num_res_crop_start = _randint(0, right_anchor)
+
+     for k, v in protein.items():
+         if k not in shape_schema or (NUM_RES not in shape_schema[k]):
+             continue
+
+         slices = []
+         for i, (dim_size, dim) in enumerate(zip(shape_schema[k], v.shape)):
+             is_num_res = dim_size == NUM_RES
+             crop_start = num_res_crop_start if is_num_res else 0
+             crop_size = num_res_crop_size if is_num_res else dim
+             slices.append(slice(crop_start, crop_start + crop_size))
+         protein[k] = v[slices]
+
+     protein["seq_length"] = protein["seq_length"].new_tensor(num_res_crop_size)
+
+     return protein
+
dockformerpp/data/errors.py ADDED
@@ -0,0 +1,22 @@
+ # Copyright 2021 AlQuraishi Laboratory
+ # Copyright 2021 DeepMind Technologies Limited
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ """General-purpose errors used throughout the data pipeline"""
+ class Error(Exception):
+     """Base class for exceptions."""
+
+
+ class MultipleChainsError(Error):
+     """An error indicating that multiple chains were found for a given ID."""
dockformerpp/data/parsers.py ADDED
@@ -0,0 +1,53 @@
+ # Copyright 2021 AlQuraishi Laboratory
+ # Copyright 2021 DeepMind Technologies Limited
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ """Functions for parsing various file formats."""
+ from typing import Sequence, Tuple
+
+
+ def parse_fasta(fasta_string: str) -> Tuple[Sequence[str], Sequence[str]]:
+     """Parses a FASTA string and returns the amino-acid sequences it contains.
+
+     Arguments:
+         fasta_string: The string contents of a FASTA file.
+
+     Returns:
+         A tuple of two lists:
+         * A list of sequences.
+         * A list of sequence descriptions taken from the comment lines. In the
+           same order as the sequences.
+     """
+     sequences = []
+     descriptions = []
+     index = -1
+     for line in fasta_string.splitlines():
+         line = line.strip()
+         if line.startswith(">"):
+             index += 1
+             descriptions.append(line[1:])  # Remove the '>' at the beginning.
+             sequences.append("")
+             continue
+         elif line.startswith("#"):
+             continue
+         elif not line:
+             continue  # Skip blank lines.
+         sequences[index] += line
+
+     return sequences, descriptions
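+ # Example: parse_fasta(">seq1 first\nMKV\nLLE\n>seq2\nGGG")
+ # returns (["MKVLLE", "GGG"], ["seq1 first", "seq2"])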
dockformerpp/data/protein_features.py ADDED
@@ -0,0 +1,71 @@
+ import numpy as np
+
+ from dockformerpp.data.utils import FeatureDict
+ from dockformerpp.utils import residue_constants, protein
+
+
+ def _make_sequence_features(sequence: str, description: str, num_res: int) -> FeatureDict:
+     """Construct a feature dict of sequence features."""
+     features = {}
+     features["aatype"] = residue_constants.sequence_to_onehot(
+         sequence=sequence,
+         mapping=residue_constants.restype_order_with_x,
+         map_unknown_to_x=True,
+     )
+     features["domain_name"] = np.array(
+         [description.encode("utf-8")], dtype=object
+     )
+     # features["residue_index"] = np.array(range(num_res), dtype=np.int32)
+     features["seq_length"] = np.array([num_res] * num_res, dtype=np.int32)
+     features["sequence"] = np.array(
+         [sequence.encode("utf-8")], dtype=object
+     )
+     return features
+
+
+ def _aatype_to_str_sequence(aatype):
+     return ''.join([
+         residue_constants.restypes_with_x[aatype[i]]
+         for i in range(len(aatype))
+     ])
+
+
+ def _make_protein_structure_features(protein_object: protein.Protein) -> FeatureDict:
+     pdb_feats = {}
+
+     all_atom_positions = protein_object.atom_positions
+     all_atom_mask = protein_object.atom_mask
+
+     pdb_feats["all_atom_positions"] = all_atom_positions.astype(np.float32)
+     pdb_feats["all_atom_mask"] = all_atom_mask.astype(np.float32)
+     pdb_feats["in_chain_residue_index"] = protein_object.residue_index.astype(np.int32)
+
+     gapped_res_indexes = []
+     prev_chain_index = protein_object.chain_index[0]
+     chain_start_res_ind = 0
+     for relative_res_ind, chain_index in zip(protein_object.residue_index, protein_object.chain_index):
+         if chain_index != prev_chain_index:
+             chain_start_res_ind = gapped_res_indexes[-1] + 50
+             prev_chain_index = chain_index
+         gapped_res_indexes.append(relative_res_ind + chain_start_res_ind)
+
+     pdb_feats["residue_index"] = np.array(gapped_res_indexes).astype(np.int32)
+     pdb_feats["chain_index"] = np.array(protein_object.chain_index).astype(np.int32)
+     pdb_feats["resolution"] = np.array([0.]).astype(np.float32)
+
+     return pdb_feats
+
+
+ def make_protein_features(protein_object: protein.Protein, description: str) -> FeatureDict:
+     feats = {}
+     aatype = protein_object.aatype
+     sequence = _aatype_to_str_sequence(aatype)
+     feats.update(
+         _make_sequence_features(sequence=sequence, description=description, num_res=len(protein_object.aatype))
+     )
+
+     feats.update(
+         _make_protein_structure_features(protein_object=protein_object)
+     )
+
+     return feats
dockformerpp/data/utils.py ADDED
@@ -0,0 +1,54 @@
+ # Copyright 2021 AlQuraishi Laboratory
+ # Copyright 2021 DeepMind Technologies Limited
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ """Common utilities for data pipeline tools."""
+ import contextlib
+ import datetime
+ import logging
+ import shutil
+ import tempfile
+ import time
+ from typing import Optional, Mapping, Dict
+
+ import numpy as np
+ import torch
+
+ FeatureDict = Dict[str, np.ndarray]
+ FeatureTensorDict = Dict[str, torch.Tensor]
+
+
+ @contextlib.contextmanager
+ def tmpdir_manager(base_dir: Optional[str] = None):
+     """Context manager that deletes a temporary directory on exit."""
+     tmpdir = tempfile.mkdtemp(dir=base_dir)
+     try:
+         yield tmpdir
+     finally:
+         shutil.rmtree(tmpdir, ignore_errors=True)
+
+
+ @contextlib.contextmanager
+ def timing(msg: str):
+     logging.info("Started %s", msg)
+     tic = time.perf_counter()
+     yield
+     toc = time.perf_counter()
+     logging.info("Finished %s in %.3f seconds", msg, toc - tic)
+
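+ # Usage: `with timing("parsing PDB"): ...` logs the wall-clock time of the block.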
+
+ def to_date(s: str):
+     return datetime.datetime(
+         year=int(s[:4]), month=int(s[5:7]), day=int(s[8:10])
+     )
dockformerpp/model/.DS_Store ADDED
Binary file (6.15 kB)
 
dockformerpp/model/__init__.py ADDED
File without changes
dockformerpp/model/dropout.py ADDED
@@ -0,0 +1,69 @@
+ # Copyright 2021 AlQuraishi Laboratory
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+
+ import torch
+ import torch.nn as nn
+ from functools import partialmethod
+ from typing import Union, List
+
+
+ class Dropout(nn.Module):
+     """
+     Implementation of dropout with the ability to share the dropout mask
+     along a particular dimension.
+
+     If not in training mode, this module computes the identity function.
+     """
+
+     def __init__(self, r: float, batch_dim: Union[int, List[int]]):
+         """
+         Args:
+             r:
+                 Dropout rate
+             batch_dim:
+                 Dimension(s) along which the dropout mask is shared
+         """
+         super(Dropout, self).__init__()
+
+         self.r = r
+         if type(batch_dim) == int:
+             batch_dim = [batch_dim]
+         self.batch_dim = batch_dim
+         self.dropout = nn.Dropout(self.r)
+
+     def forward(self, x: torch.Tensor) -> torch.Tensor:
+         """
+         Args:
+             x:
+                 Tensor to which dropout is applied. Can have any shape
+                 compatible with self.batch_dim
+         """
+         shape = list(x.shape)
+         if self.batch_dim is not None:
+             for bd in self.batch_dim:
+                 shape[bd] = 1
+         mask = x.new_ones(shape)
+         mask = self.dropout(mask)
+         x *= mask
+         return x
+
+
+ class DropoutRowwise(Dropout):
+     """
+     Convenience class for rowwise dropout as described in subsection
+     1.11.6.
+     """
+
+     __init__ = partialmethod(Dropout.__init__, batch_dim=-3)
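+ # The shared mask has size 1 along batch_dim, so for a [*, N_res, N_res, C_z]
+ # pair activation DropoutRowwise (batch_dim=-3) drops the same entries in every row.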
dockformerpp/model/embedders.py ADDED
@@ -0,0 +1,320 @@
+ # Copyright 2021 AlQuraishi Laboratory
+ # Copyright 2021 DeepMind Technologies Limited
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ from functools import partial
+
+ import torch
+ import torch.nn as nn
+ from typing import Tuple, Optional
+
+ from dockformerpp.model.primitives import Linear, LayerNorm
+ from dockformerpp.utils.tensor_utils import add
+
+
+ class StructureInputEmbedder(nn.Module):
+     """
+     Embeds a subset of the input features.
+
+     Implements a merge of Algorithm 3 and Algorithm 32.
+     """
+
+     def __init__(
+         self,
+         protein_tf_dim: int,
+         additional_tf_dim: int,
+         c_z: int,
+         c_m: int,
+         relpos_k: int,
+         prot_min_bin: float,
+         prot_max_bin: float,
+         prot_no_bins: int,
+         inf: float = 1e8,
+         **kwargs,
+     ):
+         """
+         Args:
+             protein_tf_dim:
+                 Dimension of the per-residue target features
+             additional_tf_dim:
+                 Dimension of the extra token-type flags appended to the target features
+             c_z:
+                 Pair embedding dimension
+             c_m:
+                 Single embedding dimension
+             relpos_k:
+                 Window size used in relative positional encoding
+         """
+         super(StructureInputEmbedder, self).__init__()
+
+         self.tf_dim = protein_tf_dim + additional_tf_dim
+
+         self.c_z = c_z
+         self.c_m = c_m
+
+         self.linear_tf_z_i = Linear(self.tf_dim, c_z)
+         self.linear_tf_z_j = Linear(self.tf_dim, c_z)
+         self.linear_tf_m = Linear(self.tf_dim, c_m)
+
+         # Relative positional encoding
+         self.relpos_k = relpos_k
+         self.no_bins = 2 * relpos_k + 1
+         self.linear_relpos = Linear(self.no_bins, c_z)
+
+         # Input-structure distogram ("recycling"-style) embedding
+         self.prot_min_bin = prot_min_bin
+         self.prot_max_bin = prot_max_bin
+         self.prot_no_bins = prot_no_bins
+         self.inf = inf
+
+         self.prot_recycling_linear = Linear(self.prot_no_bins + 1, self.c_z)
+         self.layer_norm_m = LayerNorm(self.c_m)
+         self.layer_norm_z = LayerNorm(self.c_z)
+
+     def relpos(self, ri: torch.Tensor):
+         """
+         Computes relative positional encodings
+
+         Implements Algorithm 4.
+
+         Args:
+             ri:
+                 "residue_index" features of shape [*, N]
+         """
+         d = ri[..., None] - ri[..., None, :]
+         boundaries = torch.arange(
+             start=-self.relpos_k, end=self.relpos_k + 1, device=d.device
+         )
+         reshaped_bins = boundaries.view(((1,) * len(d.shape)) + (len(boundaries),))
+         d = d[..., None] - reshaped_bins
+         d = torch.abs(d)
+         d = torch.argmin(d, dim=-1)
+         d = nn.functional.one_hot(d, num_classes=len(boundaries)).float()
+         d = d.to(ri.dtype)
+         return self.linear_relpos(d)
+
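+     # e.g. with relpos_k = 32 a sequence separation of -3 falls in bin 29 of 65,
+     # and any separation beyond +/-32 is clipped to the outermost bin before the
+     # linear projection.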
+     def _get_binned_distogram(self, x, min_bin, max_bin, no_bins, recycling_linear, prot_distogram_mask=None):
+         # This squared method might become problematic in FP16 mode.
+         bins = torch.linspace(
+             min_bin,
+             max_bin,
+             no_bins,
+             dtype=x.dtype,
+             device=x.device,
+             requires_grad=False,
+         )
+         squared_bins = bins ** 2
+         upper = torch.cat(
+             [squared_bins[1:], squared_bins.new_tensor([self.inf])], dim=-1
+         )
+         d = torch.sum((x[..., None, :] - x[..., None, :, :]) ** 2, dim=-1, keepdims=True)
+
+         # [*, N, N, no_bins]
+         d = ((d > squared_bins) * (d < upper)).type(x.dtype)
+
+         if prot_distogram_mask is not None:
+             # append an extra "unknown distance" bin
+             expanded_d = torch.cat([d, torch.zeros(*d.shape[:-1], 1, device=d.device)], dim=-1)
+
+             # build a pair mask that is 1 wherever either token's input position is hidden
+             input_positions_mask = (prot_distogram_mask == 1).float()  # Shape [N, crop_size]
+             mask_i = input_positions_mask.unsqueeze(2)  # Shape [N, crop_size, 1]
+             mask_j = input_positions_mask.unsqueeze(1)  # Shape [N, 1, crop_size]
+             combined_mask = (mask_i + mask_j).clamp(max=1)  # Shape [N, crop_size, crop_size]
+
+             # zero the distance bins for masked pairs...
+             expanded_d[..., :-1] *= (1 - combined_mask).unsqueeze(-1)
+             # ...and route those pairs to the extra bin instead
+             expanded_d[..., -1] += combined_mask
+             d = expanded_d
+
+         return recycling_linear(d)
+
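+     # Pairs involving a token whose coordinates were hidden therefore contribute a
+     # learned "distance unknown" embedding rather than a meaningless zero-distance bin.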
+     def forward(
+         self,
+         token_mask: torch.Tensor,
+         protein_r_mask: torch.Tensor,
+         protein_l_mask: torch.Tensor,
+         target_feat: torch.Tensor,
+         input_positions: torch.Tensor,
+         residue_index: torch.Tensor,
+         distogram_mask: torch.Tensor,
+         inplace_safe: bool = False,
+     ) -> Tuple[torch.Tensor, torch.Tensor]:
+         """
+         Args:
+             token_mask:
+                 [*, N_token] mask of valid (non-padding) tokens
+             protein_r_mask / protein_l_mask:
+                 [*, N_token] masks selecting receptor and ligand-protein tokens
+             target_feat:
+                 [*, N_token, tf_dim] target features
+             input_positions:
+                 [*, N_token, 3] input pseudo-beta coordinates
+             residue_index:
+                 [*, N_token] residue indices
+             distogram_mask:
+                 [*, N_token] mask of tokens whose input positions are hidden
+         Returns:
+             single_emb:
+                 [*, N_token, C_m] single embedding
+             pair_emb:
+                 [*, N_token, N_token, C_z] pair embedding
+         """
+         device = token_mask.device
+         pair_protein_r_mask = protein_r_mask[..., None] * protein_r_mask[..., None, :]
+         pair_protein_l_mask = protein_l_mask[..., None] * protein_l_mask[..., None, :]
+         intra_pair_protein_mask = pair_protein_r_mask + pair_protein_l_mask
+
+         # Single representation embedding - Algorithm 3
+         tf_m = self.linear_tf_m(target_feat)
+         tf_m = self.layer_norm_m(tf_m)  # previously this happened in the recycling step
+
+         # Pair representation
+         # protein pair embedding - Algorithm 3
+         # [*, N_res, c_z]
+         tf_emb_i = self.linear_tf_z_i(target_feat)
+         tf_emb_j = self.linear_tf_z_j(target_feat)
+
+         pair_emb = torch.zeros(*pair_protein_r_mask.shape, self.c_z, device=device)
+         pair_emb = add(pair_emb, tf_emb_i[..., None, :], inplace=inplace_safe)
+         pair_emb = add(pair_emb, tf_emb_j[..., None, :, :], inplace=inplace_safe)
+
+         # Apply relpos within each chain only
+         relpos = self.relpos(residue_index.type(tf_emb_i.dtype))
+         pair_emb += relpos * intra_pair_protein_mask[..., None]
+
+         del relpos
+
+         # before recycles, apply z_norm; this was previously part of the recycling step
+         pair_emb = self.layer_norm_z(pair_emb)
+
+         # embed the input-structure distogram (intra-chain pairs only)
+         prot_distogram_embed = self._get_binned_distogram(input_positions, self.prot_min_bin, self.prot_max_bin,
+                                                           self.prot_no_bins, self.prot_recycling_linear,
+                                                           distogram_mask)
+
+         pair_emb = add(pair_emb, prot_distogram_embed * intra_pair_protein_mask.unsqueeze(-1), inplace_safe)
+
+         del prot_distogram_embed
+
+         return tf_m, pair_emb
+
+
+ class RecyclingEmbedder(nn.Module):
+     """
+     Embeds the output of an iteration of the model for recycling.
+
+     Implements Algorithm 32.
+     """
+     def __init__(
+         self,
+         c_m: int,
+         c_z: int,
+         min_bin: float,
+         max_bin: float,
+         no_bins: int,
+         inf: float = 1e8,
+         **kwargs,
+     ):
+         """
+         Args:
+             c_m:
+                 Single channel dimension
+             c_z:
+                 Pair embedding channel dimension
+             min_bin:
+                 Smallest distogram bin (Angstroms)
+             max_bin:
+                 Largest distogram bin (Angstroms)
+             no_bins:
+                 Number of distogram bins
+         """
+         super(RecyclingEmbedder, self).__init__()
+
+         self.c_m = c_m
+         self.c_z = c_z
+         self.min_bin = min_bin
+         self.max_bin = max_bin
+         self.no_bins = no_bins
+         self.inf = inf
+
+         self.linear = Linear(self.no_bins, self.c_z)
+         self.layer_norm_m = LayerNorm(self.c_m)
+         self.layer_norm_z = LayerNorm(self.c_z)
+
+     def forward(
+         self,
+         m: torch.Tensor,
+         z: torch.Tensor,
+         x: torch.Tensor,
+         inplace_safe: bool = False,
+     ) -> Tuple[torch.Tensor, torch.Tensor]:
+         """
+         Args:
+             m:
+                 First row of the single embedding. [*, N_res, C_m]
+             z:
+                 [*, N_res, N_res, C_z] pair embedding
+             x:
+                 [*, N_res, 3] predicted C_beta coordinates
+         Returns:
+             m:
+                 [*, N_res, C_m] single embedding update
+             z:
+                 [*, N_res, N_res, C_z] pair embedding update
+         """
+         # [*, N, C_m]
+         m_update = self.layer_norm_m(m)
+         if inplace_safe:
+             m.copy_(m_update)
+             m_update = m
+
+         # [*, N, N, C_z]
+         z_update = self.layer_norm_z(z)
+         if inplace_safe:
+             z.copy_(z_update)
+             z_update = z
+
+         # This squared method might become problematic in FP16 mode.
+         bins = torch.linspace(
+             self.min_bin,
+             self.max_bin,
+             self.no_bins,
+             dtype=x.dtype,
+             device=x.device,
+             requires_grad=False,
+         )
+         squared_bins = bins ** 2
+         upper = torch.cat(
+             [squared_bins[1:], squared_bins.new_tensor([self.inf])], dim=-1
+         )
+         d = torch.sum(
+             (x[..., None, :] - x[..., None, :, :]) ** 2, dim=-1, keepdims=True
+         )
+
+         # [*, N, N, no_bins]
+         d = ((d > squared_bins) * (d < upper)).type(x.dtype)
+
+         # [*, N, N, C_z]
+         d = self.linear(d)
+         z_update = add(z_update, d, inplace_safe)
+
+         return m_update, z_update
+
dockformerpp/model/evoformer.py ADDED
@@ -0,0 +1,468 @@
1
+ # Copyright 2021 AlQuraishi Laboratory
2
+ # Copyright 2021 DeepMind Technologies Limited
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ import math
16
+ import sys
17
+ import torch
18
+ import torch.nn as nn
19
+ from typing import Tuple, Sequence, Optional
20
+ from functools import partial
21
+ from abc import ABC, abstractmethod
22
+
23
+ from dockformerpp.model.primitives import Linear, LayerNorm
24
+ from dockformerpp.model.dropout import DropoutRowwise
25
+ from dockformerpp.model.single_attention import SingleRowAttentionWithPairBias
26
+
27
+ from dockformerpp.model.pair_transition import PairTransition
28
+ from dockformerpp.model.triangular_attention import (
29
+ TriangleAttention,
30
+ )
31
+ from dockformerpp.model.triangular_multiplicative_update import (
32
+ TriangleMultiplicationOutgoing,
33
+ TriangleMultiplicationIncoming,
34
+ )
35
+ from dockformerpp.utils.checkpointing import checkpoint_blocks
36
+ from dockformerpp.utils.tensor_utils import add
37
+
38
+
39
+ class SingleRepTransition(nn.Module):
40
+ """
41
+ Feed-forward network applied to single representation activations after attention.
42
+
43
+ Implements Algorithm 9
44
+ """
45
+ def __init__(self, c_m, n):
46
+ """
47
+ Args:
48
+ c_m:
49
+ channel dimension
50
+ n:
51
+ Factor multiplied to c_m to obtain the hidden channel dimension
52
+ """
53
+ super(SingleRepTransition, self).__init__()
54
+
55
+ self.c_m = c_m
56
+ self.n = n
57
+
58
+ self.layer_norm = LayerNorm(self.c_m)
59
+ self.linear_1 = Linear(self.c_m, self.n * self.c_m, init="relu")
60
+ self.relu = nn.ReLU()
61
+ self.linear_2 = Linear(self.n * self.c_m, self.c_m, init="final")
62
+
63
+ def _transition(self, m, mask):
64
+ m = self.layer_norm(m)
65
+ m = self.linear_1(m)
66
+ m = self.relu(m)
67
+ m = self.linear_2(m) * mask
68
+ return m
69
+
70
+ def forward(
71
+ self,
72
+ m: torch.Tensor,
73
+ mask: Optional[torch.Tensor] = None,
74
+ ) -> torch.Tensor:
75
+ """
76
+ Args:
77
+ m:
78
+ [*, N_res, C_m] activation after attention
79
+ mask:
80
+ [*, N_res] single mask
81
+ Returns:
82
+ m:
83
+ [*, N_res, C_m] activation update
84
+ """
85
+ # DISCREPANCY: DeepMind forgets to apply the mask here.
86
+ if mask is None:
87
+ mask = m.new_ones(m.shape[:-1])
88
+
89
+ mask = mask.unsqueeze(-1)
90
+
91
+ m = self._transition(m, mask)
92
+
93
+ return m
94
+
95
+
96
+ class PairStack(nn.Module):
97
+ def __init__(
98
+ self,
99
+ c_z: int,
100
+ c_hidden_mul: int,
101
+ c_hidden_pair_att: int,
102
+ no_heads_pair: int,
103
+ transition_n: int,
104
+ pair_dropout: float,
105
+ inf: float,
106
+ eps: float
107
+ ):
108
+ super(PairStack, self).__init__()
109
+
110
+ self.tri_mul_out = TriangleMultiplicationOutgoing(
111
+ c_z,
112
+ c_hidden_mul,
113
+ )
114
+ self.tri_mul_in = TriangleMultiplicationIncoming(
115
+ c_z,
116
+ c_hidden_mul,
117
+ )
118
+
119
+ self.tri_att_start = TriangleAttention(
120
+ c_z,
121
+ c_hidden_pair_att,
122
+ no_heads_pair,
123
+ inf=inf,
124
+ )
125
+ self.tri_att_end = TriangleAttention(
126
+ c_z,
127
+ c_hidden_pair_att,
128
+ no_heads_pair,
129
+ inf=inf,
130
+ )
131
+
132
+ self.pair_transition = PairTransition(
133
+ c_z,
134
+ transition_n,
135
+ )
136
+
137
+ self.ps_dropout_row_layer = DropoutRowwise(pair_dropout)
138
+
139
+ def forward(self,
140
+ z: torch.Tensor,
141
+ pair_mask: torch.Tensor,
142
+ use_lma: bool = False,
143
+ inplace_safe: bool = False,
144
+ _mask_trans: bool = True,
145
+ ) -> torch.Tensor:
146
+ # DeepMind doesn't mask these transitions in the source, so _mask_trans
147
+ # should be disabled to better approximate the exact activations of
148
+ # the original.
149
+ pair_trans_mask = pair_mask if _mask_trans else None
150
+
151
+ tmu_update = self.tri_mul_out(
152
+ z,
153
+ mask=pair_mask,
154
+ inplace_safe=inplace_safe,
155
+ _add_with_inplace=True,
156
+ )
157
+ if (not inplace_safe):
158
+ z = z + self.ps_dropout_row_layer(tmu_update)
159
+ else:
160
+ z = tmu_update
161
+
162
+ del tmu_update
163
+
164
+ tmu_update = self.tri_mul_in(
165
+ z,
166
+ mask=pair_mask,
167
+ inplace_safe=inplace_safe,
168
+ _add_with_inplace=True,
169
+ )
170
+ if (not inplace_safe):
171
+ z = z + self.ps_dropout_row_layer(tmu_update)
172
+ else:
173
+ z = tmu_update
174
+
175
+ del tmu_update
176
+
177
+ z = add(z,
178
+ self.ps_dropout_row_layer(
179
+ self.tri_att_start(
180
+ z,
181
+ mask=pair_mask,
182
+ use_memory_efficient_kernel=False,
183
+ use_lma=use_lma,
184
+ )
185
+ ),
186
+ inplace=inplace_safe,
187
+ )
188
+
189
+ z = z.transpose(-2, -3)
190
+ if (inplace_safe):
191
+ z = z.contiguous()
192
+
193
+ z = add(z,
194
+ self.ps_dropout_row_layer(
195
+ self.tri_att_end(
196
+ z,
197
+ mask=pair_mask.transpose(-1, -2),
198
+ use_memory_efficient_kernel=False,
199
+ use_lma=use_lma,
200
+ )
201
+ ),
202
+ inplace=inplace_safe,
203
+ )
204
+
205
+ z = z.transpose(-2, -3)
206
+ if (inplace_safe):
207
+ z = z.contiguous()
208
+
209
+ z = add(z,
210
+ self.pair_transition(
211
+ z, mask=pair_trans_mask,
212
+ ),
213
+ inplace=inplace_safe,
214
+ )
215
+
216
+ return z
217
+
218
+
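Note: tri_att_start and tri_att_end above share one row-wise attention primitive; the "ending node" pass is obtained by transposing the pair representation (and its mask), attending, and transposing back. A toy stand-in (not the module's own attention) showing the pattern:

import torch

def row_update(z):
    # acts along the last residue axis (rows)
    w = torch.softmax(z.sum(-1), dim=-1)   # [*, N, N]
    return w[..., None] * z                # [*, N, N, C]

z = torch.randn(2, 5, 5, 8)                # [*, N, N, C_z]
col_updated = row_update(z.transpose(-2, -3)).transpose(-2, -3)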
219
+ class EvoformerBlock(nn.Module, ABC):
220
+ def __init__(self,
221
+ c_m: int,
222
+ c_z: int,
223
+ c_hidden_single_att: int,
224
+ c_hidden_mul: int,
225
+ c_hidden_pair_att: int,
226
+ no_heads_single: int,
227
+ no_heads_pair: int,
228
+ transition_n: int,
229
+ single_dropout: float,
230
+ pair_dropout: float,
231
+ inf: float,
232
+ eps: float,
233
+ ):
234
+ super(EvoformerBlock, self).__init__()
235
+
236
+ self.single_att_row = SingleRowAttentionWithPairBias(
237
+ c_m=c_m,
238
+ c_z=c_z,
239
+ c_hidden=c_hidden_single_att,
240
+ no_heads=no_heads_single,
241
+ inf=inf,
242
+ )
243
+
244
+ self.single_dropout_layer = DropoutRowwise(single_dropout)
245
+
246
+ self.single_transition = SingleRepTransition(
247
+ c_m=c_m,
248
+ n=transition_n,
249
+ )
250
+
251
+ self.pair_stack = PairStack(
252
+ c_z=c_z,
253
+ c_hidden_mul=c_hidden_mul,
254
+ c_hidden_pair_att=c_hidden_pair_att,
255
+ no_heads_pair=no_heads_pair,
256
+ transition_n=transition_n,
257
+ pair_dropout=pair_dropout,
258
+ inf=inf,
259
+ eps=eps
260
+ )
261
+
262
+ def forward(self,
263
+ m: Optional[torch.Tensor],
264
+ z: Optional[torch.Tensor],
265
+ single_mask: torch.Tensor,
266
+ pair_mask: torch.Tensor,
267
+ use_lma: bool = False,
268
+ inplace_safe: bool = False,
269
+ _mask_trans: bool = True,
270
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
271
+
272
+ single_trans_mask = single_mask if _mask_trans else None
273
+
274
+ input_tensors = [m, z]
275
+
276
+ m, z = input_tensors
277
+
278
+ z = self.pair_stack(
279
+ z=z,
280
+ pair_mask=pair_mask,
281
+ use_lma=use_lma,
282
+ inplace_safe=inplace_safe,
283
+ _mask_trans=_mask_trans,
284
+ )
285
+
286
+ m = add(m,
287
+ self.single_dropout_layer(
288
+ self.single_att_row(
289
+ m,
290
+ z=z,
291
+ mask=single_mask,
292
+ use_memory_efficient_kernel=False,
293
+ use_lma=use_lma,
294
+ )
295
+ ),
296
+ inplace=inplace_safe,
297
+ )
298
+
299
+ m = add(m, self.single_transition(m, mask=single_mask), inplace=inplace_safe)
300
+
301
+ return m, z
302
+
303
+
304
+ class EvoformerStack(nn.Module):
305
+ """
306
+ Main Evoformer trunk.
307
+
308
+ Implements Algorithm 6.
309
+ """
310
+
311
+ def __init__(
312
+ self,
313
+ c_m: int,
314
+ c_z: int,
315
+ c_hidden_single_att: int,
316
+ c_hidden_mul: int,
317
+ c_hidden_pair_att: int,
318
+ c_s: int,
319
+ no_heads_single: int,
320
+ no_heads_pair: int,
321
+ no_blocks: int,
322
+ transition_n: int,
323
+ single_dropout: float,
324
+ pair_dropout: float,
325
+ blocks_per_ckpt: int,
326
+ inf: float,
327
+ eps: float,
328
+ clear_cache_between_blocks: bool = False,
329
+ **kwargs,
330
+ ):
331
+ """
332
+ Args:
333
+ c_m:
334
+ single channel dimension
335
+ c_z:
336
+ Pair channel dimension
337
+ c_hidden_single_att:
338
+ Hidden dimension in single representation attention
339
+ c_hidden_mul:
340
+ Hidden dimension in multiplicative updates
341
+ c_hidden_pair_att:
342
+ Hidden dimension in triangular attention
343
+ c_s:
344
+ Channel dimension of the output "single" embedding
345
+ no_heads_single:
346
+ Number of heads used for single attention
347
+ no_heads_pair:
348
+ Number of heads used for pair attention
349
+ no_blocks:
350
+ Number of Evoformer blocks in the stack
351
+ transition_n:
352
+ Factor by which to multiply c_m to obtain the SingleRepTransition
353
+ hidden dimension
354
+ single_dropout:
355
+ Dropout rate for single activations
356
+ pair_dropout:
357
+ Dropout used for pair activations
358
+ blocks_per_ckpt:
359
+ Number of Evoformer blocks in each activation checkpoint
360
+ clear_cache_between_blocks:
361
+ Whether to clear CUDA's GPU memory cache between blocks of the
362
+ stack. Slows down each block but can reduce fragmentation
363
+ """
364
+ super(EvoformerStack, self).__init__()
365
+
366
+ self.blocks_per_ckpt = blocks_per_ckpt
367
+ self.clear_cache_between_blocks = clear_cache_between_blocks
368
+
369
+ self.blocks = nn.ModuleList()
370
+
371
+ for _ in range(no_blocks):
372
+ block = EvoformerBlock(
373
+ c_m=c_m,
374
+ c_z=c_z,
375
+ c_hidden_single_att=c_hidden_single_att,
376
+ c_hidden_mul=c_hidden_mul,
377
+ c_hidden_pair_att=c_hidden_pair_att,
378
+ no_heads_single=no_heads_single,
379
+ no_heads_pair=no_heads_pair,
380
+ transition_n=transition_n,
381
+ single_dropout=single_dropout,
382
+ pair_dropout=pair_dropout,
383
+ inf=inf,
384
+ eps=eps,
385
+ )
386
+ self.blocks.append(block)
387
+
388
+ self.linear = Linear(c_m, c_s)
389
+
390
+ def _prep_blocks(self,
391
+ use_lma: bool,
392
+ single_mask: Optional[torch.Tensor],
393
+ pair_mask: Optional[torch.Tensor],
394
+ inplace_safe: bool,
395
+ _mask_trans: bool,
396
+ ):
397
+ blocks = [
398
+ partial(
399
+ b,
400
+ single_mask=single_mask,
401
+ pair_mask=pair_mask,
402
+ use_lma=use_lma,
403
+ inplace_safe=inplace_safe,
404
+ _mask_trans=_mask_trans,
405
+ )
406
+ for b in self.blocks
407
+ ]
408
+
409
+ if self.clear_cache_between_blocks:
410
+ def block_with_cache_clear(block, *args, **kwargs):
411
+ torch.cuda.empty_cache()
412
+ return block(*args, **kwargs)
413
+
414
+ blocks = [partial(block_with_cache_clear, b) for b in blocks]
415
+
416
+ return blocks
417
+
418
+ def forward(self,
419
+ m: torch.Tensor,
420
+ z: torch.Tensor,
421
+ single_mask: torch.Tensor,
422
+ pair_mask: torch.Tensor,
423
+ use_lma: bool = False,
424
+ inplace_safe: bool = False,
425
+ _mask_trans: bool = True,
426
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
427
+ """
428
+ Args:
429
+ m:
430
+ [*, N_res, C_m] single embedding
431
+ z:
432
+ [*, N_res, N_res, C_z] pair embedding
433
+ single_mask:
434
+ [*, N_res] single mask
435
+ pair_mask:
436
+ [*, N_res, N_res] pair mask
437
+ use_lma:
438
+ Whether to use low-memory attention during inference.
439
+
440
+ Returns:
441
+ m:
442
+ [*, N_res, C_m] single embedding
443
+ z:
444
+ [*, N_res, N_res, C_z] pair embedding
445
+ s:
446
+ [*, N_res, C_s] single embedding after linear layer
447
+ """
448
+ blocks = self._prep_blocks(
449
+ use_lma=use_lma,
450
+ single_mask=single_mask,
451
+ pair_mask=pair_mask,
452
+ inplace_safe=inplace_safe,
453
+ _mask_trans=_mask_trans,
454
+ )
455
+
456
+ blocks_per_ckpt = self.blocks_per_ckpt
457
+ if(not torch.is_grad_enabled()):
458
+ blocks_per_ckpt = None
459
+
460
+ m, z = checkpoint_blocks(
461
+ blocks,
462
+ args=(m, z),
463
+ blocks_per_ckpt=blocks_per_ckpt,
464
+ )
465
+
466
+ s = self.linear(m)
467
+
468
+ return m, z, s
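Note: a minimal smoke-test sketch for the stack above, assuming the package imports cleanly; all dimensions here are toy values, not the training configuration:

import torch
from dockformerpp.model.evoformer import EvoformerStack

stack = EvoformerStack(
    c_m=32, c_z=16, c_hidden_single_att=8, c_hidden_mul=8,
    c_hidden_pair_att=8, c_s=24, no_heads_single=4, no_heads_pair=2,
    no_blocks=2, transition_n=2, single_dropout=0.1, pair_dropout=0.25,
    blocks_per_ckpt=1, inf=1e9, eps=1e-8,
).eval()

m = torch.randn(1, 10, 32)        # [*, N, C_m] single representation
z = torch.randn(1, 10, 10, 16)    # [*, N, N, C_z] pair representation
mask = torch.ones(1, 10)
with torch.no_grad():
    m, z, s = stack(m, z, single_mask=mask,
                    pair_mask=mask[..., None] * mask[..., None, :])
print(s.shape)                    # torch.Size([1, 10, 24])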
dockformerpp/model/heads.py ADDED
@@ -0,0 +1,233 @@
1
+ # Copyright 2021 AlQuraishi Laboratory
2
+ # Copyright 2021 DeepMind Technologies Limited
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ import torch
17
+ import torch.nn as nn
18
+ from torch.nn import Parameter
19
+
20
+ from dockformerpp.model.primitives import Linear, LayerNorm
21
+ from dockformerpp.utils.loss import (
22
+ compute_plddt,
23
+ compute_tm,
24
+ compute_predicted_aligned_error,
25
+ )
26
+ from dockformerpp.utils.precision_utils import is_fp16_enabled
27
+
28
+
29
+ class AuxiliaryHeads(nn.Module):
30
+ def __init__(self, config):
31
+ super(AuxiliaryHeads, self).__init__()
32
+
33
+ self.plddt = PerResidueLDDTCaPredictor(
34
+ **config["lddt"],
35
+ )
36
+
37
+ self.distogram = DistogramHead(
38
+ **config["distogram"],
39
+ )
40
+
41
+ self.affinity_2d = Affinity2DPredictor(
42
+ **config["affinity_2d"],
43
+ )
44
+
45
+ self.affinity_cls = AffinityClsTokenPredictor(
46
+ **config["affinity_cls"],
47
+ )
48
+
49
+ self.binding_site = BindingSitePredictor(
50
+ **config["binding_site"],
51
+ )
52
+
53
+ self.inter_contact = InterContactHead(
54
+ **config["inter_contact"],
55
+ )
56
+
57
+ self.config = config
58
+
59
+ def forward(self, outputs, inter_mask, affinity_mask):
60
+ aux_out = {}
61
+ lddt_logits = self.plddt(outputs["sm"]["single"])
62
+ aux_out["lddt_logits"] = lddt_logits
63
+
64
+ # Required for relaxation later on
65
+ aux_out["plddt"] = compute_plddt(lddt_logits)
66
+
67
+ distogram_logits = self.distogram(outputs["pair"])
68
+ aux_out["distogram_logits"] = distogram_logits
69
+
70
+ aux_out["inter_contact_logits"] = self.inter_contact(outputs["single"], outputs["pair"])
71
+
72
+ aux_out["affinity_2d_logits"] = self.affinity_2d(outputs["pair"], aux_out["inter_contact_logits"], inter_mask)
73
+
74
+ aux_out["affinity_cls_logits"] = self.affinity_cls(outputs["single"], affinity_mask)
75
+
76
+ aux_out["binding_site_logits"] = self.binding_site(outputs["single"])
77
+
78
+ return aux_out
79
+
80
+
81
+ class Affinity2DPredictor(nn.Module):
82
+ def __init__(self, c_z, num_bins):
83
+ super(Affinity2DPredictor, self).__init__()
84
+
85
+ self.c_z = c_z
86
+
87
+ self.weight_linear = Linear(self.c_z + 1, 1)
88
+ self.embed_linear = Linear(self.c_z, self.c_z)
89
+ self.bins_linear = Linear(self.c_z, num_bins)
90
+
91
+ def forward(self, z, inter_contacts_logits, inter_pair_mask):
92
+ z_with_inter_contacts = torch.cat((z, inter_contacts_logits), dim=-1) # [*, N, N, c_z + 1]
93
+ weights = self.weight_linear(z_with_inter_contacts) # [*, N, N, 1]
94
+
95
+ x = self.embed_linear(z) # [*, N, N, c_z]
96
+ batch_size, N, M, _ = x.shape
97
+
98
+ flat_weights = weights.reshape(batch_size, N*M, -1) # [*, N*M, 1]
99
+ flat_x = x.reshape(batch_size, N*M, -1) # [*, N*M, c_z]
100
+ flat_inter_pair_mask = inter_pair_mask.reshape(batch_size, N*M, 1)
101
+
102
+ flat_weights = flat_weights.masked_fill(~(flat_inter_pair_mask.bool()), float('-inf')) # [*, N*M, 1]
103
+ flat_weights = torch.nn.functional.softmax(flat_weights, dim=1) # [*, N*M, 1]
104
+ flat_weights = torch.nan_to_num(flat_weights, nan=0.0) # [*, N*M, 1]
105
+ weighted_sum = torch.sum((flat_weights * flat_x).reshape(batch_size, N*M, -1), dim=1) # [*, c_z]
106
+
107
+ return self.bins_linear(weighted_sum)
108
+
109
+
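Note: the pooling above is a masked softmax over all pair positions, with a nan_to_num guard because a sample whose inter-pair mask is all zeros would otherwise produce NaN weights. A self-contained sketch of just that guard:

import torch

w = torch.randn(2, 6, 1)                               # flat logits [*, N*M, 1]
mask = torch.tensor([1., 1., 0., 0., 0., 0.]).repeat(2, 1).unsqueeze(-1)
mask[1] = 0.                                           # second sample fully masked
w = w.masked_fill(~mask.bool(), float('-inf'))
w = torch.softmax(w, dim=1)                            # NaN rows where all -inf
w = torch.nan_to_num(w, nan=0.0)                       # fully-masked -> all-zero
assert torch.isfinite(w).all()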
110
+ class AffinityClsTokenPredictor(nn.Module):
111
+ def __init__(self, c_s, num_bins, **kwargs):
112
+ super(AffinityClsTokenPredictor, self).__init__()
113
+
114
+ self.c_s = c_s
115
+ self.linear = Linear(self.c_s, num_bins, init="final")
116
+
117
+ def forward(self, s, affinity_mask):
118
+ affinity_tokens = (s * affinity_mask.unsqueeze(-1)).sum(dim=1)
119
+ return self.linear(affinity_tokens)
120
+
121
+
122
+ class BindingSitePredictor(nn.Module):
123
+ def __init__(self, c_s, c_out, **kwargs):
124
+ super(BindingSitePredictor, self).__init__()
125
+
126
+ self.c_s = c_s
127
+ self.c_out = c_out
128
+
129
+ self.linear = Linear(self.c_s, self.c_out, init="final")
130
+
131
+ def forward(self, s):
132
+ # [*, N, C_out]
133
+ return self.linear(s)
134
+
135
+
136
+ class InterContactHead(nn.Module):
137
+ def __init__(self, c_s, c_z, c_out, **kwargs):
138
+ """
139
+ Args:
140
+ c_z:
141
+ Input channel dimension
142
+ c_out:
143
+ Number of output bins; since the prediction is boolean, this should be 1
144
+ """
145
+ super(InterContactHead, self).__init__()
146
+
147
+ self.c_s = c_s
148
+ self.c_z = c_z
149
+ self.c_out = c_out
150
+
151
+ self.linear = Linear(2 * self.c_s + self.c_z, self.c_out, init="final")
152
+
153
+ def forward(self, s, z): # [*, N, N, C_z]
154
+ # [*, N, N, no_bins]
155
+ batch_size, n, s_dim = s.shape
156
+
157
+ s_i = s.unsqueeze(2).expand(batch_size, n, n, s_dim)
158
+ s_j = s.unsqueeze(1).expand(batch_size, n, n, s_dim)
159
+ joined = torch.cat((s_i, s_j, z), dim=-1)
160
+
161
+ logits = self.linear(joined)
162
+
163
+ return logits
164
+
165
+
166
+ class PerResidueLDDTCaPredictor(nn.Module):
167
+ def __init__(self, no_bins, c_in, c_hidden):
168
+ super(PerResidueLDDTCaPredictor, self).__init__()
169
+
170
+ self.no_bins = no_bins
171
+ self.c_in = c_in
172
+ self.c_hidden = c_hidden
173
+
174
+ self.layer_norm = LayerNorm(self.c_in)
175
+
176
+ self.linear_1 = Linear(self.c_in, self.c_hidden, init="relu")
177
+ self.linear_2 = Linear(self.c_hidden, self.c_hidden, init="relu")
178
+ self.linear_3 = Linear(self.c_hidden, self.no_bins, init="final")
179
+
180
+ self.relu = nn.ReLU()
181
+
182
+ def forward(self, s):
183
+ s = self.layer_norm(s)
184
+ s = self.linear_1(s)
185
+ s = self.relu(s)
186
+ s = self.linear_2(s)
187
+ s = self.relu(s)
188
+ s = self.linear_3(s)
189
+
190
+ return s
191
+
192
+
193
+ class DistogramHead(nn.Module):
194
+ """
195
+ Computes a distogram probability distribution.
196
+
197
+ For use in computation of distogram loss, subsection 1.9.8
198
+ """
199
+
200
+ def __init__(self, c_z, no_bins, **kwargs):
201
+ """
202
+ Args:
203
+ c_z:
204
+ Input channel dimension
205
+ no_bins:
206
+ Number of distogram bins
207
+ """
208
+ super(DistogramHead, self).__init__()
209
+
210
+ self.c_z = c_z
211
+ self.no_bins = no_bins
212
+
213
+ self.linear = Linear(self.c_z, self.no_bins, init="final")
214
+
215
+ def _forward(self, z): # [*, N, N, C_z]
216
+ """
217
+ Args:
218
+ z:
219
+ [*, N_res, N_res, C_z] pair embedding
220
+ Returns:
221
+ [*, N, N, no_bins] distogram probability distribution
222
+ """
223
+ # [*, N, N, no_bins]
224
+ logits = self.linear(z)
225
+ logits = logits + logits.transpose(-2, -3)
226
+ return logits
227
+
228
+ def forward(self, z):
229
+ if(is_fp16_enabled()):
230
+ with torch.cuda.amp.autocast(enabled=False):
231
+ return self._forward(z.float())
232
+ else:
233
+ return self._forward(z)
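Note: _forward symmetrizes the logits because a distance map is symmetric; adding the transpose enforces logits[i, j] == logits[j, i] exactly. Self-contained check:

import torch

logits = torch.randn(1, 7, 7, 64)           # [*, N, N, no_bins]
sym = logits + logits.transpose(-2, -3)
assert torch.allclose(sym, sym.transpose(-2, -3))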
dockformerpp/model/model.py ADDED
@@ -0,0 +1,317 @@
1
+ # Copyright 2021 AlQuraishi Laboratory
2
+ # Copyright 2021 DeepMind Technologies Limited
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ from functools import partial
16
+ import weakref
17
+
18
+ import torch
19
+ import torch.nn as nn
20
+
21
+ from dockformerpp.utils.tensor_utils import masked_mean
22
+ from dockformerpp.model.embedders import (
23
+ StructureInputEmbedder,
24
+ RecyclingEmbedder,
25
+ )
26
+ from dockformerpp.model.evoformer import EvoformerStack
27
+ from dockformerpp.model.heads import AuxiliaryHeads
28
+ from dockformerpp.model.structure_module import StructureModule
29
+ import dockformerpp.utils.residue_constants as residue_constants
30
+ from dockformerpp.utils.feats import (
31
+ pseudo_beta_fn,
32
+ atom14_to_atom37,
33
+ )
34
+ from dockformerpp.utils.tensor_utils import (
35
+ add,
36
+ tensor_tree_map,
37
+ )
38
+
39
+
40
+ class AlphaFold(nn.Module):
41
+ """
42
+ Alphafold 2.
43
+
44
+ Implements Algorithm 2 (but with training).
45
+ """
46
+
47
+ def __init__(self, config):
48
+ """
49
+ Args:
50
+ config:
51
+ A dict-like config object (like the one in config.py)
52
+ """
53
+ super(AlphaFold, self).__init__()
54
+
55
+ self.globals = config.globals
56
+ self.config = config.model
57
+
58
+ # Main trunk + structure module
59
+ self.input_embedder = StructureInputEmbedder(
60
+ **self.config["structure_input_embedder"],
61
+ )
62
+
63
+ self.recycling_embedder = RecyclingEmbedder(
64
+ **self.config["recycling_embedder"],
65
+ )
66
+
67
+ self.evoformer = EvoformerStack(
68
+ **self.config["evoformer_stack"],
69
+ )
70
+
71
+ self.structure_module = StructureModule(
72
+ **self.config["structure_module"],
73
+ )
74
+ self.aux_heads = AuxiliaryHeads(
75
+ self.config["heads"],
76
+ )
77
+
78
+ def tolerance_reached(self, prev_pos, next_pos, mask, eps=1e-8) -> bool:
79
+ """
80
+ Early stopping criteria based on criteria used in
81
+ AF2Complex: https://www.nature.com/articles/s41467-022-29394-2
82
+ Args:
83
+ prev_pos: Previous atom positions in atom37/14 representation
84
+ next_pos: Current atom positions in atom37/14 representation
85
+ mask: 1-D sequence mask
86
+ eps: Epsilon used in square root calculation
87
+ Returns:
88
+ Whether to stop recycling early based on the desired tolerance.
89
+ """
90
+
91
+ def distances(points):
92
+ """Compute all pairwise distances for a set of points."""
93
+ d = points[..., None, :] - points[..., None, :, :]
94
+ return torch.sqrt(torch.sum(d ** 2, dim=-1))
95
+
96
+ if self.config.recycle_early_stop_tolerance < 0:
97
+ return False
98
+
99
+ ca_idx = residue_constants.atom_order['CA']
100
+ sq_diff = (distances(prev_pos[..., ca_idx, :]) - distances(next_pos[..., ca_idx, :])) ** 2
101
+ mask = mask[..., None] * mask[..., None, :]
102
+ sq_diff = masked_mean(mask=mask, value=sq_diff, dim=list(range(len(mask.shape))))
103
+ diff = torch.sqrt(sq_diff + eps).item()
104
+ return diff <= self.config.recycle_early_stop_tolerance
105
+
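Note: a self-contained sketch of the early-stop metric above: the masked RMS change between consecutive CA-CA distance maps, compared against a tolerance (toy tensors; the 1.5 A threshold is hypothetical):

import torch

def dist_map(p):                       # [*, N, 3] -> [*, N, N]
    d = p[..., None, :] - p[..., None, :, :]
    return torch.sqrt((d ** 2).sum(-1))

prev_ca, next_ca = torch.randn(10, 3), torch.randn(10, 3)
mask = torch.ones(10)
pair_mask = mask[..., None] * mask[..., None, :]
sq_diff = (dist_map(prev_ca) - dist_map(next_ca)) ** 2
diff = torch.sqrt((sq_diff * pair_mask).sum() / pair_mask.sum() + 1e-8)
stop = diff.item() <= 1.5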
106
+ def iteration(self, feats, prevs, _recycle=True):
107
+ # Primary output dictionary
108
+ outputs = {}
109
+
110
+ # This needs to be done manually for DeepSpeed's sake
111
+ dtype = next(self.parameters()).dtype
112
+ for k in feats:
113
+ if feats[k].dtype == torch.float32:
114
+ feats[k] = feats[k].to(dtype=dtype)
115
+
116
+ # Grab some data about the input
117
+ batch_dims, n_total = feats["token_mask"].shape
118
+ device = feats["token_mask"].device
119
+
120
+ print("doing sample of size", feats["token_mask"].shape,
121
+ feats["protein_r_mask"].sum(dim=1), feats["protein_l_mask"].sum(dim=1))
122
+
123
+ # Controls whether the model uses in-place operations throughout
124
+ # The dual condition accounts for activation checkpoints
125
+ # inplace_safe = not (self.training or torch.is_grad_enabled())
126
+ inplace_safe = False # so we don't need attn_core_inplace_cuda
127
+
128
+ # Prep some features
129
+ token_mask = feats["token_mask"]
130
+ pair_mask = token_mask[..., None] * token_mask[..., None, :]
131
+
132
+ # Initialize the single and pair representations
133
+ # m: [*, 1, n_total, C_m]
134
+ # z: [*, n_total, n_total, C_z]
135
+ m, z = self.input_embedder(
136
+ feats["token_mask"],
137
+ feats["protein_r_mask"],
138
+ feats["protein_l_mask"],
139
+ feats["target_feat"],
140
+ feats["input_positions"],
141
+ feats["residue_index"],
142
+ feats["distogram_mask"],
143
+ inplace_safe=inplace_safe,
144
+ )
145
+
146
+ # Unpack the recycling embeddings. Removing them from the list allows
147
+ # them to be freed further down in this function, saving memory
148
+ m_1_prev, z_prev, x_prev = reversed([prevs.pop() for _ in range(3)])
149
+
150
+ # Initialize the recycling embeddings, if needs be
151
+ if None in [m_1_prev, z_prev, x_prev]:
152
+ # [*, N, C_m]
153
+ m_1_prev = m.new_zeros(
154
+ (batch_dims, n_total, self.config.structure_input_embedder.c_m),
155
+ requires_grad=False,
156
+ )
157
+
158
+ # [*, N, N, C_z]
159
+ z_prev = z.new_zeros(
160
+ (batch_dims, n_total, n_total, self.config.structure_input_embedder.c_z),
161
+ requires_grad=False,
162
+ )
163
+
164
+ # [*, N, 37, 3]
165
+ x_prev = z.new_zeros(
166
+ (batch_dims, n_total, residue_constants.atom_type_num, 3),
167
+ requires_grad=False,
168
+ )
169
+
170
+ # shape == [1, n_total, 37, 3]
171
+ pseudo_beta_or_lig_x_prev = pseudo_beta_fn(feats["aatype"], x_prev, None).to(dtype=z.dtype)
172
+
173
+ # m_1_prev_emb: [*, N, C_m]
174
+ # z_prev_emb: [*, N, N, C_z]
175
+ m_1_prev_emb, z_prev_emb = self.recycling_embedder(
176
+ m_1_prev,
177
+ z_prev,
178
+ pseudo_beta_or_lig_x_prev,
179
+ inplace_safe=inplace_safe,
180
+ )
181
+
182
+ del pseudo_beta_or_lig_x_prev
183
+
184
+ # [*, S_c, N, C_m]
185
+ m += m_1_prev_emb
186
+
187
+ # [*, N, N, C_z]
188
+ z = add(z, z_prev_emb, inplace=inplace_safe)
189
+
190
+ # Deletions like these become significant for inference with large N,
191
+ # where they free unused tensors and remove references to others such
192
+ # that they can be offloaded later
193
+ del m_1_prev, z_prev, m_1_prev_emb, z_prev_emb
194
+
195
+ # Run single + pair embeddings through the trunk of the network
196
+ # m: [*, N, C_m]
197
+ # z: [*, N, N, C_z]
198
+ # s: [*, N, C_s]
199
+ m, z, s = self.evoformer(
200
+ m,
201
+ z,
202
+ single_mask=token_mask.to(dtype=m.dtype),
203
+ pair_mask=pair_mask.to(dtype=z.dtype),
204
+ use_lma=self.globals.use_lma,
205
+ inplace_safe=inplace_safe,
206
+ _mask_trans=self.config._mask_trans,
207
+ )
208
+
209
+ outputs["pair"] = z
210
+ outputs["single"] = s
211
+
212
+ del z
213
+
214
+ # Predict 3D structure
215
+ outputs["sm"] = self.structure_module(
216
+ outputs,
217
+ feats["aatype"],
218
+ mask=token_mask.to(dtype=s.dtype),
219
+ inplace_safe=inplace_safe,
220
+ )
221
+ outputs["final_atom_positions"] = atom14_to_atom37(
222
+ outputs["sm"]["positions"][-1], feats
223
+ )
224
+ outputs["final_atom_mask"] = feats["atom37_atom_exists"]
225
+
226
+ # Save embeddings for use during the next recycling iteration
227
+
228
+ # [*, N, C_m]
229
+ m_1_prev = m[..., 0, :, :]
230
+
231
+ # [*, N, N, C_z]
232
+ z_prev = outputs["pair"]
233
+
234
+ # TODO bshor: early stop depends on is_multimer, but I don't think it must
235
+ early_stop = False
236
+ # if self.globals.is_multimer:
237
+ # early_stop = self.tolerance_reached(x_prev, outputs["final_atom_positions"], seq_mask)
238
+
239
+ del x_prev
240
+
241
+ # [*, N, 3]
242
+ x_prev = outputs["final_atom_positions"]
243
+
244
+ return outputs, m_1_prev, z_prev, x_prev, early_stop
245
+
246
+ def forward(self, batch):
247
+ """
248
+ Args:
249
+ batch:
250
+ Dictionary of arguments outlined in Algorithm 2. Keys must
251
+ include the official names of the features in the
252
+ supplement subsection 1.2.9.
253
+
254
+ The final dimension of each input must have length equal to
255
+ the number of recycling iterations.
256
+
257
+ Features (without the recycling dimension):
258
+
259
+ "aatype" ([*, N_res]):
260
+ Contrary to the supplement, this tensor of residue
261
+ indices is not one-hot.
262
+ "protein_target_feat" ([*, N_res, C_tf])
263
+ One-hot encoding of the target sequence. C_tf is
264
+ config.model.input_embedder.tf_dim.
265
+ "residue_index" ([*, N_res])
266
+ Tensor whose final dimension consists of
267
+ consecutive indices from 0 to N_res.
268
+ "token_mask" ([*, N_token])
269
+ 1-D token mask
270
+ "pair_mask" ([*, N_token, N_token])
271
+ 2-D pair mask
272
+ """
273
+ # Initialize recycling embeddings
274
+ m_1_prev, z_prev, x_prev = None, None, None
275
+ prevs = [m_1_prev, z_prev, x_prev]
276
+
277
+ is_grad_enabled = torch.is_grad_enabled()
278
+
279
+ # Main recycling loop
280
+ num_iters = batch["aatype"].shape[-1]
281
+ early_stop = False
282
+ num_recycles = 0
283
+ for cycle_no in range(num_iters):
284
+ # Select the features for the current recycling cycle
285
+ fetch_cur_batch = lambda t: t[..., cycle_no]
286
+ feats = tensor_tree_map(fetch_cur_batch, batch)
287
+
288
+ # Enable grad iff we're training and it's the final recycling layer
289
+ is_final_iter = cycle_no == (num_iters - 1) or early_stop
290
+ with torch.set_grad_enabled(is_grad_enabled and is_final_iter):
291
+ if is_final_iter:
292
+ # Sidestep AMP bug (PyTorch issue #65766)
293
+ if torch.is_autocast_enabled():
294
+ torch.clear_autocast_cache()
295
+
296
+ # Run the next iteration of the model
297
+ outputs, m_1_prev, z_prev, x_prev, early_stop = self.iteration(
298
+ feats,
299
+ prevs,
300
+ _recycle=(num_iters > 1)
301
+ )
302
+
303
+ num_recycles += 1
304
+
305
+ if not is_final_iter:
306
+ del outputs
307
+ prevs = [m_1_prev, z_prev, x_prev]
308
+ del m_1_prev, z_prev, x_prev
309
+ else:
310
+ break
311
+
312
+ outputs["num_recycles"] = torch.tensor(num_recycles, device=feats["aatype"].device)
313
+
314
+ # Run auxiliary heads, remove the recycling dimension batch properties
315
+ outputs.update(self.aux_heads(outputs, batch["inter_pair_mask"][..., 0], batch["affinity_mask"][..., 0]))
316
+
317
+ return outputs
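Note: forward() expects every feature to carry the recycling iteration in its trailing dimension; tensor_tree_map then peels one slice off per cycle. A plain-dict sketch of that slicing (toy shapes):

import torch

batch = {
    "aatype": torch.zeros(1, 10, 3, dtype=torch.long),  # [*, N, n_cycles]
    "token_mask": torch.ones(1, 10, 3),                 # [*, N, n_cycles]
}
num_iters = batch["aatype"].shape[-1]
for cycle_no in range(num_iters):
    feats = {k: t[..., cycle_no] for k, t in batch.items()}
    # feats["aatype"] is [1, 10]: the features for this recycling iteration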
dockformerpp/model/pair_transition.py ADDED
@@ -0,0 +1,81 @@
1
+ # Copyright 2021 AlQuraishi Laboratory
2
+ # Copyright 2021 DeepMind Technologies Limited
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ from typing import Optional
16
+
17
+ import torch
18
+ import torch.nn as nn
19
+
20
+ from dockformerpp.model.primitives import Linear, LayerNorm
21
+
22
+
23
+ class PairTransition(nn.Module):
24
+ """
25
+ Implements Algorithm 15.
26
+ """
27
+
28
+ def __init__(self, c_z, n):
29
+ """
30
+ Args:
31
+ c_z:
32
+ Pair transition channel dimension
33
+ n:
34
+ Factor by which c_z is multiplied to obtain hidden channel
35
+ dimension
36
+ """
37
+ super(PairTransition, self).__init__()
38
+
39
+ self.c_z = c_z
40
+ self.n = n
41
+
42
+ self.layer_norm = LayerNorm(self.c_z)
43
+ self.linear_1 = Linear(self.c_z, self.n * self.c_z, init="relu")
44
+ self.relu = nn.ReLU()
45
+ self.linear_2 = Linear(self.n * self.c_z, c_z, init="final")
46
+
47
+ def _transition(self, z, mask):
48
+ # [*, N_res, N_res, C_z]
49
+ z = self.layer_norm(z)
50
+
51
+ # [*, N_res, N_res, C_hidden]
52
+ z = self.linear_1(z)
53
+ z = self.relu(z)
54
+
55
+ # [*, N_res, N_res, C_z]
56
+ z = self.linear_2(z)
57
+ z = z * mask
58
+
59
+ return z
60
+
61
+ def forward(self,
62
+ z: torch.Tensor,
63
+ mask: Optional[torch.Tensor] = None,
64
+ ) -> torch.Tensor:
65
+ """
66
+ Args:
67
+ z:
68
+ [*, N_res, N_res, C_z] pair embedding
69
+ Returns:
70
+ [*, N_res, N_res, C_z] pair embedding update
71
+ """
72
+ # DISCREPANCY: DeepMind forgets to apply the mask in this module.
73
+ if mask is None:
74
+ mask = z.new_ones(z.shape[:-1])
75
+
76
+ # [*, N_res, N_res, 1]
77
+ mask = mask.unsqueeze(-1)
78
+
79
+ z = self._transition(z=z, mask=mask)
80
+
81
+ return z
dockformerpp/model/primitives.py ADDED
@@ -0,0 +1,598 @@
1
+ # Copyright 2021 AlQuraishi Laboratory
2
+ # Copyright 2021 DeepMind Technologies Limited
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ import importlib
16
+ import math
17
+ from typing import Optional, Callable, List, Tuple
18
+ import numpy as np
19
+
20
+ import torch
21
+ import torch.nn as nn
22
+ import torch.utils.checkpoint
23
+ from scipy.stats import truncnorm
24
+
25
+ from dockformerpp.utils.kernel.attention_core import attention_core
26
+ from dockformerpp.utils.precision_utils import is_fp16_enabled
27
+ from dockformerpp.utils.tensor_utils import (
28
+ permute_final_dims,
29
+ flatten_final_dims,
30
+ )
31
+
32
+
33
+ # Suited for 40gb GPU
34
+ # DEFAULT_LMA_Q_CHUNK_SIZE = 1024
35
+ # DEFAULT_LMA_KV_CHUNK_SIZE = 4096
36
+ # Suited for 10gb GPU
37
+ DEFAULT_LMA_Q_CHUNK_SIZE = 64
38
+ DEFAULT_LMA_KV_CHUNK_SIZE = 256
39
+
40
+
41
+ def _prod(nums):
42
+ out = 1
43
+ for n in nums:
44
+ out = out * n
45
+ return out
46
+
47
+
48
+ def _calculate_fan(linear_weight_shape, fan="fan_in"):
49
+ fan_out, fan_in = linear_weight_shape
50
+
51
+ if fan == "fan_in":
52
+ f = fan_in
53
+ elif fan == "fan_out":
54
+ f = fan_out
55
+ elif fan == "fan_avg":
56
+ f = (fan_in + fan_out) / 2
57
+ else:
58
+ raise ValueError("Invalid fan option")
59
+
60
+ return f
61
+
62
+
63
+ def trunc_normal_init_(weights, scale=1.0, fan="fan_in"):
64
+ shape = weights.shape
65
+ f = _calculate_fan(shape, fan)
66
+ scale = scale / max(1, f)
67
+ a = -2
68
+ b = 2
69
+ std = math.sqrt(scale) / truncnorm.std(a=a, b=b, loc=0, scale=1)
70
+ size = _prod(shape)
71
+ samples = truncnorm.rvs(a=a, b=b, loc=0, scale=std, size=size)
72
+ samples = np.reshape(samples, shape)
73
+ with torch.no_grad():
74
+ weights.copy_(torch.tensor(samples, device=weights.device))
75
+
76
+
77
+ def lecun_normal_init_(weights):
78
+ trunc_normal_init_(weights, scale=1.0)
79
+
80
+
81
+ def he_normal_init_(weights):
82
+ trunc_normal_init_(weights, scale=2.0)
83
+
84
+
85
+ def glorot_uniform_init_(weights):
86
+ nn.init.xavier_uniform_(weights, gain=1)
87
+
88
+
89
+ def final_init_(weights):
90
+ with torch.no_grad():
91
+ weights.fill_(0.0)
92
+
93
+
94
+ def gating_init_(weights):
95
+ with torch.no_grad():
96
+ weights.fill_(0.0)
97
+
98
+
99
+ def normal_init_(weights):
100
+ torch.nn.init.kaiming_normal_(weights, nonlinearity="linear")
101
+
102
+
103
+ def ipa_point_weights_init_(weights):
104
+ with torch.no_grad():
105
+ softplus_inverse_1 = 0.541324854612918
106
+ weights.fill_(softplus_inverse_1)
107
+
108
+
109
+ class Linear(nn.Linear):
110
+ """
111
+ A Linear layer with built-in nonstandard initializations. Called just
112
+ like torch.nn.Linear.
113
+
114
+ Implements the initializers in 1.11.4, plus some additional ones found
115
+ in the code.
116
+ """
117
+
118
+ def __init__(
119
+ self,
120
+ in_dim: int,
121
+ out_dim: int,
122
+ bias: bool = True,
123
+ init: str = "default",
124
+ init_fn: Optional[Callable[[torch.Tensor, torch.Tensor], None]] = None,
125
+ precision=None
126
+ ):
127
+ """
128
+ Args:
129
+ in_dim:
130
+ The final dimension of inputs to the layer
131
+ out_dim:
132
+ The final dimension of layer outputs
133
+ bias:
134
+ Whether to learn an additive bias. True by default
135
+ init:
136
+ The initializer to use. Choose from:
137
+
138
+ "default": LeCun fan-in truncated normal initialization
139
+ "relu": He initialization w/ truncated normal distribution
140
+ "glorot": Fan-average Glorot uniform initialization
141
+ "gating": Weights=0, Bias=1
142
+ "normal": Normal initialization with std=1/sqrt(fan_in)
143
+ "final": Weights=0, Bias=0
144
+
145
+ Overridden by init_fn if the latter is not None.
146
+ init_fn:
147
+ A custom initializer taking weight and bias as inputs.
148
+ Overrides init if not None.
149
+ """
150
+ super(Linear, self).__init__(in_dim, out_dim, bias=bias)
151
+
152
+ if bias:
153
+ with torch.no_grad():
154
+ self.bias.fill_(0)
155
+
156
+ with torch.no_grad():
157
+ if init_fn is not None:
158
+ init_fn(self.weight, self.bias)
159
+ else:
160
+ if init == "default":
161
+ lecun_normal_init_(self.weight)
162
+ elif init == "relu":
163
+ he_normal_init_(self.weight)
164
+ elif init == "glorot":
165
+ glorot_uniform_init_(self.weight)
166
+ elif init == "gating":
167
+ gating_init_(self.weight)
168
+ if bias:
169
+ self.bias.fill_(1.0)
170
+ elif init == "normal":
171
+ normal_init_(self.weight)
172
+ elif init == "final":
173
+ final_init_(self.weight)
174
+ else:
175
+ raise ValueError("Invalid init string.")
176
+
177
+ self.precision = precision
178
+
179
+ def forward(self, input: torch.Tensor) -> torch.Tensor:
180
+ d = input.dtype
181
+ if self.precision is not None:
182
+ with torch.cuda.amp.autocast(enabled=False):
183
+ bias = self.bias.to(dtype=self.precision) if self.bias is not None else None
184
+ return nn.functional.linear(input.to(dtype=self.precision),
185
+ self.weight.to(dtype=self.precision),
186
+ bias).to(dtype=d)
187
+
188
+ if d is torch.bfloat16:
189
+ with torch.cuda.amp.autocast(enabled=False):
190
+ bias = self.bias.to(dtype=d) if self.bias is not None else None
191
+ return nn.functional.linear(input, self.weight.to(dtype=d), bias)
192
+
193
+ return nn.functional.linear(input, self.weight, self.bias)
194
+
195
+
196
+ class LayerNorm(nn.Module):
197
+ def __init__(self, c_in, eps=1e-5):
198
+ super(LayerNorm, self).__init__()
199
+
200
+ self.c_in = (c_in,)
201
+ self.eps = eps
202
+
203
+ self.weight = nn.Parameter(torch.ones(c_in))
204
+ self.bias = nn.Parameter(torch.zeros(c_in))
205
+
206
+ def forward(self, x):
207
+ d = x.dtype
208
+ if d is torch.bfloat16:
209
+ with torch.cuda.amp.autocast(enabled=False):
210
+ out = nn.functional.layer_norm(
211
+ x,
212
+ self.c_in,
213
+ self.weight.to(dtype=d),
214
+ self.bias.to(dtype=d),
215
+ self.eps
216
+ )
217
+ else:
218
+ out = nn.functional.layer_norm(
219
+ x,
220
+ self.c_in,
221
+ self.weight,
222
+ self.bias,
223
+ self.eps,
224
+ )
225
+
226
+ return out
227
+
228
+
229
+ @torch.jit.ignore
230
+ def softmax_no_cast(t: torch.Tensor, dim: int = -1) -> torch.Tensor:
231
+ """
232
+ Softmax, but without automatic casting to fp32 when the input is of
233
+ type bfloat16
234
+ """
235
+ d = t.dtype
236
+ if d is torch.bfloat16:
237
+ with torch.cuda.amp.autocast(enabled=False):
238
+ s = torch.nn.functional.softmax(t, dim=dim)
239
+ else:
240
+ s = torch.nn.functional.softmax(t, dim=dim)
241
+
242
+ return s
243
+
244
+
245
+ #@torch.jit.script
246
+ def _attention(query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, biases: List[torch.Tensor]) -> torch.Tensor:
247
+ # [*, H, C_hidden, K]
248
+ key = permute_final_dims(key, (1, 0))
249
+
250
+ # [*, H, Q, K]
251
+ a = torch.matmul(query, key)
252
+
253
+ for b in biases:
254
+ a += b
255
+
256
+ a = softmax_no_cast(a, -1)
257
+
258
+ # [*, H, Q, C_hidden]
259
+ a = torch.matmul(a, value)
260
+
261
+ return a
262
+
263
+
264
+ class Attention(nn.Module):
265
+ """
266
+ Standard multi-head attention using AlphaFold's default layer
267
+ initialization. Allows multiple bias vectors.
268
+ """
269
+ def __init__(
270
+ self,
271
+ c_q: int,
272
+ c_k: int,
273
+ c_v: int,
274
+ c_hidden: int,
275
+ no_heads: int,
276
+ gating: bool = True,
277
+ ):
278
+ """
279
+ Args:
280
+ c_q:
281
+ Input dimension of query data
282
+ c_k:
283
+ Input dimension of key data
284
+ c_v:
285
+ Input dimension of value data
286
+ c_hidden:
287
+ Per-head hidden dimension
288
+ no_heads:
289
+ Number of attention heads
290
+ gating:
291
+ Whether the output should be gated using query data
292
+ """
293
+ super(Attention, self).__init__()
294
+
295
+ self.c_q = c_q
296
+ self.c_k = c_k
297
+ self.c_v = c_v
298
+ self.c_hidden = c_hidden
299
+ self.no_heads = no_heads
300
+ self.gating = gating
301
+
302
+ # DISCREPANCY: c_hidden is not the per-head channel dimension, as
303
+ # stated in the supplement, but the overall channel dimension.
304
+
305
+ self.linear_q = Linear(
306
+ self.c_q, self.c_hidden * self.no_heads, bias=False, init="glorot"
307
+ )
308
+ self.linear_k = Linear(
309
+ self.c_k, self.c_hidden * self.no_heads, bias=False, init="glorot"
310
+ )
311
+ self.linear_v = Linear(
312
+ self.c_v, self.c_hidden * self.no_heads, bias=False, init="glorot"
313
+ )
314
+ self.linear_o = Linear(
315
+ self.c_hidden * self.no_heads, self.c_q, init="final"
316
+ )
317
+
318
+ self.linear_g = None
319
+ if self.gating:
320
+ self.linear_g = Linear(
321
+ self.c_q, self.c_hidden * self.no_heads, init="gating"
322
+ )
323
+
324
+ self.sigmoid = nn.Sigmoid()
325
+
326
+ def _prep_qkv(self,
327
+ q_x: torch.Tensor,
328
+ kv_x: torch.Tensor,
329
+ apply_scale: bool = True
330
+ ) -> Tuple[
331
+ torch.Tensor, torch.Tensor, torch.Tensor
332
+ ]:
333
+ # [*, Q/K/V, H * C_hidden]
334
+ q = self.linear_q(q_x)
335
+ k = self.linear_k(kv_x)
336
+ v = self.linear_v(kv_x)
337
+
338
+ # [*, Q/K, H, C_hidden]
339
+ q = q.view(q.shape[:-1] + (self.no_heads, -1))
340
+ k = k.view(k.shape[:-1] + (self.no_heads, -1))
341
+ v = v.view(v.shape[:-1] + (self.no_heads, -1))
342
+
343
+ # [*, H, Q/K, C_hidden]
344
+ q = q.transpose(-2, -3)
345
+ k = k.transpose(-2, -3)
346
+ v = v.transpose(-2, -3)
347
+
348
+ if apply_scale:
349
+ q /= math.sqrt(self.c_hidden)
350
+
351
+ return q, k, v
352
+
353
+ def _wrap_up(self,
354
+ o: torch.Tensor,
355
+ q_x: torch.Tensor
356
+ ) -> torch.Tensor:
357
+ if self.linear_g is not None:
358
+ g = self.sigmoid(self.linear_g(q_x))
359
+
360
+ # [*, Q, H, C_hidden]
361
+ g = g.view(g.shape[:-1] + (self.no_heads, -1))
362
+ o = o * g
363
+
364
+ # [*, Q, H * C_hidden]
365
+ o = flatten_final_dims(o, 2)
366
+
367
+ # [*, Q, C_q]
368
+ o = self.linear_o(o)
369
+
370
+ return o
371
+
372
+ def forward(
373
+ self,
374
+ q_x: torch.Tensor,
375
+ kv_x: torch.Tensor,
376
+ biases: Optional[List[torch.Tensor]] = None,
377
+ use_memory_efficient_kernel: bool = False,
378
+ use_lma: bool = False,
379
+ lma_q_chunk_size: int = DEFAULT_LMA_Q_CHUNK_SIZE,
380
+ lma_kv_chunk_size: int = DEFAULT_LMA_KV_CHUNK_SIZE,
381
+ ) -> torch.Tensor:
382
+ """
383
+ Args:
384
+ q_x:
385
+ [*, Q, C_q] query data
386
+ kv_x:
387
+ [*, K, C_k] key data
388
+ biases:
389
+ List of biases that broadcast to [*, H, Q, K]
390
+ use_memory_efficient_kernel:
391
+ Whether to use a custom memory-efficient attention kernel.
392
+ This should be the default choice for most. If none of the
393
+ "use_<...>" flags are True, a stock PyTorch implementation
394
+ is used instead
395
+ use_lma:
396
+ Whether to use low-memory attention (Staats & Rabe 2021). If
397
+ none of the "use_<...>" flags are True, a stock PyTorch
398
+ implementation is used instead
399
+ lma_q_chunk_size:
400
+ Query chunk size (for LMA)
401
+ lma_kv_chunk_size:
402
+ Key/Value chunk size (for LMA)
403
+ Returns
404
+ [*, Q, C_q] attention update
405
+ """
406
+ if use_lma and (lma_q_chunk_size is None or lma_kv_chunk_size is None):
407
+ raise ValueError(
408
+ "If use_lma is specified, lma_q_chunk_size and "
409
+ "lma_kv_chunk_size must be provided"
410
+ )
411
+
412
+ attn_options = [use_memory_efficient_kernel, use_lma]
413
+ if sum(attn_options) > 1:
414
+ raise ValueError(
415
+ "Choose at most one alternative attention algorithm"
416
+ )
417
+
418
+ if biases is None:
419
+ biases = []
420
+
421
+ q, k, v = self._prep_qkv(q_x, kv_x, apply_scale=True)
422
+
423
+ if is_fp16_enabled():
424
+ use_memory_efficient_kernel = False
425
+
426
+ if use_memory_efficient_kernel:
427
+ if len(biases) > 2:
428
+ raise ValueError(
429
+ "If use_memory_efficient_kernel is True, you may only "
430
+ "provide up to two bias terms"
431
+ )
432
+ o = attention_core(q, k, v, *((biases + [None] * 2)[:2]))
433
+ o = o.transpose(-2, -3)
434
+ elif use_lma:
435
+ biases = [
436
+ b.expand(b.shape[:-2] + (q_x.shape[-2],) + (kv_x.shape[-2],))
437
+ for b in biases
438
+ ]
439
+ o = _lma(q, k, v, biases, lma_q_chunk_size, lma_kv_chunk_size)
440
+ o = o.transpose(-2, -3)
441
+ else:
442
+ o = _attention(q, k, v, biases)
443
+ o = o.transpose(-2, -3)
444
+
445
+ o = self._wrap_up(o, q_x)
446
+
447
+ return o
448
+
449
+
450
+ class GlobalAttention(nn.Module):
451
+ def __init__(self, c_in, c_hidden, no_heads, inf, eps):
452
+ super(GlobalAttention, self).__init__()
453
+
454
+ self.c_in = c_in
455
+ self.c_hidden = c_hidden
456
+ self.no_heads = no_heads
457
+ self.inf = inf
458
+ self.eps = eps
459
+
460
+ self.linear_q = Linear(
461
+ c_in, c_hidden * no_heads, bias=False, init="glorot"
462
+ )
463
+
464
+ self.linear_k = Linear(
465
+ c_in, c_hidden, bias=False, init="glorot",
466
+ )
467
+ self.linear_v = Linear(
468
+ c_in, c_hidden, bias=False, init="glorot",
469
+ )
470
+ self.linear_g = Linear(c_in, c_hidden * no_heads, init="gating")
471
+ self.linear_o = Linear(c_hidden * no_heads, c_in, init="final")
472
+
473
+ self.sigmoid = nn.Sigmoid()
474
+
475
+ def forward(self,
476
+ m: torch.Tensor,
477
+ mask: torch.Tensor,
478
+ use_lma: bool = False,
479
+ ) -> torch.Tensor:
480
+ # [*, N_res, C_in]
481
+ q = torch.sum(m * mask.unsqueeze(-1), dim=-2) / (
482
+ torch.sum(mask, dim=-1)[..., None] + self.eps
483
+ )
484
+
485
+ # [*, N_res, H * C_hidden]
486
+ q = self.linear_q(q)
487
+ q *= (self.c_hidden ** (-0.5))
488
+
489
+ # [*, N_res, H, C_hidden]
490
+ q = q.view(q.shape[:-1] + (self.no_heads, -1))
491
+
492
+ # [*, N_res, C_hidden]
493
+ k = self.linear_k(m)
494
+ v = self.linear_v(m)
495
+
496
+ bias = (self.inf * (mask - 1))[..., :, None, :]
497
+ if not use_lma:
498
+ # [*, N_res, H, N_seq]
499
+ a = torch.matmul(
500
+ q,
501
+ k.transpose(-1, -2), # [*, N_res, C_hidden, N_seq]
502
+ )
503
+ a += bias
504
+ a = softmax_no_cast(a)
505
+
506
+ # [*, N_res, H, C_hidden]
507
+ o = torch.matmul(
508
+ a,
509
+ v,
510
+ )
511
+ else:
512
+ o = _lma(
513
+ q,
514
+ k,
515
+ v,
516
+ [bias],
517
+ DEFAULT_LMA_Q_CHUNK_SIZE,
518
+ DEFAULT_LMA_KV_CHUNK_SIZE
519
+ )
520
+
521
+ # [*, N_res, C_hidden]
522
+ g = self.sigmoid(self.linear_g(m))
523
+
524
+ # [*, N_res, H, C_hidden]
525
+ g = g.view(g.shape[:-1] + (self.no_heads, -1))
526
+
527
+ # [*, N_res, H, C_hidden]
528
+ o = o.unsqueeze(-3) * g
529
+
530
+ # [*, N_res, H * C_hidden]
531
+ o = o.reshape(o.shape[:-2] + (-1,))
532
+
533
+ # [*, N_res, C_in]
534
+ m = self.linear_o(o)
535
+
536
+ return m
537
+
538
+
539
+ def _lma(
540
+ q: torch.Tensor,
541
+ k: torch.Tensor,
542
+ v: torch.Tensor,
543
+ biases: List[torch.Tensor],
544
+ q_chunk_size: int,
545
+ kv_chunk_size: int,
546
+ ):
547
+ no_q, no_kv = q.shape[-2], k.shape[-2]
548
+
549
+ # [*, H, Q, C_hidden]
550
+ o = q.new_zeros(q.shape)
551
+ for q_s in range(0, no_q, q_chunk_size):
552
+ q_chunk = q[..., q_s: q_s + q_chunk_size, :]
553
+ large_bias_chunks = [
554
+ b[..., q_s: q_s + q_chunk_size, :] for b in biases
555
+ ]
556
+
557
+ maxes = []
558
+ weights = []
559
+ values = []
560
+ for kv_s in range(0, no_kv, kv_chunk_size):
561
+ k_chunk = k[..., kv_s: kv_s + kv_chunk_size, :]
562
+ v_chunk = v[..., kv_s: kv_s + kv_chunk_size, :]
563
+ small_bias_chunks = [
564
+ b[..., kv_s: kv_s + kv_chunk_size] for b in large_bias_chunks
565
+ ]
566
+
567
+ a = torch.einsum(
568
+ "...hqd,...hkd->...hqk", q_chunk, k_chunk,
569
+ )
570
+
571
+ for b in small_bias_chunks:
572
+ a += b
573
+
574
+ max_a = torch.max(a, dim=-1, keepdim=True)[0]
575
+ exp_a = torch.exp(a - max_a)
576
+ exp_v = torch.einsum("...hvf,...hqv->...hqf", v_chunk, exp_a)
577
+
578
+ maxes.append(max_a.detach().squeeze(-1))
579
+ weights.append(torch.sum(exp_a, dim=-1))
580
+ values.append(exp_v)
581
+
582
+ chunk_max = torch.stack(maxes, dim=-3)
583
+ chunk_weights = torch.stack(weights, dim=-3)
584
+ chunk_values = torch.stack(values, dim=-4)
585
+
586
+ global_max = torch.max(chunk_max, dim=-3, keepdim=True)[0]
587
+ max_diffs = torch.exp(chunk_max - global_max)
588
+ chunk_values = chunk_values * max_diffs.unsqueeze(-1)
589
+ chunk_weights = chunk_weights * max_diffs
590
+
591
+ all_values = torch.sum(chunk_values, dim=-4)
592
+ all_weights = torch.sum(chunk_weights.unsqueeze(-1), dim=-4)
593
+
594
+ q_chunk_out = all_values / all_weights
595
+
596
+ o[..., q_s: q_s + q_chunk_size, :] = q_chunk_out
597
+
598
+ return o
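Note: a numerical sketch of the chunked attention idea behind _lma: a running max plus rescaled partial sums over key/value chunks reproduces naive softmax attention without materializing the full [Q, K] matrix. Self-contained, with toy shapes:

import torch

q = torch.randn(2, 4, 16, 8)   # [*, H, Q, C_hidden]
k = torch.randn(2, 4, 32, 8)
v = torch.randn(2, 4, 32, 8)

naive = torch.softmax(q @ k.transpose(-1, -2), dim=-1) @ v

chunk = 8
num, den = 0, 0
m_run = torch.full(q.shape[:-1] + (1,), -float('inf'))
for s in range(0, k.shape[-2], chunk):
    a = q @ k[..., s:s + chunk, :].transpose(-1, -2)        # [*, H, Q, chunk]
    m_new = torch.maximum(m_run, a.max(dim=-1, keepdim=True)[0])
    scale = torch.exp(m_run - m_new)                        # rescale old sums
    p = torch.exp(a - m_new)
    num = num * scale + p @ v[..., s:s + chunk, :]
    den = den * scale + p.sum(dim=-1, keepdim=True)
    m_run = m_new

assert torch.allclose(naive, num / den, atol=1e-5)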
dockformerpp/model/single_attention.py ADDED
@@ -0,0 +1,184 @@
1
+ # Copyright 2021 AlQuraishi Laboratory
2
+ # Copyright 2021 DeepMind Technologies Limited
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ from functools import partial
16
+ import math
17
+ import torch
18
+ import torch.nn as nn
19
+ import torch.utils.checkpoint
20
+ from typing import Optional, List, Tuple
21
+
22
+ from dockformerpp.model.primitives import (
23
+ Linear,
24
+ LayerNorm,
25
+ Attention,
26
+ )
27
+ from dockformerpp.utils.tensor_utils import permute_final_dims
28
+
29
+
30
+ class SingleAttention(nn.Module):
31
+ def __init__(
32
+ self,
33
+ c_in,
34
+ c_hidden,
35
+ no_heads,
36
+ pair_bias=False,
37
+ c_z=None,
38
+ inf=1e9,
39
+ ):
40
+ """
41
+ Args:
42
+ c_in:
43
+ Input channel dimension
44
+ c_hidden:
45
+ Per-head hidden channel dimension
46
+ no_heads:
47
+ Number of attention heads
48
+ pair_bias:
49
+ Whether to use pair embedding bias
50
+ c_z:
51
+ Pair embedding channel dimension. Ignored unless pair_bias
52
+ is true
53
+ inf:
54
+ A large number to be used in computing the attention mask
55
+ """
56
+ super(SingleAttention, self).__init__()
57
+
58
+ self.c_in = c_in
59
+ self.c_hidden = c_hidden
60
+ self.no_heads = no_heads
61
+ self.pair_bias = pair_bias
62
+ self.c_z = c_z
63
+ self.inf = inf
64
+
65
+ self.layer_norm_m = LayerNorm(self.c_in)
66
+
67
+ self.layer_norm_z = None
68
+ self.linear_z = None
69
+ if self.pair_bias:
70
+ self.layer_norm_z = LayerNorm(self.c_z)
71
+ self.linear_z = Linear(
72
+ self.c_z, self.no_heads, bias=False, init="normal"
73
+ )
74
+
75
+ self.mha = Attention(
76
+ self.c_in,
77
+ self.c_in,
78
+ self.c_in,
79
+ self.c_hidden,
80
+ self.no_heads,
81
+ )
82
+
83
+ def _prep_inputs(self,
84
+ m: torch.Tensor,
85
+ z: Optional[torch.Tensor],
86
+ mask: Optional[torch.Tensor],
87
+ inplace_safe: bool = False,
88
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
89
+ if mask is None:
90
+ # [*, N_res]
91
+ mask = m.new_ones(m.shape[:-1])
92
+
93
+ # [*, 1, 1, N_res]
94
+ mask_bias = (self.inf * (mask - 1))[..., :, None, None, :]
95
+
96
+ if (self.pair_bias and
97
+ z is not None and # For the
98
+ self.layer_norm_z is not None and # benefit of
99
+ self.linear_z is not None # TorchScript
100
+ ):
101
+ chunks = []
102
+
103
+ for i in range(0, z.shape[-3], 256):
104
+ z_chunk = z[..., i: i + 256, :, :]
105
+
106
+ # [*, N_res, N_res, C_z]
107
+ z_chunk = self.layer_norm_z(z_chunk)
108
+
109
+ # [*, N_res, N_res, no_heads]
110
+ z_chunk = self.linear_z(z_chunk)
111
+
112
+ chunks.append(z_chunk)
113
+
114
+ z = torch.cat(chunks, dim=-3)
115
+
116
+ # [*, no_heads, N_res, N_res]
117
+ z = permute_final_dims(z, (2, 0, 1))
118
+
119
+ return m, mask_bias, z
120
+
121
+ def forward(self,
122
+ m: torch.Tensor,
123
+ z: Optional[torch.Tensor] = None,
124
+ mask: Optional[torch.Tensor] = None,
125
+ use_memory_efficient_kernel: bool = False,
126
+ use_lma: bool = False,
127
+ inplace_safe: bool = False,
128
+ ) -> torch.Tensor:
129
+ """
130
+ Args:
131
+ m:
132
+ [*, N_res, C_m] single embedding
133
+ z:
134
+ [*, N_res, N_res, C_z] pair embedding. Required only if pair_bias is True
135
+ mask:
136
+ [*, N_res] single mask
137
+ """
138
+ m, mask_bias, z = self._prep_inputs(
139
+ m, z, mask, inplace_safe=inplace_safe
140
+ )
141
+
142
+ biases = [mask_bias]
143
+ if(z is not None):
144
+ biases.append(z)
145
+
146
+ m = self.layer_norm_m(m)
147
+ m = self.mha(
148
+ q_x=m,
149
+ kv_x=m,
150
+ biases=biases,
151
+ use_memory_efficient_kernel=use_memory_efficient_kernel,
152
+ use_lma=use_lma,
153
+ )
154
+
155
+ return m
156
+
157
+
158
+ class SingleRowAttentionWithPairBias(SingleAttention):
159
+ """
160
+ Implements Algorithm 7.
161
+ """
162
+
163
+ def __init__(self, c_m, c_z, c_hidden, no_heads, inf=1e9):
164
+ """
165
+ Args:
166
+ c_m:
167
+ Input channel dimension
168
+ c_z:
169
+ Pair embedding channel dimension
170
+ c_hidden:
171
+ Per-head hidden channel dimension
172
+ no_heads:
173
+ Number of attention heads
174
+ inf:
175
+ Large number used to construct attention masks
176
+ """
177
+ super(SingleRowAttentionWithPairBias, self).__init__(
178
+ c_m,
179
+ c_hidden,
180
+ no_heads,
181
+ pair_bias=True,
182
+ c_z=c_z,
183
+ inf=inf,
184
+ )
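Note: the additive mask bias built in _prep_inputs maps valid positions to 0 and masked positions to -inf (approximated by -self.inf), so they vanish after the softmax. Self-contained sketch:

import torch

inf = 1e9
mask = torch.tensor([[1., 1., 0.]])                 # [*, N_res]
mask_bias = (inf * (mask - 1))[..., :, None, None, :]
print(mask_bias.squeeze())                          # -> [0., 0., -1e9]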
dockformerpp/model/structure_module.py ADDED
@@ -0,0 +1,837 @@
1
+ # Copyright 2021 AlQuraishi Laboratory
2
+ # Copyright 2021 DeepMind Technologies Limited
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ from functools import reduce
16
+ import importlib
17
+ import math
18
+ import sys
19
+ from operator import mul
20
+
21
+ import torch
22
+ import torch.nn as nn
23
+ from typing import Optional, Tuple, Sequence, Union
24
+
25
+ from dockformerpp.model.primitives import Linear, LayerNorm, ipa_point_weights_init_
26
+ from dockformerpp.utils.residue_constants import (
27
+ restype_rigid_group_default_frame,
28
+ restype_atom14_to_rigid_group,
29
+ restype_atom14_mask,
30
+ restype_atom14_rigid_group_positions,
31
+ )
32
+ from dockformerpp.utils.geometry.quat_rigid import QuatRigid
33
+ from dockformerpp.utils.geometry.rigid_matrix_vector import Rigid3Array
34
+ from dockformerpp.utils.geometry.vector import Vec3Array, square_euclidean_distance
35
+ from dockformerpp.utils.feats import (
36
+ frames_and_literature_positions_to_atom14_pos,
37
+ torsion_angles_to_frames,
38
+ )
39
+ from dockformerpp.utils.precision_utils import is_fp16_enabled
40
+ from dockformerpp.utils.rigid_utils import Rotation, Rigid
41
+ from dockformerpp.utils.tensor_utils import (
42
+ dict_multimap,
43
+ permute_final_dims,
44
+ flatten_final_dims,
45
+ )
46
+
47
+ import importlib.util
48
+ attn_core_is_installed = importlib.util.find_spec("attn_core_inplace_cuda") is not None
49
+ attn_core_inplace_cuda = None
50
+ if attn_core_is_installed:
51
+ attn_core_inplace_cuda = importlib.import_module("attn_core_inplace_cuda")
52
+
53
+
54
+ class AngleResnetBlock(nn.Module):
55
+ def __init__(self, c_hidden):
56
+ """
57
+ Args:
58
+ c_hidden:
59
+ Hidden channel dimension
60
+ """
61
+ super(AngleResnetBlock, self).__init__()
62
+
63
+ self.c_hidden = c_hidden
64
+
65
+ self.linear_1 = Linear(self.c_hidden, self.c_hidden, init="relu")
66
+ self.linear_2 = Linear(self.c_hidden, self.c_hidden, init="final")
67
+
68
+ self.relu = nn.ReLU()
69
+
70
+ def forward(self, a: torch.Tensor) -> torch.Tensor:
71
+
72
+ s_initial = a
73
+
74
+ a = self.relu(a)
75
+ a = self.linear_1(a)
76
+ a = self.relu(a)
77
+ a = self.linear_2(a)
78
+
79
+ return a + s_initial
80
+
81
+
82
+ class AngleResnet(nn.Module):
83
+ """
84
+ Implements Algorithm 20, lines 11-14
85
+ """
86
+
87
+ def __init__(self, c_in, c_hidden, no_blocks, no_angles, epsilon):
88
+ """
89
+ Args:
90
+ c_in:
91
+ Input channel dimension
92
+ c_hidden:
93
+ Hidden channel dimension
94
+ no_blocks:
95
+ Number of resnet blocks
96
+ no_angles:
97
+ Number of torsion angles to generate
98
+ epsilon:
99
+ Small constant for normalization
100
+ """
101
+ super(AngleResnet, self).__init__()
102
+
103
+ self.c_in = c_in
104
+ self.c_hidden = c_hidden
105
+ self.no_blocks = no_blocks
106
+ self.no_angles = no_angles
107
+ self.eps = epsilon
108
+
109
+ self.linear_in = Linear(self.c_in, self.c_hidden)
110
+ self.linear_initial = Linear(self.c_in, self.c_hidden)
111
+
112
+ self.layers = nn.ModuleList()
113
+ for _ in range(self.no_blocks):
114
+ layer = AngleResnetBlock(c_hidden=self.c_hidden)
115
+ self.layers.append(layer)
116
+
117
+ self.linear_out = Linear(self.c_hidden, self.no_angles * 2)
118
+
119
+ self.relu = nn.ReLU()
120
+
121
+ def forward(
122
+ self, s: torch.Tensor, s_initial: torch.Tensor
123
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
124
+ """
125
+ Args:
126
+ s:
127
+ [*, C_in] single embedding
128
+ s_initial:
129
+ [*, C_in] single embedding as of the start of the
130
+ StructureModule
131
+ Returns:
132
+ [*, no_angles, 2] predicted angles
133
+ """
134
+ # NOTE: The ReLU's applied to the inputs are absent from the supplement
135
+ # pseudocode but present in the source. For maximal compatibility with
136
+ # the pretrained weights, I'm going with the source.
137
+
138
+ # [*, C_hidden]
139
+ s_initial = self.relu(s_initial)
140
+ s_initial = self.linear_initial(s_initial)
141
+ s = self.relu(s)
142
+ s = self.linear_in(s)
143
+ s = s + s_initial
144
+
145
+ for l in self.layers:
146
+ s = l(s)
147
+
148
+ s = self.relu(s)
149
+
150
+ # [*, no_angles * 2]
151
+ s = self.linear_out(s)
152
+
153
+ # [*, no_angles, 2]
154
+ s = s.view(s.shape[:-1] + (-1, 2))
155
+
156
+ unnormalized_s = s
157
+ norm_denom = torch.sqrt(
158
+ torch.clamp(
159
+ torch.sum(s ** 2, dim=-1, keepdim=True),
160
+ min=self.eps,
161
+ )
162
+ )
163
+ s = s / norm_denom
164
+
165
+ return unnormalized_s, s
166
+
167
+
168
+ class PointProjection(nn.Module):
169
+ def __init__(self,
170
+ c_hidden: int,
171
+ num_points: int,
172
+ no_heads: int,
173
+ return_local_points: bool = False,
174
+ ):
175
+ super().__init__()
176
+ self.return_local_points = return_local_points
177
+ self.no_heads = no_heads
178
+ self.num_points = num_points
179
+
180
+ # Multimer requires this to be run with fp32 precision during training
181
+ precision = None
182
+ self.linear = Linear(c_hidden, no_heads * 3 * num_points, precision=precision)
183
+
184
+ def forward(self,
185
+ activations: torch.Tensor,
186
+ rigids: Union[Rigid, Rigid3Array],
187
+ ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
188
+ # TODO: Needs to run in high precision during training
189
+ points_local = self.linear(activations)
190
+ out_shape = points_local.shape[:-1] + (self.no_heads, self.num_points, 3)
191
+
192
+ points_local = torch.split(
193
+ points_local, points_local.shape[-1] // 3, dim=-1
194
+ )
195
+
196
+ points_local = torch.stack(points_local, dim=-1).view(out_shape)
197
+
198
+ points_global = rigids[..., None, None].apply(points_local)
199
+
200
+ if(self.return_local_points):
201
+ return points_global, points_local
202
+
203
+ return points_global
204
+
205
+
206
+ class InvariantPointAttention(nn.Module):
207
+ """
208
+ Implements Algorithm 22.
209
+ """
210
+ def __init__(
211
+ self,
212
+ c_s: int,
213
+ c_z: int,
214
+ c_hidden: int,
215
+ no_heads: int,
216
+ no_qk_points: int,
217
+ no_v_points: int,
218
+ inf: float = 1e5,
219
+ eps: float = 1e-8,
220
+ ):
221
+ """
222
+ Args:
223
+ c_s:
224
+ Single representation channel dimension
225
+ c_z:
226
+ Pair representation channel dimension
227
+ c_hidden:
228
+ Hidden channel dimension
229
+ no_heads:
230
+ Number of attention heads
231
+ no_qk_points:
232
+ Number of query/key points to generate
233
+ no_v_points:
234
+ Number of value points to generate
235
+ """
236
+ super(InvariantPointAttention, self).__init__()
237
+
238
+ self.c_s = c_s
239
+ self.c_z = c_z
240
+ self.c_hidden = c_hidden
241
+ self.no_heads = no_heads
242
+ self.no_qk_points = no_qk_points
243
+ self.no_v_points = no_v_points
244
+ self.inf = inf
245
+ self.eps = eps
246
+
247
+ # These linear layers differ from their specifications in the
248
+ # supplement. There, they lack bias and use Glorot initialization.
249
+ # Here as in the official source, they have bias and use the default
250
+ # Lecun initialization.
251
+ hc = self.c_hidden * self.no_heads
252
+ self.linear_q = Linear(self.c_s, hc, bias=True)
253
+
254
+ self.linear_q_points = PointProjection(
255
+ self.c_s,
256
+ self.no_qk_points,
257
+ self.no_heads,
258
+ )
259
+
260
+
261
+ self.linear_kv = Linear(self.c_s, 2 * hc)
262
+ self.linear_kv_points = PointProjection(
263
+ self.c_s,
264
+ self.no_qk_points + self.no_v_points,
265
+ self.no_heads,
266
+ )
267
+
268
+ self.linear_b = Linear(self.c_z, self.no_heads)
269
+
270
+ self.head_weights = nn.Parameter(torch.zeros((no_heads)))
271
+ ipa_point_weights_init_(self.head_weights)
272
+
273
+ concat_out_dim = self.no_heads * (
274
+ self.c_z + self.c_hidden + self.no_v_points * 4
275
+ )
276
+ self.linear_out = Linear(concat_out_dim, self.c_s, init="final")
277
+
278
+ self.softmax = nn.Softmax(dim=-1)
279
+ self.softplus = nn.Softplus()
280
+
281
+ def forward(
282
+ self,
283
+ s: torch.Tensor,
284
+ z: torch.Tensor,
285
+ r: Union[Rigid, Rigid3Array],
286
+ mask: torch.Tensor,
287
+ inplace_safe: bool = False,
288
+ ) -> torch.Tensor:
289
+ """
290
+ Args:
291
+ s:
292
+ [*, N_res, C_s] single representation
293
+ z:
294
+ [*, N_res, N_res, C_z] pair representation
295
+ r:
296
+ [*, N_res] transformation object
297
+ mask:
298
+ [*, N_res] mask
299
+ Returns:
300
+ [*, N_res, C_s] single representation update
301
+ """
302
+ z = [z]
303
+
304
+ #######################################
305
+ # Generate scalar and point activations
306
+ #######################################
307
+ # [*, N_res, H * C_hidden]
308
+ q = self.linear_q(s)
309
+
310
+ # [*, N_res, H, C_hidden]
311
+ q = q.view(q.shape[:-1] + (self.no_heads, -1))
312
+
313
+ # [*, N_res, H, P_qk, 3]
314
+ q_pts = self.linear_q_points(s, r)
315
+
316
+ # The following two blocks are equivalent
317
+ # They're separated only to preserve compatibility with old AF weights
318
+
319
+ # [*, N_res, H * 2 * C_hidden]
320
+ kv = self.linear_kv(s)
321
+
322
+ # [*, N_res, H, 2 * C_hidden]
323
+ kv = kv.view(kv.shape[:-1] + (self.no_heads, -1))
324
+
325
+ # [*, N_res, H, C_hidden]
326
+ k, v = torch.split(kv, self.c_hidden, dim=-1)
327
+
328
+ kv_pts = self.linear_kv_points(s, r)
329
+
330
+ # [*, N_res, H, P_q/P_v, 3]
331
+ k_pts, v_pts = torch.split(
332
+ kv_pts, [self.no_qk_points, self.no_v_points], dim=-2
333
+ )
334
+
335
+ ##########################
336
+ # Compute attention scores
337
+ ##########################
338
+ # [*, N_res, N_res, H]
339
+ b = self.linear_b(z[0])
340
+
341
+ # [*, H, N_res, N_res]
342
+ if (is_fp16_enabled()):
343
+ with torch.cuda.amp.autocast(enabled=False):
344
+ a = torch.matmul(
345
+ permute_final_dims(q.float(), (1, 0, 2)), # [*, H, N_res, C_hidden]
346
+ permute_final_dims(k.float(), (1, 2, 0)), # [*, H, C_hidden, N_res]
347
+ )
348
+ else:
349
+ a = torch.matmul(
350
+ permute_final_dims(q, (1, 0, 2)), # [*, H, N_res, C_hidden]
351
+ permute_final_dims(k, (1, 2, 0)), # [*, H, C_hidden, N_res]
352
+ )
353
+
354
+ a *= math.sqrt(1.0 / (3 * self.c_hidden))
355
+ a += (math.sqrt(1.0 / 3) * permute_final_dims(b, (2, 0, 1)))
356
+
357
+ # [*, N_res, N_res, H, P_q, 3]
358
+ pt_att = q_pts.unsqueeze(-4) - k_pts.unsqueeze(-5)
359
+
360
+ if (inplace_safe):
361
+ pt_att *= pt_att
362
+ else:
363
+ pt_att = pt_att ** 2
364
+
365
+ pt_att = sum(torch.unbind(pt_att, dim=-1))
366
+
367
+ head_weights = self.softplus(self.head_weights).view(
368
+ *((1,) * len(pt_att.shape[:-2]) + (-1, 1))
369
+ )
370
+ head_weights = head_weights * math.sqrt(
371
+ 1.0 / (3 * (self.no_qk_points * 9.0 / 2))
372
+ )
373
+
374
+ if (inplace_safe):
375
+ pt_att *= head_weights
376
+ else:
377
+ pt_att = pt_att * head_weights
378
+
379
+ # [*, N_res, N_res, H]
380
+ pt_att = torch.sum(pt_att, dim=-1) * (-0.5)
381
+
382
+ # [*, N_res, N_res]
383
+ square_mask = mask.unsqueeze(-1) * mask.unsqueeze(-2)
384
+ square_mask = self.inf * (square_mask - 1)
385
+
386
+ # [*, H, N_res, N_res]
387
+ pt_att = permute_final_dims(pt_att, (2, 0, 1))
388
+
389
+ if (inplace_safe):
390
+ a += pt_att
391
+ del pt_att
392
+ a += square_mask.unsqueeze(-3)
393
+ # in-place softmax
394
+ attn_core_inplace_cuda.forward_(
395
+ a,
396
+ reduce(mul, a.shape[:-1]),
397
+ a.shape[-1],
398
+ )
399
+ else:
400
+ a = a + pt_att
401
+ a = a + square_mask.unsqueeze(-3)
402
+ a = self.softmax(a)
403
+
404
+ ################
405
+ # Compute output
406
+ ################
407
+ # [*, N_res, H, C_hidden]
408
+ o = torch.matmul(
409
+ a, v.transpose(-2, -3).to(dtype=a.dtype)
410
+ ).transpose(-2, -3)
411
+
412
+ # [*, N_res, H * C_hidden]
413
+ o = flatten_final_dims(o, 2)
414
+
415
+ # [*, H, 3, N_res, P_v]
416
+ if (inplace_safe):
417
+ v_pts = permute_final_dims(v_pts, (1, 3, 0, 2))
418
+ o_pt = [
419
+ torch.matmul(a, v.to(a.dtype))
420
+ for v in torch.unbind(v_pts, dim=-3)
421
+ ]
422
+ o_pt = torch.stack(o_pt, dim=-3)
423
+ else:
424
+ o_pt = torch.sum(
425
+ (
426
+ a[..., None, :, :, None]
427
+ * permute_final_dims(v_pts, (1, 3, 0, 2))[..., None, :, :]
428
+ ),
429
+ dim=-2,
430
+ )
431
+
432
+ # [*, N_res, H, P_v, 3]
433
+ o_pt = permute_final_dims(o_pt, (2, 0, 3, 1))
434
+ o_pt = r[..., None, None].invert_apply(o_pt)
435
+
436
+ # [*, N_res, H * P_v]
437
+ o_pt_norm = flatten_final_dims(
438
+ torch.sqrt(torch.sum(o_pt ** 2, dim=-1) + self.eps), 2
439
+ )
440
+
441
+ # [*, N_res, H * P_v, 3]
442
+ o_pt = o_pt.reshape(*o_pt.shape[:-3], -1, 3)
443
+ o_pt = torch.unbind(o_pt, dim=-1)
444
+
445
+ # [*, N_res, H, C_z]
446
+ o_pair = torch.matmul(a.transpose(-2, -3), z[0].to(dtype=a.dtype))
447
+
448
+ # [*, N_res, H * C_z]
449
+ o_pair = flatten_final_dims(o_pair, 2)
450
+
451
+ # [*, N_res, C_s]
452
+ s = self.linear_out(
453
+ torch.cat(
454
+ (o, *o_pt, o_pt_norm, o_pair), dim=-1
455
+ ).to(dtype=z[0].dtype)
456
+ )
457
+
458
+ return s
459
+
460
+
461
+ class BackboneUpdate(nn.Module):
462
+ """
463
+ Implements part of Algorithm 23.
464
+ """
465
+
466
+ def __init__(self, c_s):
467
+ """
468
+ Args:
469
+ c_s:
470
+ Single representation channel dimension
471
+ """
472
+ super(BackboneUpdate, self).__init__()
473
+
474
+ self.c_s = c_s
475
+
476
+ self.linear = Linear(self.c_s, 6, init="final")
477
+
478
+ def forward(self, s: torch.Tensor) -> torch.Tensor:
479
+ """
480
+ Args:
481
+ s: [*, N_res, C_s] single representation
482
+ Returns:
483
+ [*, N_res, 6] update vector
484
+ """
485
+ # [*, 6]
486
+ update = self.linear(s)
487
+
488
+ return update
489
+
490
+
491
+ class StructureModuleTransitionLayer(nn.Module):
492
+ def __init__(self, c):
493
+ super(StructureModuleTransitionLayer, self).__init__()
494
+
495
+ self.c = c
496
+
497
+ self.linear_1 = Linear(self.c, self.c, init="relu")
498
+ self.linear_2 = Linear(self.c, self.c, init="relu")
499
+ self.linear_3 = Linear(self.c, self.c, init="final")
500
+
501
+ self.relu = nn.ReLU()
502
+
503
+ def forward(self, s):
504
+ s_initial = s
505
+ s = self.linear_1(s)
506
+ s = self.relu(s)
507
+ s = self.linear_2(s)
508
+ s = self.relu(s)
509
+ s = self.linear_3(s)
510
+
511
+ s = s + s_initial
512
+
513
+ return s
514
+
515
+
516
+ class StructureModuleTransition(nn.Module):
517
+ def __init__(self, c, num_layers, dropout_rate):
518
+ super(StructureModuleTransition, self).__init__()
519
+
520
+ self.c = c
521
+ self.num_layers = num_layers
522
+ self.dropout_rate = dropout_rate
523
+
524
+ self.layers = nn.ModuleList()
525
+ for _ in range(self.num_layers):
526
+ l = StructureModuleTransitionLayer(self.c)
527
+ self.layers.append(l)
528
+
529
+ self.dropout = nn.Dropout(self.dropout_rate)
530
+ self.layer_norm = LayerNorm(self.c)
531
+
532
+ def forward(self, s):
533
+ for l in self.layers:
534
+ s = l(s)
535
+
536
+ s = self.dropout(s)
537
+ s = self.layer_norm(s)
538
+
539
+ return s
540
+
541
+
542
+ class StructureModule(nn.Module):
543
+ def __init__(
544
+ self,
545
+ c_s,
546
+ c_z,
547
+ c_ipa,
548
+ c_resnet,
549
+ no_heads_ipa,
550
+ no_qk_points,
551
+ no_v_points,
552
+ dropout_rate,
553
+ no_blocks,
554
+ no_transition_layers,
555
+ no_resnet_blocks,
556
+ no_angles,
557
+ trans_scale_factor,
558
+ epsilon,
559
+ inf,
560
+ **kwargs,
561
+ ):
562
+ """
563
+ Args:
564
+ c_s:
565
+ Single representation channel dimension
566
+ c_z:
567
+ Pair representation channel dimension
568
+ c_ipa:
569
+ IPA hidden channel dimension
570
+ c_resnet:
571
+ Angle resnet (Alg. 20 lines 11-14) hidden channel dimension
572
+ no_heads_ipa:
573
+ Number of IPA heads
574
+ no_qk_points:
575
+ Number of query/key points to generate during IPA
576
+ no_v_points:
577
+ Number of value points to generate during IPA
578
+ dropout_rate:
579
+ Dropout rate used throughout the layer
580
+ no_blocks:
581
+ Number of structure module blocks
582
+ no_transition_layers:
583
+ Number of layers in the single representation transition
584
+ (Alg. 20 lines 8-9)
585
+ no_resnet_blocks:
586
+ Number of blocks in the angle resnet
587
+ no_angles:
588
+ Number of angles to generate in the angle resnet
589
+ trans_scale_factor:
590
+ Factor by which the predicted translations are scaled
591
+ epsilon:
592
+ Small number used in angle resnet normalization
593
+ inf:
594
+ Large number used for attention masking
595
+ """
596
+ super(StructureModule, self).__init__()
597
+
598
+ self.c_s = c_s
599
+ self.c_z = c_z
600
+ self.c_ipa = c_ipa
601
+ self.c_resnet = c_resnet
602
+ self.no_heads_ipa = no_heads_ipa
603
+ self.no_qk_points = no_qk_points
604
+ self.no_v_points = no_v_points
605
+ self.dropout_rate = dropout_rate
606
+ self.no_blocks = no_blocks
607
+ self.no_transition_layers = no_transition_layers
608
+ self.no_resnet_blocks = no_resnet_blocks
609
+ self.no_angles = no_angles
610
+ self.trans_scale_factor = trans_scale_factor
611
+ self.epsilon = epsilon
612
+ self.inf = inf
613
+
614
+ # Buffers to be lazily initialized later
615
+ # self.default_frames
616
+ # self.group_idx
617
+ # self.atom_mask
618
+ # self.lit_positions
619
+
620
+ self.layer_norm_s = LayerNorm(self.c_s)
621
+ self.layer_norm_z = LayerNorm(self.c_z)
622
+
623
+ self.linear_in = Linear(self.c_s, self.c_s)
624
+
625
+ self.ipa = InvariantPointAttention(
626
+ self.c_s,
627
+ self.c_z,
628
+ self.c_ipa,
629
+ self.no_heads_ipa,
630
+ self.no_qk_points,
631
+ self.no_v_points,
632
+ inf=self.inf,
633
+ eps=self.epsilon,
634
+ )
635
+
636
+ self.ipa_dropout = nn.Dropout(self.dropout_rate)
637
+ self.layer_norm_ipa = LayerNorm(self.c_s)
638
+
639
+ self.transition = StructureModuleTransition(
640
+ self.c_s,
641
+ self.no_transition_layers,
642
+ self.dropout_rate,
643
+ )
644
+
645
+ self.bb_update = BackboneUpdate(self.c_s)
646
+
647
+ self.angle_resnet = AngleResnet(
648
+ self.c_s,
649
+ self.c_resnet,
650
+ self.no_resnet_blocks,
651
+ self.no_angles,
652
+ self.epsilon,
653
+ )
654
+
655
+ def forward(
656
+ self,
657
+ evoformer_output_dict,
658
+ aatype,
659
+ mask=None,
660
+ inplace_safe=False,
661
+ ):
662
+ """
663
+ Args:
664
+ evoformer_output_dict:
665
+ Dictionary containing:
666
+ "single":
667
+ [*, N_res, C_s] single representation
668
+ "pair":
669
+ [*, N_res, N_res, C_z] pair representation
670
+ aatype:
671
+ [*, N_res] amino acid indices
672
+ mask:
673
+ Optional [*, N_res] sequence mask
674
+ Returns:
675
+ A dictionary of outputs
676
+ """
677
+ s = evoformer_output_dict["single"]
678
+
679
+ if mask is None:
680
+ # [*, N]
681
+ mask = s.new_ones(s.shape[:-1])
682
+
683
+ # [*, N, C_s]
684
+ s = self.layer_norm_s(s)
685
+
686
+ # [*, N, N, C_z]
687
+ z = self.layer_norm_z(evoformer_output_dict["pair"])
688
+
689
+ # [*, N, C_s]
690
+ s_initial = s
691
+ s = self.linear_in(s)
692
+
693
+ # [*, N]
694
+ rigids = Rigid.identity(
695
+ s.shape[:-1],
696
+ s.dtype,
697
+ s.device,
698
+ self.training,
699
+ fmt="quat",
700
+ )
701
+ outputs = []
702
+ for i in range(self.no_blocks):
703
+ # [*, N, C_s]
704
+ s = s + self.ipa(
705
+ s,
706
+ z,
707
+ rigids,
708
+ mask,
709
+ inplace_safe=inplace_safe,
710
+ )
711
+ s = self.ipa_dropout(s)
712
+ s = self.layer_norm_ipa(s)
713
+ s = self.transition(s)
714
+
715
+ # [*, N]
716
+
717
+ # [*, N_res, 6] vector of translations and rotations
718
+ bb_update_output = self.bb_update(s)
719
+
720
+ rigids = rigids.compose_q_update_vec(bb_update_output)
721
+
722
+
723
+ # To hew as closely as possible to AlphaFold, we convert our
724
+ # quaternion-based transformations to rotation-matrix ones
725
+ # here
726
+ backb_to_global = Rigid(
727
+ Rotation(
728
+ rot_mats=rigids.get_rots().get_rot_mats(),
729
+ quats=None
730
+ ),
731
+ rigids.get_trans(),
732
+ )
733
+
734
+ backb_to_global = backb_to_global.scale_translation(
735
+ self.trans_scale_factor
736
+ )
737
+
738
+ # [*, N, 7, 2]
739
+ unnormalized_angles, angles = self.angle_resnet(s, s_initial)
740
+
741
+ all_frames_to_global = self.torsion_angles_to_frames(
742
+ backb_to_global,
743
+ angles,
744
+ aatype,
745
+ )
746
+
747
+ pred_xyz = self.frames_and_literature_positions_to_atom14_pos(
748
+ all_frames_to_global,
749
+ aatype,
750
+ )
751
+
752
+ scaled_rigids = rigids.scale_translation(self.trans_scale_factor)
753
+
754
+ preds = {
755
+ "frames": scaled_rigids.to_tensor_7(),
756
+ "sidechain_frames": all_frames_to_global.to_tensor_4x4(),
757
+ "unnormalized_angles": unnormalized_angles,
758
+ "angles": angles,
759
+ "positions": pred_xyz,
760
+ "states": s,
761
+ }
762
+
763
+ outputs.append(preds)
764
+
765
+ rigids = rigids.stop_rot_gradient()
766
+
767
+ del z
768
+
769
+ outputs = dict_multimap(torch.stack, outputs)
770
+ outputs["single"] = s
771
+
772
+ return outputs
773
+
774
+ def _init_residue_constants(self, float_dtype, device):
775
+ if not hasattr(self, "default_frames"):
776
+ self.register_buffer(
777
+ "default_frames",
778
+ torch.tensor(
779
+ restype_rigid_group_default_frame,
780
+ dtype=float_dtype,
781
+ device=device,
782
+ requires_grad=False,
783
+ ),
784
+ persistent=False,
785
+ )
786
+ if not hasattr(self, "group_idx"):
787
+ self.register_buffer(
788
+ "group_idx",
789
+ torch.tensor(
790
+ restype_atom14_to_rigid_group,
791
+ device=device,
792
+ requires_grad=False,
793
+ ),
794
+ persistent=False,
795
+ )
796
+ if not hasattr(self, "atom_mask"):
797
+ self.register_buffer(
798
+ "atom_mask",
799
+ torch.tensor(
800
+ restype_atom14_mask,
801
+ dtype=float_dtype,
802
+ device=device,
803
+ requires_grad=False,
804
+ ),
805
+ persistent=False,
806
+ )
807
+ if not hasattr(self, "lit_positions"):
808
+ self.register_buffer(
809
+ "lit_positions",
810
+ torch.tensor(
811
+ restype_atom14_rigid_group_positions,
812
+ dtype=float_dtype,
813
+ device=device,
814
+ requires_grad=False,
815
+ ),
816
+ persistent=False,
817
+ )
818
+
819
+ def torsion_angles_to_frames(self, r, alpha, f):
820
+ # Lazily initialize the residue constants on the correct device
821
+ self._init_residue_constants(alpha.dtype, alpha.device)
822
+ # Separated purely to make testing less annoying
823
+ return torsion_angles_to_frames(r, alpha, f, self.default_frames)
824
+
825
+ def frames_and_literature_positions_to_atom14_pos(
826
+ self, r, f # [*, N, 8] # [*, N]
827
+ ):
828
+ # Lazily initialize the residue constants on the correct device
829
+ self._init_residue_constants(r.dtype, r.device)
830
+ return frames_and_literature_positions_to_atom14_pos(
831
+ r,
832
+ f,
833
+ self.default_frames,
834
+ self.group_idx,
835
+ self.atom_mask,
836
+ self.lit_positions,
837
+ )
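
A toy forward pass through the StructureModule above. All hyperparameters are small illustrative values rather than the repository's config, and the all-zeros aatype is an assumed valid restype index:

import torch
from dockformerpp.model.structure_module import StructureModule

n_res = 8
sm = StructureModule(
    c_s=64, c_z=32, c_ipa=16, c_resnet=32, no_heads_ipa=4,
    no_qk_points=4, no_v_points=8, dropout_rate=0.1, no_blocks=2,
    no_transition_layers=1, no_resnet_blocks=2, no_angles=7,
    trans_scale_factor=10, epsilon=1e-8, inf=1e5,
)

outputs = sm(
    {"single": torch.rand(1, n_res, 64), "pair": torch.rand(1, n_res, n_res, 32)},
    aatype=torch.zeros(1, n_res, dtype=torch.long),  # toy all-alanine sequence
)
# Per-block predictions are stacked along a new leading dimension:
print(outputs["positions"].shape)  # torch.Size([2, 1, 8, 14, 3])
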
dockformerpp/model/torchscript.py ADDED
@@ -0,0 +1,171 @@
1
+ # Copyright 2021 AlQuraishi Laboratory
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from typing import Optional, Sequence, Tuple
16
+
17
+ import torch
18
+ import torch.nn as nn
19
+
20
+ from dockformerpp.model.evoformer import (
21
+ EvoformerBlock,
22
+ EvoformerStack,
23
+ )
24
+ from dockformerpp.model.single_attention import SingleRowAttentionWithPairBias
25
+ from dockformerpp.model.primitives import Attention, GlobalAttention
26
+
27
+
28
+ def script_preset_(model: torch.nn.Module):
29
+ """
30
+ TorchScript a handful of low-level but frequently used submodule types
31
+ that are known to be scriptable.
32
+
33
+ Args:
34
+ model:
35
+ A torch.nn.Module. It should contain at least some modules from
36
+ this repository, or this function won't do anything.
37
+ """
38
+ script_submodules_(
39
+ model,
40
+ [
41
+ nn.Dropout,
42
+ Attention,
43
+ GlobalAttention,
44
+ EvoformerBlock,
45
+ ],
46
+ attempt_trace=False,
47
+ batch_dims=None,
48
+ )
49
+
50
+
51
+ def _get_module_device(module: torch.nn.Module) -> torch.device:
52
+ """
53
+ Fetches the device of a module, assuming that all of the module's
54
+ parameters reside on a single device
55
+
56
+ Args:
57
+ module: A torch.nn.Module
58
+ Returns:
59
+ The module's device
60
+ """
61
+ return next(module.parameters()).device
62
+
63
+
64
+ def _trace_module(module, batch_dims=None):
65
+ if(batch_dims is None):
66
+ batch_dims = ()
67
+
68
+ # Stand-in values
69
+ n_seq = 10
70
+ n_res = 10
71
+
72
+ device = _get_module_device(module)
73
+
74
+ def msa(channel_dim):
75
+ return torch.rand(
76
+ (*batch_dims, n_seq, n_res, channel_dim),
77
+ device=device,
78
+ )
79
+
80
+ def pair(channel_dim):
81
+ return torch.rand(
82
+ (*batch_dims, n_res, n_res, channel_dim),
83
+ device=device,
84
+ )
85
+
86
+ if(isinstance(module, SingleRowAttentionWithPairBias)):
87
+ inputs = {
88
+ "forward": (
89
+ msa(module.c_in), # m
90
+ pair(module.c_z), # z
91
+ torch.randint(
92
+ 0, 2,
93
+ (*batch_dims, n_seq, n_res)
94
+ ), # mask
95
+ ),
96
+ }
97
+ else:
98
+ raise TypeError(
99
+ f"tracing is not supported for modules of type {type(module)}"
100
+ )
101
+
102
+ return torch.jit.trace_module(module, inputs)
103
+
104
+
105
+ def _script_submodules_helper_(
106
+ model,
107
+ types,
108
+ attempt_trace,
109
+ to_trace,
110
+ ):
111
+ for name, child in model.named_children():
112
+ if(types is None or any(isinstance(child, t) for t in types)):
113
+ try:
114
+ scripted = torch.jit.script(child)
115
+ setattr(model, name, scripted)
116
+ continue
117
+ except (RuntimeError, torch.jit.frontend.NotSupportedError) as e:
118
+ if(attempt_trace):
119
+ to_trace.add(type(child))
120
+ else:
121
+ raise e
122
+
123
+ _script_submodules_helper_(child, types, attempt_trace, to_trace)
124
+
125
+
126
+ def _trace_submodules_(
127
+ model,
128
+ types,
129
+ batch_dims=None,
130
+ ):
131
+ for name, child in model.named_children():
132
+ if(any(isinstance(child, t) for t in types)):
133
+ traced = _trace_module(child, batch_dims=batch_dims)
134
+ setattr(model, name, traced)
135
+ else:
136
+ _trace_submodules_(child, types, batch_dims=batch_dims)
137
+
138
+
139
+ def script_submodules_(
140
+ model: nn.Module,
141
+ types: Optional[Sequence[type]] = None,
142
+ attempt_trace: Optional[bool] = True,
143
+ batch_dims: Optional[Tuple[int]] = None,
144
+ ):
145
+ """
146
+ Convert all submodules whose types match one of those in the input
147
+ list to recursively scripted equivalents in place. To script the entire
148
+ model, just call torch.jit.script on it directly.
149
+
150
+ When types is None, all submodules are scripted.
151
+
152
+ Args:
153
+ model:
154
+ A torch.nn.Module
155
+ types:
156
+ A list of types of submodules to script
157
+ attempt_trace:
158
+ Whether to attempt to trace specified modules if scripting
159
+ fails. Recall that tracing eliminates all conditional
160
+ logic---with great tracing comes the mild responsibility of
161
+ having to remember to ensure that the modules in question
162
+ perform the same computations no matter what.
163
+ """
164
+ to_trace = set()
165
+
166
+ # Aggressively script as much as possible first...
167
+ _script_submodules_helper_(model, types, attempt_trace, to_trace)
168
+
169
+ # ... and then trace stragglers.
170
+ if(attempt_trace and len(to_trace) > 0):
171
+ _trace_submodules_(model, to_trace, batch_dims=batch_dims)
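
A usage sketch for the scripting entry point above; the toy Sequential stands in for the repository's actual model class:

import torch.nn as nn
from dockformerpp.model.torchscript import script_preset_

model = nn.Sequential(nn.Linear(8, 8), nn.Dropout(0.1))  # stand-in module tree
script_preset_(model)  # the nn.Dropout child is replaced in place
print(type(model[1]))  # a torch.jit.ScriptModule subclass
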
dockformerpp/model/triangular_attention.py ADDED
@@ -0,0 +1,104 @@
1
+ # Copyright 2021 AlQuraishi Laboratory
2
+ # Copyright 2021 DeepMind Technologies Limited
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ from functools import partialmethod, partial
17
+ import math
18
+ from typing import Optional, List
19
+
20
+ import torch
21
+ import torch.nn as nn
22
+
23
+ from dockformerpp.model.primitives import Linear, LayerNorm, Attention
24
+ from dockformerpp.utils.tensor_utils import permute_final_dims
25
+
26
+
27
+ class TriangleAttention(nn.Module):
28
+ def __init__(
29
+ self, c_in, c_hidden, no_heads, starting=True, inf=1e9
30
+ ):
31
+ """
32
+ Args:
33
+ c_in:
34
+ Input channel dimension
35
+ c_hidden:
36
+ Per-head hidden channel dimension
37
+ no_heads:
38
+ Number of attention heads
39
+ """
40
+ super(TriangleAttention, self).__init__()
41
+
42
+ self.c_in = c_in
43
+ self.c_hidden = c_hidden
44
+ self.no_heads = no_heads
45
+ self.starting = starting
46
+ self.inf = inf
47
+
48
+ self.layer_norm = LayerNorm(self.c_in)
49
+
50
+ self.linear = Linear(c_in, self.no_heads, bias=False, init="normal")
51
+
52
+ self.mha = Attention(
53
+ self.c_in, self.c_in, self.c_in, self.c_hidden, self.no_heads
54
+ )
55
+
56
+ def forward(self,
57
+ x: torch.Tensor,
58
+ mask: Optional[torch.Tensor] = None,
59
+ use_memory_efficient_kernel: bool = False,
60
+ use_lma: bool = False,
61
+ ) -> torch.Tensor:
62
+ """
63
+ Args:
64
+ x:
65
+ [*, I, J, C_in] input tensor (e.g. the pair representation)
66
+ Returns:
67
+ [*, I, J, C_in] output tensor
68
+ """
69
+ if mask is None:
70
+ # [*, I, J]
71
+ mask = x.new_ones(
72
+ x.shape[:-1],
73
+ )
74
+
75
+ if(not self.starting):
76
+ x = x.transpose(-2, -3)
77
+ mask = mask.transpose(-1, -2)
78
+
79
+ # [*, I, J, C_in]
80
+ x = self.layer_norm(x)
81
+
82
+ # [*, I, 1, 1, J]
83
+ mask_bias = (self.inf * (mask - 1))[..., :, None, None, :]
84
+
85
+ # [*, H, I, J]
86
+ triangle_bias = permute_final_dims(self.linear(x), (2, 0, 1))
87
+
88
+ # [*, 1, H, I, J]
89
+ triangle_bias = triangle_bias.unsqueeze(-4)
90
+
91
+ biases = [mask_bias, triangle_bias]
92
+
93
+ x = self.mha(
94
+ q_x=x,
95
+ kv_x=x,
96
+ biases=biases,
97
+ use_memory_efficient_kernel=use_memory_efficient_kernel,
98
+ use_lma=use_lma
99
+ )
100
+
101
+ if(not self.starting):
102
+ x = x.transpose(-2, -3)
103
+
104
+ return x
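
A usage sketch: one instance per orientation, with input shapes following the forward docstring and illustrative channel sizes:

import torch
from dockformerpp.model.triangular_attention import TriangleAttention

n_res, c_z = 16, 32
row_attn = TriangleAttention(c_in=c_z, c_hidden=8, no_heads=4, starting=True)
col_attn = TriangleAttention(c_in=c_z, c_hidden=8, no_heads=4, starting=False)

z = torch.rand(1, n_res, n_res, c_z)  # [*, I, J, C_in] pair representation
mask = torch.ones(1, n_res, n_res)    # [*, I, J]

z = z + row_attn(z, mask=mask)        # attention around the "starting" node
z = z + col_attn(z, mask=mask)        # attention around the "ending" node
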
dockformerpp/model/triangular_multiplicative_update.py ADDED
@@ -0,0 +1,173 @@
1
+ # Copyright 2021 AlQuraishi Laboratory
2
+ # Copyright 2021 DeepMind Technologies Limited
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ from functools import partialmethod
17
+ from typing import Optional
18
+ from abc import ABC, abstractmethod
19
+
20
+ import torch
21
+ import torch.nn as nn
22
+
23
+ from dockformerpp.model.primitives import Linear, LayerNorm
24
+ from dockformerpp.utils.precision_utils import is_fp16_enabled
25
+ from dockformerpp.utils.tensor_utils import permute_final_dims
26
+
27
+
28
+ class BaseTriangleMultiplicativeUpdate(nn.Module, ABC):
29
+ """
30
+ Implements Algorithms 11 and 12.
31
+ """
32
+ @abstractmethod
33
+ def __init__(self, c_z, c_hidden, _outgoing):
34
+ """
35
+ Args:
36
+ c_z:
37
+ Input channel dimension
38
+ c_hidden:
39
+ Hidden channel dimension
40
+ """
41
+ super(BaseTriangleMultiplicativeUpdate, self).__init__()
42
+ self.c_z = c_z
43
+ self.c_hidden = c_hidden
44
+ self._outgoing = _outgoing
45
+
46
+ self.linear_g = Linear(self.c_z, self.c_z, init="gating")
47
+ self.linear_z = Linear(self.c_hidden, self.c_z, init="final")
48
+
49
+ self.layer_norm_in = LayerNorm(self.c_z)
50
+ self.layer_norm_out = LayerNorm(self.c_hidden)
51
+
52
+ self.sigmoid = nn.Sigmoid()
53
+
54
+ def _combine_projections(self,
55
+ a: torch.Tensor,
56
+ b: torch.Tensor,
57
+ ) -> torch.Tensor:
58
+ if(self._outgoing):
59
+ a = permute_final_dims(a, (2, 0, 1))
60
+ b = permute_final_dims(b, (2, 1, 0))
61
+ else:
62
+ a = permute_final_dims(a, (2, 1, 0))
63
+ b = permute_final_dims(b, (2, 0, 1))
64
+
65
+ p = torch.matmul(a, b)
66
+
67
+ return permute_final_dims(p, (1, 2, 0))
68
+
69
+ @abstractmethod
70
+ def forward(self,
71
+ z: torch.Tensor,
72
+ mask: Optional[torch.Tensor] = None,
73
+ inplace_safe: bool = False,
74
+ _add_with_inplace: bool = False
75
+ ) -> torch.Tensor:
76
+ """
77
+ Args:
78
+ z:
79
+ [*, N_res, N_res, C_z] input tensor
80
+ mask:
81
+ [*, N_res, N_res] input mask
82
+ Returns:
83
+ [*, N_res, N_res, C_z] output tensor
84
+ """
85
+ pass
86
+
87
+
88
+ class TriangleMultiplicativeUpdate(BaseTriangleMultiplicativeUpdate):
89
+ """
90
+ Implements Algorithms 11 and 12.
91
+ """
92
+ def __init__(self, c_z, c_hidden, _outgoing=True):
93
+ """
94
+ Args:
95
+ c_z:
96
+ Input channel dimension
97
+ c_hidden:
98
+ Hidden channel dimension
99
+ """
100
+ super(TriangleMultiplicativeUpdate, self).__init__(c_z=c_z,
101
+ c_hidden=c_hidden,
102
+ _outgoing=_outgoing)
103
+
104
+ self.linear_a_p = Linear(self.c_z, self.c_hidden)
105
+ self.linear_a_g = Linear(self.c_z, self.c_hidden, init="gating")
106
+ self.linear_b_p = Linear(self.c_z, self.c_hidden)
107
+ self.linear_b_g = Linear(self.c_z, self.c_hidden, init="gating")
108
+
109
+ def forward(self,
110
+ z: torch.Tensor,
111
+ mask: Optional[torch.Tensor] = None,
112
+ inplace_safe: bool = False,
113
+ _add_with_inplace: bool = False,
114
+ ) -> torch.Tensor:
115
+ """
116
+ Args:
117
+ z:
118
+ [*, N_res, N_res, C_z] input tensor
119
+ mask:
120
+ [*, N_res, N_res] input mask
121
+ Returns:
122
+ [*, N_res, N_res, C_z] output tensor
123
+ """
124
+
125
+ if mask is None:
126
+ mask = z.new_ones(z.shape[:-1])
127
+
128
+ mask = mask.unsqueeze(-1)
129
+
130
+ z = self.layer_norm_in(z)
131
+ a = mask
132
+ a = a * self.sigmoid(self.linear_a_g(z))
133
+ a = a * self.linear_a_p(z)
134
+ b = mask
135
+ b = b * self.sigmoid(self.linear_b_g(z))
136
+ b = b * self.linear_b_p(z)
137
+
138
+ # Prevents overflow of torch.matmul in combine projections in
139
+ # reduced-precision modes
140
+ a_std = a.std()
141
+ b_std = b.std()
142
+ if(is_fp16_enabled() and a_std != 0. and b_std != 0.):
143
+ a = a / a_std
144
+ b = b / b_std
145
+
146
+ if(is_fp16_enabled()):
147
+ with torch.cuda.amp.autocast(enabled=False):
148
+ x = self._combine_projections(a.float(), b.float())
149
+ else:
150
+ x = self._combine_projections(a, b)
151
+
152
+ del a, b
153
+ x = self.layer_norm_out(x)
154
+ x = self.linear_z(x)
155
+ g = self.sigmoid(self.linear_g(z))
156
+ x = x * g
157
+
158
+ return x
159
+
160
+
161
+ class TriangleMultiplicationOutgoing(TriangleMultiplicativeUpdate):
162
+ """
163
+ Implements Algorithm 11.
164
+ """
165
+ __init__ = partialmethod(TriangleMultiplicativeUpdate.__init__, _outgoing=True)
166
+
167
+
168
+ class TriangleMultiplicationIncoming(TriangleMultiplicativeUpdate):
169
+ """
170
+ Implements Algorithm 12.
171
+ """
172
+ __init__ = partialmethod(TriangleMultiplicativeUpdate.__init__, _outgoing=False)
173
+
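A usage sketch for the two subclasses above; both preserve the [*, N_res, N_res, C_z] shape (sizes illustrative):

import torch
from dockformerpp.model.triangular_multiplicative_update import (
    TriangleMultiplicationIncoming,
    TriangleMultiplicationOutgoing,
)

n_res, c_z = 16, 32
tri_out = TriangleMultiplicationOutgoing(c_z, c_hidden=16)
tri_in = TriangleMultiplicationIncoming(c_z, c_hidden=16)

z = torch.rand(1, n_res, n_res, c_z)
mask = torch.ones(1, n_res, n_res)

z = z + tri_out(z, mask=mask)  # Algorithm 11, "outgoing" edges
z = z + tri_in(z, mask=mask)   # Algorithm 12, "incoming" edges
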
dockformerpp/resources/.DS_Store ADDED
Binary file (6.15 kB).
dockformerpp/resources/__init__.py ADDED
File without changes
dockformerpp/resources/stereo_chemical_props.txt ADDED
@@ -0,0 +1,345 @@
1
+ Bond Residue Mean StdDev
2
+ CA-CB ALA 1.520 0.021
3
+ N-CA ALA 1.459 0.020
4
+ CA-C ALA 1.525 0.026
5
+ C-O ALA 1.229 0.019
6
+ CA-CB ARG 1.535 0.022
7
+ CB-CG ARG 1.521 0.027
8
+ CG-CD ARG 1.515 0.025
9
+ CD-NE ARG 1.460 0.017
10
+ NE-CZ ARG 1.326 0.013
11
+ CZ-NH1 ARG 1.326 0.013
12
+ CZ-NH2 ARG 1.326 0.013
13
+ N-CA ARG 1.459 0.020
14
+ CA-C ARG 1.525 0.026
15
+ C-O ARG 1.229 0.019
16
+ CA-CB ASN 1.527 0.026
17
+ CB-CG ASN 1.506 0.023
18
+ CG-OD1 ASN 1.235 0.022
19
+ CG-ND2 ASN 1.324 0.025
20
+ N-CA ASN 1.459 0.020
21
+ CA-C ASN 1.525 0.026
22
+ C-O ASN 1.229 0.019
23
+ CA-CB ASP 1.535 0.022
24
+ CB-CG ASP 1.513 0.021
25
+ CG-OD1 ASP 1.249 0.023
26
+ CG-OD2 ASP 1.249 0.023
27
+ N-CA ASP 1.459 0.020
28
+ CA-C ASP 1.525 0.026
29
+ C-O ASP 1.229 0.019
30
+ CA-CB CYS 1.526 0.013
31
+ CB-SG CYS 1.812 0.016
32
+ N-CA CYS 1.459 0.020
33
+ CA-C CYS 1.525 0.026
34
+ C-O CYS 1.229 0.019
35
+ CA-CB GLU 1.535 0.022
36
+ CB-CG GLU 1.517 0.019
37
+ CG-CD GLU 1.515 0.015
38
+ CD-OE1 GLU 1.252 0.011
39
+ CD-OE2 GLU 1.252 0.011
40
+ N-CA GLU 1.459 0.020
41
+ CA-C GLU 1.525 0.026
42
+ C-O GLU 1.229 0.019
43
+ CA-CB GLN 1.535 0.022
44
+ CB-CG GLN 1.521 0.027
45
+ CG-CD GLN 1.506 0.023
46
+ CD-OE1 GLN 1.235 0.022
47
+ CD-NE2 GLN 1.324 0.025
48
+ N-CA GLN 1.459 0.020
49
+ CA-C GLN 1.525 0.026
50
+ C-O GLN 1.229 0.019
51
+ N-CA GLY 1.456 0.015
52
+ CA-C GLY 1.514 0.016
53
+ C-O GLY 1.232 0.016
54
+ CA-CB HIS 1.535 0.022
55
+ CB-CG HIS 1.492 0.016
56
+ CG-ND1 HIS 1.369 0.015
57
+ CG-CD2 HIS 1.353 0.017
58
+ ND1-CE1 HIS 1.343 0.025
59
+ CD2-NE2 HIS 1.415 0.021
60
+ CE1-NE2 HIS 1.322 0.023
61
+ N-CA HIS 1.459 0.020
62
+ CA-C HIS 1.525 0.026
63
+ C-O HIS 1.229 0.019
64
+ CA-CB ILE 1.544 0.023
65
+ CB-CG1 ILE 1.536 0.028
66
+ CB-CG2 ILE 1.524 0.031
67
+ CG1-CD1 ILE 1.500 0.069
68
+ N-CA ILE 1.459 0.020
69
+ CA-C ILE 1.525 0.026
70
+ C-O ILE 1.229 0.019
71
+ CA-CB LEU 1.533 0.023
72
+ CB-CG LEU 1.521 0.029
73
+ CG-CD1 LEU 1.514 0.037
74
+ CG-CD2 LEU 1.514 0.037
75
+ N-CA LEU 1.459 0.020
76
+ CA-C LEU 1.525 0.026
77
+ C-O LEU 1.229 0.019
78
+ CA-CB LYS 1.535 0.022
79
+ CB-CG LYS 1.521 0.027
80
+ CG-CD LYS 1.520 0.034
81
+ CD-CE LYS 1.508 0.025
82
+ CE-NZ LYS 1.486 0.025
83
+ N-CA LYS 1.459 0.020
84
+ CA-C LYS 1.525 0.026
85
+ C-O LYS 1.229 0.019
86
+ CA-CB MET 1.535 0.022
87
+ CB-CG MET 1.509 0.032
88
+ CG-SD MET 1.807 0.026
89
+ SD-CE MET 1.774 0.056
90
+ N-CA MET 1.459 0.020
91
+ CA-C MET 1.525 0.026
92
+ C-O MET 1.229 0.019
93
+ CA-CB PHE 1.535 0.022
94
+ CB-CG PHE 1.509 0.017
95
+ CG-CD1 PHE 1.383 0.015
96
+ CG-CD2 PHE 1.383 0.015
97
+ CD1-CE1 PHE 1.388 0.020
98
+ CD2-CE2 PHE 1.388 0.020
99
+ CE1-CZ PHE 1.369 0.019
100
+ CE2-CZ PHE 1.369 0.019
101
+ N-CA PHE 1.459 0.020
102
+ CA-C PHE 1.525 0.026
103
+ C-O PHE 1.229 0.019
104
+ CA-CB PRO 1.531 0.020
105
+ CB-CG PRO 1.495 0.050
106
+ CG-CD PRO 1.502 0.033
107
+ CD-N PRO 1.474 0.014
108
+ N-CA PRO 1.468 0.017
109
+ CA-C PRO 1.524 0.020
110
+ C-O PRO 1.228 0.020
111
+ CA-CB SER 1.525 0.015
112
+ CB-OG SER 1.418 0.013
113
+ N-CA SER 1.459 0.020
114
+ CA-C SER 1.525 0.026
115
+ C-O SER 1.229 0.019
116
+ CA-CB THR 1.529 0.026
117
+ CB-OG1 THR 1.428 0.020
118
+ CB-CG2 THR 1.519 0.033
119
+ N-CA THR 1.459 0.020
120
+ CA-C THR 1.525 0.026
121
+ C-O THR 1.229 0.019
122
+ CA-CB TRP 1.535 0.022
123
+ CB-CG TRP 1.498 0.018
124
+ CG-CD1 TRP 1.363 0.014
125
+ CG-CD2 TRP 1.432 0.017
126
+ CD1-NE1 TRP 1.375 0.017
127
+ NE1-CE2 TRP 1.371 0.013
128
+ CD2-CE2 TRP 1.409 0.012
129
+ CD2-CE3 TRP 1.399 0.015
130
+ CE2-CZ2 TRP 1.393 0.017
131
+ CE3-CZ3 TRP 1.380 0.017
132
+ CZ2-CH2 TRP 1.369 0.019
133
+ CZ3-CH2 TRP 1.396 0.016
134
+ N-CA TRP 1.459 0.020
135
+ CA-C TRP 1.525 0.026
136
+ C-O TRP 1.229 0.019
137
+ CA-CB TYR 1.535 0.022
138
+ CB-CG TYR 1.512 0.015
139
+ CG-CD1 TYR 1.387 0.013
140
+ CG-CD2 TYR 1.387 0.013
141
+ CD1-CE1 TYR 1.389 0.015
142
+ CD2-CE2 TYR 1.389 0.015
143
+ CE1-CZ TYR 1.381 0.013
144
+ CE2-CZ TYR 1.381 0.013
145
+ CZ-OH TYR 1.374 0.017
146
+ N-CA TYR 1.459 0.020
147
+ CA-C TYR 1.525 0.026
148
+ C-O TYR 1.229 0.019
149
+ CA-CB VAL 1.543 0.021
150
+ CB-CG1 VAL 1.524 0.021
151
+ CB-CG2 VAL 1.524 0.021
152
+ N-CA VAL 1.459 0.020
153
+ CA-C VAL 1.525 0.026
154
+ C-O VAL 1.229 0.019
155
+ -
156
+
157
+ Angle Residue Mean StdDev
158
+ N-CA-CB ALA 110.1 1.4
159
+ CB-CA-C ALA 110.1 1.5
160
+ N-CA-C ALA 111.0 2.7
161
+ CA-C-O ALA 120.1 2.1
162
+ N-CA-CB ARG 110.6 1.8
163
+ CB-CA-C ARG 110.4 2.0
164
+ CA-CB-CG ARG 113.4 2.2
165
+ CB-CG-CD ARG 111.6 2.6
166
+ CG-CD-NE ARG 111.8 2.1
167
+ CD-NE-CZ ARG 123.6 1.4
168
+ NE-CZ-NH1 ARG 120.3 0.5
169
+ NE-CZ-NH2 ARG 120.3 0.5
170
+ NH1-CZ-NH2 ARG 119.4 1.1
171
+ N-CA-C ARG 111.0 2.7
172
+ CA-C-O ARG 120.1 2.1
173
+ N-CA-CB ASN 110.6 1.8
174
+ CB-CA-C ASN 110.4 2.0
175
+ CA-CB-CG ASN 113.4 2.2
176
+ CB-CG-ND2 ASN 116.7 2.4
177
+ CB-CG-OD1 ASN 121.6 2.0
178
+ ND2-CG-OD1 ASN 121.9 2.3
179
+ N-CA-C ASN 111.0 2.7
180
+ CA-C-O ASN 120.1 2.1
181
+ N-CA-CB ASP 110.6 1.8
182
+ CB-CA-C ASP 110.4 2.0
183
+ CA-CB-CG ASP 113.4 2.2
184
+ CB-CG-OD1 ASP 118.3 0.9
185
+ CB-CG-OD2 ASP 118.3 0.9
186
+ OD1-CG-OD2 ASP 123.3 1.9
187
+ N-CA-C ASP 111.0 2.7
188
+ CA-C-O ASP 120.1 2.1
189
+ N-CA-CB CYS 110.8 1.5
190
+ CB-CA-C CYS 111.5 1.2
191
+ CA-CB-SG CYS 114.2 1.1
192
+ N-CA-C CYS 111.0 2.7
193
+ CA-C-O CYS 120.1 2.1
194
+ N-CA-CB GLU 110.6 1.8
195
+ CB-CA-C GLU 110.4 2.0
196
+ CA-CB-CG GLU 113.4 2.2
197
+ CB-CG-CD GLU 114.2 2.7
198
+ CG-CD-OE1 GLU 118.3 2.0
199
+ CG-CD-OE2 GLU 118.3 2.0
200
+ OE1-CD-OE2 GLU 123.3 1.2
201
+ N-CA-C GLU 111.0 2.7
202
+ CA-C-O GLU 120.1 2.1
203
+ N-CA-CB GLN 110.6 1.8
204
+ CB-CA-C GLN 110.4 2.0
205
+ CA-CB-CG GLN 113.4 2.2
206
+ CB-CG-CD GLN 111.6 2.6
207
+ CG-CD-OE1 GLN 121.6 2.0
208
+ CG-CD-NE2 GLN 116.7 2.4
209
+ OE1-CD-NE2 GLN 121.9 2.3
210
+ N-CA-C GLN 111.0 2.7
211
+ CA-C-O GLN 120.1 2.1
212
+ N-CA-C GLY 113.1 2.5
213
+ CA-C-O GLY 120.6 1.8
214
+ N-CA-CB HIS 110.6 1.8
215
+ CB-CA-C HIS 110.4 2.0
216
+ CA-CB-CG HIS 113.6 1.7
217
+ CB-CG-ND1 HIS 123.2 2.5
218
+ CB-CG-CD2 HIS 130.8 3.1
219
+ CG-ND1-CE1 HIS 108.2 1.4
220
+ ND1-CE1-NE2 HIS 109.9 2.2
221
+ CE1-NE2-CD2 HIS 106.6 2.5
222
+ NE2-CD2-CG HIS 109.2 1.9
223
+ CD2-CG-ND1 HIS 106.0 1.4
224
+ N-CA-C HIS 111.0 2.7
225
+ CA-C-O HIS 120.1 2.1
226
+ N-CA-CB ILE 110.8 2.3
227
+ CB-CA-C ILE 111.6 2.0
228
+ CA-CB-CG1 ILE 111.0 1.9
229
+ CB-CG1-CD1 ILE 113.9 2.8
230
+ CA-CB-CG2 ILE 110.9 2.0
231
+ CG1-CB-CG2 ILE 111.4 2.2
232
+ N-CA-C ILE 111.0 2.7
233
+ CA-C-O ILE 120.1 2.1
234
+ N-CA-CB LEU 110.4 2.0
235
+ CB-CA-C LEU 110.2 1.9
236
+ CA-CB-CG LEU 115.3 2.3
237
+ CB-CG-CD1 LEU 111.0 1.7
238
+ CB-CG-CD2 LEU 111.0 1.7
239
+ CD1-CG-CD2 LEU 110.5 3.0
240
+ N-CA-C LEU 111.0 2.7
241
+ CA-C-O LEU 120.1 2.1
242
+ N-CA-CB LYS 110.6 1.8
243
+ CB-CA-C LYS 110.4 2.0
244
+ CA-CB-CG LYS 113.4 2.2
245
+ CB-CG-CD LYS 111.6 2.6
246
+ CG-CD-CE LYS 111.9 3.0
247
+ CD-CE-NZ LYS 111.7 2.3
248
+ N-CA-C LYS 111.0 2.7
249
+ CA-C-O LYS 120.1 2.1
250
+ N-CA-CB MET 110.6 1.8
251
+ CB-CA-C MET 110.4 2.0
252
+ CA-CB-CG MET 113.3 1.7
253
+ CB-CG-SD MET 112.4 3.0
254
+ CG-SD-CE MET 100.2 1.6
255
+ N-CA-C MET 111.0 2.7
256
+ CA-C-O MET 120.1 2.1
257
+ N-CA-CB PHE 110.6 1.8
258
+ CB-CA-C PHE 110.4 2.0
259
+ CA-CB-CG PHE 113.9 2.4
260
+ CB-CG-CD1 PHE 120.8 0.7
261
+ CB-CG-CD2 PHE 120.8 0.7
262
+ CD1-CG-CD2 PHE 118.3 1.3
263
+ CG-CD1-CE1 PHE 120.8 1.1
264
+ CG-CD2-CE2 PHE 120.8 1.1
265
+ CD1-CE1-CZ PHE 120.1 1.2
266
+ CD2-CE2-CZ PHE 120.1 1.2
267
+ CE1-CZ-CE2 PHE 120.0 1.8
268
+ N-CA-C PHE 111.0 2.7
269
+ CA-C-O PHE 120.1 2.1
270
+ N-CA-CB PRO 103.3 1.2
271
+ CB-CA-C PRO 111.7 2.1
272
+ CA-CB-CG PRO 104.8 1.9
273
+ CB-CG-CD PRO 106.5 3.9
274
+ CG-CD-N PRO 103.2 1.5
275
+ CA-N-CD PRO 111.7 1.4
276
+ N-CA-C PRO 112.1 2.6
277
+ CA-C-O PRO 120.2 2.4
278
+ N-CA-CB SER 110.5 1.5
279
+ CB-CA-C SER 110.1 1.9
280
+ CA-CB-OG SER 111.2 2.7
281
+ N-CA-C SER 111.0 2.7
282
+ CA-C-O SER 120.1 2.1
283
+ N-CA-CB THR 110.3 1.9
284
+ CB-CA-C THR 111.6 2.7
285
+ CA-CB-OG1 THR 109.0 2.1
286
+ CA-CB-CG2 THR 112.4 1.4
287
+ OG1-CB-CG2 THR 110.0 2.3
288
+ N-CA-C THR 111.0 2.7
289
+ CA-C-O THR 120.1 2.1
290
+ N-CA-CB TRP 110.6 1.8
291
+ CB-CA-C TRP 110.4 2.0
292
+ CA-CB-CG TRP 113.7 1.9
293
+ CB-CG-CD1 TRP 127.0 1.3
294
+ CB-CG-CD2 TRP 126.6 1.3
295
+ CD1-CG-CD2 TRP 106.3 0.8
296
+ CG-CD1-NE1 TRP 110.1 1.0
297
+ CD1-NE1-CE2 TRP 109.0 0.9
298
+ NE1-CE2-CD2 TRP 107.3 1.0
299
+ CE2-CD2-CG TRP 107.3 0.8
300
+ CG-CD2-CE3 TRP 133.9 0.9
301
+ NE1-CE2-CZ2 TRP 130.4 1.1
302
+ CE3-CD2-CE2 TRP 118.7 1.2
303
+ CD2-CE2-CZ2 TRP 122.3 1.2
304
+ CE2-CZ2-CH2 TRP 117.4 1.0
305
+ CZ2-CH2-CZ3 TRP 121.6 1.2
306
+ CH2-CZ3-CE3 TRP 121.2 1.1
307
+ CZ3-CE3-CD2 TRP 118.8 1.3
308
+ N-CA-C TRP 111.0 2.7
309
+ CA-C-O TRP 120.1 2.1
310
+ N-CA-CB TYR 110.6 1.8
311
+ CB-CA-C TYR 110.4 2.0
312
+ CA-CB-CG TYR 113.4 1.9
313
+ CB-CG-CD1 TYR 121.0 0.6
314
+ CB-CG-CD2 TYR 121.0 0.6
315
+ CD1-CG-CD2 TYR 117.9 1.1
316
+ CG-CD1-CE1 TYR 121.3 0.8
317
+ CG-CD2-CE2 TYR 121.3 0.8
318
+ CD1-CE1-CZ TYR 119.8 0.9
319
+ CD2-CE2-CZ TYR 119.8 0.9
320
+ CE1-CZ-CE2 TYR 119.8 1.6
321
+ CE1-CZ-OH TYR 120.1 2.7
322
+ CE2-CZ-OH TYR 120.1 2.7
323
+ N-CA-C TYR 111.0 2.7
324
+ CA-C-O TYR 120.1 2.1
325
+ N-CA-CB VAL 111.5 2.2
326
+ CB-CA-C VAL 111.4 1.9
327
+ CA-CB-CG1 VAL 110.9 1.5
328
+ CA-CB-CG2 VAL 110.9 1.5
329
+ CG1-CB-CG2 VAL 110.9 1.6
330
+ N-CA-C VAL 111.0 2.7
331
+ CA-C-O VAL 120.1 2.1
332
+ -
333
+
334
+ Non-bonded distance Minimum Dist Tolerance
335
+ C-C 3.4 1.5
336
+ C-N 3.25 1.5
337
+ C-S 3.5 1.5
338
+ C-O 3.22 1.5
339
+ N-N 3.1 1.5
340
+ N-S 3.35 1.5
341
+ N-O 3.07 1.5
342
+ O-S 3.32 1.5
343
+ O-O 3.04 1.5
344
+ S-S 2.03 1.0
345
+ -
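
The file above holds three whitespace-separated tables (bond lengths, bond angles, non-bonded minima), each terminated by a lone "-". A small parsing sketch for the bond table; load_bond_table is a hypothetical helper, not one the repository defines:

def load_bond_table(path):
    """Read the first table into {(residue, bond): (mean, stddev)}."""
    table = {}
    with open(path) as f:
        lines = iter(f)
        next(lines)  # skip the "Bond Residue Mean StdDev" header
        for line in lines:
            parts = line.split()
            if not parts or parts == ["-"]:
                break  # end of the bond-length table
            bond, residue, mean, stddev = parts
            table[(residue, bond)] = (float(mean), float(stddev))
    return table
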
dockformerpp/utils/.DS_Store ADDED
Binary file (6.15 kB).
dockformerpp/utils/__init__.py ADDED
File without changes
dockformerpp/utils/callbacks.py ADDED
@@ -0,0 +1,15 @@
1
+ from lightning.pytorch.callbacks import EarlyStopping
2
+ from lightning_utilities.core.rank_zero import rank_zero_info
3
+
4
+
5
+ class EarlyStoppingVerbose(EarlyStopping):
6
+ """
7
+ The default EarlyStopping callback's verbose mode is too verbose.
8
+ This class outputs a message only when it's getting ready to stop.
9
+ """
10
+ def _evaluate_stopping_criteria(self, *args, **kwargs):
11
+ should_stop, reason = super()._evaluate_stopping_criteria(*args, **kwargs)
12
+ if(should_stop):
13
+ rank_zero_info(f"{reason}\n")
14
+
15
+ return should_stop, reason
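
A usage sketch, assuming the standard Lightning EarlyStopping constructor arguments; "val/loss" is an illustrative metric name, not one this commit defines:

import lightning.pytorch as pl
from dockformerpp.utils.callbacks import EarlyStoppingVerbose

early_stopping = EarlyStoppingVerbose(monitor="val/loss", patience=10, mode="min")
trainer = pl.Trainer(callbacks=[early_stopping])
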
dockformerpp/utils/checkpointing.py ADDED
@@ -0,0 +1,78 @@
1
+ # Copyright 2021 AlQuraishi Laboratory
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ import importlib
15
+ from typing import Any, Tuple, List, Callable, Optional
16
+
17
+
18
+ import torch
19
+ import torch.utils.checkpoint
20
+
21
+
22
+ BLOCK_ARG = Any
23
+ BLOCK_ARGS = List[BLOCK_ARG]
24
+
25
+
26
+ @torch.jit.ignore
27
+ def checkpoint_blocks(
28
+ blocks: List[Callable],
29
+ args: BLOCK_ARGS,
30
+ blocks_per_ckpt: Optional[int],
31
+ ) -> BLOCK_ARGS:
32
+ """
33
+ Chunk a list of blocks and run each chunk with activation
34
+ checkpointing. We define a "block" as a callable whose only inputs are
35
+ the outputs of the previous block.
36
+
37
+ Implements Subsection 1.11.8
38
+
39
+ Args:
40
+ blocks:
41
+ List of blocks
42
+ args:
43
+ Tuple of arguments for the first block.
44
+ blocks_per_ckpt:
45
+ Size of each chunk. A higher value corresponds to fewer
46
+ checkpoints, and trades memory for speed. If None, no checkpointing
47
+ is performed.
48
+ Returns:
49
+ The output of the final block
50
+ """
51
+ def wrap(a):
52
+ return (a,) if type(a) is not tuple else a
53
+
54
+ def exec(b, a):
55
+ for block in b:
56
+ a = wrap(block(*a))
57
+ return a
58
+
59
+ def chunker(s, e):
60
+ def exec_sliced(*a):
61
+ return exec(blocks[s:e], a)
62
+
63
+ return exec_sliced
64
+
65
+ # Avoids mishaps when the blocks take just one argument
66
+ args = wrap(args)
67
+
68
+ if blocks_per_ckpt is None or not torch.is_grad_enabled():
69
+ return exec(blocks, args)
70
+ elif blocks_per_ckpt < 1 or blocks_per_ckpt > len(blocks):
71
+ raise ValueError("blocks_per_ckpt must be between 1 and len(blocks)")
72
+
73
+ for s in range(0, len(blocks), blocks_per_ckpt):
74
+ e = s + blocks_per_ckpt
75
+ args = torch.utils.checkpoint.checkpoint(chunker(s, e), *args, use_reentrant=True)
76
+ args = wrap(args)
77
+
78
+ return args
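
A toy sketch of the chunking behaviour: with blocks_per_ckpt=2, eight blocks are wrapped in four checkpoints, and intermediate activations are recomputed chunk by chunk on the backward pass:

import torch
import torch.nn as nn
from dockformerpp.utils.checkpointing import checkpoint_blocks

blocks = [nn.Linear(16, 16) for _ in range(8)]  # each block feeds the next
x = torch.rand(4, 16, requires_grad=True)

(out,) = checkpoint_blocks(blocks, (x,), blocks_per_ckpt=2)
out.sum().backward()
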
dockformerpp/utils/config_tools.py ADDED
@@ -0,0 +1,32 @@
1
+ import importlib.util
2
+
3
+ import ml_collections as mlc
4
+
5
+
6
+ def set_inf(c, inf):
7
+ for k, v in c.items():
8
+ if isinstance(v, mlc.ConfigDict):
9
+ set_inf(v, inf)
10
+ elif k == "inf":
11
+ c[k] = inf
12
+
13
+
14
+ def enforce_config_constraints(config):
15
+ def string_to_setting(s):
16
+ path = s.split('.')
17
+ setting = config
18
+ for p in path:
19
+ setting = setting.get(p)
20
+
21
+ return setting
22
+
23
+ mutually_exclusive_bools = [
24
+ (
25
+ "globals.use_lma",
26
+ ),
27
+ ]
28
+
29
+ for options in mutually_exclusive_bools:
30
+ option_settings = [string_to_setting(o) for o in options]
31
+ if sum(option_settings) > 1:
32
+ raise ValueError(f"Only one of {', '.join(options)} may be set at a time")
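
A usage sketch for set_inf, which rewrites every key named "inf" in a nested ConfigDict (useful, for example, to lower masking constants for reduced-precision runs):

import ml_collections as mlc
from dockformerpp.utils.config_tools import set_inf

c = mlc.ConfigDict({"model": {"ipa": {"inf": 1e9}, "c_s": 384}})
set_inf(c, 1e4)
assert c.model.ipa.inf == 1e4  # only "inf" entries are touched
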
dockformerpp/utils/consts.py ADDED
@@ -0,0 +1,25 @@
1
+ from rdkit.Chem.rdchem import ChiralType, BondType
2
+
3
+ # Survey of atom types in the PDBBind dataset
4
+ # {'C': 403253, 'O': 101283, 'N': 81325, 'S': 6262, 'F': 5256, 'P': 3378, 'Cl': 2920, 'Br': 552, 'B': 237, 'I': 185,
5
+ # 'H': 181, 'Fe': 19, 'Se': 15, 'Ru': 10, 'Si': 5, 'Co': 4, 'Ir': 4, 'As': 2, 'Pt': 2, 'V': 1, 'Mg': 1, 'Be': 1,
6
+ # 'Rh': 1, 'Cu': 1, 'Re': 1}
7
+ # I have changed the uncommon types to common ions for the plinder dataset
8
+ # {'As': "Zn", 'Pt': "Mn", 'V': "Ca", 'Mg': "Mg", 'Be': "Na", 'Rh': "Al", 'Cu': "K", 'Re': "Ni"}
9
+
10
+ POSSIBLE_ATOM_TYPES = ['C', 'O', 'N', 'S', 'F', 'P', 'Cl', 'Br', 'B', 'I', 'H', 'Fe', 'Se', 'Ru', 'Si', 'Co', 'Ir',
11
+ 'Zn', 'Mn', 'Ca', 'Mg', 'Na', 'Al', 'K', 'Ni']
12
+
13
+ # bonds Counter({BondType.SINGLE: 366857, BondType.AROMATIC: 214238, BondType.DOUBLE: 59725, BondType.TRIPLE: 866,
14
+ # BondType.UNSPECIFIED: 18, BondType.DATIVE: 8})
15
+ POSSIBLE_BOND_TYPES = [BondType.SINGLE, BondType.DOUBLE, BondType.TRIPLE, BondType.AROMATIC, BondType.UNSPECIFIED,
16
+ BondType.DATIVE]
17
+
18
+ # {0: 580061, 1: 13273, -1: 11473, 2: 44, 7: 17, -2: 8, 9: 7, 10: 7, 5: 3, 3: 3, 4: 1, 6: 1, 8: 1}
19
+ POSSIBLE_CHARGES = [-1, 0, 1]
20
+
21
+ # {ChiralType.CHI_UNSPECIFIED: 551374, ChiralType.CHI_TETRAHEDRAL_CCW: 27328, ChiralType.CHI_TETRAHEDRAL_CW: 26178,
22
+ # ChiralType.CHI_OCTAHEDRAL: 13, ChiralType.CHI_SQUAREPLANAR: 3, ChiralType.CHI_TRIGONALBIPYRAMIDAL: 3}
23
+ POSSIBLE_CHIRALITIES = [ChiralType.CHI_UNSPECIFIED, ChiralType.CHI_TETRAHEDRAL_CCW, ChiralType.CHI_TETRAHEDRAL_CW,
24
+ ChiralType.CHI_OCTAHEDRAL, ChiralType.CHI_SQUAREPLANAR, ChiralType.CHI_TRIGONALBIPYRAMIDAL]
25
+
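A featurization sketch built on the vocabularies above; atom_to_indices is a hypothetical helper, and the repository's real pipeline may handle out-of-vocabulary values differently:

from rdkit import Chem
from dockformerpp.utils.consts import (
    POSSIBLE_ATOM_TYPES, POSSIBLE_CHARGES, POSSIBLE_CHIRALITIES,
)

def atom_to_indices(atom):
    # list.index() raises ValueError for values outside the vocabularies
    return (
        POSSIBLE_ATOM_TYPES.index(atom.GetSymbol()),
        POSSIBLE_CHARGES.index(atom.GetFormalCharge()),
        POSSIBLE_CHIRALITIES.index(atom.GetChiralTag()),
    )

mol = Chem.MolFromSmiles("CC(=O)O")  # acetic acid
print([atom_to_indices(a) for a in mol.GetAtoms()])
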
dockformerpp/utils/exponential_moving_average.py ADDED
@@ -0,0 +1,71 @@
1
+ from collections import OrderedDict
2
+ import copy
3
+ import torch
4
+ import torch.nn as nn
5
+
6
+ from dockformerpp.utils.tensor_utils import tensor_tree_map
7
+
8
+
9
+ class ExponentialMovingAverage:
10
+ """
11
+ Maintains moving averages of parameters with exponential decay
12
+
13
+ At each step, the stored copy `copy` of each parameter `param` is
14
+ updated as follows:
15
+
16
+ `copy = decay * copy + (1 - decay) * param`
17
+
18
+ where `decay` is an attribute of the ExponentialMovingAverage object.
19
+ """
20
+
21
+ def __init__(self, model: nn.Module, decay: float):
22
+ """
23
+ Args:
24
+ model:
25
+ A torch.nn.Module whose parameters are to be tracked
26
+ decay:
27
+ A value (usually close to 1.) by which updates are
28
+ weighted as part of the above formula
29
+ """
30
+ super(ExponentialMovingAverage, self).__init__()
31
+
32
+ clone_param = lambda t: t.clone().detach()
33
+ self.params = tensor_tree_map(clone_param, model.state_dict())
34
+ self.decay = decay
35
+ self.device = next(model.parameters()).device
36
+
37
+ def to(self, device):
38
+ self.params = tensor_tree_map(lambda t: t.to(device), self.params)
39
+ self.device = device
40
+
41
+ def _update_state_dict_(self, update, state_dict):
42
+ with torch.no_grad():
43
+ for k, v in update.items():
44
+ stored = state_dict[k]
45
+ if not isinstance(v, torch.Tensor):
46
+ self._update_state_dict_(v, stored)
47
+ else:
48
+ diff = stored - v
49
+ diff *= 1 - self.decay
50
+ stored -= diff
51
+
52
+ def update(self, model: torch.nn.Module) -> None:
53
+ """
54
+ Updates the stored parameters using the state dict of the provided
55
+ module. The module should have the same structure as that used to
56
+ initialize the ExponentialMovingAverage object.
57
+ """
58
+ self._update_state_dict_(model.state_dict(), self.params)
59
+
60
+ def load_state_dict(self, state_dict: OrderedDict) -> None:
61
+ for k in state_dict["params"].keys():
62
+ self.params[k] = state_dict["params"][k].clone()
63
+ self.decay = state_dict["decay"]
64
+
65
+ def state_dict(self) -> OrderedDict:
66
+ return OrderedDict(
67
+ {
68
+ "params": self.params,
69
+ "decay": self.decay,
70
+ }
71
+ )
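
The standard usage pattern for this class, sketched; the placement after optimizer.step() and the eval-time copy are the usual EMA recipe, not code from this commit:

import torch.nn as nn
from dockformerpp.utils.exponential_moving_average import ExponentialMovingAverage

model = nn.Linear(8, 8)
ema = ExponentialMovingAverage(model, decay=0.999)

# ... in the training loop, after each optimizer.step():
ema.update(model)

# ... at evaluation time, load the shadow weights into a fresh copy:
eval_model = nn.Linear(8, 8)
eval_model.load_state_dict(ema.params)
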
dockformerpp/utils/feats.py ADDED
@@ -0,0 +1,174 @@
1
+ # Copyright 2021 AlQuraishi Laboratory
2
+ # Copyright 2021 DeepMind Technologies Limited
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ import math
17
+
18
+ import numpy as np
19
+ import torch
20
+ import torch.nn as nn
21
+ from typing import Dict, Union
22
+
23
+ from dockformerpp.utils import protein
24
+ import dockformerpp.utils.residue_constants as rc
25
+ from dockformerpp.utils.geometry import rigid_matrix_vector, rotation_matrix, vector
26
+ from dockformerpp.utils.rigid_utils import Rotation, Rigid
27
+ from dockformerpp.utils.tensor_utils import (
28
+ batched_gather,
29
+ one_hot,
30
+ tree_map,
31
+ tensor_tree_map,
32
+ )
33
+
34
+
35
+ def pseudo_beta_fn(aatype, all_atom_positions, all_atom_masks):
36
+ # rc.restype_order["Z"] defines a ligand, and the atom position used is the CA
37
+ is_gly_or_lig = (aatype == rc.restype_order["G"]) | (aatype == rc.restype_order["Z"])
38
+ ca_idx = rc.atom_order["CA"]
39
+ cb_idx = rc.atom_order["CB"]
40
+ pseudo_beta = torch.where(
41
+ is_gly_or_lig[..., None].expand(*((-1,) * len(is_gly_or_lig.shape)), 3),
42
+ all_atom_positions[..., ca_idx, :],
43
+ all_atom_positions[..., cb_idx, :],
44
+ )
45
+
46
+ if all_atom_masks is not None:
47
+ pseudo_beta_mask = torch.where(
48
+ is_gly_or_lig,
49
+ all_atom_masks[..., ca_idx],
50
+ all_atom_masks[..., cb_idx],
51
+ )
52
+ return pseudo_beta, pseudo_beta_mask
53
+ else:
54
+ return pseudo_beta
55
+
56
+
57
+ def atom14_to_atom37(atom14, batch):
58
+ atom37_data = batched_gather(
59
+ atom14,
60
+ batch["residx_atom37_to_atom14"],
61
+ dim=-2,
62
+ no_batch_dims=len(atom14.shape[:-2]),
63
+ )
64
+
65
+ atom37_data = atom37_data * batch["atom37_atom_exists"][..., None]
66
+
67
+ return atom37_data
68
+
69
+
70
+ def torsion_angles_to_frames(
71
+ r: Union[Rigid, rigid_matrix_vector.Rigid3Array],
72
+ alpha: torch.Tensor,
73
+ aatype: torch.Tensor,
74
+ rrgdf: torch.Tensor,
75
+ ):
76
+
77
+ rigid_type = type(r)
78
+
79
+ # [*, N, 8, 4, 4]
80
+ default_4x4 = rrgdf[aatype, ...]
81
+
82
+ # [*, N, 8] transformations, i.e.
83
+ # One [*, N, 8, 3, 3] rotation matrix and
84
+ # One [*, N, 8, 3] translation matrix
85
+ default_r = rigid_type.from_tensor_4x4(default_4x4)
86
+
87
+ bb_rot = alpha.new_zeros((*((1,) * len(alpha.shape[:-1])), 2))
88
+ bb_rot[..., 1] = 1
89
+
90
+ # [*, N, 8, 2]
91
+ alpha = torch.cat(
92
+ [bb_rot.expand(*alpha.shape[:-2], -1, -1), alpha], dim=-2
93
+ )
94
+
95
+ # [*, N, 8, 3, 3]
96
+ # Produces rotation matrices of the form:
97
+ # [
98
+ # [1, 0 , 0 ],
99
+ # [0, a_2,-a_1],
100
+ # [0, a_1, a_2]
101
+ # ]
102
+ # This follows the original code rather than the supplement, which uses
103
+ # different indices.
104
+
105
+ all_rots = alpha.new_zeros(default_r.shape + (4, 4))
106
+ all_rots[..., 0, 0] = 1
107
+ all_rots[..., 1, 1] = alpha[..., 1]
108
+ all_rots[..., 1, 2] = -alpha[..., 0]
109
+ all_rots[..., 2, 1:3] = alpha
110
+
111
+ all_rots = rigid_type.from_tensor_4x4(all_rots)
112
+ all_frames = default_r.compose(all_rots)
113
+
114
+ chi2_frame_to_frame = all_frames[..., 5]
115
+ chi3_frame_to_frame = all_frames[..., 6]
116
+ chi4_frame_to_frame = all_frames[..., 7]
117
+
118
+ chi1_frame_to_bb = all_frames[..., 4]
119
+ chi2_frame_to_bb = chi1_frame_to_bb.compose(chi2_frame_to_frame)
120
+ chi3_frame_to_bb = chi2_frame_to_bb.compose(chi3_frame_to_frame)
121
+ chi4_frame_to_bb = chi3_frame_to_bb.compose(chi4_frame_to_frame)
122
+
123
+ all_frames_to_bb = rigid_type.cat(
124
+ [
125
+ all_frames[..., :5],
126
+ chi2_frame_to_bb.unsqueeze(-1),
127
+ chi3_frame_to_bb.unsqueeze(-1),
128
+ chi4_frame_to_bb.unsqueeze(-1),
129
+ ],
130
+ dim=-1,
131
+ )
132
+
133
+ all_frames_to_global = r[..., None].compose(all_frames_to_bb)
134
+
135
+ return all_frames_to_global
136
+
137
+
138
+ def frames_and_literature_positions_to_atom14_pos(
139
+ r: Union[Rigid, rigid_matrix_vector.Rigid3Array],
140
+ aatype: torch.Tensor,
141
+ default_frames,
142
+ group_idx,
143
+ atom_mask,
144
+ lit_positions,
145
+ ):
146
+ # [*, N, 14, 4, 4]
147
+ default_4x4 = default_frames[aatype, ...]
148
+
149
+ # [*, N, 14]
150
+ group_mask = group_idx[aatype, ...]
151
+
152
+ # [*, N, 14, 8]
153
+ group_mask = nn.functional.one_hot(
154
+ group_mask,
155
+ num_classes=default_frames.shape[-3],
156
+ )
157
+
158
+ # [*, N, 14, 8]
159
+ t_atoms_to_global = r[..., None, :] * group_mask
160
+
161
+ # [*, N, 14]
162
+ t_atoms_to_global = t_atoms_to_global.map_tensor_fn(
163
+ lambda x: torch.sum(x, dim=-1)
164
+ )
165
+
166
+ # [*, N, 14]
167
+ atom_mask = atom_mask[aatype, ...].unsqueeze(-1)
168
+
169
+ # [*, N, 14, 3]
170
+ lit_positions = lit_positions[aatype, ...]
171
+ pred_positions = t_atoms_to_global.apply(lit_positions)
172
+ pred_positions = pred_positions * atom_mask
173
+
174
+ return pred_positions
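The comment in `torsion_angles_to_frames` above describes the x-axis rotation assembled from each `(sin, cos)` torsion pair. A standalone sketch of that construction (the function name is made up; only the indexing mirrors the code above):

```python
import torch

def x_rotation_from_sin_cos(alpha: torch.Tensor) -> torch.Tensor:
    """alpha: [..., 2] holding (sin(a), cos(a)); returns [..., 3, 3]."""
    rot = alpha.new_zeros(alpha.shape[:-1] + (3, 3))
    rot[..., 0, 0] = 1
    rot[..., 1, 1] = alpha[..., 1]   # cos
    rot[..., 1, 2] = -alpha[..., 0]  # -sin
    rot[..., 2, 1] = alpha[..., 0]   # sin
    rot[..., 2, 2] = alpha[..., 1]   # cos
    return rot

angle = torch.tensor(0.5)
alpha = torch.stack([angle.sin(), angle.cos()], dim=-1)
print(x_rotation_from_sin_cos(alpha))
```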
dockformerpp/utils/geometry/__init__.py ADDED
@@ -0,0 +1,28 @@
+ # Copyright 2021 DeepMind Technologies Limited
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #      http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ """Geometry Module."""
+
+ from dockformerpp.utils.geometry import rigid_matrix_vector
+ from dockformerpp.utils.geometry import rotation_matrix
+ from dockformerpp.utils.geometry import vector
+
+ Rot3Array = rotation_matrix.Rot3Array
+ Rigid3Array = rigid_matrix_vector.Rigid3Array
+
+ Vec3Array = vector.Vec3Array
+ square_euclidean_distance = vector.square_euclidean_distance
+ euclidean_distance = vector.euclidean_distance
+ dihedral_angle = vector.dihedral_angle
+ dot = vector.dot
+ cross = vector.cross
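A quick sketch of the aliases this `__init__` re-exports (assumes the package is installed; shapes are arbitrary):

```python
import torch
from dockformerpp.utils import geometry

a = geometry.Vec3Array.from_array(torch.randn(4, 3))
b = geometry.Vec3Array.zeros((4,))
print(geometry.euclidean_distance(a, b).shape)  # torch.Size([4])
```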
dockformerpp/utils/geometry/__pycache__/__init__.cpython-39.pyc ADDED
Binary file (545 Bytes).
dockformerpp/utils/geometry/__pycache__/quat_rigid.cpython-39.pyc ADDED
Binary file (1.51 kB).
dockformerpp/utils/geometry/__pycache__/rigid_matrix_vector.cpython-39.pyc ADDED
Binary file (6.84 kB).
dockformerpp/utils/geometry/__pycache__/rotation_matrix.cpython-39.pyc ADDED
Binary file (7.98 kB).
dockformerpp/utils/geometry/__pycache__/utils.cpython-39.pyc ADDED
Binary file (575 Bytes).
dockformerpp/utils/geometry/__pycache__/vector.cpython-39.pyc ADDED
Binary file (8.92 kB).
dockformerpp/utils/geometry/quat_rigid.py ADDED
@@ -0,0 +1,38 @@
+ import torch
+ import torch.nn as nn
+
+ from dockformerpp.model.primitives import Linear
+ from dockformerpp.utils.geometry.rigid_matrix_vector import Rigid3Array
+ from dockformerpp.utils.geometry.rotation_matrix import Rot3Array
+ from dockformerpp.utils.geometry.vector import Vec3Array
+
+
+ class QuatRigid(nn.Module):
+     def __init__(self, c_hidden, full_quat):
+         super().__init__()
+         self.full_quat = full_quat
+         if self.full_quat:
+             rigid_dim = 7
+         else:
+             rigid_dim = 6
+
+         self.linear = Linear(c_hidden, rigid_dim, init="final", precision=torch.float32)
+
+     def forward(self, activations: torch.Tensor) -> Rigid3Array:
+         # NOTE: During training, this needs to be run in higher precision
+         rigid_flat = self.linear(activations)
+
+         rigid_flat = torch.unbind(rigid_flat, dim=-1)
+         if self.full_quat:
+             qw, qx, qy, qz = rigid_flat[:4]
+             translation = rigid_flat[4:]
+         else:
+             qx, qy, qz = rigid_flat[:3]
+             qw = torch.ones_like(qx)
+             translation = rigid_flat[3:]
+
+         rotation = Rot3Array.from_quaternion(
+             qw, qx, qy, qz, normalize=True,
+         )
+         translation = Vec3Array(*translation)
+         return Rigid3Array(rotation, translation)
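A sketch of how `QuatRigid` might be invoked (the hidden size and batch shape are invented for illustration):

```python
import torch
from dockformerpp.utils.geometry.quat_rigid import QuatRigid

quat_rigid = QuatRigid(c_hidden=128, full_quat=False)
activations = torch.randn(2, 10, 128)  # [batch, residues, c_hidden]
rigids = quat_rigid(activations)       # Rigid3Array with shape [2, 10]
print(rigids.shape)
```

In the 6-DOF branch `qw` is fixed at 1, so small linear outputs parameterize small rotations near the identity; if `init="final"` zero-initializes the projection, as in OpenFold-style primitives, the module initially emits identity rigids.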
dockformerpp/utils/geometry/rigid_matrix_vector.py ADDED
@@ -0,0 +1,181 @@
+ # Copyright 2021 DeepMind Technologies Limited
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #      http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ """Rigid3Array Transformations represented by a Matrix and a Vector."""
+
+ from __future__ import annotations
+ import dataclasses
+ from typing import Union, List
+
+ import torch
+
+ from dockformerpp.utils.geometry import rotation_matrix
+ from dockformerpp.utils.geometry import vector
+
+
+ Float = Union[float, torch.Tensor]
+
+
+ @dataclasses.dataclass(frozen=True)
+ class Rigid3Array:
+     """Rigid Transformation, i.e. element of special euclidean group."""
+
+     rotation: rotation_matrix.Rot3Array
+     translation: vector.Vec3Array
+
+     def __matmul__(self, other: Rigid3Array) -> Rigid3Array:
+         new_rotation = self.rotation @ other.rotation  # __matmul__
+         new_translation = self.apply_to_point(other.translation)
+         return Rigid3Array(new_rotation, new_translation)
+
+     def __getitem__(self, index) -> Rigid3Array:
+         return Rigid3Array(
+             self.rotation[index],
+             self.translation[index],
+         )
+
+     def __mul__(self, other: torch.Tensor) -> Rigid3Array:
+         return Rigid3Array(
+             self.rotation * other,
+             self.translation * other,
+         )
+
+     def map_tensor_fn(self, fn) -> Rigid3Array:
+         return Rigid3Array(
+             self.rotation.map_tensor_fn(fn),
+             self.translation.map_tensor_fn(fn),
+         )
+
+     def inverse(self) -> Rigid3Array:
+         """Return Rigid3Array corresponding to inverse transform."""
+         inv_rotation = self.rotation.inverse()
+         inv_translation = inv_rotation.apply_to_point(-self.translation)
+         return Rigid3Array(inv_rotation, inv_translation)
+
+     def apply_to_point(self, point: vector.Vec3Array) -> vector.Vec3Array:
+         """Apply Rigid3Array transform to point."""
+         return self.rotation.apply_to_point(point) + self.translation
+
+     def apply(self, point: torch.Tensor) -> torch.Tensor:
+         return self.apply_to_point(vector.Vec3Array.from_array(point)).to_tensor()
+
+     def apply_inverse_to_point(self, point: vector.Vec3Array) -> vector.Vec3Array:
+         """Apply inverse Rigid3Array transform to point."""
+         new_point = point - self.translation
+         return self.rotation.apply_inverse_to_point(new_point)
+
+     def invert_apply(self, point: torch.Tensor) -> torch.Tensor:
+         return self.apply_inverse_to_point(vector.Vec3Array.from_array(point)).to_tensor()
+
+     def compose_rotation(self, other_rotation):
+         rot = self.rotation @ other_rotation
+         return Rigid3Array(rot, self.translation.clone())
+
+     def compose(self, other_rigid):
+         return self @ other_rigid
+
+     def unsqueeze(self, dim: int):
+         return Rigid3Array(
+             self.rotation.unsqueeze(dim),
+             self.translation.unsqueeze(dim),
+         )
+
+     @property
+     def shape(self) -> torch.Size:
+         return self.rotation.xx.shape
+
+     @property
+     def dtype(self) -> torch.dtype:
+         return self.rotation.xx.dtype
+
+     @property
+     def device(self) -> torch.device:
+         return self.rotation.xx.device
+
+     @classmethod
+     def identity(cls, shape, device) -> Rigid3Array:
+         """Return identity Rigid3Array of given shape."""
+         return cls(
+             rotation_matrix.Rot3Array.identity(shape, device),
+             vector.Vec3Array.zeros(shape, device)
+         )
+
+     @classmethod
+     def cat(cls, rigids: List[Rigid3Array], dim: int) -> Rigid3Array:
+         return cls(
+             rotation_matrix.Rot3Array.cat(
+                 [r.rotation for r in rigids], dim=dim
+             ),
+             vector.Vec3Array.cat(
+                 [r.translation for r in rigids], dim=dim
+             ),
+         )
+
+     def scale_translation(self, factor: Float) -> Rigid3Array:
+         """Scale translation in Rigid3Array by 'factor'."""
+         return Rigid3Array(self.rotation, self.translation * factor)
+
+     def to_tensor(self) -> torch.Tensor:
+         rot_array = self.rotation.to_tensor()
+         vec_array = self.translation.to_tensor()
+         array = torch.zeros(
+             rot_array.shape[:-2] + (4, 4),
+             device=rot_array.device,
+             dtype=rot_array.dtype
+         )
+         array[..., :3, :3] = rot_array
+         array[..., :3, 3] = vec_array
+         array[..., 3, 3] = 1.
+         return array
+
+     def to_tensor_4x4(self) -> torch.Tensor:
+         return self.to_tensor()
+
+     def reshape(self, new_shape) -> Rigid3Array:
+         rots = self.rotation.reshape(new_shape)
+         trans = self.translation.reshape(new_shape)
+         return Rigid3Array(rots, trans)
+
+     def stop_rot_gradient(self) -> Rigid3Array:
+         return Rigid3Array(
+             self.rotation.stop_gradient(),
+             self.translation,
+         )
+
+     @classmethod
+     def from_array(cls, array):
+         rot = rotation_matrix.Rot3Array.from_array(
+             array[..., :3, :3],
+         )
+         vec = vector.Vec3Array.from_array(array[..., :3, 3])
+         return cls(rot, vec)
+
+     @classmethod
+     def from_tensor_4x4(cls, array):
+         return cls.from_array(array)
+
+     @classmethod
+     def from_array4x4(cls, array: torch.Tensor) -> Rigid3Array:
+         """Construct Rigid3Array from homogeneous 4x4 array."""
+         rotation = rotation_matrix.Rot3Array(
+             array[..., 0, 0], array[..., 0, 1], array[..., 0, 2],
+             array[..., 1, 0], array[..., 1, 1], array[..., 1, 2],
+             array[..., 2, 0], array[..., 2, 1], array[..., 2, 2]
+         )
+         translation = vector.Vec3Array(
+             array[..., 0, 3], array[..., 1, 3], array[..., 2, 3]
+         )
+         return cls(rotation, translation)
+
+     def cuda(self) -> Rigid3Array:
+         return Rigid3Array.from_tensor_4x4(self.to_tensor_4x4().cuda())
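A round-trip sketch for `Rigid3Array`: applying a transform and then its inverse should recover the input points (random quaternion and translation; tolerance is illustrative):

```python
import torch
from dockformerpp.utils.geometry import Rigid3Array, Rot3Array, Vec3Array

w, x, y, z = torch.randn(4, 5).unbind(0)
rot = Rot3Array.from_quaternion(w, x, y, z, normalize=True)
rigid = Rigid3Array(rot, Vec3Array.from_array(torch.randn(5, 3)))

points = torch.randn(5, 3)
restored = rigid.invert_apply(rigid.apply(points))
print(torch.allclose(points, restored, atol=1e-5))  # True
```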
dockformerpp/utils/geometry/rotation_matrix.py ADDED
@@ -0,0 +1,208 @@
+ # Copyright 2021 DeepMind Technologies Limited
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #      http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ """Rot3Array Matrix Class."""
+
+ from __future__ import annotations
+ import dataclasses
+ from typing import List
+
+ import torch
+
+ from dockformerpp.utils.geometry import utils
+ from dockformerpp.utils.geometry import vector
+ from dockformerpp.utils.tensor_utils import tensor_tree_map
+
+
+ COMPONENTS = ['xx', 'xy', 'xz', 'yx', 'yy', 'yz', 'zx', 'zy', 'zz']
+
+ @dataclasses.dataclass(frozen=True)
+ class Rot3Array:
+     """Rot3Array Matrix in 3 dimensional Space implemented as struct of arrays."""
+     xx: torch.Tensor = dataclasses.field(metadata={'dtype': torch.float32})
+     xy: torch.Tensor
+     xz: torch.Tensor
+     yx: torch.Tensor
+     yy: torch.Tensor
+     yz: torch.Tensor
+     zx: torch.Tensor
+     zy: torch.Tensor
+     zz: torch.Tensor
+
+     __array_ufunc__ = None
+
+     def __getitem__(self, index):
+         field_names = utils.get_field_names(Rot3Array)
+         return Rot3Array(
+             **{
+                 name: getattr(self, name)[index]
+                 for name in field_names
+             }
+         )
+
+     def __mul__(self, other: torch.Tensor):
+         field_names = utils.get_field_names(Rot3Array)
+         return Rot3Array(
+             **{
+                 name: getattr(self, name) * other
+                 for name in field_names
+             }
+         )
+
+     def __matmul__(self, other: Rot3Array) -> Rot3Array:
+         """Composes two Rot3Arrays."""
+         c0 = self.apply_to_point(vector.Vec3Array(other.xx, other.yx, other.zx))
+         c1 = self.apply_to_point(vector.Vec3Array(other.xy, other.yy, other.zy))
+         c2 = self.apply_to_point(vector.Vec3Array(other.xz, other.yz, other.zz))
+         return Rot3Array(c0.x, c1.x, c2.x, c0.y, c1.y, c2.y, c0.z, c1.z, c2.z)
+
+     def map_tensor_fn(self, fn) -> Rot3Array:
+         field_names = utils.get_field_names(Rot3Array)
+         return Rot3Array(
+             **{
+                 name: fn(getattr(self, name))
+                 for name in field_names
+             }
+         )
+
+     def inverse(self) -> Rot3Array:
+         """Returns inverse of Rot3Array."""
+         return Rot3Array(
+             self.xx, self.yx, self.zx,
+             self.xy, self.yy, self.zy,
+             self.xz, self.yz, self.zz
+         )
+
+     def apply_to_point(self, point: vector.Vec3Array) -> vector.Vec3Array:
+         """Applies Rot3Array to point."""
+         return vector.Vec3Array(
+             self.xx * point.x + self.xy * point.y + self.xz * point.z,
+             self.yx * point.x + self.yy * point.y + self.yz * point.z,
+             self.zx * point.x + self.zy * point.y + self.zz * point.z
+         )
+
+     def apply_inverse_to_point(self, point: vector.Vec3Array) -> vector.Vec3Array:
+         """Applies inverse Rot3Array to point."""
+         return self.inverse().apply_to_point(point)
+
+     def unsqueeze(self, dim: int):
+         return Rot3Array(
+             *tensor_tree_map(
+                 lambda t: t.unsqueeze(dim),
+                 [getattr(self, c) for c in COMPONENTS]
+             )
+         )
+
+     def stop_gradient(self) -> Rot3Array:
+         return Rot3Array(
+             *[getattr(self, c).detach() for c in COMPONENTS]
+         )
+
+     @classmethod
+     def identity(cls, shape, device) -> Rot3Array:
+         """Returns identity of given shape."""
+         ones = torch.ones(shape, dtype=torch.float32, device=device)
+         zeros = torch.zeros(shape, dtype=torch.float32, device=device)
+         return cls(ones, zeros, zeros, zeros, ones, zeros, zeros, zeros, ones)
+
+     @classmethod
+     def from_two_vectors(
+         cls, e0: vector.Vec3Array,
+         e1: vector.Vec3Array
+     ) -> Rot3Array:
+         """Construct Rot3Array from two Vectors.
+
+         Rot3Array is constructed such that in the corresponding frame 'e0' lies on
+         the positive x-Axis and 'e1' lies in the xy plane with positive sign of y.
+
+         Args:
+             e0: Vector
+             e1: Vector
+         Returns:
+             Rot3Array
+         """
+         # Normalize the unit vector for the x-axis, e0.
+         e0 = e0.normalized()
+         # Make e1 perpendicular to e0.
+         c = e1.dot(e0)
+         e1 = (e1 - c * e0).normalized()
+         # Compute e2 as cross product of e0 and e1.
+         e2 = e0.cross(e1)
+         return cls(e0.x, e1.x, e2.x, e0.y, e1.y, e2.y, e0.z, e1.z, e2.z)
+
+     @classmethod
+     def from_array(cls, array: torch.Tensor) -> Rot3Array:
+         """Construct Rot3Array Matrix from array of shape [..., 3, 3]."""
+         rows = torch.unbind(array, dim=-2)
+         rc = [torch.unbind(e, dim=-1) for e in rows]
+         return cls(*[e for row in rc for e in row])
+
+     def to_tensor(self) -> torch.Tensor:
+         """Convert Rot3Array to array of shape [..., 3, 3]."""
+         return torch.stack(
+             [
+                 torch.stack([self.xx, self.xy, self.xz], dim=-1),
+                 torch.stack([self.yx, self.yy, self.yz], dim=-1),
+                 torch.stack([self.zx, self.zy, self.zz], dim=-1)
+             ],
+             dim=-2
+         )
+
+     @classmethod
+     def from_quaternion(cls,
+                         w: torch.Tensor,
+                         x: torch.Tensor,
+                         y: torch.Tensor,
+                         z: torch.Tensor,
+                         normalize: bool = True,
+                         eps: float = 1e-6
+                         ) -> Rot3Array:
+         """Construct Rot3Array from components of quaternion."""
+         if normalize:
+             inv_norm = torch.rsqrt(torch.clamp(w**2 + x**2 + y**2 + z**2, min=eps))
+             w = w * inv_norm
+             x = x * inv_norm
+             y = y * inv_norm
+             z = z * inv_norm
+         xx = 1.0 - 2.0 * (y ** 2 + z ** 2)
+         xy = 2.0 * (x * y - w * z)
+         xz = 2.0 * (x * z + w * y)
+         yx = 2.0 * (x * y + w * z)
+         yy = 1.0 - 2.0 * (x ** 2 + z ** 2)
+         yz = 2.0 * (y * z - w * x)
+         zx = 2.0 * (x * z - w * y)
+         zy = 2.0 * (y * z + w * x)
+         zz = 1.0 - 2.0 * (x ** 2 + y ** 2)
+         return cls(xx, xy, xz, yx, yy, yz, zx, zy, zz)
+
+     def reshape(self, new_shape):
+         field_names = utils.get_field_names(Rot3Array)
+         reshape_fn = lambda t: t.reshape(new_shape)
+         return Rot3Array(
+             **{
+                 name: reshape_fn(getattr(self, name))
+                 for name in field_names
+             }
+         )
+
+     @classmethod
+     def cat(cls, rots: List[Rot3Array], dim: int) -> Rot3Array:
+         field_names = utils.get_field_names(Rot3Array)
+         cat_fn = lambda l: torch.cat(l, dim=dim)
+         return cls(
+             **{
+                 name: cat_fn([getattr(r, name) for r in rots])
+                 for name in field_names
+             }
+         )
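A worked sketch of `Rot3Array.from_two_vectors`: `e0` maps onto the positive x-axis of the constructed frame, so expressing `e0` in that frame leaves only an x-component equal to its length (values chosen by hand; not part of the commit):

```python
import torch
from dockformerpp.utils.geometry import Rot3Array, Vec3Array

e0 = Vec3Array.from_array(torch.tensor([0.0, 2.0, 0.0]))
e1 = Vec3Array.from_array(torch.tensor([1.0, 1.0, 0.0]))
rot = Rot3Array.from_two_vectors(e0, e1)
# Express e0 in the new frame: only an x-component (its length) remains.
print(rot.apply_inverse_to_point(e0).to_tensor())  # ~[2., 0., 0.]
```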
dockformerpp/utils/geometry/test_utils.py ADDED
@@ -0,0 +1,97 @@
+ # Copyright 2021 DeepMind Technologies Limited
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #      http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ """Shared utils for tests."""
+
+ import dataclasses
+ import torch
+
+ from dockformerpp.utils.geometry import rigid_matrix_vector
+ from dockformerpp.utils.geometry import rotation_matrix
+ from dockformerpp.utils.geometry import vector
+
+
+ def assert_rotation_matrix_equal(matrix1: rotation_matrix.Rot3Array,
+                                  matrix2: rotation_matrix.Rot3Array):
+     for field in dataclasses.fields(rotation_matrix.Rot3Array):
+         name = field.name
+         assert torch.equal(
+             getattr(matrix1, name), getattr(matrix2, name))
+
+
+ def assert_rotation_matrix_close(mat1: rotation_matrix.Rot3Array,
+                                  mat2: rotation_matrix.Rot3Array):
+     assert torch.allclose(mat1.to_tensor(), mat2.to_tensor(), atol=1e-6)
+
+
+ def assert_array_equal_to_rotation_matrix(array: torch.Tensor,
+                                           matrix: rotation_matrix.Rot3Array):
+     """Check that array and Matrix match."""
+     assert torch.equal(matrix.xx, array[..., 0, 0])
+     assert torch.equal(matrix.xy, array[..., 0, 1])
+     assert torch.equal(matrix.xz, array[..., 0, 2])
+     assert torch.equal(matrix.yx, array[..., 1, 0])
+     assert torch.equal(matrix.yy, array[..., 1, 1])
+     assert torch.equal(matrix.yz, array[..., 1, 2])
+     assert torch.equal(matrix.zx, array[..., 2, 0])
+     assert torch.equal(matrix.zy, array[..., 2, 1])
+     assert torch.equal(matrix.zz, array[..., 2, 2])
+
+
+ def assert_array_close_to_rotation_matrix(array: torch.Tensor,
+                                           matrix: rotation_matrix.Rot3Array):
+     assert torch.allclose(matrix.to_tensor(), array, atol=1e-6)
+
+
+ def assert_vectors_equal(vec1: vector.Vec3Array, vec2: vector.Vec3Array):
+     assert torch.equal(vec1.x, vec2.x)
+     assert torch.equal(vec1.y, vec2.y)
+     assert torch.equal(vec1.z, vec2.z)
+
+
+ def assert_vectors_close(vec1: vector.Vec3Array, vec2: vector.Vec3Array):
+     assert torch.allclose(vec1.x, vec2.x, atol=1e-6, rtol=0.)
+     assert torch.allclose(vec1.y, vec2.y, atol=1e-6, rtol=0.)
+     assert torch.allclose(vec1.z, vec2.z, atol=1e-6, rtol=0.)
+
+
+ def assert_array_close_to_vector(array: torch.Tensor, vec: vector.Vec3Array):
+     assert torch.allclose(vec.to_tensor(), array, atol=1e-6, rtol=0.)
+
+
+ def assert_array_equal_to_vector(array: torch.Tensor, vec: vector.Vec3Array):
+     assert torch.equal(vec.to_tensor(), array)
+
+
+ def assert_rigid_equal_to_rigid(rigid1: rigid_matrix_vector.Rigid3Array,
+                                 rigid2: rigid_matrix_vector.Rigid3Array):
+     assert_rot_trans_equal_to_rigid(rigid1.rotation, rigid1.translation, rigid2)
+
+
+ def assert_rigid_close_to_rigid(rigid1: rigid_matrix_vector.Rigid3Array,
+                                 rigid2: rigid_matrix_vector.Rigid3Array):
+     assert_rot_trans_close_to_rigid(rigid1.rotation, rigid1.translation, rigid2)
+
+
+ def assert_rot_trans_equal_to_rigid(rot: rotation_matrix.Rot3Array,
+                                     trans: vector.Vec3Array,
+                                     rigid: rigid_matrix_vector.Rigid3Array):
+     assert_rotation_matrix_equal(rot, rigid.rotation)
+     assert_vectors_equal(trans, rigid.translation)
+
+
+ def assert_rot_trans_close_to_rigid(rot: rotation_matrix.Rot3Array,
+                                     trans: vector.Vec3Array,
+                                     rigid: rigid_matrix_vector.Rigid3Array):
+     assert_rotation_matrix_close(rot, rigid.rotation)
+     assert_vectors_close(trans, rigid.translation)
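These helpers compose naturally in tests; a hypothetical pytest-style example (the test name and setup are made up):

```python
import torch
from dockformerpp.utils.geometry import rigid_matrix_vector
from dockformerpp.utils.geometry.test_utils import (
    assert_array_close_to_rotation_matrix,
    assert_array_close_to_vector,
    assert_rigid_close_to_rigid,
)

def test_rigid_tensor_round_trip():
    array = torch.eye(4).expand(3, 4, 4).clone()
    rigid = rigid_matrix_vector.Rigid3Array.from_array(array)
    assert_array_close_to_rotation_matrix(array[..., :3, :3], rigid.rotation)
    assert_array_close_to_vector(array[..., :3, 3], rigid.translation)
    assert_rigid_close_to_rigid(rigid, rigid)
```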
dockformerpp/utils/geometry/utils.py ADDED
@@ -0,0 +1,22 @@
+ # Copyright 2021 DeepMind Technologies Limited
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #      http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ """Utils for geometry library."""
+
+ import dataclasses
+
+
+ def get_field_names(cls):
+     fields = dataclasses.fields(cls)
+     field_names = [f.name for f in fields]
+     return field_names
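`get_field_names` simply lists dataclass fields in declaration order, which the geometry classes rely on when mapping functions over their components. A tiny sketch with a made-up dataclass:

```python
import dataclasses
from dockformerpp.utils.geometry.utils import get_field_names

@dataclasses.dataclass(frozen=True)
class Point2:  # hypothetical example class
    x: float
    y: float

print(get_field_names(Point2))  # ['x', 'y']
```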
dockformerpp/utils/geometry/vector.py ADDED
@@ -0,0 +1,261 @@
+ # Copyright 2021 DeepMind Technologies Limited
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #      http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ """Vec3Array Class."""
+
+ from __future__ import annotations
+ import dataclasses
+ from typing import Union, List
+
+ import torch
+
+ Float = Union[float, torch.Tensor]
+
+ @dataclasses.dataclass(frozen=True)
+ class Vec3Array:
+     x: torch.Tensor = dataclasses.field(metadata={'dtype': torch.float32})
+     y: torch.Tensor
+     z: torch.Tensor
+
+     def __post_init__(self):
+         if hasattr(self.x, 'dtype'):
+             assert self.x.dtype == self.y.dtype
+             assert self.x.dtype == self.z.dtype
+             assert all([x == y for x, y in zip(self.x.shape, self.y.shape)])
+             assert all([x == z for x, z in zip(self.x.shape, self.z.shape)])
+
+     def __add__(self, other: Vec3Array) -> Vec3Array:
+         return Vec3Array(
+             self.x + other.x,
+             self.y + other.y,
+             self.z + other.z,
+         )
+
+     def __sub__(self, other: Vec3Array) -> Vec3Array:
+         return Vec3Array(
+             self.x - other.x,
+             self.y - other.y,
+             self.z - other.z,
+         )
+
+     def __mul__(self, other: Float) -> Vec3Array:
+         return Vec3Array(
+             self.x * other,
+             self.y * other,
+             self.z * other,
+         )
+
+     def __rmul__(self, other: Float) -> Vec3Array:
+         return self * other
+
+     def __truediv__(self, other: Float) -> Vec3Array:
+         return Vec3Array(
+             self.x / other,
+             self.y / other,
+             self.z / other,
+         )
+
+     def __neg__(self) -> Vec3Array:
+         return self * -1
+
+     def __pos__(self) -> Vec3Array:
+         return self * 1
+
+     def __getitem__(self, index) -> Vec3Array:
+         return Vec3Array(
+             self.x[index],
+             self.y[index],
+             self.z[index],
+         )
+
+     def __iter__(self):
+         return iter((self.x, self.y, self.z))
+
+     @property
+     def shape(self):
+         return self.x.shape
+
+     def map_tensor_fn(self, fn) -> Vec3Array:
+         return Vec3Array(
+             fn(self.x),
+             fn(self.y),
+             fn(self.z),
+         )
+
+     def cross(self, other: Vec3Array) -> Vec3Array:
+         """Compute cross product between 'self' and 'other'."""
+         new_x = self.y * other.z - self.z * other.y
+         new_y = self.z * other.x - self.x * other.z
+         new_z = self.x * other.y - self.y * other.x
+         return Vec3Array(new_x, new_y, new_z)
+
+     def dot(self, other: Vec3Array) -> Float:
+         """Compute dot product between 'self' and 'other'."""
+         return self.x * other.x + self.y * other.y + self.z * other.z
+
+     def norm(self, epsilon: float = 1e-6) -> Float:
+         """Compute Norm of Vec3Array, clipped to epsilon."""
+         # To avoid NaN on the backward pass, we must clamp before the sqrt
+         norm2 = self.dot(self)
+         if epsilon:
+             norm2 = torch.clamp(norm2, min=epsilon**2)
+         return torch.sqrt(norm2)
+
+     def norm2(self):
+         return self.dot(self)
+
+     def normalized(self, epsilon: float = 1e-6) -> Vec3Array:
+         """Return unit vector with optional clipping."""
+         return self / self.norm(epsilon)
+
+     def clone(self) -> Vec3Array:
+         return Vec3Array(
+             self.x.clone(),
+             self.y.clone(),
+             self.z.clone(),
+         )
+
+     def reshape(self, new_shape) -> Vec3Array:
+         x = self.x.reshape(new_shape)
+         y = self.y.reshape(new_shape)
+         z = self.z.reshape(new_shape)
+
+         return Vec3Array(x, y, z)
+
+     def sum(self, dim: int) -> Vec3Array:
+         return Vec3Array(
+             torch.sum(self.x, dim=dim),
+             torch.sum(self.y, dim=dim),
+             torch.sum(self.z, dim=dim),
+         )
+
+     def unsqueeze(self, dim: int):
+         return Vec3Array(
+             self.x.unsqueeze(dim),
+             self.y.unsqueeze(dim),
+             self.z.unsqueeze(dim),
+         )
+
+     @classmethod
+     def zeros(cls, shape, device="cpu"):
+         """Return Vec3Array corresponding to zeros of given shape."""
+         return cls(
+             torch.zeros(shape, dtype=torch.float32, device=device),
+             torch.zeros(shape, dtype=torch.float32, device=device),
+             torch.zeros(shape, dtype=torch.float32, device=device)
+         )
+
+     def to_tensor(self) -> torch.Tensor:
+         return torch.stack([self.x, self.y, self.z], dim=-1)
+
+     @classmethod
+     def from_array(cls, tensor):
+         return cls(*torch.unbind(tensor, dim=-1))
+
+     @classmethod
+     def cat(cls, vecs: List[Vec3Array], dim: int) -> Vec3Array:
+         return cls(
+             torch.cat([v.x for v in vecs], dim=dim),
+             torch.cat([v.y for v in vecs], dim=dim),
+             torch.cat([v.z for v in vecs], dim=dim),
+         )
+
+
+ def square_euclidean_distance(
+     vec1: Vec3Array,
+     vec2: Vec3Array,
+     epsilon: float = 1e-6
+ ) -> Float:
+     """Computes square of euclidean distance between 'vec1' and 'vec2'.
+
+     Args:
+         vec1: Vec3Array to compute distance to
+         vec2: Vec3Array to compute distance from, should be
+             broadcast compatible with 'vec1'
+         epsilon: distance is clipped from below to be at least epsilon
+
+     Returns:
+         Array of square euclidean distances;
+         shape will be result of broadcasting 'vec1' and 'vec2'
+     """
+     difference = vec1 - vec2
+     distance = difference.dot(difference)
+     if epsilon:
+         distance = torch.clamp(distance, min=epsilon)
+     return distance
+
+
+ def dot(vector1: Vec3Array, vector2: Vec3Array) -> Float:
+     return vector1.dot(vector2)
+
+
+ def cross(vector1: Vec3Array, vector2: Vec3Array) -> Vec3Array:
+     return vector1.cross(vector2)
+
+
+ def norm(vector: Vec3Array, epsilon: float = 1e-6) -> Float:
+     return vector.norm(epsilon)
+
+
+ def normalized(vector: Vec3Array, epsilon: float = 1e-6) -> Vec3Array:
+     return vector.normalized(epsilon)
+
+
+ def euclidean_distance(
+     vec1: Vec3Array,
+     vec2: Vec3Array,
+     epsilon: float = 1e-6
+ ) -> Float:
+     """Computes euclidean distance between 'vec1' and 'vec2'.
+
+     Args:
+         vec1: Vec3Array to compute euclidean distance to
+         vec2: Vec3Array to compute euclidean distance from, should be
+             broadcast compatible with 'vec1'
+         epsilon: distance is clipped from below to be at least epsilon
+
+     Returns:
+         Array of euclidean distances;
+         shape will be result of broadcasting 'vec1' and 'vec2'
+     """
+     distance_sq = square_euclidean_distance(vec1, vec2, epsilon**2)
+     distance = torch.sqrt(distance_sq)
+     return distance
+
+
+ def dihedral_angle(a: Vec3Array, b: Vec3Array, c: Vec3Array,
+                    d: Vec3Array) -> Float:
+     """Computes torsion angle for a quadruple of points.
+
+     For points (a, b, c, d), this is the angle between the planes defined by
+     points (a, b, c) and (b, c, d). It is also known as the dihedral angle.
+
+     Arguments:
+         a: A Vec3Array of coordinates.
+         b: A Vec3Array of coordinates.
+         c: A Vec3Array of coordinates.
+         d: A Vec3Array of coordinates.
+
+     Returns:
+         A tensor of angles in radians: [-pi, pi].
+     """
+     v1 = a - b
+     v2 = b - c
+     v3 = d - c
+
+     c1 = v1.cross(v2)
+     c2 = v3.cross(v2)
+     c3 = c2.cross(c1)
+
+     v2_mag = v2.norm()
+     return torch.atan2(c3.dot(v2), v2_mag * c1.dot(c2))
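A worked check of `dihedral_angle` (a sketch; the `vec` helper is made up). For these four points the planes (a, b, c) and (b, c, d) are perpendicular, and the result is -π/2 under this implementation's sign convention:

```python
import math
import torch
from dockformerpp.utils.geometry import Vec3Array, dihedral_angle

def vec(x, y, z):
    return Vec3Array.from_array(torch.tensor([x, y, z], dtype=torch.float32))

a, b, c, d = vec(1, 0, 0), vec(0, 0, 0), vec(0, 1, 0), vec(0, 1, 1)
print(dihedral_angle(a, b, c, d).item(), -math.pi / 2)  # both ~ -1.5708
```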