# Spaces: Runtime error
import os | |
import argparse | |
import copy | |
import json | |
from tqdm.auto import tqdm | |
from torch.utils.data import DataLoader | |
from diffab.datasets.custom import preprocess_antibody_structure | |
from diffab.models import get_model | |
from diffab.modules.common.geometry import reconstruct_backbone_partially | |
from diffab.modules.common.so3 import so3vec_to_rotation | |
from diffab.utils.inference import RemoveNative | |
from diffab.utils.protein.writers import save_pdb | |
from diffab.utils.train import recursive_to | |
from diffab.utils.misc import * | |
from diffab.utils.data import * | |
from diffab.utils.transforms import * | |
from diffab.utils.inference import * | |
from diffab.tools.renumber import renumber as renumber_antibody | |
def create_data_variants(config, structure_factory):
    """Build the list of masked design tasks ("variants") for one structure.

    Args:
        config: Configuration object. ``config.mode`` selects the task type
            (``'single_cdr'``, ``'multiple_cdrs'``, ``'full'`` or ``'abopt'``).
            ``config.sampling.cdrs`` is read in every mode except ``'full'``;
            ``config.sampling.optimize_steps`` is read only in ``'abopt'``.
        structure_factory: Zero-argument callable returning a fresh structure
            dict (it must at least contain an ``'id'`` key). It is called once
            up front and once more for every transform that is applied, so
            each variant starts from its own structure object.

    Returns:
        A list of dicts. Each has ``'data'`` (the transformed structure),
        ``'name'``, ``'tag'``, ``'residue_first'`` and ``'residue_last'``;
        CDR-specific modes additionally carry ``'cdr'`` / ``'cdrs'`` and, for
        ``'abopt'``, ``'opt_step'``.

    Raises:
        ValueError: If ``config.mode`` is not one of the four known modes.
    """
    reference = structure_factory()
    base_name = reference['id']
    mode = config.mode

    if mode not in ('single_cdr', 'multiple_cdrs', 'full', 'abopt'):
        raise ValueError(f'Unknown mode: {config.mode}.')

    variants = []

    if mode == 'full':
        mask_all = Compose([
            MaskAntibody(),
            MergeChains(),
        ])
        variants.append({
            'data': mask_all(structure_factory()),
            'name': f'{base_name}-Full',
            'tag': 'Full',
            'residue_first': None,
            'residue_last': None,
        })
        return variants

    # Every remaining mode works on the CDRs that are both found in the
    # structure and listed in the config, visited in sorted order.
    found = set(find_cdrs(reference))
    selected = sorted(found.intersection(config.sampling.cdrs))

    if mode == 'multiple_cdrs':
        mask_many = Compose([
            MaskMultipleCDRs(selection=selected, augmentation=False),
            MergeChains(),
        ])
        variants.append({
            'data': mask_many(structure_factory()),
            'name': f'{base_name}-MultipleCDRs',
            'tag': 'MultipleCDRs',
            'cdrs': selected,
            'residue_first': None,
            'residue_last': None,
        })
        return variants

    # 'single_cdr' and 'abopt' both mask one CDR at a time; 'abopt' then emits
    # one variant per optimization step, all sharing the same masked data.
    for cdr_name in selected:
        mask_one = Compose([
            MaskSingleCDR(cdr_name, augmentation=False),
            MergeChains(),
        ])
        masked = mask_one(structure_factory())
        first, last = get_residue_first_last(masked)

        if mode == 'single_cdr':
            variants.append({
                'data': masked,
                'name': f'{base_name}-{cdr_name}',
                'tag': f'{cdr_name}',
                'cdr': cdr_name,
                'residue_first': first,
                'residue_last': last,
            })
            continue

        for opt_step in config.sampling.optimize_steps:
            variants.append({
                'data': masked,
                'name': f'{base_name}-{cdr_name}-O{opt_step}',
                'tag': f'{cdr_name}-O{opt_step}',
                'cdr': cdr_name,
                'opt_step': opt_step,
                'residue_first': first,
                'residue_last': last,
            })

    return variants
def design_for_pdb(args):
    """Sample antibody designs for a single PDB file and write them to disk.

    Args:
        args: Namespace-like object providing ``pdb_path``, ``heavy``,
            ``light``, ``no_renumber``, ``config``, ``out_root``, ``tag``,
            ``seed``, ``device`` and ``batch_size``. When renumbering is
            enabled and ``heavy`` / ``light`` is ``None``, the attribute is
            filled in place with the first chain reported by the renumbering
            step.

    Raises:
        ValueError: If neither a heavy nor a light chain id is known after
            the optional renumbering step.

    Side effects:
        Unless ``args.no_renumber`` is set, writes ``<input>_chothia.pdb``
        next to the input file. Creates a new log directory under
        ``args.out_root`` containing ``reference.pdb``, ``metadata.json`` and,
        for every variant tag, a sub-directory with ``REF1.pdb`` and the
        numbered samples ``0000.pdb``, ``0001.pdb``, ...
    """
    # Load configs
    config, config_name = load_config(args.config)
    seed_all(args.seed if args.seed is not None else config.sampling.seed)

    # Structure loading
    data_id = os.path.basename(args.pdb_path)
    if args.no_renumber:
        pdb_path = args.pdb_path
    else:
        in_pdb_path = args.pdb_path
        out_pdb_path = os.path.splitext(in_pdb_path)[0] + '_chothia.pdb'
        heavy_chains, light_chains = renumber_antibody(in_pdb_path, out_pdb_path)
        pdb_path = out_pdb_path

        # Fall back to the first detected chain when the caller gave none.
        if args.heavy is None and len(heavy_chains) > 0:
            args.heavy = heavy_chains[0]
        if args.light is None and len(light_chains) > 0:
            args.light = light_chains[0]

    if args.heavy is None and args.light is None:
        raise ValueError("Neither heavy chain id (--heavy) or light chain id (--light) is specified.")

    def get_structure():
        # Re-parse on every call so each consumer gets its own structure object.
        return preprocess_antibody_structure({
            'id': data_id,
            'pdb_path': pdb_path,
            'heavy_id': args.heavy,
            # If the input is a nanobody, the light chain will be ignored
            'light_id': args.light,
        })

    # Logging
    structure_ = get_structure()
    structure_id = structure_['id']
    tag_postfix = '_%s' % args.tag if args.tag else ''
    log_dir = get_new_log_dir(
        os.path.join(args.out_root, config_name + tag_postfix),
        prefix=data_id,
    )
    logger = get_logger('sample', log_dir)
    logger.info(f'Data ID: {structure_id}')
    logger.info(f'Results will be saved to {log_dir}')
    data_native = MergeChains()(structure_)
    save_pdb(data_native, os.path.join(log_dir, 'reference.pdb'))

    # Load checkpoint and model
    logger.info(f'Loading model config and checkpoints: {config.model.checkpoint}')
    # NOTE: torch.load unpickles the file; only point the config at checkpoints
    # from a trusted source.
    ckpt = torch.load(config.model.checkpoint, map_location='cpu')
    cfg_ckpt = ckpt['config']
    model = get_model(cfg_ckpt.model).to(args.device)
    lsd = model.load_state_dict(ckpt['model'])
    logger.info(str(lsd))
    # Inference only: switch to eval mode once instead of on every batch.
    model.eval()

    # Make data variants
    data_variants = create_data_variants(
        config=config,
        structure_factory=get_structure,
    )

    # Save metadata (everything except the structure data itself)
    metadata = {
        'identifier': structure_id,
        'index': data_id,
        'config': args.config,
        'items': [{kk: vv for kk, vv in var.items() if kk != 'data'} for var in data_variants],
    }
    with open(os.path.join(log_dir, 'metadata.json'), 'w') as f:
        json.dump(metadata, f, indent=2)

    # Start sampling
    is_abopt = 'abopt' in config.mode
    collate_fn = PaddingCollate(eight=False)
    inference_tfm = [PatchAroundAnchor(), ]
    if not is_abopt:  # Don't remove native CDR in optimization mode
        inference_tfm.append(RemoveNative(
            remove_structure=config.sampling.sample_structure,
            remove_sequence=config.sampling.sample_sequence,
        ))
    inference_tfm = Compose(inference_tfm)

    # Disable autograd only for the duration of sampling; the caller's grad
    # mode is restored on exit instead of being switched off for good.
    with torch.no_grad():
        for variant in data_variants:
            variant_dir = os.path.join(log_dir, variant['tag'])
            os.makedirs(variant_dir, exist_ok=True)
            logger.info(f"Start sampling for: {variant['tag']}")

            save_pdb(data_native, os.path.join(variant_dir, 'REF1.pdb'))  # w/ OpenMM minimization

            # Transform a copy so the variant's full structure stays intact
            # as the template the generated patch is written back into.
            data_tmpl = variant['data']
            data_cropped = inference_tfm(copy.deepcopy(data_tmpl))
            patch_idx = data_cropped['patch_idx']

            # The same cropped input is repeated once per requested sample.
            data_list_repeat = [data_cropped] * config.sampling.num_samples
            loader = DataLoader(
                data_list_repeat,
                batch_size=args.batch_size,
                shuffle=False,
                collate_fn=collate_fn,
            )

            count = 0
            for batch in tqdm(loader, desc=variant['name'], dynamic_ncols=True):
                batch = recursive_to(batch, args.device)
                if is_abopt:
                    # Antibody optimization starting from native
                    traj_batch = model.optimize(batch, opt_step=variant['opt_step'], optimize_opt={
                        'pbar': True,
                        'sample_structure': config.sampling.sample_structure,
                        'sample_sequence': config.sampling.sample_sequence,
                    })
                else:
                    # De novo design
                    traj_batch = model.sample(batch, sample_opt={
                        'pbar': True,
                        'sample_structure': config.sampling.sample_structure,
                        'sample_sequence': config.sampling.sample_sequence,
                    })

                # traj_batch[0] is the last sampling step; within it index 0
                # is passed through so3vec_to_rotation, index 1 is used as the
                # translation and index 2 is the amino acid tensor.
                final_step = traj_batch[0]
                aa_new = final_step[2]
                pos_atom_new, mask_atom_new = reconstruct_backbone_partially(
                    pos_ctx=batch['pos_heavyatom'],
                    R_new=so3vec_to_rotation(final_step[0]),
                    t_new=final_step[1],
                    aa=aa_new,
                    chain_nb=batch['chain_nb'],
                    res_nb=batch['res_nb'],
                    mask_atoms=batch['mask_heavyatom'],
                    mask_recons=batch['generate_flag'],
                )
                aa_new = aa_new.cpu()
                pos_atom_new = pos_atom_new.cpu()
                mask_atom_new = mask_atom_new.cpu()

                for i in range(aa_new.size(0)):
                    # Write the generated patch back into the full template.
                    aa = apply_patch_to_tensor(data_tmpl['aa'], aa_new[i], patch_idx)
                    mask_ha = apply_patch_to_tensor(
                        data_tmpl['mask_heavyatom'], mask_atom_new[i], patch_idx
                    )
                    # Positions are offset by the batch's per-sample 'origin'
                    # before merging (presumably undoing the centering done
                    # by PatchAroundAnchor -- confirm against that transform).
                    pos_ha = apply_patch_to_tensor(
                        data_tmpl['pos_heavyatom'],
                        pos_atom_new[i] + batch['origin'][i].view(1, 1, 3).cpu(),
                        patch_idx,
                    )

                    save_path = os.path.join(variant_dir, f'{count:04d}.pdb')
                    save_pdb({
                        'chain_nb': data_tmpl['chain_nb'],
                        'chain_id': data_tmpl['chain_id'],
                        'resseq': data_tmpl['resseq'],
                        'icode': data_tmpl['icode'],
                        # Generated
                        'aa': aa,
                        'mask_heavyatom': mask_ha,
                        'pos_heavyatom': pos_ha,
                    }, path=save_path)
                    count += 1

            logger.info('Finished.\n')
def args_from_cmdline(argv=None):
    """Parse the command-line options of the design script.

    Args:
        argv: Optional list of argument strings to parse instead of the
            process command line. ``None`` (the default) makes argparse read
            ``sys.argv[1:]``, which is the original behaviour.

    Returns:
        ``argparse.Namespace`` with ``pdb_path``, ``heavy``, ``light``,
        ``no_renumber``, ``config``, ``out_root``, ``tag``, ``seed``,
        ``device`` and ``batch_size``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('pdb_path', type=str, help='Path to the input PDB file.')
    parser.add_argument('--heavy', type=str, default=None, help='Chain id of the heavy chain.')
    parser.add_argument('--light', type=str, default=None, help='Chain id of the light chain.')
    parser.add_argument('--no_renumber', action='store_true',
                        help='Use the input PDB as is instead of renumbering it first.')
    parser.add_argument('-c', '--config', type=str, default='./configs/test/codesign_single.yml',
                        help='Path to the configuration file.')
    parser.add_argument('-o', '--out_root', type=str, default='./results',
                        help='Root directory for the results.')
    parser.add_argument('-t', '--tag', type=str, default='',
                        help='Optional tag appended to the output directory name.')
    parser.add_argument('-s', '--seed', type=int, default=None,
                        help='Random seed; the seed from the config is used when omitted.')
    parser.add_argument('-d', '--device', type=str, default='cuda',
                        help='Device the model and batches are moved to.')
    parser.add_argument('-b', '--batch_size', type=int, default=16,
                        help='Number of samples per batch.')
    return parser.parse_args(argv)
def args_factory(**kwargs):
    """Create an argument object for ``design_for_pdb`` without a command line.

    Starts from a fixed set of defaults and overlays ``kwargs`` on top of
    them. Unlike the command-line parser, ``heavy`` / ``light`` default to
    ``'H'`` / ``'L'``. There is no default for ``pdb_path``; callers supply it
    through ``kwargs``.

    Returns:
        An ``EasyDict`` holding the defaults updated with ``kwargs``.
    """
    defaults = {
        'heavy': 'H',
        'light': 'L',
        'no_renumber': False,
        'config': './configs/test/codesign_single.yml',
        'out_root': './results',
        'tag': '',
        'seed': None,
        'device': 'cuda',
        'batch_size': 16,
    }
    merged = EasyDict(**defaults)
    # Caller-supplied values win over the defaults; unknown keys are added.
    merged.update(kwargs)
    return merged
if __name__ == '__main__':
    # Script entry point: parse the command line, then run the design pipeline.
    cli_args = args_from_cmdline()
    design_for_pdb(cli_args)