FAPM_demo / esm /esmfold /v1 /categorical_mixture.py
wenkai's picture
Upload 31 files
3f0529e verified
raw
history blame
1.3 kB
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import torch
class CategoricalMixture:
    """Categorical distribution over evenly spaced bins spanning ``[start, end]``.

    The last dimension of the logits holds one unnormalised score per bin, and
    every bin is represented by the value at its centre (``v_bins``).
    """

    def __init__(self, param: torch.Tensor, bins: int = 50, start: float = 0, end: float = 1):
        """Store the logits and precompute the bin centres.

        Args:
            param: Logits whose last dimension indexes the bins, i.e. shape
                ``(..., bins)``.
            bins: Number of bins; ``mean`` requires this to equal
                ``param.shape[-1]``.
            start: Lower edge of the first bin.
            end: Upper edge of the last bin.
        """
        # All tensors are of shape ..., bins.
        self.logits = param
        # ``bins + 1`` edges delimit ``bins`` intervals; build them on the same
        # device and dtype as the logits so later arithmetic needs no casts.
        edges = torch.linspace(
            start, end, bins + 1, device=self.logits.device, dtype=self.logits.dtype
        )
        # Centre of each interval = average of its two edges.
        self.v_bins = (edges[:-1] + edges[1:]) / 2

    def log_prob(self, true: torch.Tensor) -> torch.Tensor:
        """Return the log-probability of the bin closest to each target value.

        Args:
            true: Target values of shape ``(...)``, lining up with the leading
                dimensions of the logits. A 0-dim tensor is accepted.

        Returns:
            Log-softmax of the logits gathered at the nearest bin centre, with
            the bin dimension removed.
        """
        # (..., 1) - (bins,) broadcasts to (..., bins). Plain broadcasting
        # replaces the former list-of-None index, which is deprecated in
        # recent PyTorch and failed for 0-dim targets.
        distance = (true.unsqueeze(-1) - self.v_bins).abs()
        true_index = distance.argmin(-1)
        log_probs = self.logits.log_softmax(-1)
        return torch.take_along_dim(
            log_probs, true_index.unsqueeze(-1), dim=-1
        ).squeeze(-1)

    def mean(self) -> torch.Tensor:
        """Return the expected bin-centre value under the softmax distribution.

        Returns:
            Tensor with the bin dimension reduced away.
        """
        probs = self.logits.softmax(-1)
        # (..., bins) @ (bins, 1) -> (..., 1); drop the trailing singleton.
        return (probs @ self.v_bins.unsqueeze(1)).squeeze(-1)
def categorical_lddt(logits, bins=50):
    """Return the expected bin-centre value of binned logits.

    Wraps ``logits`` in a ``CategoricalMixture`` with its default ``[0, 1]``
    range and returns that distribution's mean, reducing the bin dimension.

    Args:
        logits: Tensor whose last dimension indexes the bins. The original
            note gives the shape as ``(..., 37, bins)``.
        bins: Number of bins along the last dimension of ``logits``.
    """
    mixture = CategoricalMixture(logits, bins=bins)
    return mixture.mean()