import os
import argparse
import copy
import json
import torch
from easydict import EasyDict
from tqdm.auto import tqdm
from torch.utils.data import DataLoader

from diffab.datasets.custom import preprocess_antibody_structure
from diffab.models import get_model
from diffab.modules.common.geometry import reconstruct_backbone_partially
from diffab.modules.common.so3 import so3vec_to_rotation
from diffab.utils.inference import RemoveNative
from diffab.utils.protein.writers import save_pdb
from diffab.utils.train import recursive_to
from diffab.utils.misc import *
from diffab.utils.data import *
from diffab.utils.transforms import *
from diffab.utils.inference import *
from diffab.tools.renumber import renumber as renumber_antibody


def create_data_variants(config, structure_factory):
    structure = structure_factory()
    structure_id = structure['id']

    data_variants = []
    if config.mode == 'single_cdr':
        cdrs = sorted(list(set(find_cdrs(structure)).intersection(config.sampling.cdrs)))
        for cdr_name in cdrs:
            transform = Compose([
                MaskSingleCDR(cdr_name, augmentation=False),
                MergeChains(),
            ])
            data_var = transform(structure_factory())
            residue_first, residue_last = get_residue_first_last(data_var)
            data_variants.append({
                'data': data_var,
                'name': f'{structure_id}-{cdr_name}',
                'tag': f'{cdr_name}',
                'cdr': cdr_name,
                'residue_first': residue_first,
                'residue_last': residue_last,
            })
    elif config.mode == 'multiple_cdrs':
        cdrs = sorted(list(set(find_cdrs(structure)).intersection(config.sampling.cdrs)))
        transform = Compose([
            MaskMultipleCDRs(selection=cdrs, augmentation=False),
            MergeChains(),
        ])
        data_var = transform(structure_factory())
        data_variants.append({
            'data': data_var,
            'name': f'{structure_id}-MultipleCDRs',
            'tag': 'MultipleCDRs',
            'cdrs': cdrs,
            'residue_first': None,
            'residue_last': None,
        })
    elif config.mode == 'full':
        transform = Compose([
            MaskAntibody(),
            MergeChains(),
        ])
        data_var = transform(structure_factory())
        data_variants.append({
            'data': data_var,
            'name': f'{structure_id}-Full',
            'tag': 'Full',
            'residue_first': None,
            'residue_last': None,
        })
    elif config.mode == 'abopt':
        cdrs = sorted(list(set(find_cdrs(structure)).intersection(config.sampling.cdrs)))
        for cdr_name in cdrs:
            transform = Compose([
                MaskSingleCDR(cdr_name, augmentation=False),
                MergeChains(),
            ])
            data_var = transform(structure_factory())
            residue_first, residue_last = get_residue_first_last(data_var)
            for opt_step in config.sampling.optimize_steps:
                data_variants.append({
                    'data': data_var,
                    'name': f'{structure_id}-{cdr_name}-O{opt_step}',
                    'tag': f'{cdr_name}-O{opt_step}',
                    'cdr': cdr_name,
                    'opt_step': opt_step,
                    'residue_first': residue_first,
                    'residue_last': residue_last,
                })
    else:
        raise ValueError(f'Unknown mode: {config.mode}.')

    return data_variants


def design_for_pdb(args):
    # Load configs
    config, config_name = load_config(args.config)
    seed_all(args.seed if args.seed is not None else config.sampling.seed)

    # Structure loading
    data_id = os.path.basename(args.pdb_path)
    if args.no_renumber:
        pdb_path = args.pdb_path
    else:
        in_pdb_path = args.pdb_path
        out_pdb_path = os.path.splitext(in_pdb_path)[0] + '_chothia.pdb'
        heavy_chains, light_chains = renumber_antibody(in_pdb_path, out_pdb_path)
        pdb_path = out_pdb_path

        if args.heavy is None and len(heavy_chains) > 0:
            args.heavy = heavy_chains[0]
        if args.light is None and len(light_chains) > 0:
            args.light = light_chains[0]
    if args.heavy is None and args.light is None:
        raise ValueError("Neither the heavy chain id (--heavy) nor the light chain id (--light) is specified.")

    get_structure = lambda: preprocess_antibody_structure({
        'id': data_id,
        'pdb_path': pdb_path,
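        # Chain ids below either come from --heavy/--light on the command line
        # or were auto-detected during the Chothia renumbering step above.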
        'heavy_id': args.heavy,
        # If the input is a nanobody, the light chain will be ignored
        'light_id': args.light,
    })

    # Logging
    structure_ = get_structure()
    structure_id = structure_['id']
    tag_postfix = '_%s' % args.tag if args.tag else ''
    log_dir = get_new_log_dir(
        os.path.join(args.out_root, config_name + tag_postfix),
        prefix=data_id
    )
    logger = get_logger('sample', log_dir)
    logger.info(f'Data ID: {structure_["id"]}')
    logger.info(f'Results will be saved to {log_dir}')
    data_native = MergeChains()(structure_)
    save_pdb(data_native, os.path.join(log_dir, 'reference.pdb'))

    # Load checkpoint and model
    logger.info('Loading model config and checkpoints: %s' % (config.model.checkpoint))
    ckpt = torch.load(config.model.checkpoint, map_location='cpu')
    cfg_ckpt = ckpt['config']
    model = get_model(cfg_ckpt.model).to(args.device)
    lsd = model.load_state_dict(ckpt['model'])
    logger.info(str(lsd))

    # Make data variants
    data_variants = create_data_variants(
        config=config,
        structure_factory=get_structure,
    )

    # Save metadata
    metadata = {
        'identifier': structure_id,
        'index': data_id,
        'config': args.config,
        'items': [{kk: vv for kk, vv in var.items() if kk != 'data'} for var in data_variants],
    }
    with open(os.path.join(log_dir, 'metadata.json'), 'w') as f:
        json.dump(metadata, f, indent=2)

    # Start sampling
    collate_fn = PaddingCollate(eight=False)
    inference_tfm = [
        PatchAroundAnchor(),
    ]
    if 'abopt' not in config.mode:  # Don't remove native CDR in optimization mode
        inference_tfm.append(RemoveNative(
            remove_structure=config.sampling.sample_structure,
            remove_sequence=config.sampling.sample_sequence,
        ))
    inference_tfm = Compose(inference_tfm)

    for variant in data_variants:
        os.makedirs(os.path.join(log_dir, variant['tag']), exist_ok=True)
        logger.info(f"Start sampling for: {variant['tag']}")

        save_pdb(data_native, os.path.join(log_dir, variant['tag'], 'REF1.pdb'))  # w/ OpenMM minimization

        data_cropped = inference_tfm(
            copy.deepcopy(variant['data'])
        )
        data_list_repeat = [data_cropped] * config.sampling.num_samples
        loader = DataLoader(data_list_repeat, batch_size=args.batch_size, shuffle=False, collate_fn=collate_fn)

        count = 0
        for batch in tqdm(loader, desc=variant['name'], dynamic_ncols=True):
            torch.set_grad_enabled(False)
            model.eval()
            batch = recursive_to(batch, args.device)
            if 'abopt' in config.mode:
                # Antibody optimization starting from native
                traj_batch = model.optimize(batch, opt_step=variant['opt_step'], optimize_opt={
                    'pbar': True,
                    'sample_structure': config.sampling.sample_structure,
                    'sample_sequence': config.sampling.sample_sequence,
                })
            else:
                # De novo design
                traj_batch = model.sample(batch, sample_opt={
                    'pbar': True,
                    'sample_structure': config.sampling.sample_structure,
                    'sample_sequence': config.sampling.sample_sequence,
                })

            aa_new = traj_batch[0][2]  # 0: Last sampling step. 2: Amino acid.
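            # Rebuild backbone atom coordinates from the final sampling step:
            # the sampled SO(3) orientation vectors and CA translations are
            # converted to backbone atoms only at residues flagged by
            # 'generate_flag'; context residues keep their input coordinates.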
            pos_atom_new, mask_atom_new = reconstruct_backbone_partially(
                pos_ctx=batch['pos_heavyatom'],
                R_new=so3vec_to_rotation(traj_batch[0][0]),
                t_new=traj_batch[0][1],
                aa=aa_new,
                chain_nb=batch['chain_nb'],
                res_nb=batch['res_nb'],
                mask_atoms=batch['mask_heavyatom'],
                mask_recons=batch['generate_flag'],
            )
            aa_new = aa_new.cpu()
            pos_atom_new = pos_atom_new.cpu()
            mask_atom_new = mask_atom_new.cpu()

            for i in range(aa_new.size(0)):
                data_tmpl = variant['data']
                aa = apply_patch_to_tensor(data_tmpl['aa'], aa_new[i], data_cropped['patch_idx'])
                mask_ha = apply_patch_to_tensor(data_tmpl['mask_heavyatom'], mask_atom_new[i], data_cropped['patch_idx'])
                pos_ha = apply_patch_to_tensor(
                    data_tmpl['pos_heavyatom'],
                    pos_atom_new[i] + batch['origin'][i].view(1, 1, 3).cpu(),
                    data_cropped['patch_idx']
                )

                save_path = os.path.join(log_dir, variant['tag'], '%04d.pdb' % (count, ))
                save_pdb({
                    'chain_nb': data_tmpl['chain_nb'],
                    'chain_id': data_tmpl['chain_id'],
                    'resseq': data_tmpl['resseq'],
                    'icode': data_tmpl['icode'],
                    # Generated
                    'aa': aa,
                    'mask_heavyatom': mask_ha,
                    'pos_heavyatom': pos_ha,
                }, path=save_path)
                # save_pdb({
                #     'chain_nb': data_cropped['chain_nb'],
                #     'chain_id': data_cropped['chain_id'],
                #     'resseq': data_cropped['resseq'],
                #     'icode': data_cropped['icode'],
                #     # Generated
                #     'aa': aa_new[i],
                #     'mask_heavyatom': mask_atom_new[i],
                #     'pos_heavyatom': pos_atom_new[i] + batch['origin'][i].view(1, 1, 3).cpu(),
                # }, path=os.path.join(log_dir, variant['tag'], '%04d_patch.pdb' % (count, )))
                count += 1

    logger.info('Finished.\n')


def args_from_cmdline():
    parser = argparse.ArgumentParser()
    parser.add_argument('pdb_path', type=str)
    parser.add_argument('--heavy', type=str, default=None, help='Chain id of the heavy chain.')
    parser.add_argument('--light', type=str, default=None, help='Chain id of the light chain.')
    parser.add_argument('--no_renumber', action='store_true', default=False)
    parser.add_argument('-c', '--config', type=str, default='./configs/test/codesign_single.yml')
    parser.add_argument('-o', '--out_root', type=str, default='./results')
    parser.add_argument('-t', '--tag', type=str, default='')
    parser.add_argument('-s', '--seed', type=int, default=None)
    parser.add_argument('-d', '--device', type=str, default='cuda')
    parser.add_argument('-b', '--batch_size', type=int, default=16)
    args = parser.parse_args()
    return args


def args_factory(**kwargs):
    default_args = EasyDict(
        heavy='H',
        light='L',
        no_renumber=False,
        config='./configs/test/codesign_single.yml',
        out_root='./results',
        tag='',
        seed=None,
        device='cuda',
        batch_size=16,
    )
    default_args.update(kwargs)
    return default_args


if __name__ == '__main__':
    design_for_pdb(args_from_cmdline())
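# Example invocation (the script and PDB file names are illustrative; any
# antibody or antibody-antigen complex structure should work):
#   python <this_script>.py ./example/antibody.pdb \
#       --config ./configs/test/codesign_single.yml --device cuda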