luost26's picture
Update
753e275
import torch
import random
from typing import List, Optional
from ..protein import constants
from ._base import register_transform
def random_shrink_extend(flag, min_length=5, shrink_limit=1, extend_limit=2):
    """Randomly shrink or extend the ``True`` region of a 1-D boolean mask.

    Each end of the region is moved independently by an integer offset drawn
    uniformly from ``[-shrink_limit, extend_limit]``. A positive offset pushes
    that end outward (extension); a negative offset pulls it inward (shrink).
    The resulting range is clipped to the bounds of ``flag``.

    Args:
        flag: 1-D boolean tensor. Only the first and last ``True`` positions
            are used; the region is treated as one contiguous run.
            NOTE(review): contiguity is assumed, not checked.
        min_length: If ``flag.sum() - 2 * shrink_limit`` is below this value,
            shrinking is disabled for this call (``shrink_limit`` becomes 0).
        shrink_limit: Maximum number of positions removed from each end.
        extend_limit: Maximum number of positions added to each end.

    Returns:
        A new boolean tensor shaped like ``flag`` that is ``True`` exactly on
        the resampled range. ``flag`` itself is not modified.
    """
    true_positions = torch.nonzero(flag).flatten()
    first = true_positions.min().item()
    last = true_positions.max().item()

    length = flag.sum().item()
    if (length - 2 * shrink_limit) < min_length:
        # Region is too short to lose residues on both sides: only extend.
        shrink_limit = 0

    first_ext = max(0, first - random.randint(-shrink_limit, extend_limit))
    last_ext = min(last + random.randint(-shrink_limit, extend_limit), flag.size(0) - 1)

    # Build the result from an all-False mask. Starting from a copy of `flag`
    # would keep every original True position, so a negative (shrink) offset
    # could never take effect.
    flag_ext = torch.zeros_like(flag, dtype=torch.bool)
    flag_ext[first_ext:last_ext + 1] = True
    return flag_ext
def continuous_flag_to_range(flag):
    """Return the indices of the first and last selected entries of ``flag``.

    Args:
        flag: 1-D mask tensor used to index ``torch.arange(flag.size(0))``.
            Callers in this file pass boolean masks.

    Returns:
        Tuple ``(first, last)`` of Python ints: the smallest and largest
        selected index. Entries in between are not inspected, so gaps in the
        mask are not detected.
    """
    selected = torch.arange(0, flag.size(0))[flag]
    lowest, highest = selected.min(), selected.max()
    return lowest.item(), highest.item()
@register_transform('mask_single_cdr')
class MaskSingleCDR(object):
    """Transform that marks a single CDR of one antibody chain for generation.

    On the chosen chain dict it writes two boolean tensors:

    * ``generate_flag`` - mask of the chosen CDR (passed through
      ``random_shrink_extend`` when ``augmentation`` is enabled).
    * ``anchor_flag`` - ``True`` at the positions immediately before and
      after the masked region, clipped to the chain bounds.

    Args:
        selection: ``None`` to pick a random available chain and a random CDR
            present on it, or one of ``'H1'``/``'H2'``/``'H3'``/``'L1'``/
            ``'L2'``/``'L3'``, their ``'H_CDRn'``/``'L_CDRn'`` aliases, or
            ``'CDR3'`` (H3 when a heavy chain is present, otherwise L3).
        augmentation: Whether to apply ``random_shrink_extend`` to the CDR
            mask before computing anchors.

    Raises:
        ValueError: If ``selection`` is not ``None`` and not a known name.
    """

    def __init__(self, selection=None, augmentation=True):
        super().__init__()
        cdr_str_to_enum = {
            'H1': constants.CDR.H1,
            'H2': constants.CDR.H2,
            'H3': constants.CDR.H3,
            'L1': constants.CDR.L1,
            'L2': constants.CDR.L2,
            'L3': constants.CDR.L3,
            'H_CDR1': constants.CDR.H1,
            'H_CDR2': constants.CDR.H2,
            'H_CDR3': constants.CDR.H3,
            'L_CDR1': constants.CDR.L1,
            'L_CDR2': constants.CDR.L2,
            'L_CDR3': constants.CDR.L3,
            'CDR3': 'CDR3',  # H3 first, then fallback to L3
        }
        # Validate with an explicit exception rather than `assert`: asserts
        # are stripped under `python -O`, and an unknown name would then
        # silently map to None (i.e. random-CDR masking).
        if selection is not None and selection not in cdr_str_to_enum:
            raise ValueError(
                f"Unknown CDR selection {selection!r}; expected None or one of "
                f"{sorted(cdr_str_to_enum)}"
            )
        self.selection = cdr_str_to_enum.get(selection, None)
        self.augmentation = augmentation

    def perform_masking_(self, data, selection=None):
        """Write ``generate_flag`` and ``anchor_flag`` into ``data`` in place.

        Args:
            data: Chain dict; reads ``data['cdr_flag']`` (positive entries
                label CDRs) and the shape of ``data['aa']``.
            selection: CDR label to mask, compared against ``cdr_flag`` with
                ``==``. When ``None``, one of the positive labels present in
                ``cdr_flag`` is chosen at random.
        """
        cdr_flag = data['cdr_flag']
        if selection is None:
            cdr_all = cdr_flag[cdr_flag > 0].unique().tolist()
            cdr_to_mask = random.choice(cdr_all)
        else:
            cdr_to_mask = selection

        cdr_to_mask_flag = (cdr_flag == cdr_to_mask)
        if self.augmentation:
            cdr_to_mask_flag = random_shrink_extend(cdr_to_mask_flag)

        # Anchors are the residues flanking the (possibly augmented) region;
        # at a chain terminus the index is clipped to stay in range.
        cdr_first, cdr_last = continuous_flag_to_range(cdr_to_mask_flag)
        left_idx = max(0, cdr_first - 1)
        right_idx = min(data['aa'].size(0) - 1, cdr_last + 1)
        anchor_flag = torch.zeros(data['aa'].shape, dtype=torch.bool)
        anchor_flag[left_idx] = True
        anchor_flag[right_idx] = True

        data['generate_flag'] = cdr_to_mask_flag
        data['anchor_flag'] = anchor_flag

    def __call__(self, structure):
        """Mask one CDR on ``structure['heavy']`` or ``structure['light']``.

        The chain dict is modified in place and ``structure`` is returned.
        NOTE(review): a selected chain that is ``None`` is not handled here
        and will fail inside ``perform_masking_``.
        """
        if self.selection is None:
            # No preference: pick uniformly among the chains that exist.
            ab_data = []
            if structure['heavy'] is not None:
                ab_data.append(structure['heavy'])
            if structure['light'] is not None:
                ab_data.append(structure['light'])
            data_to_mask = random.choice(ab_data)
            sel = None
        elif self.selection in (constants.CDR.H1, constants.CDR.H2, constants.CDR.H3, ):
            data_to_mask = structure['heavy']
            sel = int(self.selection)
        elif self.selection in (constants.CDR.L1, constants.CDR.L2, constants.CDR.L3, ):
            data_to_mask = structure['light']
            sel = int(self.selection)
        elif self.selection == 'CDR3':
            # Prefer H3; fall back to L3 when there is no heavy chain.
            # Converted with int() so every branch passes the same label type.
            if structure['heavy'] is not None:
                data_to_mask = structure['heavy']
                sel = int(constants.CDR.H3)
            else:
                data_to_mask = structure['light']
                sel = int(constants.CDR.L3)

        self.perform_masking_(data_to_mask, selection=sel)
        return structure
@register_transform('mask_multiple_cdrs')
class MaskMultipleCDRs(object):
    """Transform that marks several CDRs per antibody chain for generation.

    For every non-``None`` chain among ``structure['heavy']`` and
    ``structure['light']`` it accumulates (logical OR) into the chain dict:

    * ``generate_flag`` - union of the masks of the chosen CDRs.
    * ``anchor_flag`` - union of the flanking positions of each chosen CDR.

    Args:
        selection: Optional list of CDR names (``'H1'``..``'L3'`` or the
            ``'H_CDRn'``/``'L_CDRn'`` aliases). When given, each chain masks
            those selected CDRs that occur in its ``cdr_flag``. When ``None``,
            a random non-empty subset of the CDRs present on the chain is
            masked.
        augmentation: Whether to apply ``random_shrink_extend`` to each CDR
            mask before computing anchors.

    Raises:
        KeyError: If ``selection`` contains an unknown name.
    """

    def __init__(self, selection: Optional[List[str]] = None, augmentation=True):
        super().__init__()
        cdr_str_to_enum = {
            'H1': constants.CDR.H1,
            'H2': constants.CDR.H2,
            'H3': constants.CDR.H3,
            'L1': constants.CDR.L1,
            'L2': constants.CDR.L2,
            'L3': constants.CDR.L3,
            'H_CDR1': constants.CDR.H1,
            'H_CDR2': constants.CDR.H2,
            'H_CDR3': constants.CDR.H3,
            'L_CDR1': constants.CDR.L1,
            'L_CDR2': constants.CDR.L2,
            'L_CDR3': constants.CDR.L3,
        }
        if selection is not None:
            self.selection = [cdr_str_to_enum[s] for s in selection]
        else:
            self.selection = None
        self.augmentation = augmentation

    def mask_one_cdr_(self, data, cdr_to_mask):
        """Add the mask and anchors of one CDR to ``data`` in place.

        Args:
            data: Chain dict; reads ``data['cdr_flag']`` and the shape of
                ``data['aa']``.
            cdr_to_mask: CDR label compared against ``cdr_flag`` with ``==``.
        """
        cdr_flag = data['cdr_flag']

        cdr_to_mask_flag = (cdr_flag == cdr_to_mask)
        if self.augmentation:
            cdr_to_mask_flag = random_shrink_extend(cdr_to_mask_flag)

        # Anchors are the residues flanking the (possibly augmented) region;
        # at a chain terminus the index is clipped to stay in range.
        cdr_first, cdr_last = continuous_flag_to_range(cdr_to_mask_flag)
        left_idx = max(0, cdr_first - 1)
        right_idx = min(data['aa'].size(0) - 1, cdr_last + 1)
        anchor_flag = torch.zeros(data['aa'].shape, dtype=torch.bool)
        anchor_flag[left_idx] = True
        anchor_flag[right_idx] = True

        if 'generate_flag' not in data:
            data['generate_flag'] = cdr_to_mask_flag
            data['anchor_flag'] = anchor_flag
        else:
            # A previous CDR on this chain already set the flags: merge.
            data['generate_flag'] |= cdr_to_mask_flag
            data['anchor_flag'] |= anchor_flag

    def mask_for_one_chain_(self, data):
        """Choose the CDRs to mask on one chain and mask each of them.

        A chain whose ``cdr_flag`` has no positive label is left untouched.
        """
        cdr_flag = data['cdr_flag']
        cdr_all = cdr_flag[cdr_flag > 0].unique().tolist()

        if self.selection is not None:
            cdrs_to_mask = list(set(cdr_all).intersection(self.selection))
        elif cdr_all:
            # The subset size is drawn only here, where it is actually used.
            # Drawing it unconditionally made `randint(1, 0)` raise on chains
            # without any CDR label, even when an explicit selection was set.
            num_cdrs_to_mask = random.randint(1, len(cdr_all))
            random.shuffle(cdr_all)
            cdrs_to_mask = cdr_all[:num_cdrs_to_mask]
        else:
            cdrs_to_mask = []

        for cdr_to_mask in cdrs_to_mask:
            self.mask_one_cdr_(data, cdr_to_mask)

    def __call__(self, structure):
        """Mask CDRs on every chain that is present; return ``structure``."""
        if structure['heavy'] is not None:
            self.mask_for_one_chain_(structure['heavy'])
        if structure['light'] is not None:
            self.mask_for_one_chain_(structure['light'])
        return structure
@register_transform('mask_antibody')
class MaskAntibody(object):
    """Transform that marks whole antibody chains for generation.

    Every present chain (``structure['heavy']``, ``structure['light']``) gets
    an all-``True`` ``generate_flag``. When an antigen is present, it
    additionally receives a ``contact_flag`` (antigen residues whose CA atom
    lies within 6.0 of some antibody CA atom) and an ``anchor_flag`` with one
    randomly sampled contact residue set.
    """

    def mask_ab_chain_(self, data):
        """Set ``data['generate_flag']`` to all-True, shaped like ``data['aa']``."""
        data['generate_flag'] = torch.ones(data['aa'].shape, dtype=torch.bool)

    def __call__(self, structure):
        """Flag antibody chains and pick an antigen anchor; return ``structure``."""
        ca_per_chain = []
        for chain_name in ('heavy', 'light'):
            chain = structure[chain_name]
            if chain is None:
                continue
            self.mask_ab_chain_(chain)
            ca_per_chain.append(chain['pos_heavyatom'][:, constants.BBHeavyAtom.CA])
        ab_ca = torch.cat(ca_per_chain, dim=0)  # (L_Ab, 3)

        antigen = structure['antigen']
        if antigen is None:
            return structure

        ag_ca = antigen['pos_heavyatom'][:, constants.BBHeavyAtom.CA]
        pairwise = torch.cdist(ag_ca, ab_ca)  # (L_Ag, L_Ab)
        # Distance from each antigen residue to its nearest antibody residue.
        nearest = pairwise.min(dim=1).values  # (L_Ag)
        in_contact = nearest <= 6.0  # (L_Ag); presumably Angstrom - TODO confirm
        if not in_contact.any():
            # Nothing within the cutoff: fall back to the single closest residue.
            in_contact[nearest.argmin()] = True

        # Sample one contact residue uniformly to serve as the anchor.
        picked = torch.multinomial(in_contact.float(), num_samples=1).item()
        anchor = torch.zeros(antigen['aa'].shape, dtype=torch.bool)
        anchor[picked] = True

        antigen['anchor_flag'] = anchor
        antigen['contact_flag'] = in_contact
        return structure
@register_transform('remove_antigen')
class RemoveAntigen:
    """Transform that blanks the antigen entries of a structure."""

    def __call__(self, structure):
        """Set ``'antigen'`` and ``'antigen_seqmap'`` to ``None``; return ``structure``."""
        for key in ('antigen', 'antigen_seqmap'):
            structure[key] = None
        return structure