luost26's picture
Update
753e275
import torch
from ._base import _mask_select_data, register_transform
from ..protein import constants
@register_transform('patch_around_anchor')
class PatchAroundAnchor(object):
    """Crop a structure to the residues around the anchor residues and center it.

    When at least one residue is flagged in ``anchor_flag``, the patch keeps:
      * every residue flagged in ``generate_flag`` or ``anchor_flag``;
      * the ``initial_patch_size`` residues (any fragment type) whose CA atom
        is closest to an anchor CA atom;
      * up to ``antigen_size`` antigen residues whose CA atom is closest to an
        anchor CA atom.
    Coordinates are then translated so the mean anchor CA position is at the
    origin.

    When no residue is flagged as an anchor, every non-antigen residue is kept
    and coordinates are centered on the mean of their CA positions.

    In both cases the returned dict carries ``patch_idx``: the indices of the
    kept residues in the input ``data``.

    Args:
        initial_patch_size: Number of residues nearest to the anchors to
            include, regardless of fragment type.
        antigen_size: Maximum number of antigen residues nearest to the
            anchors to include.
    """

    def __init__(self, initial_patch_size=128, antigen_size=128):
        super().__init__()
        self.initial_patch_size = initial_patch_size
        self.antigen_size = antigen_size

    def _center(self, data, origin):
        """Translate heavy-atom coordinates so that ``origin`` maps to (0, 0, 0).

        Coordinates of atoms whose ``mask_heavyatom`` entry is False are set
        to zero. The translation vector is stored in ``data['origin']``.

        Args:
            data: Dict with ``pos_heavyatom`` indexed as (residue, atom, xyz)
                and ``mask_heavyatom`` indexed as (residue, atom).
            origin: Tensor with 3 elements.

        Returns:
            ``data`` with ``pos_heavyatom`` replaced and ``origin`` set.
        """
        origin = origin.reshape(1, 1, 3)
        # Out-of-place subtraction: never mutates a tensor the caller may share.
        pos = data['pos_heavyatom'] - origin  # (L, A, 3)
        data['pos_heavyatom'] = pos * data['mask_heavyatom'][:, :, None]
        data['origin'] = origin.reshape(3)
        return data

    def _crop(self, data, mask, origin):
        """Keep residues where ``mask`` is True, center on ``origin``, record their indices."""
        data_patch = _mask_select_data(data, mask)
        data_patch = self._center(data_patch, origin=origin)
        # nonzero() yields the same int64 indices as arange(L)[mask] but on
        # the mask's own device, so this also works for non-CPU tensors.
        data_patch['patch_idx'] = torch.nonzero(mask, as_tuple=True)[0]
        return data_patch

    def __call__(self, data):
        """Return the cropped and centered patch for ``data``.

        Args:
            data: Dict of per-residue features. This method reads
                ``anchor_flag``, ``generate_flag``, ``fragment_type``,
                ``pos_heavyatom`` and ``mask_heavyatom``.

        Returns:
            A new dict built by ``_mask_select_data`` containing the kept
            residues, with centered ``pos_heavyatom``, ``origin`` (the
            translation that was subtracted) and ``patch_idx``.
        """
        ca = constants.BBHeavyAtom.CA
        anchor_flag = data['anchor_flag']  # (L,)
        antigen_mask = (data['fragment_type'] == constants.Fragment.Antigen)
        antibody_mask = torch.logical_not(antigen_mask)

        if not anchor_flag.any():
            # Generating full antibody-Fv, no antigen given: keep all
            # non-antigen residues, centered on their mean CA position.
            origin = data['pos_heavyatom'][antibody_mask, ca].mean(dim=0)
            return self._crop(data, antibody_mask, origin=origin)

        pos_alpha = data['pos_heavyatom'][:, ca]  # (L, 3)
        anchor_points = pos_alpha[anchor_flag]  # (n_anchors, 3)

        # Distance from every residue's CA to its nearest anchor CA.
        dist_anchor = torch.cdist(pos_alpha, anchor_points).min(dim=1)[0]  # (L,)
        initial_patch_idx = torch.topk(
            dist_anchor,
            k=min(self.initial_patch_size, dist_anchor.size(0)),
            largest=False,
        )[1]  # (initial_patch_size,)

        # Antibody residues get +inf so the topk below only picks antigen ones.
        dist_anchor_antigen = dist_anchor.masked_fill(
            mask=antibody_mask,
            value=float('+inf'),
        )  # (L,)
        antigen_patch_idx = torch.topk(
            dist_anchor_antigen,
            k=min(self.antigen_size, antigen_mask.sum().item()),
            largest=False,
            sorted=True,
        )[1]  # (ag_size,)

        patch_mask = torch.logical_or(data['generate_flag'], anchor_flag)
        patch_mask[initial_patch_idx] = True
        patch_mask[antigen_patch_idx] = True

        return self._crop(data, patch_mask, origin=anchor_points.mean(dim=0))