# Source: DiffAb repository, diffab/tools/runner/design_for_testset.py
# Revision 753e275 ("Update"); page avatar: luost26
import argparse
import copy
import json
import os

import torch
from torch.utils.data import DataLoader
from tqdm.auto import tqdm

from diffab.datasets import get_dataset
from diffab.models import get_model
from diffab.modules.common.geometry import reconstruct_backbone_partially
from diffab.modules.common.so3 import so3vec_to_rotation
from diffab.utils.inference import RemoveNative
from diffab.utils.protein.writers import save_pdb
from diffab.utils.train import recursive_to
from diffab.utils.misc import *
from diffab.utils.data import *
from diffab.utils.transforms import *
from diffab.utils.inference import *
def create_data_variants(config, structure_factory):
    """Build the list of masked design tasks ("variants") for one structure.

    Args:
        config: Run configuration. ``config.mode`` selects the masking scheme;
            ``config.sampling.cdrs`` (and ``config.sampling.optimize_steps`` in
            ``'abopt'`` mode) are read for the modes that need them.
        structure_factory: Zero-argument callable returning the structure. It
            is called once up front and once more for every masked copy.

    Returns:
        A list of dicts. Every dict carries ``data``, ``name``, ``tag``,
        ``residue_first`` and ``residue_last``; depending on the mode it also
        carries ``cdr``, ``cdrs`` or ``opt_step``.

        * ``'single_cdr'``: one variant per selected CDR.
        * ``'multiple_cdrs'``: a single variant masking all selected CDRs.
        * ``'full'``: a single variant produced with ``MaskAntibody``.
        * ``'abopt'``: one variant per (selected CDR, optimization step); all
          steps of one CDR share the same ``data`` object.

    Raises:
        ValueError: If ``config.mode`` is not one of the modes above.
    """
    reference = structure_factory()
    structure_id = reference['id']
    mode = config.mode

    def selected_cdrs():
        # CDRs found in the structure that are also requested in the config,
        # in sorted order.
        found = set(find_cdrs(reference))
        return sorted(found.intersection(config.sampling.cdrs))

    def masked_copy(mask_transform):
        # Mask first, then merge chains, on a structure obtained from a new
        # factory call.
        pipeline = Compose([
            mask_transform,
            MergeChains(),
        ])
        return pipeline(structure_factory())

    if mode == 'single_cdr':
        variants = []
        for cdr_name in selected_cdrs():
            masked = masked_copy(MaskSingleCDR(cdr_name, augmentation=False))
            first, last = get_residue_first_last(masked)
            variants.append({
                'data': masked,
                'name': f'{structure_id}-{cdr_name}',
                'tag': f'{cdr_name}',
                'cdr': cdr_name,
                'residue_first': first,
                'residue_last': last,
            })
        return variants

    if mode == 'multiple_cdrs':
        cdr_names = selected_cdrs()
        masked = masked_copy(MaskMultipleCDRs(selection=cdr_names, augmentation=False))
        return [{
            'data': masked,
            'name': f'{structure_id}-MultipleCDRs',
            'tag': 'MultipleCDRs',
            'cdrs': cdr_names,
            'residue_first': None,
            'residue_last': None,
        }]

    if mode == 'full':
        masked = masked_copy(MaskAntibody())
        return [{
            'data': masked,
            'name': f'{structure_id}-Full',
            'tag': 'Full',
            'residue_first': None,
            'residue_last': None,
        }]

    if mode == 'abopt':
        variants = []
        for cdr_name in selected_cdrs():
            masked = masked_copy(MaskSingleCDR(cdr_name, augmentation=False))
            first, last = get_residue_first_last(masked)
            # One entry per optimization step; the masked data is shared.
            variants.extend(
                {
                    'data': masked,
                    'name': f'{structure_id}-{cdr_name}-O{opt_step}',
                    'tag': f'{cdr_name}-O{opt_step}',
                    'cdr': cdr_name,
                    'opt_step': opt_step,
                    'residue_first': first,
                    'residue_last': last,
                }
                for opt_step in config.sampling.optimize_steps
            )
        return variants

    raise ValueError(f'Unknown mode: {mode}.')
def main():
    """Sample designs for one test-set structure and write them as PDB files.

    Command-line arguments:
        index: Position of the structure in the test dataset.
        -c/--config: Path to the sampling configuration file.
        -o/--out_root: Root directory for results.
        -t/--tag: Optional suffix appended to the config name in the log path.
        -s/--seed: Random seed; falls back to ``config.sampling.seed``.
        -d/--device: Device the model and batches are moved to.
        -b/--batch_size: Batch size used by the sampling ``DataLoader``.

    Outputs, all under a new log directory: ``reference.pdb``,
    ``metadata.json`` and, per data variant, a sub-directory holding
    ``REF1.pdb`` plus one numbered ``%04d.pdb`` file per generated sample.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('index', type=int)
    parser.add_argument('-c', '--config', type=str, default='./configs/test/codesign_single.yml')
    parser.add_argument('-o', '--out_root', type=str, default='./results')
    parser.add_argument('-t', '--tag', type=str, default='')
    parser.add_argument('-s', '--seed', type=int, default=None)
    parser.add_argument('-d', '--device', type=str, default='cuda')
    parser.add_argument('-b', '--batch_size', type=int, default=16)
    args = parser.parse_args()

    # Load configs
    config, config_name = load_config(args.config)
    seed_all(args.seed if args.seed is not None else config.sampling.seed)

    # Testset
    dataset = get_dataset(config.dataset.test)

    def get_structure():
        # Looked up in the dataset on every call rather than cached.
        return dataset[args.index]

    # Logging
    structure_ = get_structure()
    structure_id = structure_['id']
    tag_postfix = '_%s' % args.tag if args.tag else ''
    log_dir = get_new_log_dir(
        os.path.join(args.out_root, config_name + tag_postfix),
        prefix='%04d_%s' % (args.index, structure_id),
    )
    logger = get_logger('sample', log_dir)
    logger.info('Data ID: %s' % structure_id)
    data_native = MergeChains()(structure_)
    save_pdb(data_native, os.path.join(log_dir, 'reference.pdb'))

    # Load checkpoint and model
    logger.info('Loading model config and checkpoints: %s' % (config.model.checkpoint))
    # NOTE(security): torch.load unpickles the file; only load checkpoints
    # from a trusted source.
    ckpt = torch.load(config.model.checkpoint, map_location='cpu')
    cfg_ckpt = ckpt['config']
    model = get_model(cfg_ckpt.model).to(args.device)
    lsd = model.load_state_dict(ckpt['model'])
    logger.info(str(lsd))

    # Make data variants
    data_variants = create_data_variants(
        config=config,
        structure_factory=get_structure,
    )

    # Save metadata (everything except the tensor payload under 'data')
    metadata = {
        'identifier': structure_id,
        'index': args.index,
        'config': args.config,
        'items': [{kk: vv for kk, vv in var.items() if kk != 'data'} for var in data_variants],
    }
    with open(os.path.join(log_dir, 'metadata.json'), 'w') as f:
        json.dump(metadata, f, indent=2)

    # Start sampling
    collate_fn = PaddingCollate(eight=False)
    is_abopt = 'abopt' in config.mode
    inference_tfm = [PatchAroundAnchor()]
    if not is_abopt:  # Don't remove native CDR in optimization mode
        inference_tfm.append(RemoveNative(
            remove_structure=config.sampling.sample_structure,
            remove_sequence=config.sampling.sample_sequence,
        ))
    inference_tfm = Compose(inference_tfm)

    # Options shared by model.sample and model.optimize; a copy is handed to
    # each call so the model cannot alter the template between batches.
    sampling_opt = {
        'pbar': True,
        'sample_structure': config.sampling.sample_structure,
        'sample_sequence': config.sampling.sample_sequence,
    }

    # Inference only: set once instead of on every batch.
    torch.set_grad_enabled(False)
    model.eval()

    for variant in data_variants:
        variant_dir = os.path.join(log_dir, variant['tag'])
        os.makedirs(variant_dir, exist_ok=True)
        logger.info(f"Start sampling for: {variant['tag']}")
        # Second copy of the native structure inside the variant folder.
        # The original comment read "w/ OpenMM minimization" -- presumably a
        # later pipeline step works on this file; confirm.
        save_pdb(data_native, os.path.join(variant_dir, 'REF1.pdb'))

        data_tmpl = variant['data']
        # Deep copy so the inference transforms do not touch the template that
        # generated patches are written back into below.
        data_cropped = inference_tfm(copy.deepcopy(data_tmpl))
        patch_idx = data_cropped['patch_idx']

        data_list_repeat = [data_cropped] * config.sampling.num_samples
        loader = DataLoader(data_list_repeat, batch_size=args.batch_size, shuffle=False, collate_fn=collate_fn)

        count = 0
        for batch in tqdm(loader, desc=variant['name'], dynamic_ncols=True):
            batch = recursive_to(batch, args.device)
            if is_abopt:
                # Antibody optimization starting from native
                traj_batch = model.optimize(
                    batch,
                    opt_step=variant['opt_step'],
                    optimize_opt=dict(sampling_opt),
                )
            else:
                # De novo design
                traj_batch = model.sample(batch, sample_opt=dict(sampling_opt))

            aa_new = traj_batch[0][2]  # 0: Last sampling step. 2: Amino acid.
            pos_atom_new, mask_atom_new = reconstruct_backbone_partially(
                pos_ctx=batch['pos_heavyatom'],
                R_new=so3vec_to_rotation(traj_batch[0][0]),
                t_new=traj_batch[0][1],
                aa=aa_new,
                chain_nb=batch['chain_nb'],
                res_nb=batch['res_nb'],
                mask_atoms=batch['mask_heavyatom'],
                mask_recons=batch['generate_flag'],
            )
            aa_new = aa_new.cpu()
            pos_atom_new = pos_atom_new.cpu()
            mask_atom_new = mask_atom_new.cpu()

            for i in range(aa_new.size(0)):
                # Write the generated patch back into the full-size template.
                aa = apply_patch_to_tensor(data_tmpl['aa'], aa_new[i], patch_idx)
                mask_ha = apply_patch_to_tensor(data_tmpl['mask_heavyatom'], mask_atom_new[i], patch_idx)
                pos_ha = apply_patch_to_tensor(
                    data_tmpl['pos_heavyatom'],
                    # Add the per-sample origin back to the generated coordinates.
                    pos_atom_new[i] + batch['origin'][i].view(1, 1, 3).cpu(),
                    patch_idx,
                )

                save_path = os.path.join(variant_dir, '%04d.pdb' % (count, ))
                save_pdb({
                    'chain_nb': data_tmpl['chain_nb'],
                    'chain_id': data_tmpl['chain_id'],
                    'resseq': data_tmpl['resseq'],
                    'icode': data_tmpl['icode'],
                    # Generated
                    'aa': aa,
                    'mask_heavyatom': mask_ha,
                    'pos_heavyatom': pos_ha,
                }, path=save_path)
                count += 1

        logger.info('Finished.\n')
# Run the sampler only when executed as a script, not when imported.
if __name__ == '__main__':
    main()