# Copyright 2021 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""lDDT protein distance score."""
import jax.numpy as jnp


def _pairwise_distances(points):
    """Return the (batch, length, length) matrix of distances between points.

    A small constant is added under the square root so the result (and its
    gradient) stays finite on the diagonal, where the squared distance is 0.
    """
    deltas = points[:, :, None] - points[:, None, :]
    return jnp.sqrt(1e-10 + jnp.sum(deltas**2, axis=-1))


def lddt(predicted_points,
         true_points,
         true_points_mask,
         cutoff=15.,
         per_residue=False):
    """Compute an approximate lDDT score for a batch of coordinates.

    lDDT reference:
    Mariani, V., Biasini, M., Barbato, A. & Schwede, T. lDDT: A local
    superposition-free score for comparing protein structures and models using
    distance difference tests. Bioinformatics 29, 2722–2728 (2013).

    The score compares the distance matrix of the predicted points with the
    distance matrix of the true points. Only pairs of distinct points that are
    closer than `cutoff` *in the true structure*, and that are both present
    according to the mask, contribute. Each contributing pair earns a quarter
    point for every tolerance (0.5, 1, 2 and 4) that its absolute distance
    error is strictly below.

    This is not the exact lDDT of the original paper: the terms for physical
    feasibility (e.g. bond length violations) are left out, so the result is
    only an approximation.

    Args:
      predicted_points: (batch, length, 3) array of predicted 3D points.
      true_points: (batch, length, 3) array of true 3D points.
      true_points_mask: (batch, length, 1) binary-valued float array that is 1
        for points that exist in the true points.
      cutoff: Maximum true distance for a pair of points to be included.
      per_residue: If true, return one score per residue, shape
        (batch, length), instead of one score per batch element, shape
        (batch,). The overall lDDT is not exactly the mean of the per-residue
        values because some residues have more contacts than others.

    Returns:
      An (approximate, see above) lDDT score in the range 0-1.
    """
    # Shape checks: rank-3 predictions with xyz last, rank-3 mask with a
    # trailing singleton axis.
    assert len(predicted_points.shape) == 3
    assert predicted_points.shape[-1] == 3
    assert true_points_mask.shape[-1] == 1
    assert len(true_points_mask.shape) == 3

    true_dists = _pairwise_distances(true_points)
    predicted_dists = _pairwise_distances(predicted_points)

    # Weight is 1 for every pair (i, j) that should be scored: close enough in
    # the true structure, both endpoints present, and i != j.
    within_cutoff = (true_dists < cutoff).astype(jnp.float32)
    column_mask = jnp.swapaxes(true_points_mask, 1, 2)
    off_diagonal = 1. - jnp.eye(true_dists.shape[1])
    pair_weights = (
        within_cutoff * true_points_mask * column_mask * off_diagonal)

    # Absolute error of each predicted distance against the true one.
    abs_error = jnp.abs(true_dists - predicted_dists)

    # True lDDT uses a fixed set of tolerance bins; the physical plausibility
    # correction is ignored here.
    tolerances = (0.5, 1.0, 2.0, 4.0)
    preserved = [(abs_error < tol).astype(jnp.float32) for tol in tolerances]
    pair_score = 0.25 * sum(preserved[1:], preserved[0])

    # Reduce over partners only (per residue) or over all pairs (per batch).
    if per_residue:
        reduce_axes = (-1,)
    else:
        reduce_axes = (-2, -1)

    # The 1e-10 terms keep the ratio finite when no pair is scored.
    total_weight = jnp.sum(pair_weights, axis=reduce_axes)
    weighted_score = jnp.sum(pair_weights * pair_score, axis=reduce_axes)
    inverse_weight = 1. / (1e-10 + total_weight)
    return inverse_weight * (1e-10 + weighted_score)