import torch
import random
from typing import List, Optional

from ..protein import constants
from ._base import register_transform


def random_shrink_extend(flag, min_length=5, shrink_limit=1, extend_limit=2):
    """Randomly perturb the boundaries of the contiguous True span in `flag`.

    Each boundary is shifted by an offset drawn from [-shrink_limit, extend_limit];
    shrinking is disabled when it could bring the span below `min_length`. Because
    the new range is only written as True onto a clone of `flag`, a negative
    (shrink) draw leaves the original span unchanged, so in practice the span can
    only grow.
    """
    first, last = continuous_flag_to_range(flag)
    length = flag.sum().item()
    if (length - 2*shrink_limit) < min_length:
        shrink_limit = 0
    first_ext = max(0, first-random.randint(-shrink_limit, extend_limit))
    last_ext = min(last+random.randint(-shrink_limit, extend_limit), flag.size(0)-1)
    flag_ext = flag.clone()
    flag_ext[first_ext : last_ext+1] = True
    return flag_ext


def continuous_flag_to_range(flag):
    """Return the (first, last) indices of the True entries in a boolean mask."""
    index = torch.arange(0, flag.size(0))[flag]
    first = index.min().item()
    last = index.max().item()
    return first, last
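
# Worked example for the two helpers above (illustrative; not part of the
# original module; the tensor values are made up):
#
#   flag = torch.tensor([False, False, True, True, True, False, False])
#   continuous_flag_to_range(flag)   # -> (2, 4)
#   random_shrink_extend(flag)       # -> same span, possibly extended by up to
#                                    #    `extend_limit` positions on each side
#                                    #    (too short to shrink, so `shrink_limit`
#                                    #    is zeroed out by the `min_length` check)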


@register_transform('mask_single_cdr')
class MaskSingleCDR(object):
    """Mask one CDR, either a fixed selection or one chosen at random, for
    generation, and mark the residues flanking it as anchors."""

    def __init__(self, selection=None, augmentation=True):
        super().__init__()
        cdr_str_to_enum = {
            'H1': constants.CDR.H1,
            'H2': constants.CDR.H2,
            'H3': constants.CDR.H3,
            'L1': constants.CDR.L1,
            'L2': constants.CDR.L2,
            'L3': constants.CDR.L3,
            'H_CDR1': constants.CDR.H1,
            'H_CDR2': constants.CDR.H2,
            'H_CDR3': constants.CDR.H3,
            'L_CDR1': constants.CDR.L1,
            'L_CDR2': constants.CDR.L2,
            'L_CDR3': constants.CDR.L3,
            'CDR3': 'CDR3',     # H3 first, then fallback to L3
        }
        assert selection is None or selection in cdr_str_to_enum
        self.selection = cdr_str_to_enum.get(selection, None)
        self.augmentation = augmentation

    def perform_masking_(self, data, selection=None):
        """Set `generate_flag` over the chosen CDR span and `anchor_flag` on the
        residues immediately flanking it."""
        cdr_flag = data['cdr_flag']

        if selection is None:
            cdr_all = cdr_flag[cdr_flag > 0].unique().tolist()
            cdr_to_mask = random.choice(cdr_all)
        else:
            cdr_to_mask = selection

        cdr_to_mask_flag = (cdr_flag == cdr_to_mask)
        if self.augmentation:
            cdr_to_mask_flag = random_shrink_extend(cdr_to_mask_flag)

        cdr_first, cdr_last = continuous_flag_to_range(cdr_to_mask_flag)
        left_idx = max(0, cdr_first-1)
        right_idx = min(data['aa'].size(0)-1, cdr_last+1)
        anchor_flag = torch.zeros(data['aa'].shape, dtype=torch.bool)
        anchor_flag[left_idx] = True
        anchor_flag[right_idx] = True

        data['generate_flag'] = cdr_to_mask_flag
        data['anchor_flag'] = anchor_flag

    def __call__(self, structure):
        if self.selection is None:
            ab_data = []
            if structure['heavy'] is not None:
                ab_data.append(structure['heavy'])
            if structure['light'] is not None:
                ab_data.append(structure['light'])
            data_to_mask = random.choice(ab_data)
            sel = None
        elif self.selection in (constants.CDR.H1, constants.CDR.H2, constants.CDR.H3, ):
            data_to_mask = structure['heavy']
            sel = int(self.selection)
        elif self.selection in (constants.CDR.L1, constants.CDR.L2, constants.CDR.L3, ):
            data_to_mask = structure['light']
            sel = int(self.selection)
        elif self.selection == 'CDR3':
            if structure['heavy'] is not None:
                data_to_mask = structure['heavy']
                sel = constants.CDR.H3
            else:
                data_to_mask = structure['light']
                sel = constants.CDR.L3

        self.perform_masking_(data_to_mask, selection=sel)
        return structure


@register_transform('mask_multiple_cdrs')
class MaskMultipleCDRs(object):
    """Mask several CDRs for generation: either the configured selection or a
    random non-empty subset of the CDRs present on each chain."""

    def __init__(self, selection: Optional[List[str]]=None, augmentation=True):
        super().__init__()
        cdr_str_to_enum = {
            'H1': constants.CDR.H1,
            'H2': constants.CDR.H2,
            'H3': constants.CDR.H3,
            'L1': constants.CDR.L1,
            'L2': constants.CDR.L2,
            'L3': constants.CDR.L3,
            'H_CDR1': constants.CDR.H1,
            'H_CDR2': constants.CDR.H2,
            'H_CDR3': constants.CDR.H3,
            'L_CDR1': constants.CDR.L1,
            'L_CDR2': constants.CDR.L2,
            'L_CDR3': constants.CDR.L3,
        }
        if selection is not None:
            self.selection = [cdr_str_to_enum[s] for s in selection]
        else:
            self.selection = None
        self.augmentation = augmentation

    def mask_one_cdr_(self, data, cdr_to_mask):
        """Mark one CDR for generation, OR-ing the new flags into any existing ones."""
        cdr_flag = data['cdr_flag']

        cdr_to_mask_flag = (cdr_flag == cdr_to_mask)
        if self.augmentation:
            cdr_to_mask_flag = random_shrink_extend(cdr_to_mask_flag)

        cdr_first, cdr_last = continuous_flag_to_range(cdr_to_mask_flag)
        left_idx = max(0, cdr_first-1)
        right_idx = min(data['aa'].size(0)-1, cdr_last+1)
        anchor_flag = torch.zeros(data['aa'].shape, dtype=torch.bool)
        anchor_flag[left_idx] = True
        anchor_flag[right_idx] = True

        if 'generate_flag' not in data:
            data['generate_flag'] = cdr_to_mask_flag
            data['anchor_flag'] = anchor_flag
        else:
            data['generate_flag'] |= cdr_to_mask_flag
            data['anchor_flag'] |= anchor_flag

    def mask_for_one_chain_(self, data):
        """Choose which CDRs of one chain to mask: the configured selection if
        given, otherwise a random non-empty subset of the CDRs present."""
        cdr_flag = data['cdr_flag']
        cdr_all = cdr_flag[cdr_flag > 0].unique().tolist()

        if self.selection is not None:
            cdrs_to_mask = list(set(cdr_all).intersection(self.selection))
        else:
            num_cdrs_to_mask = random.randint(1, len(cdr_all))
            random.shuffle(cdr_all)
            cdrs_to_mask = cdr_all[:num_cdrs_to_mask]

        for cdr_to_mask in cdrs_to_mask:
            self.mask_one_cdr_(data, cdr_to_mask)

    def __call__(self, structure):
        if structure['heavy'] is not None:
            self.mask_for_one_chain_(structure['heavy'])
        if structure['light'] is not None:
            self.mask_for_one_chain_(structure['light'])
        return structure

        
@register_transform('mask_antibody')
class MaskAntibody(object):
    """Mask both antibody chains entirely for generation; if an antigen is
    present, flag its residues in contact with the antibody (CA-CA distance
    <= 6.0 A) and pick one of them as the anchor."""

    def mask_ab_chain_(self, data):
        data['generate_flag'] = torch.ones(data['aa'].shape, dtype=torch.bool)

    def __call__(self, structure):
        pos_ab_alpha = []
        if structure['heavy'] is not None:
            self.mask_ab_chain_(structure['heavy'])
            pos_ab_alpha.append(
                structure['heavy']['pos_heavyatom'][:, constants.BBHeavyAtom.CA]
            )
        if structure['light'] is not None:
            self.mask_ab_chain_(structure['light'])
            pos_ab_alpha.append(
                structure['light']['pos_heavyatom'][:, constants.BBHeavyAtom.CA]
            )
        pos_ab_alpha = torch.cat(pos_ab_alpha, dim=0)   # (L_Ab, 3)

        if structure['antigen'] is not None:
            pos_ag_alpha = structure['antigen']['pos_heavyatom'][:, constants.BBHeavyAtom.CA]
            ag_ab_dist = torch.cdist(pos_ag_alpha, pos_ab_alpha)    # (L_Ag, L_Ab)
            nn_ab_dist = ag_ab_dist.min(dim=1)[0]   # (L_Ag)
            contact_flag = (nn_ab_dist <= 6.0)      # (L_Ag)
            if contact_flag.sum().item() == 0:
                contact_flag[nn_ab_dist.argmin()] = True

            anchor_idx = torch.multinomial(contact_flag.float(), num_samples=1).item()
            anchor_flag = torch.zeros(structure['antigen']['aa'].shape, dtype=torch.bool)
            anchor_flag[anchor_idx] = True
            structure['antigen']['anchor_flag'] = anchor_flag
            structure['antigen']['contact_flag'] = contact_flag
        
        return structure


@register_transform('remove_antigen')
class RemoveAntigen(object):
    """Drop the antigen (and its sequence map) from the structure."""

    def __call__(self, structure):
        structure['antigen'] = None
        structure['antigen_seqmap'] = None
        return structure
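

# Minimal usage sketch (added for illustration; not part of the original
# module). It builds a tiny synthetic heavy chain carrying the fields the
# transforms above read ('aa', 'cdr_flag') and applies MaskSingleCDR to it.
# The chain length, CDR placement, and residue types are made up, and running
# this directly assumes the relative imports at the top of the file resolve
# (i.e. the module is executed inside its package).
if __name__ == '__main__':
    n_res = 20
    heavy = {
        'aa': torch.zeros(n_res, dtype=torch.long),         # dummy residue types
        'cdr_flag': torch.zeros(n_res, dtype=torch.long),   # 0 = framework
    }
    heavy['cdr_flag'][8:14] = int(constants.CDR.H3)         # residues 8..13 form CDR-H3
    structure = {'heavy': heavy, 'light': None, 'antigen': None}

    transform = MaskSingleCDR(selection='H3', augmentation=False)
    structure = transform(structure)

    gen_idx = structure['heavy']['generate_flag'].nonzero(as_tuple=True)[0].tolist()
    anchor_idx = structure['heavy']['anchor_flag'].nonzero(as_tuple=True)[0].tolist()
    print('generate:', gen_idx)      # expected: residues 8..13
    print('anchors: ', anchor_idx)   # expected: residues 7 and 14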