Spaces:
Runtime error
Runtime error
import torch | |
from ._base import _mask_select_data, register_transform | |
from ..protein import constants | |
# NOTE(review): `register_transform` is imported by this module but is not
# applied to this class here -- confirm whether it should be registered.
class PatchAroundAnchor(object):
    """Crop a complex down to a local patch around its anchor residues.

    When any residue is flagged in ``data['anchor_flag']``, the patch keeps:

    * every residue flagged in ``generate_flag`` or ``anchor_flag``;
    * the ``initial_patch_size`` residues (of any fragment type) whose CA atom
      is closest to an anchor CA;
    * the ``antigen_size`` antigen residues closest to an anchor CA.

    The kept coordinates are then re-centred on the mean anchor CA position.

    When no residue is an anchor, all non-antigen residues are kept and
    re-centred on their mean CA position.

    In both cases the returned dict gains ``'origin'`` (the subtracted centre,
    shape ``(3,)``) and ``'patch_idx'`` (indices of the kept residues in the
    input).
    """

    def __init__(self, initial_patch_size=128, antigen_size=128):
        """
        Args:
            initial_patch_size: Maximum number of residues (any fragment type)
                nearest to the anchors to keep.
            antigen_size: Maximum number of antigen residues nearest to the
                anchors to keep.
        """
        super().__init__()
        self.initial_patch_size = initial_patch_size
        self.antigen_size = antigen_size

    def _center(self, data, origin):
        """Translate heavy-atom coordinates so that ``origin`` becomes zero.

        Coordinates of atoms whose ``mask_heavyatom`` entry is False are set
        to zero. ``data`` is updated and returned. The coordinate tensor is
        replaced rather than modified in place, so tensors the caller still
        holds are left untouched.

        Args:
            data: Dict holding ``'pos_heavyatom'`` (L, A, 3) and
                ``'mask_heavyatom'`` (L, A).
            origin: Tensor with 3 elements: the new coordinate origin.

        Returns:
            ``data``, with re-centred ``'pos_heavyatom'`` and with
            ``'origin'`` set to ``origin`` reshaped to ``(3,)``.
        """
        origin = origin.reshape(1, 1, 3)
        # Out-of-place on purpose: `-=` would write through to any tensor
        # shared with the caller (e.g. if the selection step returned a view).
        centred = data['pos_heavyatom'] - origin  # (L, A, 3)
        data['pos_heavyatom'] = centred * data['mask_heavyatom'][:, :, None]
        data['origin'] = origin.reshape(3)
        return data

    def __call__(self, data):
        """Return a re-centred patch of ``data`` around its anchor residues.

        Args:
            data: Dict of per-residue tensors. Reads ``'anchor_flag'`` (L,),
                ``'fragment_type'`` (L,), ``'pos_heavyatom'`` (L, A, 3) and
                ``'mask_heavyatom'``. When anchors exist it also reads
                ``'generate_flag'`` (L,). ``anchor_flag`` is used as a boolean
                index, so it is presumably a bool tensor -- confirm with
                callers.

        Returns:
            The dict produced by ``_mask_select_data`` for the kept residues,
            re-centred, plus ``'origin'`` and ``'patch_idx'``.
        """
        anchor_flag = data['anchor_flag']  # (L,)
        antigen_mask = data['fragment_type'] == constants.Fragment.Antigen
        antibody_mask = torch.logical_not(antigen_mask)

        if not anchor_flag.any():
            # Generating full antibody-Fv, no antigen given.
            data_patch = _mask_select_data(
                data=data,
                mask=antibody_mask,
            )
            data_patch = self._center(
                data_patch,
                origin=data_patch['pos_heavyatom'][:, constants.BBHeavyAtom.CA].mean(dim=0),
            )
            # Record the kept residue indices, as the anchored branch does,
            # so every output maps back to the input in the same way.
            data_patch['patch_idx'] = torch.nonzero(antibody_mask, as_tuple=True)[0]
            return data_patch

        pos_alpha = data['pos_heavyatom'][:, constants.BBHeavyAtom.CA]  # (L, 3)
        anchor_points = pos_alpha[anchor_flag]  # (n_anchors, 3)

        # Distance from each residue's CA to its nearest anchor CA.
        dist_anchor = torch.cdist(pos_alpha, anchor_points).min(dim=1)[0]  # (L,)
        initial_patch_idx = torch.topk(
            dist_anchor,
            k=min(self.initial_patch_size, dist_anchor.size(0)),
            largest=False,
        )[1]  # (<= initial_patch_size,)

        # Push antibody residues to +inf so this top-k only selects antigen.
        dist_anchor_antigen = dist_anchor.masked_fill(
            mask=antibody_mask,
            value=float('+inf'),
        )  # (L,)
        antigen_patch_idx = torch.topk(
            dist_anchor_antigen,
            k=min(self.antigen_size, int(antigen_mask.sum().item())),
            largest=False,
            sorted=True,
        )[1]  # (<= antigen_size,)

        # Residues to be generated and the anchors themselves are always kept.
        patch_mask = torch.logical_or(data['generate_flag'], anchor_flag)
        patch_mask[initial_patch_idx] = True
        patch_mask[antigen_patch_idx] = True
        # nonzero keeps the indices on the mask's device. The previous
        # default-device `arange` could not be indexed by a mask living on
        # another device.
        patch_idx = torch.nonzero(patch_mask, as_tuple=True)[0]

        data_patch = _mask_select_data(data, patch_mask)
        data_patch = self._center(
            data_patch,
            origin=anchor_points.mean(dim=0),
        )
        data_patch['patch_idx'] = patch_idx
        return data_patch