Spaces:
Runtime error
Runtime error
import torch | |
import random | |
from typing import List, Optional | |
from ..protein import constants | |
from ._base import register_transform | |
def random_shrink_extend(flag, min_length=5, shrink_limit=1, extend_limit=2):
    """Randomly shrink or extend the True-region of a boolean mask.

    Each end of the region is moved independently by an offset drawn
    uniformly from ``[-shrink_limit, extend_limit]``. A negative offset moves
    that boundary inward (shrink); a positive one moves it outward (extend).
    The new boundaries are clamped to ``[0, flag.size(0) - 1]``.

    Args:
        flag: 1-D boolean tensor whose True entries mark the region
            (assumed contiguous, as produced by ``cdr_flag == label`` —
            TODO confirm for all callers).
        min_length: If shrinking both ends by ``shrink_limit`` would leave
            fewer than ``min_length`` positions, shrinking is disabled for
            this call and the region can only keep its size or grow.
        shrink_limit: Maximum number of positions removed from each end.
        extend_limit: Maximum number of positions added to each end.

    Returns:
        A new boolean tensor, same shape/device as ``flag``, that is True
        exactly on ``[first_ext, last_ext]``.
    """
    first, last = continuous_flag_to_range(flag)
    length = flag.sum().item()
    if (length - 2 * shrink_limit) < min_length:
        shrink_limit = 0
    first_ext = max(0, first - random.randint(-shrink_limit, extend_limit))
    last_ext = min(last + random.randint(-shrink_limit, extend_limit), flag.size(0) - 1)
    # Start from an all-False mask. Cloning ``flag`` (as before) kept the
    # original boundary residues set, so negative offsets never shrank the
    # region and the ``min_length`` guard above had no effect.
    flag_ext = torch.zeros_like(flag)
    flag_ext[first_ext:last_ext + 1] = True
    return flag_ext
def continuous_flag_to_range(flag):
    """Return the first and last indices at which ``flag`` is True.

    Args:
        flag: 1-D boolean tensor.

    Returns:
        ``(first, last)`` as Python ints. Gaps between True entries are not
        checked; only the extreme True indices are reported.

    Raises:
        ValueError: If ``flag`` contains no True entry.
    """
    # ``nonzero`` returns ascending indices on ``flag``'s own device, so this
    # also works for CUDA masks (a CPU ``arange`` indexed by a CUDA mask
    # does not), and it is computed once instead of twice.
    idx = torch.nonzero(flag, as_tuple=True)[0]
    if idx.numel() == 0:
        raise ValueError('continuous_flag_to_range: flag has no True entries')
    return idx[0].item(), idx[-1].item()
class MaskSingleCDR(object):
    """Mark a single CDR loop of one antibody chain for generation.

    On the chosen chain's data dict this sets ``generate_flag`` (positions of
    the CDR, optionally randomly shrunk/extended) and ``anchor_flag`` (the
    residue just before and just after that region, clamped to the chain
    bounds).

    Args:
        selection: Which CDR to mask. One of 'H1'..'L3', 'H_CDR1'..'L_CDR3',
            or 'CDR3' (heavy-chain H3 when a heavy chain is present,
            otherwise light-chain L3). ``None`` picks a random chain and then
            a random CDR on it.
        augmentation: If True, the masked region is passed through
            ``random_shrink_extend``.

    Raises:
        ValueError: If ``selection`` is not one of the accepted strings.
    """

    def __init__(self, selection=None, augmentation=True):
        super().__init__()
        cdr_str_to_enum = {
            'H1': constants.CDR.H1,
            'H2': constants.CDR.H2,
            'H3': constants.CDR.H3,
            'L1': constants.CDR.L1,
            'L2': constants.CDR.L2,
            'L3': constants.CDR.L3,
            'H_CDR1': constants.CDR.H1,
            'H_CDR2': constants.CDR.H2,
            'H_CDR3': constants.CDR.H3,
            'L_CDR1': constants.CDR.L1,
            'L_CDR2': constants.CDR.L2,
            'L_CDR3': constants.CDR.L3,
            'CDR3': 'CDR3',     # H3 first, then fallback to L3
        }
        # Explicit check instead of ``assert``: under ``python -O`` the assert
        # is stripped and an invalid name would silently become ``None``,
        # i.e. a random CDR would be masked.
        if selection is not None and selection not in cdr_str_to_enum:
            raise ValueError(
                f'Unknown CDR selection {selection!r}; '
                f'expected None or one of {sorted(cdr_str_to_enum)}'
            )
        self.selection = cdr_str_to_enum.get(selection, None)
        self.augmentation = augmentation

    def perform_masking_(self, data, selection=None):
        """Set ``generate_flag`` and ``anchor_flag`` on one chain, in place.

        Args:
            data: Per-chain dict. ``'cdr_flag'`` holds integer CDR labels
                (non-CDR positions appear to be 0, judging by the ``> 0``
                filter) and ``'aa'`` provides the chain length.
            selection: CDR label to mask. ``None`` picks uniformly among the
                labels present on this chain.

        Raises:
            ValueError: If the chain has no CDR label at all, or the requested
                CDR does not occur on it.
        """
        cdr_flag = data['cdr_flag']
        if selection is None:
            cdr_all = cdr_flag[cdr_flag > 0].unique().tolist()
            if not cdr_all:
                raise ValueError('MaskSingleCDR: chain has no annotated CDR to mask')
            cdr_to_mask = random.choice(cdr_all)
        else:
            cdr_to_mask = selection

        cdr_to_mask_flag = (cdr_flag == cdr_to_mask)
        if not cdr_to_mask_flag.any():
            raise ValueError(f'MaskSingleCDR: CDR {cdr_to_mask} not found on the selected chain')
        if self.augmentation:
            cdr_to_mask_flag = random_shrink_extend(cdr_to_mask_flag)

        cdr_first, cdr_last = continuous_flag_to_range(cdr_to_mask_flag)
        left_idx = max(0, cdr_first - 1)
        right_idx = min(data['aa'].size(0) - 1, cdr_last + 1)
        anchor_flag = torch.zeros(data['aa'].shape, dtype=torch.bool)
        anchor_flag[left_idx] = True
        anchor_flag[right_idx] = True

        data['generate_flag'] = cdr_to_mask_flag
        data['anchor_flag'] = anchor_flag

    def __call__(self, structure):
        """Mask one CDR on ``structure['heavy']`` or ``structure['light']``.

        Returns the same (mutated) ``structure``.

        Raises:
            ValueError: If the chain required by ``selection`` is missing.
        """
        heavy, light = structure['heavy'], structure['light']
        if self.selection is None:
            chains = [chain for chain in (heavy, light) if chain is not None]
            if not chains:
                raise ValueError('MaskSingleCDR: structure has neither a heavy nor a light chain')
            data_to_mask = random.choice(chains)
            sel = None
        elif self.selection in (constants.CDR.H1, constants.CDR.H2, constants.CDR.H3):
            data_to_mask = heavy
            sel = int(self.selection)
        elif self.selection in (constants.CDR.L1, constants.CDR.L2, constants.CDR.L3):
            data_to_mask = light
            sel = int(self.selection)
        elif self.selection == 'CDR3':
            # Prefer the heavy-chain H3; fall back to the light-chain L3.
            if heavy is not None:
                data_to_mask, sel = heavy, int(constants.CDR.H3)
            else:
                data_to_mask, sel = light, int(constants.CDR.L3)
        else:
            raise ValueError(f'MaskSingleCDR: unsupported selection {self.selection!r}')

        if data_to_mask is None:
            raise ValueError(f'MaskSingleCDR: structure lacks the chain required for CDR {sel}')
        self.perform_masking_(data_to_mask, selection=sel)
        return structure
class MaskMultipleCDRs(object):
    """Mark one or more CDR loops on every antibody chain for generation.

    For each present chain (``structure['heavy']`` / ``structure['light']``)
    the chosen CDRs are OR-accumulated into ``generate_flag``, and their
    flanking residues into ``anchor_flag``.

    Args:
        selection: CDR names to mask (e.g. ``['H3', 'L3']``). Only those
            actually present on a chain are masked there. ``None`` masks a
            random non-empty subset of each chain's CDRs.
        augmentation: If True, each masked region is passed through
            ``random_shrink_extend``.

    Raises:
        KeyError: If a name in ``selection`` is not recognised.
    """

    def __init__(self, selection: Optional[List[str]]=None, augmentation=True):
        super().__init__()
        cdr_str_to_enum = {
            'H1': constants.CDR.H1,
            'H2': constants.CDR.H2,
            'H3': constants.CDR.H3,
            'L1': constants.CDR.L1,
            'L2': constants.CDR.L2,
            'L3': constants.CDR.L3,
            'H_CDR1': constants.CDR.H1,
            'H_CDR2': constants.CDR.H2,
            'H_CDR3': constants.CDR.H3,
            'L_CDR1': constants.CDR.L1,
            'L_CDR2': constants.CDR.L2,
            'L_CDR3': constants.CDR.L3,
        }
        if selection is not None:
            self.selection = [cdr_str_to_enum[s] for s in selection]
        else:
            self.selection = None
        self.augmentation = augmentation

    def mask_one_cdr_(self, data, cdr_to_mask):
        """OR one CDR (and its two flanking anchors) into ``data``'s flags.

        The first call on a chain creates ``generate_flag``/``anchor_flag``;
        later calls update them in place with ``|=``.
        """
        cdr_flag = data['cdr_flag']
        cdr_to_mask_flag = (cdr_flag == cdr_to_mask)
        if self.augmentation:
            cdr_to_mask_flag = random_shrink_extend(cdr_to_mask_flag)

        cdr_first, cdr_last = continuous_flag_to_range(cdr_to_mask_flag)
        left_idx = max(0, cdr_first - 1)
        right_idx = min(data['aa'].size(0) - 1, cdr_last + 1)
        anchor_flag = torch.zeros(data['aa'].shape, dtype=torch.bool)
        anchor_flag[left_idx] = True
        anchor_flag[right_idx] = True

        if 'generate_flag' not in data:
            data['generate_flag'] = cdr_to_mask_flag
            data['anchor_flag'] = anchor_flag
        else:
            data['generate_flag'] |= cdr_to_mask_flag
            data['anchor_flag'] |= anchor_flag

    def mask_for_one_chain_(self, data):
        """Choose which CDRs of one chain to mask and mask them.

        A chain without any CDR label (no ``cdr_flag > 0``) is left untouched,
        the same as when ``selection`` does not overlap the chain's CDRs.
        """
        cdr_flag = data['cdr_flag']
        cdr_all = cdr_flag[cdr_flag > 0].unique().tolist()
        if not cdr_all:
            # Previously ``random.randint(1, 0)`` was evaluated here and raised,
            # even when ``selection`` was given and the count was never used.
            return

        if self.selection is not None:
            cdrs_to_mask = sorted(set(cdr_all).intersection(self.selection))
        else:
            num_cdrs_to_mask = random.randint(1, len(cdr_all))
            random.shuffle(cdr_all)
            cdrs_to_mask = cdr_all[:num_cdrs_to_mask]

        for cdr_to_mask in cdrs_to_mask:
            self.mask_one_cdr_(data, cdr_to_mask)

    def __call__(self, structure):
        """Mask CDRs on each present antibody chain; return ``structure``."""
        if structure['heavy'] is not None:
            self.mask_for_one_chain_(structure['heavy'])
        if structure['light'] is not None:
            self.mask_for_one_chain_(structure['light'])
        return structure
class MaskAntibody(object):
    """Mark every antibody residue for generation and pick an antigen anchor.

    Each present antibody chain gets an all-True ``generate_flag``. If an
    antigen is present, ``contact_flag`` marks antigen residues whose CA is
    within ``contact_cutoff`` of the nearest antibody CA. When none is that
    close, the single closest antigen residue is used instead. One contact
    residue, sampled uniformly, is marked in the antigen's ``anchor_flag``.

    Args:
        contact_cutoff: CA-CA distance threshold for an antigen contact, in
            the units of ``pos_heavyatom`` (presumably Angstrom — confirm).
            Defaults to the previously hard-coded 6.0.
    """

    def __init__(self, contact_cutoff=6.0):
        super().__init__()
        self.contact_cutoff = contact_cutoff

    def mask_ab_chain_(self, data):
        """Mark every residue of one antibody chain for generation."""
        data['generate_flag'] = torch.ones(data['aa'].shape, dtype=torch.bool)

    def __call__(self, structure):
        """Apply the masking in place and return ``structure``.

        Raises:
            ValueError: If the structure has neither a heavy nor a light chain.
        """
        ab_ca_list = []
        for chain_key in ('heavy', 'light'):
            chain = structure[chain_key]
            if chain is None:
                continue
            self.mask_ab_chain_(chain)
            ab_ca_list.append(chain['pos_heavyatom'][:, constants.BBHeavyAtom.CA])
        if not ab_ca_list:
            # ``torch.cat([])`` would fail with an opaque error here.
            raise ValueError('MaskAntibody: structure has neither a heavy nor a light chain')
        pos_ab_alpha = torch.cat(ab_ca_list, dim=0)    # (L_Ab, 3)

        antigen = structure['antigen']
        if antigen is not None:
            pos_ag_alpha = antigen['pos_heavyatom'][:, constants.BBHeavyAtom.CA]
            ag_ab_dist = torch.cdist(pos_ag_alpha, pos_ab_alpha)    # (L_Ag, L_Ab)
            nn_ab_dist = ag_ab_dist.min(dim=1)[0]   # (L_Ag)
            contact_flag = (nn_ab_dist <= self.contact_cutoff)  # (L_Ag)
            if not contact_flag.any():
                # Guarantee at least one contact so multinomial has support.
                contact_flag[nn_ab_dist.argmin()] = True
            # 0/1 weights => uniform choice among the contact residues.
            anchor_idx = torch.multinomial(contact_flag.float(), num_samples=1).item()
            anchor_flag = torch.zeros(antigen['aa'].shape, dtype=torch.bool)
            anchor_flag[anchor_idx] = True
            antigen['anchor_flag'] = anchor_flag
            antigen['contact_flag'] = contact_flag
        return structure
class RemoveAntigen:
    """Strip the antigen from a structure dict.

    Both the antigen chain data and its sequence map are replaced with
    ``None``. The dict is modified in place and also returned, so the
    transform can be chained.
    """

    def __call__(self, structure):
        for field in ('antigen', 'antigen_seqmap'):
            structure[field] = None
        return structure