Spaces:
Runtime error
Runtime error
File size: 7,863 Bytes
753e275 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 |
import torch
import random
from typing import List, Optional
from ..protein import constants
from ._base import register_transform
def random_shrink_extend(flag, min_length=5, shrink_limit=1, extend_limit=2):
    """Randomly jitter both ends of the True span in a 1-D boolean flag.

    Each end of the span is moved independently by an integer offset drawn
    uniformly from ``[-shrink_limit, extend_limit]`` (``random.randint``,
    both bounds inclusive).
    - A positive offset extends the span outward.
    - A negative offset shrinks it inward.

    Shrinking is disabled entirely when shrinking both ends by the maximum
    could leave fewer than ``min_length`` entries, and it never leaves fewer
    than one. The extended span is clipped to ``[0, len(flag) - 1]``.

    Args:
        flag: 1-D boolean tensor. The span runs from its first to its last
            True entry.
        min_length: minimum span length that shrinking must preserve.
        shrink_limit: maximum number of positions each end may move inward.
        extend_limit: maximum number of positions each end may move outward.

    Returns:
        A new boolean tensor, with the same shape and device as ``flag``, that
        is True exactly on ``[first_ext, last_ext]``.

    Raises:
        ValueError: if ``flag`` has no True entries.
    """
    positions = torch.nonzero(flag, as_tuple=True)[0]
    if positions.numel() == 0:
        raise ValueError('random_shrink_extend: flag has no True entries')
    first = positions.min().item()
    last = positions.max().item()
    length = positions.numel()

    # Forbid shrinking spans that are too short. The floor of 1 keeps the
    # span non-empty even if a caller passes min_length <= 0.
    if (length - 2 * shrink_limit) < max(min_length, 1):
        shrink_limit = 0

    # Offsets are drawn in the original order: start first, then end.
    first_ext = max(0, first - random.randint(-shrink_limit, extend_limit))
    last_ext = min(last + random.randint(-shrink_limit, extend_limit), flag.size(0) - 1)

    # Start from an all-False mask. Cloning ``flag`` here would keep the
    # original span True and silently turn every shrink into a no-op.
    flag_ext = torch.zeros_like(flag, dtype=torch.bool)
    flag_ext[first_ext:last_ext + 1] = True
    return flag_ext
def continuous_flag_to_range(flag):
    """Return the first and last indices at which a 1-D flag is True.

    Any gaps between the two indices are ignored; the caller is expected to
    pass a flag with one contiguous span.

    Args:
        flag: 1-D boolean tensor, on any device.

    Returns:
        ``(first, last)`` as Python ints, both inclusive.

    Raises:
        ValueError: if ``flag`` has no True entries.
    """
    # nonzero() works on the flag's own device and avoids building (twice) a
    # CPU ``arange`` that cannot be masked by a non-CPU boolean tensor.
    positions = torch.nonzero(flag, as_tuple=True)[0]
    if positions.numel() == 0:
        raise ValueError('continuous_flag_to_range: flag has no True entries')
    return positions.min().item(), positions.max().item()
@register_transform('mask_single_cdr')
class MaskSingleCDR(object):
    """Mark a single CDR on one antibody chain for generation.

    On the chosen chain's data dict this sets two flags:
    - ``generate_flag``: the residues of the selected CDR, optionally
      jittered by ``random_shrink_extend``.
    - ``anchor_flag``: the residue just before and just after that span,
      clipped to the chain ends.

    Args:
        selection: one of
            - ``None``: pick a random present chain and a random CDR on it;
            - a CDR name such as ``'H3'`` or ``'L_CDR1'``: always mask that
              CDR;
            - ``'CDR3'``: mask H3 when a heavy chain is present, otherwise L3.
        augmentation: whether to randomly shrink/extend the masked span.

    Raises:
        ValueError: if ``selection`` is not one of the recognised names.
    """

    def __init__(self, selection=None, augmentation=True):
        super().__init__()
        cdr_str_to_enum = {
            'H1': constants.CDR.H1,
            'H2': constants.CDR.H2,
            'H3': constants.CDR.H3,
            'L1': constants.CDR.L1,
            'L2': constants.CDR.L2,
            'L3': constants.CDR.L3,
            'H_CDR1': constants.CDR.H1,
            'H_CDR2': constants.CDR.H2,
            'H_CDR3': constants.CDR.H3,
            'L_CDR1': constants.CDR.L1,
            'L_CDR2': constants.CDR.L2,
            'L_CDR3': constants.CDR.L3,
            'CDR3': 'CDR3',     # H3 first, then fallback to L3
        }
        # Explicit check instead of ``assert``. Asserts are stripped under
        # ``python -O``; an unknown name would then map to ``None`` and
        # silently switch the transform to random-CDR masking.
        if selection is not None and selection not in cdr_str_to_enum:
            raise ValueError(
                f'Unknown CDR selection {selection!r}; '
                f'expected one of {sorted(cdr_str_to_enum)}'
            )
        self.selection = cdr_str_to_enum.get(selection, None)
        self.augmentation = augmentation

    def perform_masking_(self, data, selection=None):
        """Set ``generate_flag`` and ``anchor_flag`` on one chain, in place.

        Args:
            data: chain data dict. Reads ``data['cdr_flag']``, whose values
                > 0 are treated as CDR labels, and ``data['aa']``.
            selection: CDR label to mask. ``None`` picks uniformly among the
                labels > 0 present in ``cdr_flag``.

        Raises:
            ValueError: if the chain has no CDR to pick from, or the
                requested CDR does not occur on the chain.
        """
        cdr_flag = data['cdr_flag']
        if selection is None:
            cdr_all = cdr_flag[cdr_flag > 0].unique().tolist()
            if not cdr_all:
                raise ValueError('MaskSingleCDR: chain has no annotated CDRs to mask')
            cdr_to_mask = random.choice(cdr_all)
        else:
            cdr_to_mask = selection

        cdr_to_mask_flag = (cdr_flag == cdr_to_mask)
        if not cdr_to_mask_flag.any():
            raise ValueError(f'MaskSingleCDR: CDR {cdr_to_mask!r} not found on chain')
        if self.augmentation:
            cdr_to_mask_flag = random_shrink_extend(cdr_to_mask_flag)

        # Anchors are the residues flanking the (possibly jittered) span.
        cdr_first, cdr_last = continuous_flag_to_range(cdr_to_mask_flag)
        left_idx = max(0, cdr_first - 1)
        right_idx = min(data['aa'].size(0) - 1, cdr_last + 1)
        anchor_flag = torch.zeros(data['aa'].shape, dtype=torch.bool)
        anchor_flag[left_idx] = True
        anchor_flag[right_idx] = True

        data['generate_flag'] = cdr_to_mask_flag
        data['anchor_flag'] = anchor_flag

    def __call__(self, structure):
        """Mask one CDR on ``structure`` in place and return it.

        Reads ``structure['heavy']`` and ``structure['light']``; either may be
        ``None``.

        Raises:
            ValueError: if the chain required by ``selection`` is missing, or
                no chain is present at all.
        """
        heavy = structure['heavy']
        light = structure['light']
        if self.selection is None:
            ab_data = [chain for chain in (heavy, light) if chain is not None]
            if not ab_data:
                raise ValueError('MaskSingleCDR: structure has neither a heavy nor a light chain')
            data_to_mask = random.choice(ab_data)
            sel = None
        elif self.selection in (constants.CDR.H1, constants.CDR.H2, constants.CDR.H3, ):
            data_to_mask = heavy
            sel = int(self.selection)
        elif self.selection in (constants.CDR.L1, constants.CDR.L2, constants.CDR.L3, ):
            data_to_mask = light
            sel = int(self.selection)
        elif self.selection == 'CDR3':
            if heavy is not None:
                data_to_mask = heavy
                sel = constants.CDR.H3
            else:
                data_to_mask = light
                sel = constants.CDR.L3

        if data_to_mask is None:
            raise ValueError(
                f'MaskSingleCDR: selection {self.selection!r} requires a chain '
                f'that is missing from the structure'
            )
        self.perform_masking_(data_to_mask, selection=sel)
        return structure
@register_transform('mask_multiple_cdrs')
class MaskMultipleCDRs(object):
    """Mark one or more CDRs on every present antibody chain for generation.

    For each chain, the CDRs to mask are chosen as follows:
    - with ``selection`` set: the configured CDRs that occur on that chain;
    - otherwise: a random non-empty subset of the CDRs that occur on it.

    Each masked CDR is OR-ed into the chain's ``generate_flag``. Its two
    flanking residues are OR-ed into ``anchor_flag``.

    Args:
        selection: optional list of CDR names (e.g. ``['H3', 'L_CDR3']``).
        augmentation: whether to randomly shrink/extend each masked span.

    Raises:
        ValueError: if ``selection`` contains an unrecognised name.
    """

    def __init__(self, selection: Optional[List[str]]=None, augmentation=True):
        super().__init__()
        cdr_str_to_enum = {
            'H1': constants.CDR.H1,
            'H2': constants.CDR.H2,
            'H3': constants.CDR.H3,
            'L1': constants.CDR.L1,
            'L2': constants.CDR.L2,
            'L3': constants.CDR.L3,
            'H_CDR1': constants.CDR.H1,
            'H_CDR2': constants.CDR.H2,
            'H_CDR3': constants.CDR.H3,
            'L_CDR1': constants.CDR.L1,
            'L_CDR2': constants.CDR.L2,
            'L_CDR3': constants.CDR.L3,
        }
        if selection is not None:
            unknown = [s for s in selection if s not in cdr_str_to_enum]
            if unknown:
                raise ValueError(
                    f'Unknown CDR selection(s) {unknown!r}; '
                    f'expected names from {sorted(cdr_str_to_enum)}'
                )
            self.selection = [cdr_str_to_enum[s] for s in selection]
        else:
            self.selection = None
        self.augmentation = augmentation

    def mask_one_cdr_(self, data, cdr_to_mask):
        """Add one CDR (and its flanking anchors) to the chain's flags, in place.

        Creates ``generate_flag`` and ``anchor_flag`` on the first call for a
        chain. On later calls, ORs into the existing flags.
        """
        cdr_flag = data['cdr_flag']
        cdr_to_mask_flag = (cdr_flag == cdr_to_mask)
        if self.augmentation:
            cdr_to_mask_flag = random_shrink_extend(cdr_to_mask_flag)

        cdr_first, cdr_last = continuous_flag_to_range(cdr_to_mask_flag)
        left_idx = max(0, cdr_first - 1)
        right_idx = min(data['aa'].size(0) - 1, cdr_last + 1)
        anchor_flag = torch.zeros(data['aa'].shape, dtype=torch.bool)
        anchor_flag[left_idx] = True
        anchor_flag[right_idx] = True

        if 'generate_flag' not in data:
            data['generate_flag'] = cdr_to_mask_flag
            data['anchor_flag'] = anchor_flag
        else:
            data['generate_flag'] |= cdr_to_mask_flag
            data['anchor_flag'] |= anchor_flag

    def mask_for_one_chain_(self, data):
        """Choose and mask the CDRs of a single chain, in place.

        A chain whose ``cdr_flag`` has no values > 0 is left untouched.
        """
        cdr_flag = data['cdr_flag']
        cdr_all = cdr_flag[cdr_flag > 0].unique().tolist()
        if not cdr_all:
            # Nothing to mask. Without this guard, random.randint(1, 0) below
            # would raise ValueError.
            return

        if self.selection is not None:
            cdrs_to_mask = list(set(cdr_all).intersection(self.selection))
        else:
            num_cdrs_to_mask = random.randint(1, len(cdr_all))
            random.shuffle(cdr_all)
            cdrs_to_mask = cdr_all[:num_cdrs_to_mask]

        for cdr_to_mask in cdrs_to_mask:
            self.mask_one_cdr_(data, cdr_to_mask)

    def __call__(self, structure):
        """Mask CDRs on the heavy and/or light chain of ``structure``; return it."""
        if structure['heavy'] is not None:
            self.mask_for_one_chain_(structure['heavy'])
        if structure['light'] is not None:
            self.mask_for_one_chain_(structure['light'])
        return structure
@register_transform('mask_antibody')
class MaskAntibody(object):
    """Mark every antibody residue for generation and annotate antigen contacts.

    Each present heavy/light chain gets an all-True ``generate_flag``. If an
    antigen is present, two flags are written to it:
    - ``contact_flag``: antigen residues whose CA lies within
      ``contact_threshold`` of any antibody CA. If none qualify, the single
      closest residue is used.
    - ``anchor_flag``: one residue sampled uniformly from ``contact_flag``.

    Args:
        contact_threshold: CA-CA distance cutoff, in the same units as
            ``pos_heavyatom`` (presumably Angstrom; verify). The default 6.0
            matches the previously hard-coded value.
    """

    def __init__(self, contact_threshold=6.0):
        super().__init__()
        self.contact_threshold = contact_threshold

    def mask_ab_chain_(self, data):
        """Flag every residue of one antibody chain for generation."""
        data['generate_flag'] = torch.ones(data['aa'].shape, dtype=torch.bool)

    def __call__(self, structure):
        """Apply the masking to ``structure`` in place and return it.

        Raises:
            ValueError: if an antigen is present but there is no heavy or light
                chain to measure contacts against.
        """
        pos_ab_alpha = []
        for chain_key in ('heavy', 'light'):
            chain = structure[chain_key]
            if chain is None:
                continue
            self.mask_ab_chain_(chain)
            pos_ab_alpha.append(chain['pos_heavyatom'][:, constants.BBHeavyAtom.CA])

        antigen = structure['antigen']
        if antigen is None:
            return structure
        if not pos_ab_alpha:
            # torch.cat([]) would fail with an opaque error.
            raise ValueError('MaskAntibody: antigen present but no antibody chain to contact')

        pos_ab_alpha = torch.cat(pos_ab_alpha, dim=0)  # (L_Ab, 3)
        pos_ag_alpha = antigen['pos_heavyatom'][:, constants.BBHeavyAtom.CA]  # (L_Ag, 3)
        ag_ab_dist = torch.cdist(pos_ag_alpha, pos_ab_alpha)  # (L_Ag, L_Ab)
        nn_ab_dist = ag_ab_dist.min(dim=1)[0]  # (L_Ag)
        contact_flag = (nn_ab_dist <= self.contact_threshold)  # (L_Ag)
        if contact_flag.sum().item() == 0:
            # Nothing within the cutoff: fall back to the closest antigen residue
            # so there is always at least one anchor candidate.
            contact_flag[nn_ab_dist.argmin()] = True

        # Uniform draw over contact residues (all weights are 0 or 1).
        anchor_idx = torch.multinomial(contact_flag.float(), num_samples=1).item()
        anchor_flag = torch.zeros(antigen['aa'].shape, dtype=torch.bool)
        anchor_flag[anchor_idx] = True

        antigen['anchor_flag'] = anchor_flag
        antigen['contact_flag'] = contact_flag
        return structure
@register_transform('remove_antigen')
class RemoveAntigen:
    """Strip the antigen entries from a structure dict."""

    def __call__(self, structure):
        """Set ``antigen`` and ``antigen_seqmap`` to ``None`` in place; return ``structure``."""
        for field in ('antigen', 'antigen_seqmap'):
            structure[field] = None
        return structure
|