Spaces:
Runtime error
Runtime error
File size: 2,026 Bytes
753e275 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 |
import torch
from .protein import constants
def find_cdrs(structure):
    """List the CDR loops present in an antibody structure.

    Args:
        structure: mapping with keys ``'heavy'`` and ``'light'``. Each value is
            either ``None`` (chain absent) or a mapping with a ``'cdr_flag'``
            entry. The entry is presumably a per-residue tensor of
            ``constants.CDR`` codes (TODO confirm against the parser).

    Returns:
        A list of labels such as ``'H_CDR1'`` or ``'L_CDR3'``. Heavy-chain
        loops come first, then light-chain loops, each in CDR1..CDR3 order.
    """
    chain_table = (
        ('heavy', ((constants.CDR.H1, 'H_CDR1'),
                   (constants.CDR.H2, 'H_CDR2'),
                   (constants.CDR.H3, 'H_CDR3'))),
        ('light', ((constants.CDR.L1, 'L_CDR1'),
                   (constants.CDR.L2, 'L_CDR2'),
                   (constants.CDR.L3, 'L_CDR3'))),
    )
    present = []
    for chain_key, loop_codes in chain_table:
        chain = structure[chain_key]
        if chain is None:
            # Chain missing from this structure: nothing to report for it.
            continue
        cdr_flag = chain['cdr_flag']
        for code, label in loop_codes:
            # A loop counts as present if any residue carries its code.
            if int(code) in cdr_flag:
                present.append(label)
    return present
def get_residue_first_last(data):
    """Return identifiers of the first and last residues selected by ``generate_flag``.

    Args:
        data: mapping with the following entries:

            - ``'generate_flag'``: a 1-D mask; bool or 0/1 integers.
            - ``'resseq'``: a tensor indexable by position.
            - ``'chain_id'`` and ``'icode'``: sequences indexable by position.

    Returns:
        ``(residue_first, residue_last)``. Each is a
        ``(chain_id, resseq, icode)`` tuple taken at the lowest and highest
        flagged position; ``resseq`` is converted to a Python number.

    Raises:
        ValueError: if ``generate_flag`` selects no positions.
    """
    loop_flag = data['generate_flag']
    # nonzero() runs on loop_flag's own device and treats any non-zero entry
    # as selected. Indexing a CPU arange with the flag would break on CUDA
    # tensors and would misread integer masks as index lists.
    loop_idx = torch.nonzero(loop_flag, as_tuple=True)[0]
    if loop_idx.numel() == 0:
        raise ValueError("generate_flag selects no residues; cannot locate first/last residue")
    # nonzero returns indices in ascending order, so the ends are min and max.
    idx_first, idx_last = int(loop_idx[0]), int(loop_idx[-1])
    residue_first = (data['chain_id'][idx_first], data['resseq'][idx_first].item(), data['icode'][idx_first])
    residue_last = (data['chain_id'][idx_last], data['resseq'][idx_last].item(), data['icode'][idx_last])
    return residue_first, residue_last
class RemoveNative(object):
    """Transform that hides native sequence and/or structure at flagged positions.

    Positions are selected by ``data['generate_flag']``. The input mapping is
    modified in place and also returned.

    Args:
        remove_structure: if true, replace ``data['pos_heavyatom']`` at flagged
            positions with Gaussian noise scaled by ``noise_scale``.
        remove_sequence: if true, replace ``data['aa']`` at flagged positions
            with ``constants.AA.UNK``.
        noise_scale: standard deviation multiplier for the replacement
            coordinates (default ``10.0``, matching the previous hard-coded value).
    """

    def __init__(self, remove_structure, remove_sequence, noise_scale=10.0):
        super().__init__()
        self.remove_structure = remove_structure
        self.remove_sequence = remove_sequence
        self.noise_scale = noise_scale

    def __call__(self, data):
        # torch.where needs a bool condition; this also accepts 0/1 integer masks.
        generate_flag = data['generate_flag'].bool()
        if self.remove_sequence:
            data['aa'] = torch.where(
                generate_flag,
                torch.full_like(data['aa'], fill_value=int(constants.AA.UNK)),  # Is loop
                data['aa'],
            )
        if self.remove_structure:
            pos = data['pos_heavyatom']
            # Assumes pos is (num_residues, num_atoms, 3); TODO confirm.
            # The per-residue flag broadcasts over the trailing atom/coord dims.
            data['pos_heavyatom'] = torch.where(
                generate_flag[:, None, None],
                torch.randn_like(pos) * self.noise_scale,
                pos,
            )
        return data