Spaces:
Runtime error
Runtime error
import torch | |
from .protein import constants | |
def find_cdrs(structure):
    """Return the names of the CDRs annotated in ``structure``.

    ``structure['heavy']`` and ``structure['light']`` are each either ``None``
    (chain absent) or a mapping whose ``'cdr_flag'`` entry is searched with
    ``in`` for the integer value of each ``constants.CDR`` member.
    ``cdr_flag`` is presumably a per-residue tensor of CDR labels (TODO confirm).

    Names come back heavy chain first, then light chain, each in CDR1, CDR2,
    CDR3 order, e.g. ``['H_CDR1', 'H_CDR3', 'L_CDR2']``.
    """
    # (chain key, ((CDR enum member, reported name), ...)) in output order.
    chain_tables = (
        ('heavy', (
            (constants.CDR.H1, 'H_CDR1'),
            (constants.CDR.H2, 'H_CDR2'),
            (constants.CDR.H3, 'H_CDR3'),
        )),
        ('light', (
            (constants.CDR.L1, 'L_CDR1'),
            (constants.CDR.L2, 'L_CDR2'),
            (constants.CDR.L3, 'L_CDR3'),
        )),
    )
    present = []
    for chain_key, table in chain_tables:
        chain = structure[chain_key]
        if chain is None:
            continue
        cdr_flag = chain['cdr_flag']
        present.extend(name for cdr, name in table if int(cdr) in cdr_flag)
    return present
def get_residue_first_last(data):
    """Return identifiers of the first and last residues flagged in ``data['generate_flag']``.

    Args:
        data: mapping with
            - ``'generate_flag'``: boolean mask tensor whose first dimension
              indexes residues (presumably 1-D of shape (L,), TODO confirm);
            - ``'chain_id'`` and ``'icode'``: per-residue sequences indexable
              by a Python int;
            - ``'resseq'``: per-residue tensor; the selected element is
              converted with ``.item()``.

    Returns:
        ``(residue_first, residue_last)``. Each is a ``(chain_id, resseq, icode)``
        tuple for the lowest and highest flagged index, respectively.

    Raises:
        RuntimeError: raised by torch when no residue is flagged, because
            min/max of an empty index tensor is undefined.
    """
    loop_flag = data['generate_flag']
    # Build the index range on the mask's own device. A CPU arange indexed by
    # a mask living on another device (e.g. CUDA) raises a device-mismatch error.
    loop_idx = torch.arange(loop_flag.size(0), device=loop_flag.device)[loop_flag]
    idx_first, idx_last = loop_idx.min().item(), loop_idx.max().item()

    def _residue_at(i):
        # (chain id, residue sequence number as a Python scalar, insertion code)
        return (data['chain_id'][i], data['resseq'][i].item(), data['icode'][i])

    return _residue_at(idx_first), _residue_at(idx_last)
class RemoveNative(object):
    """Transform that hides the native sequence and/or structure of flagged residues.

    Residues are selected by the boolean mask ``data['generate_flag']``.
    ``data`` is modified in place (its entries are reassigned) and also returned.

    Args:
        remove_structure: if true, replace ``data['pos_heavyatom']`` at flagged
            residues with standard-normal noise scaled by ``noise_std``.
        remove_sequence: if true, replace ``data['aa']`` at flagged residues
            with ``int(constants.AA.UNK)``.
        noise_std: multiplier applied to the standard-normal noise. Defaults
            to ``10.0``, the value previously hard-coded.
    """

    def __init__(self, remove_structure, remove_sequence, noise_std=10.0):
        super().__init__()
        self.remove_structure = remove_structure
        self.remove_sequence = remove_sequence
        self.noise_std = noise_std

    def __call__(self, data):
        # torch.where never writes to the mask, so no defensive copy is needed.
        generate_flag = data['generate_flag']
        if self.remove_sequence:
            data['aa'] = torch.where(
                generate_flag,
                torch.full_like(data['aa'], fill_value=int(constants.AA.UNK)),  # flagged (loop) residue
                data['aa'],
            )
        if self.remove_structure:
            pos = data['pos_heavyatom']
            # The [:, None, None] view followed by expand(pos.shape) requires pos
            # to be 3-D with dim 0 matching the mask, presumably
            # (L, num_atoms, 3) (TODO confirm).
            data['pos_heavyatom'] = torch.where(
                generate_flag[:, None, None].expand(pos.shape),
                torch.randn_like(pos) * self.noise_std,
                pos,
            )
        return data