diff --git a/.github/workflows/stale.yml b/.github/workflows/stale.yml new file mode 100644 index 0000000000000000000000000000000000000000..5be2100e0abd949bb98439896128189fbf2930d0 --- /dev/null +++ b/.github/workflows/stale.yml @@ -0,0 +1,27 @@ +# This workflow warns and then closes issues and PRs that have had no activity for a specified amount of time. +# +# You can adjust the behavior by modifying this file. +# For more information, see: +# https://github.com/actions/stale +name: Close inactive issues +on: + schedule: + - cron: "30 1 * * *" + +jobs: + close-issues: + runs-on: ubuntu-latest + permissions: + issues: write + pull-requests: write + steps: + - uses: actions/stale@v9 + with: + days-before-issue-stale: 30 + days-before-issue-close: 14 + stale-issue-label: "stale" + stale-issue-message: "This issue is stale because it has been open for 30 days with no activity." + close-issue-message: "This issue was closed because it has been inactive for 14 days since being marked as stale." + days-before-pr-stale: -1 + days-before-pr-close: -1 + repo-token: ${{ secrets.GITHUB_TOKEN }} diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000000000000000000000000000000000000..805fe80b12d5e1ba0758d6058d1f5baf3de3d039 --- /dev/null +++ b/.gitignore @@ -0,0 +1,35 @@ +__pycache__ + +__cache__ + +__tmcache__ + +ckpts + +checkpoints + +*_results* + +datasets + +exps + +DockQ + +TMscore + +*.txt + +*.pt + +*.png + +*.pkl + +*.svg + +*.log + +*.pdb + +*.jsonl diff --git a/.idea/.gitignore b/.idea/.gitignore new file mode 100644 index 0000000000000000000000000000000000000000..359bb5307e8535ab7d59faf27a7377033291821e --- /dev/null +++ b/.idea/.gitignore @@ -0,0 +1,3 @@ +# 默认忽略的文件 +/shelf/ +/workspace.xml diff --git a/.idea/PepGLAD.iml b/.idea/PepGLAD.iml new file mode 100644 index 0000000000000000000000000000000000000000..e04280bd4cb0b5b9512962f3b4c0e4b57fe79b63 --- /dev/null +++ b/.idea/PepGLAD.iml @@ -0,0 +1,8 @@ + + + + + + + + \ No newline at end of file diff --git a/.idea/inspectionProfiles/profiles_settings.xml b/.idea/inspectionProfiles/profiles_settings.xml new file mode 100644 index 0000000000000000000000000000000000000000..105ce2da2d6447d11dfe32bfb846c3d5b199fc99 --- /dev/null +++ b/.idea/inspectionProfiles/profiles_settings.xml @@ -0,0 +1,6 @@ + + + + \ No newline at end of file diff --git a/.idea/misc.xml b/.idea/misc.xml new file mode 100644 index 0000000000000000000000000000000000000000..bc506c6fc1f22cbb3301ebbbf93d5047202fbee1 --- /dev/null +++ b/.idea/misc.xml @@ -0,0 +1,4 @@ + + + + \ No newline at end of file diff --git a/.idea/modules.xml b/.idea/modules.xml new file mode 100644 index 0000000000000000000000000000000000000000..246133dfe02589c17d99750e3b22b4e8ca6ceb32 --- /dev/null +++ b/.idea/modules.xml @@ -0,0 +1,8 @@ + + + + + + + + \ No newline at end of file diff --git a/.idea/vcs.xml b/.idea/vcs.xml new file mode 100644 index 0000000000000000000000000000000000000000..35eb1ddfbbc029bcab630581847471d7f238ec53 --- /dev/null +++ b/.idea/vcs.xml @@ -0,0 +1,6 @@ + + + + + + \ No newline at end of file diff --git a/LICENSE b/LICENSE new file mode 100644 index 0000000000000000000000000000000000000000..7029d81cc0d2568b8f316b4d4eb9b780b8f096b9 --- /dev/null +++ b/LICENSE @@ -0,0 +1,21 @@ +MIT License + +Copyright (c) 2024 THUNLP + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, 
copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. diff --git a/README.md b/README.md index 7be5fc7f47d5db027d120b8024982df93db95b74..484cba8feec8539c2efe19ccba73bceee5e074e5 100644 --- a/README.md +++ b/README.md @@ -1,3 +1,214 @@ ---- -license: mit ---- +# PepGLAD: Full-Atom Peptide Design with Geometric Latent Diffusion + +![cover](./assets/cover.png) + +## Quick Links + +- [Setup](#setup) + - [Environment](#environment) + - [Datasets](#optional-datasets) + - [Trained Weights](#trained-weights) +- [Usage](#usage) + - [Peptide Sequence-Structure Co-Design](#peptide-sequence-structure-co-design) + - [Peptide Binding Conformation Generation](#peptide-binding-conformation-generation) +- [Reproduction of Paper Experiments](#reproduction-of-paper-experiments) + - [Codesign](#codesign) + - [Binding Conformation Generation](#binding-conformation-generation) +- [Contact](#contact) +- [Reference](#reference) + +## Updates + +Changes for compatibility and extended functionality are kept in the [beta](https://github.com/THUNLP-MT/PepGLAD/tree/beta) branch. Thanks to [@Barry0121](https://github.com/Barry0121) for the help. + +- PyTorch 2.6.0 and OpenMM 8.2.0 are supported, with a new environment configuration at [2025_env.yaml](https://github.com/THUNLP-MT/PepGLAD/blob/beta/2025_env.yml). +- Support for non-canonical amino acids in `detect_pocket.py`. + + +## Setup + +### Environment + +The conda environment can be constructed with the configuration `env.yaml`: + +```bash +conda env create -f env.yaml +``` + +The code is tested with CUDA `11.7` and PyTorch `1.13.1`. + +Don't forget to activate the environment before running the code: + +```bash +conda activate PepGLAD +``` + +#### (Optional) pyRosetta + +PyRosetta is used to calculate the interface energy of generated peptides. If you need this metric, please follow the instructions [here](https://www.pyrosetta.org/downloads) to install it. + +### (Optional) Datasets + +These datasets are only used for benchmarking models. If you just want to run inference on your own cases with the trained weights, there is no need to download them. + +#### PepBench + +1. Download + +The datasets, originally introduced in this paper, are uploaded to Zenodo at [this URL](https://zenodo.org/records/13373108). You can download them as follows: + +```bash +mkdir datasets # all datasets will be put into this directory +wget https://zenodo.org/records/13373108/files/train_valid.tar.gz?download=1 -O ./datasets/train_valid.tar.gz # training/validation +wget https://zenodo.org/records/13373108/files/LNR.tar.gz?download=1 -O ./datasets/LNR.tar.gz # test set +wget https://zenodo.org/records/13373108/files/ProtFrag.tar.gz?download=1 -O ./datasets/ProtFrag.tar.gz # augmentation dataset +``` + +2. 
Decompress + +```bash +tar zxvf ./datasets/train_valid.tar.gz -C ./datasets +tar zxvf ./datasets/LNR.tar.gz -C ./datasets +tar zxvf ./datasets/ProtFrag.tar.gz -C ./datasets +``` + +3. Process + +```bash +python -m scripts.data_process.process --index ./datasets/train_valid/all.txt --out_dir ./datasets/train_valid/processed # train/validation set +python -m scripts.data_process.process --index ./datasets/LNR/test.txt --out_dir ./datasets/LNR/processed # test set +python -m scripts.data_process.process --index ./datasets/ProtFrag/all.txt --out_dir ./datasets/ProtFrag/processed # augmentation dataset +``` + +The index files for the train/validation splits need to be generated as follows, which will result in `datasets/train_valid/processed/train_index.txt` and `datasets/train_valid/processed/valid_index.txt`: + +```bash +python -m scripts.data_process.split --train_index datasets/train_valid/train.txt --valid_index datasets/train_valid/valid.txt --processed_dir datasets/train_valid/processed/ +``` + +#### PepBDB + +1. Download + +```bash +wget http://huanglab.phys.hust.edu.cn/pepbdb/db/download/pepbdb-20200318.tgz -O ./datasets/pepbdb.tgz +``` + +2. Decompress + +```bash +tar zxvf ./datasets/pepbdb.tgz -C ./datasets/pepbdb +``` + + +3. Process + +```bash +python -m scripts.data_process.pepbdb --index ./datasets/pepbdb/peptidelist.txt --out_dir ./datasets/pepbdb/processed +python -m scripts.data_process.split --train_index ./datasets/pepbdb/train.txt --valid_index ./datasets/pepbdb/valid.txt --test_index ./datasets/pepbdb/test.txt --processed_dir datasets/pepbdb/processed/ +mv ./datasets/pepbdb/processed/pdbs ./datasets/pepbdb # re-locate +``` + + +### Trained Weights + +- codesign: `./checkpoints/codesign.ckpt` +- conformation generation: `./checkpoints/fixseq.ckpt` + +Both can be downloaded at the [release page](https://github.com/THUNLP-MT/PepGLAD/releases/tag/v1.0). These checkpoints were trained on PepBench. + +## Usage + +:warning: Before running the following commands, please first download the trained weights mentioned above. + +### Peptide Sequence-Structure Co-Design + +Take `./assets/1ssc_A_B.pdb` as an example, where chain A is the target protein: + +```bash +# obtain the binding site, which might also be manually crafted or derived from other ligands (e.g. small molecules, antibodies) +python -m api.detect_pocket --pdb assets/1ssc_A_B.pdb --target_chains A --ligand_chains B --out assets/1ssc_A_pocket.json +# sequence-structure codesign with length in [8, 15) +CUDA_VISIBLE_DEVICES=0 python -m api.run \ + --mode codesign \ + --pdb assets/1ssc_A_B.pdb \ + --pocket assets/1ssc_A_pocket.json \ + --out_dir ./output/codesign \ + --length_min 8 \ + --length_max 15 \ + --n_samples 10 +``` +Then 10 generated peptides will be output to the folder `./output/codesign`. + +### Peptide Binding Conformation Generation + +Take `./assets/1ssc_A_B.pdb` as an example, where chain A is the target protein: + +```bash +# obtain the binding site, which might also be manually crafted or derived from other ligands (e.g. small molecules, antibodies) +python -m api.detect_pocket --pdb assets/1ssc_A_B.pdb --target_chains A --ligand_chains B --out assets/1ssc_A_pocket.json +# generate binding conformation +CUDA_VISIBLE_DEVICES=0 python -m api.run \ + --mode struct_pred \ + --pdb assets/1ssc_A_B.pdb \ + --pocket assets/1ssc_A_pocket.json \ + --out_dir ./output/struct_pred \ + --peptide_seq PYVPVHFDASV \ + --n_samples 10 +``` +Then 10 conformations will be output to the folder `./output/struct_pred`. 
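If you want to specify the binding site yourself instead of running `api.detect_pocket`, the pocket file is simply a JSON list of `[chain_id, [residue_number, insertion_code]]` entries, as in `assets/1ssc_A_pocket.json`. Below is a minimal sketch for writing one by hand; the residue range and output path are only illustrative:

```python
# Hand-craft a pocket definition in the same format as assets/1ssc_A_pocket.json
import json

# residues 3-13 of chain A, with a blank insertion code (illustrative range)
pocket = [["A", [num, " "]] for num in range(3, 14)]

with open("assets/my_pocket.json", "w") as fout:  # hypothetical output path
    json.dump(pocket, fout)
```

The resulting file can then be passed to `api.run` via `--pocket`, exactly like the output of `detect_pocket`.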
+ + +## Reproduction of Paper Experiments + +Each task requires the following steps, which we have integrated into the script `./scripts/run_exp_pipe.sh`: + +1. Train autoencoder +2. Train latent diffusion model +3. Calculate distribution of latent distances between consecutive residues +4. Generation & Evaluation + +On the other hand, if you want to evaluate existing checkpoints, please follow the instructions below (e.g. conformation generation): + +```bash +# generate results on the test set and save to ./results/fixseq +python generate.py --config configs/pepbench/test_fixseq.yaml --ckpt checkpoints/fixseq.ckpt --gpu 0 --save_dir ./results/fixseq +# calculate metrics +python cal_metrics.py --results ./results/fixseq/results.jsonl +``` + +### Codesign + +Codesign experiments on PepBench: + +```bash +GPU=0 bash scripts/run_exp_pipe.sh pepbench_codesign configs/pepbench/autoencoder/train_codesign.yaml configs/pepbench/ldm/train_codesign.yaml configs/pepbench/ldm/setup_latent_guidance.yaml configs/pepbench/test_codesign.yaml +``` + + +### Binding Conformation Generation + +Conformation generation experiments on PepBench: + +```bash +GPU=0 bash scripts/run_exp_pipe.sh pepbench_fixseq configs/pepbench/autoencoder/train_fixseq.yaml configs/pepbench/ldm/train_fixseq.yaml configs/pepbench/ldm/setup_latent_guidance.yaml configs/pepbench/test_fixseq.yaml +``` + +## Contact + +Thank you for your interest in our work! + +Please feel free to ask about any questions about the algorithms, codes, as well as problems encountered in running them so that we can make it clearer and better. You can either create an issue in the github repo or contact us at jackie_kxz@outlook.com. + +## Reference + +```bibtex +@article{kong2025full, + title={Full-atom peptide design with geometric latent diffusion}, + author={Kong, Xiangzhe and Jia, Yinjun and Huang, Wenbing and Liu, Yang}, + journal={Advances in Neural Information Processing Systems}, + volume={37}, + pages={74808--74839}, + year={2025} +} +``` diff --git a/api/detect_pocket.py b/api/detect_pocket.py new file mode 100644 index 0000000000000000000000000000000000000000..4246035663316ce28be16caac51ce17ba7b9cce2 --- /dev/null +++ b/api/detect_pocket.py @@ -0,0 +1,72 @@ +#!/usr/bin/python +# -*- coding:utf-8 -*- +import argparse +import numpy as np + +from data.converter.pdb_to_list_blocks import pdb_to_list_blocks +from data.converter.blocks_interface import blocks_cb_interface, dist_matrix_from_blocks + + +def get_interface(pdb, receptor_chains, ligand_chains, pocket_th=10.0): # CB distance + list_blocks, chain_ids = pdb_to_list_blocks(pdb, receptor_chains + ligand_chains, return_chain_ids=True) + chain2blocks = {chain: block for chain, block in zip(chain_ids, list_blocks)} + for c in receptor_chains: + assert c in chain2blocks, f'Chain {c} not found for receptor' + for c in ligand_chains: + assert c in chain2blocks, f'Chain {c} not found for ligand' + + rec_blocks, rec_block_chains, lig_blocks, lig_block_chains = [], [], [], [] + for c in receptor_chains: + for block in chain2blocks[c]: + rec_blocks.append(block) + rec_block_chains.append(c) + for c in ligand_chains: + for block in chain2blocks[c]: + lig_blocks.append(block) + lig_block_chains.append(c) + + _, (pocket_idx, lig_if_idx) = blocks_cb_interface(rec_blocks, lig_blocks, pocket_th) # 10A for pocket size based on CB + epitope = [] + for i in pocket_idx: + epitope.append((rec_blocks[i], rec_block_chains[i], i)) + + dist_mat = dist_matrix_from_blocks([rec_blocks[i] for i in pocket_idx], [lig_blocks[i] 
for i in lig_if_idx]) + min_dists = np.min(dist_mat, axis=-1) # [Nrec] + lig_idxs = np.argmin(dist_mat, axis=-1) # [Nrec] + dists = [] + for i, d in zip(lig_idxs, min_dists): + i = lig_if_idx[i] + dists.append((lig_blocks[i], lig_block_chains[i], i, d)) + + return epitope, dists + + +if __name__ == '__main__': + import json + parser = argparse.ArgumentParser(description='get interface') + parser.add_argument('--pdb', type=str, required=True, help='Path to the complex pdb') + parser.add_argument('--target_chains', type=str, nargs='+', required=True, help='Specify target chain ids') + parser.add_argument('--ligand_chains', type=str, nargs='+', required=True, help='Specify ligand chain ids') + parser.add_argument('--pocket_th', type=int, default=10.0, help='CB distance threshold for defining the binding site') + parser.add_argument('--out', type=str, default=None, help='Save epitope information to json file if specified') + args = parser.parse_args() + epitope, dists = get_interface(args.pdb, args.target_chains, args.ligand_chains, args.pocket_th) + para_res = {} + for _, chain_name, i, d in dists: + key = f'{chain_name}-{i}' + para_res[key] = 1 + print(f'REMARK: {len(epitope)} residues in the binding site on the target protein, with {len(para_res)} residues in ligand:') + print(f' \tchain\tresidue id\ttype\tchain\tresidue id\ttype\tdistance') + for i, (e, p) in enumerate(zip(epitope, dists)): + e_res, e_chain_name, _ = e + p_res, p_chain_name, _, d = p + print(f'{i+1}\t{e_chain_name}\t{e_res.id}\t{e_res.abrv}\t' + \ + f'{p_chain_name}\t{p_res.id}\t{p_res.abrv}\t{round(d, 3)}') + + if args.out: + data = [] + for e in epitope: + res, chain_name, _ = e + data.append((chain_name, res.id)) + with open(args.out, 'w') as fout: + json.dump(data, fout) \ No newline at end of file diff --git a/api/run.py b/api/run.py new file mode 100644 index 0000000000000000000000000000000000000000..773a539383ce4dc0908b39b340e31c36ef800e2a --- /dev/null +++ b/api/run.py @@ -0,0 +1,274 @@ +#!/usr/bin/python +# -*- coding:utf-8 -*- +import os +import sys +import json +import argparse +from tqdm import tqdm +from os.path import splitext, basename + +import ray +import numpy as np +import torch +from torch.utils.data import DataLoader + +from data.format import Atom, Block, VOCAB +from data.converter.pdb_to_list_blocks import pdb_to_list_blocks +from data.converter.list_blocks_to_pdb import list_blocks_to_pdb +from data.codesign import calculate_covariance_matrix +from utils.const import sidechain_atoms +from utils.logger import print_log +from evaluation.dG.openmm_relaxer import ForceFieldMinimizer + + +class DesignDataset(torch.utils.data.Dataset): + + MAX_N_ATOM = 14 + + def __init__(self, pdbs, epitopes, lengths_range=None, seqs=None) -> None: + super().__init__() + self.pdbs = pdbs + self.epitopes = epitopes + self.lengths_range = lengths_range + self.seqs = seqs + # structure prediction or codesign + assert (self.seqs is not None and self.lengths_range is None) | \ + (self.seqs is None and self.lengths_range is not None) + + def get_epitope(self, idx): + pdb, epitope_def = self.pdbs[idx], self.epitopes[idx] + + with open(epitope_def, 'r') as fin: + epitope = json.load(fin) + to_str = lambda pos: f'{pos[0]}-{pos[1]}' + epi_map = {} + for chain_name, pos in epitope: + if chain_name not in epi_map: + epi_map[chain_name] = {} + epi_map[chain_name][to_str(pos)] = True + residues, position_ids = [], [] + chain2blocks = pdb_to_list_blocks(pdb, list(epi_map.keys()), dict_form=True) + if len(chain2blocks) != len(epi_map): 
+ print_log(f'Some chains in the epitope are missing. Parsed {list(chain2blocks.keys())}, given {list(epi_map.keys())}.', level='WARN') + for chain_name in chain2blocks: + chain = chain2blocks[chain_name] + for i, block in enumerate(chain): # residue + if to_str(block.id) in epi_map[chain_name]: + residues.append(block) + position_ids.append(i + 1) # position ids start from 1 + return residues, position_ids, chain2blocks + + def generate_pep_chain(self, idx): + if self.lengths_range is not None: # codesign + lmin, lmax = self.lengths_range[idx] + length = np.random.randint(lmin, lmax) + unk_block = Block(VOCAB.symbol_to_abrv(VOCAB.UNK), [Atom('CA', [0, 0, 0], 'C')]) + return [unk_block] * length + else: + seq = self.seqs[idx] + blocks = [] + for s in seq: + atoms = [] + for atom_name in VOCAB.backbone_atoms + sidechain_atoms.get(s, []): + atoms.append(Atom(atom_name, [0, 0, 0], atom_name[0])) + blocks.append(Block(VOCAB.symbol_to_abrv(s), atoms)) + return blocks + + def __len__(self): + return len(self.pdbs) + + def __getitem__(self, idx: int): + rec_blocks, rec_position_ids, rec_chain2blocks = self.get_epitope(idx) + lig_blocks = self.generate_pep_chain(idx) + + mask = [0 for _ in rec_blocks] + [1 for _ in lig_blocks] + position_ids = rec_position_ids + [i + 1 for i, _ in enumerate(lig_blocks)] + X, S, atom_mask = [], [], [] + for block in rec_blocks + lig_blocks: + symbol = VOCAB.abrv_to_symbol(block.abrv) + atom2coord = { unit.name: unit.get_coord() for unit in block.units } + bb_pos = np.mean(list(atom2coord.values()), axis=0).tolist() + coords, coord_mask = [], [] + for atom_name in VOCAB.backbone_atoms + sidechain_atoms.get(symbol, []): + if atom_name in atom2coord: + coords.append(atom2coord[atom_name]) + coord_mask.append(1) + else: + coords.append(bb_pos) + coord_mask.append(0) + n_pad = self.MAX_N_ATOM - len(coords) + for _ in range(n_pad): + coords.append(bb_pos) + coord_mask.append(0) + + X.append(coords) + S.append(VOCAB.symbol_to_idx(symbol)) + atom_mask.append(coord_mask) + + X, atom_mask = torch.tensor(X, dtype=torch.float), torch.tensor(atom_mask, dtype=torch.bool) + mask = torch.tensor(mask, dtype=torch.bool) + cov = calculate_covariance_matrix(X[~mask][:, 1][atom_mask[~mask][:, 1]].numpy()) # only use the receptor to derive the affine transformation + eps = 1e-4 + cov = cov + eps * np.identity(cov.shape[0]) + L = torch.from_numpy(np.linalg.cholesky(cov)).float().unsqueeze(0) + + return { + 'X': X, # [N, 14] or [N, 4] if backbone_only == True + 'S': torch.tensor(S, dtype=torch.long), # [N] + 'position_ids': torch.tensor(position_ids, dtype=torch.long), # [N] + 'mask': mask, # [N], 1 for generation + 'atom_mask': atom_mask, # [N, 14] or [N, 4], 1 for having records in the PDB + 'lengths': len(S), + 'rec_chain2blocks': rec_chain2blocks, + 'L': L + } + + def collate_fn(self, batch): + results = {} + for key in batch[0]: + values = [item[key] for item in batch] + if key == 'lengths': + results[key] = torch.tensor(values, dtype=torch.long) + elif key == 'rec_chain2blocks': + results[key] = values + else: + results[key] = torch.cat(values, dim=0) + return results + + +@ray.remote(num_cpus=1, num_gpus=1/16) +def openmm_relax(pdb_path): + force_field = ForceFieldMinimizer() + force_field(pdb_path, pdb_path) + return pdb_path + + +def design(mode, ckpt, gpu, pdbs, epitope_defs, n_samples, out_dir, + lengths_range=None, seqs=None, identifiers=None, batch_size=8, num_workers=4): + + # create out dir + if not os.path.exists(out_dir): + os.makedirs(out_dir) + result_summary = 
open(os.path.join(out_dir, 'summary.jsonl'), 'w') + if identifiers is None: + identifiers = [splitext(basename(pdb))[0] for pdb in pdbs] + # load model + device = torch.device('cpu' if gpu == -1 else f'cuda:{gpu}') + model = torch.load(ckpt, map_location='cpu') + model.to(device) + model.eval() + + # generate dataset + # expand data + if lengths_range is None: lengths_range = [None for _ in pdbs] + if seqs is None: seqs = [None for _ in pdbs] + expand_pdbs, expand_epitopes, expand_lens, expand_ids, expand_seqs = [], [], [], [], [] + for _id, pdb, epitope, l, s, n in zip(identifiers, pdbs, epitope_defs, lengths_range, seqs, n_samples): + expand_ids.extend([f'{_id}_{i}' for i in range(n)]) + expand_pdbs.extend([pdb for _ in range(n)]) + expand_epitopes.extend([epitope for _ in range(n)]) + expand_lens.extend([l for _ in range(n)]) + expand_seqs.extend([s for _ in range(n)]) + # create dataset + if expand_lens[0] is None: expand_lens = None + if expand_seqs[0] is None: expand_seqs = None + dataset = DesignDataset(expand_pdbs, expand_epitopes, expand_lens, expand_seqs) + dataloader = DataLoader(dataset, batch_size=batch_size, + num_workers=num_workers, + collate_fn=dataset.collate_fn, + shuffle=False + ) + + # generate peptides + cnt = 0 + all_pdbs = [] + for batch in tqdm(dataloader): + with torch.no_grad(): + # move data + for k in batch: + if hasattr(batch[k], 'to'): + batch[k] = batch[k].to(device) + # generate + batch_X, batch_S, batch_pmetric = model.sample( + batch['X'], batch['S'], + batch['mask'], batch['position_ids'], + batch['lengths'], batch['atom_mask'], + L=batch['L'], sample_opt={ + 'energy_func': 'default', + 'energy_lambda': 0.5 if mode == 'struct_pred' else 0.8 + } + ) + # save data + for X, S, pmetric, rec_chain2blocks in zip(batch_X, batch_S, batch_pmetric, batch['rec_chain2blocks']): + if S is None: S = expand_seqs[cnt] # structure prediction + lig_blocks = [] + for x, s in zip(X, S): + abrv = VOCAB.symbol_to_abrv(s) + atoms = VOCAB.backbone_atoms + sidechain_atoms[VOCAB.abrv_to_symbol(abrv)] + units = [ + Atom(atom_name, coord, atom_name[0]) for atom_name, coord in zip(atoms, x) + ] + lig_blocks.append(Block(abrv, units)) + list_blocks, chain_names = [], [] + for chain in rec_chain2blocks: + list_blocks.append(rec_chain2blocks[chain]) + chain_names.append(chain) + pep_chain_id = chr(max([ord(c) for c in chain_names]) + 1) + list_blocks.append(lig_blocks) + chain_names.append(pep_chain_id) + out_pdb = os.path.join(out_dir, expand_ids[cnt] + '.pdb') + list_blocks_to_pdb(list_blocks, chain_names, out_pdb) + all_pdbs.append(out_pdb) + result_summary.write(json.dumps({ + 'id': expand_ids[cnt], + 'rec_chains': list(rec_chain2blocks.keys()), + 'pep_chain': pep_chain_id, + 'pep_seq': ''.join([VOCAB.abrv_to_symbol(block.abrv) for block in lig_blocks]) + }) + '\n') + result_summary.flush() + cnt += 1 + result_summary.close() + + print_log(f'Running openmm relaxation...') + ray.init(num_cpus=8) + futures = [openmm_relax.remote(path) for path in all_pdbs] + pbar = tqdm(total=len(futures)) + while len(futures) > 0: + done_ids, futures = ray.wait(futures, num_returns=1) + for done_id in done_ids: + done_path = ray.get(done_id) + pbar.update(1) + print_log(f'Done') + + +def parse(): + parser = argparse.ArgumentParser(description='run pepglad for codesign or structure prediction') + parser.add_argument('--mode', type=str, required=True, choices=['codesign', 'struct_pred'], help='Running mode') + parser.add_argument('--pdb', type=str, required=True, help='Path to the PDB file of the 
target protein') + parser.add_argument('--pocket', type=str, required=True, help='Path to the pocket definition (*.json generated by detect_pocket)') + parser.add_argument('--n_samples', type=int, default=10, help='Number of samples') + parser.add_argument('--out_dir', type=str, required=True, help='Output directory') + parser.add_argument('--peptide_seq', type=str, required='struct_pred' in sys.argv, help='Peptide sequence for structure prediction') + parser.add_argument('--length_min', type=int, required='codesign' in sys.argv, help='Minimum peptide length for codesign (inclusive)') + parser.add_argument('--length_max', type=int, required='codesign' in sys.argv, help='Maximum peptide length for codesign (exclusive)') + parser.add_argument('--gpu', type=int, default=0, help='GPU to use') + return parser.parse_args() + + +if __name__ == '__main__': + args = parse() + proj_dir = os.path.join(os.path.dirname(__file__), '..') + ckpt = os.path.join(proj_dir, 'checkpoints', 'fixseq.ckpt' if args.mode == 'struct_pred' else 'codesign.ckpt') + print_log(f'Loading checkpoint: {ckpt}') + design( + mode=args.mode, + ckpt=ckpt, # path to the checkpoint of the trained model + gpu=args.gpu, # the ID of the GPU to use + pdbs=[args.pdb], # paths to the PDB file of each antigen + epitope_defs=[args.pocket], # paths to the epitope (pocket) definitions + n_samples=[args.n_samples], # number of samples for each epitope + out_dir=args.out_dir, # output directory + identifiers=[os.path.basename(os.path.splitext(args.pdb)[0])], # file name (name of each output candidate) + lengths_range=[(args.length_min, args.length_max)] if args.mode == 'codesign' else None, # range of acceptable peptide lengths, left inclusive, right exclusive + seqs=[args.peptide_seq] if args.mode == 'struct_pred' else None # peptide sequences for structure prediction + ) \ No newline at end of file diff --git a/assets/1ssc_A_pocket.json b/assets/1ssc_A_pocket.json new file mode 100644 index 0000000000000000000000000000000000000000..ebe5c8536b862b8a6ebf978b72f8711d237f0b3e --- /dev/null +++ b/assets/1ssc_A_pocket.json @@ -0,0 +1 @@ +[["A", [3, " "]], ["A", [4, " "]], ["A", [5, " "]], ["A", [6, " "]], ["A", [7, " "]], ["A", [8, " "]], ["A", [9, " "]], ["A", [11, " "]], ["A", [12, " "]], ["A", [13, " "]], ["A", [43, " "]], ["A", [44, " "]], ["A", [45, " "]], ["A", [46, " "]], ["A", [47, " "]], ["A", [51, " "]], ["A", [54, " "]], ["A", [55, " "]], ["A", [56, " "]], ["A", [57, " "]], ["A", [58, " "]], ["A", [59, " "]], ["A", [63, " "]], ["A", [64, " "]], ["A", [65, " "]], ["A", [66, " "]], ["A", [67, " "]], ["A", [69, " "]], ["A", [71, " "]], ["A", [72, " "]], ["A", [73, " "]], ["A", [74, " "]], ["A", [75, " "]], ["A", [78, " "]], ["A", [79, " "]], ["A", [81, " "]], ["A", [83, " "]], ["A", [102, " "]], ["A", [103, " "]], ["A", [104, " "]], ["A", [105, " "]], ["A", [106, " "]], ["A", [107, " "]], ["A", [108, " "]], ["A", [109, " "]], ["A", [110, " "]], ["A", [111, " "]], ["A", [112, " "]]] \ No newline at end of file diff --git a/cal_metrics.py b/cal_metrics.py new file mode 100644 index 0000000000000000000000000000000000000000..30324ca6a3ac7d365e155da43ac0f54a20e68d11 --- /dev/null +++ b/cal_metrics.py @@ -0,0 +1,228 @@ +#!/usr/bin/python +# -*- coding:utf-8 -*- +import argparse +import json +import os +import random +from copy import deepcopy +from collections import defaultdict +from tqdm import tqdm +from tqdm.contrib.concurrent import process_map +import statistics +import warnings +warnings.filterwarnings("ignore") + +import numpy as np 
+from scipy.stats import spearmanr + +from data.converter.pdb_to_list_blocks import pdb_to_list_blocks +from evaluation import diversity +from evaluation.dockq import dockq +from evaluation.rmsd import compute_rmsd +from utils.random_seed import setup_seed +from evaluation.seq_metric import aar, slide_aar + + +def _get_ref_pdb(_id, root_dir): + return os.path.join(root_dir, 'references', f'{_id}_ref.pdb') + + +def _get_gen_pdb(_id, number, root_dir, use_rosetta): + suffix = '_rosetta' if use_rosetta else '' + return os.path.join(root_dir, 'candidates', _id, f'{_id}_gen_{number}{suffix}.pdb') + + +def cal_metrics(items): + # all of the items are conditioned on the same binding pocket + root_dir = items[0]['root_dir'] + ref_pdb, rec_chain, lig_chain = items[0]['ref_pdb'], items[0]['rec_chain'], items[0]['lig_chain'] + ref_pdb = _get_ref_pdb(items[0]['id'], root_dir) + seq_only, struct_only, backbone_only = items[0]['seq_only'], items[0]['struct_only'], items[0]['backbone_only'] + + # prepare + results = defaultdict(list) + cand_seqs, cand_ca_xs = [], [] + rec_blocks, ref_pep_blocks = pdb_to_list_blocks(ref_pdb, [rec_chain, lig_chain]) + ref_ca_x, ca_mask = [], [] + for ref_block in ref_pep_blocks: + if ref_block.has_unit('CA'): + ca_mask.append(1) + ref_ca_x.append(ref_block.get_unit_by_name('CA').get_coord()) + else: + ca_mask.append(0) + ref_ca_x.append([0, 0, 0]) + ref_ca_x, ca_mask = np.array(ref_ca_x), np.array(ca_mask).astype(bool) + + for item in items: + if not struct_only: + cand_seqs.append(item['gen_seq']) + results['Slide AAR'].append(slide_aar(item['gen_seq'], item['ref_seq'], aar)) + + # structure metrics + gen_pdb = _get_gen_pdb(item['id'], item['number'], root_dir, item['rosetta']) + _, gen_pep_blocks = pdb_to_list_blocks(gen_pdb, [rec_chain, lig_chain]) + assert len(gen_pep_blocks) == len(ref_pep_blocks), f'{item}\t{len(ref_pep_blocks)}\t{len(gen_pep_blocks)}' + + # CA RMSD + gen_ca_x = np.array([block.get_unit_by_name('CA').get_coord() for block in gen_pep_blocks]) + cand_ca_xs.append(gen_ca_x) + rmsd = compute_rmsd(ref_ca_x[ca_mask], gen_ca_x[ca_mask], aligned=True) + results['RMSD(CA)'].append(rmsd) + if struct_only: + results['RMSD<=2.0'].append(1 if rmsd <= 2.0 else 0) + results['RMSD<=5.0'].append(1 if rmsd <= 5.0 else 0) + results['RMSD<=10.0'].append(1 if rmsd <= 10.0 else 0) + + + if backbone_only: + continue + + # 5. 
DockQ + dockq_score = dockq(gen_pdb, ref_pdb, lig_chain) + results['DockQ'].append(dockq_score) + if struct_only: + results['DockQ>=0.23'].append(1 if dockq_score >= 0.23 else 0) + results['DockQ>=0.49'].append(1 if dockq_score >= 0.49 else 0) + results['DockQ>=0.80'].append(1 if dockq_score >= 0.80 else 0) + + # Full atom RMSD + if struct_only: + gen_all_x, ref_all_x = [], [] + for gen_block, ref_block in zip(gen_pep_blocks, ref_pep_blocks): + for ref_atom in ref_block: + if gen_block.has_unit(ref_atom.name): + ref_all_x.append(ref_atom.get_coord()) + gen_all_x.append(gen_block.get_unit_by_name(ref_atom.name).get_coord()) + results['RMSD(full-atom)'].append(compute_rmsd( + np.array(gen_all_x), np.array(ref_all_x), aligned=True + )) + + pmets = [item['pmetric'] for item in items] + indexes = list(range(len(items))) + # aggregation + for name in results: + vals = results[name] + corr = spearmanr(vals, pmets, nan_policy='omit').statistic + if np.isnan(corr): + corr = 0 + aggr_res = { + 'max': max(vals), + 'min': min(vals), + 'mean': sum(vals) / len(vals), + 'random': vals[0], + 'max*': vals[(max if corr > 0 else min)(indexes, key=lambda i: pmets[i])], + 'min*': vals[(min if corr > 0 else max)(indexes, key=lambda i: pmets[i])], + 'pmet_corr': corr, + 'individual': vals, + 'individual_pmet': pmets + } + results[name] = aggr_res + + if len(cand_seqs) > 1 and not seq_only: + seq_div, struct_div, co_div, consistency = diversity.diversity(cand_seqs, np.array(cand_ca_xs)) + results['Sequence Diversity'] = seq_div + results['Struct Diversity'] = struct_div + results['Codesign Diversity'] = co_div + results['Consistency'] = consistency + + return results + + +def cnt_aa_dist(seqs): + cnts = {} + for seq in seqs: + for aa in seq: + if aa not in cnts: + cnts[aa] = 0 + cnts[aa] += 1 + aas = sorted(list(cnts.keys()), key=lambda aa: cnts[aa]) + total = sum(cnts.values()) + for aa in aas: + print(f'\t{aa}: {cnts[aa] / total}') + + +def main(args): + root_dir = os.path.dirname(args.results) + # load dG filter + if args.filter_dG is None: + filter_func = lambda _id, n: True + else: + dG_results = json.load(open(args.filter_dG, 'r')) + filter_func = lambda _id, n: dG_results[_id]['all'][str(n)] < 0 + # load results + with open(args.results, 'r') as fin: + lines = fin.read().strip().split('\n') + id2items = {} + for line in lines: + item = json.loads(line) + _id = item['id'] + if not filter_func(_id, item['number']): + continue + if _id not in id2items: + id2items[_id] = [] + item['root_dir'] = root_dir + item['rosetta'] = args.rosetta + id2items[_id].append(item) + ids = list(id2items.keys()) + + if args.filter_dG is not None: + # delete results with only one sample since it cannot calculate diversity + del_ids = [_id for _id in ids if len(id2items[_id]) < 2] + for _id in del_ids: + print(f'Deleting {_id} since it only has one sample passed the filter') + del id2items[_id] + + if args.num_workers > 1: + metrics = process_map(cal_metrics, id2items.values(), max_workers=args.num_workers, chunksize=1) + else: + metrics = [cal_metrics(inputs) for inputs in tqdm(id2items.values())] + + eval_results_path = os.path.join(os.path.dirname(args.results), 'eval_report.json') + with open(eval_results_path, 'w') as fout: + for i, _id in enumerate(id2items): + metric = deepcopy(metrics[i]) + metric['id'] = _id + fout.write(json.dumps(metric) + '\n') + + # individual level results + print('Point-wise evaluation results:') + for name in metrics[0]: + vals = [item[name] for item in metrics] + if isinstance(vals[0], dict): + if 
'RMSD' in name and '<=' not in name: + aggr = 'min' + else: + aggr = 'max' + aggr_vals = [val[aggr] for val in vals] + if '>=' in name or '<=' in name: # percentage + print(f'{name}: {sum(aggr_vals) / len(aggr_vals)}') + else: + if 'RMSD' in name: + print(f'{name}(median): {statistics.median(aggr_vals)}') # unbounded, some extreme values will affect the mean but not the median + else: + print(f'{name}(mean): {sum(aggr_vals) / len(aggr_vals)}') + lowest_i = min([i for i in range(len(aggr_vals))], key=lambda i: aggr_vals[i]) + highest_i = max([i for i in range(len(aggr_vals))], key=lambda i: aggr_vals[i]) + print(f'\tlowest: {aggr_vals[lowest_i]}, id: {ids[lowest_i]}', end='') + print(f'\thighest: {aggr_vals[highest_i]}, id: {ids[highest_i]}') + else: + print(f'{name} (mean): {sum(vals) / len(vals)}') + lowest_i = min([i for i in range(len(vals))], key=lambda i: vals[i]) + highest_i = max([i for i in range(len(vals))], key=lambda i: vals[i]) + print(f'\tlowest: {vals[lowest_i]}, id: {ids[lowest_i]}') + print(f'\thighest: {vals[highest_i]}, id: {ids[highest_i]}') + + +def parse(): + parser = argparse.ArgumentParser(description='calculate metrics') + parser.add_argument('--results', type=str, required=True, help='Path to test set') + parser.add_argument('--num_workers', type=int, default=8, help='Number of workers to use') + parser.add_argument('--rosetta', action='store_true', help='Use the rosetta-refined structure') + parser.add_argument('--filter_dG', type=str, default=None, help='Only calculate results on samples with dG<0') + + return parser.parse_args() + + +if __name__ == '__main__': + setup_seed(0) + main(parse()) diff --git a/configs/pepbdb/autoencoder/train_codesign.yaml b/configs/pepbdb/autoencoder/train_codesign.yaml new file mode 100644 index 0000000000000000000000000000000000000000..142f37442cf98aaab00f9457b7fdaf8e9b834f12 --- /dev/null +++ b/configs/pepbdb/autoencoder/train_codesign.yaml @@ -0,0 +1,66 @@ +dataset: + train: + - class: CoDesignDataset + mmap_dir: ./datasets/pepbdb/processed + specify_index: ./datasets/pepbdb/processed/train_index.txt + backbone_only: false + cluster: ./datasets/pepbdb/train.cluster + - class: CoDesignDataset + mmap_dir: ./datasets/ProtFrag/processed + backbone_only: false + valid: + class: CoDesignDataset + mmap_dir: ./datasets/pepbdb/processed + specify_index: ./datasets/pepbdb/processed/valid_index.txt + backbone_only: false + +dataloader: + shuffle: true + num_workers: 4 + wrapper: + class: DynamicBatchWrapper + complexity: n**2 + ubound_per_batch: 60000 # batch size ~24 + +trainer: + class: AutoEncoderTrainer + config: + max_epoch: 100 + save_topk: 10 + save_dir: ./ckpts/autoencoder_codesign_pepbdb + patience: 10 + metric_min_better: true + + optimizer: + class: AdamW + lr: 1.0e-4 + + scheduler: + class: ReduceLROnPlateau + factor: 0.8 + patience: 5 + mode: min + frequency: val_epoch + min_lr: 5.0e-6 + +model: + class: AutoEncoder + embed_size: 128 + hidden_size: 128 + latent_size: 8 + latent_n_channel: 1 + n_layers: 3 + n_channel: 14 # all atom + h_kl_weight: 0.3 + z_kl_weight: 0.5 + coord_loss_ratio: 0.5 + coord_loss_weights: + Xloss: 1.0 + ca_Xloss: 1.0 + bb_bond_lengths_loss: 1.0 + sc_bond_lengths_loss: 1.0 + bb_dihedral_angles_loss: 0.0 + sc_chi_angles_loss: 0.5 + relative_position: false + anchor_at_ca: true + mask_ratio: 0.25 \ No newline at end of file diff --git a/configs/pepbdb/autoencoder/train_fixseq.yaml b/configs/pepbdb/autoencoder/train_fixseq.yaml new file mode 100644 index 
0000000000000000000000000000000000000000..8e9ec4f31c517ff95e8db29a3d6890718420b996 --- /dev/null +++ b/configs/pepbdb/autoencoder/train_fixseq.yaml @@ -0,0 +1,63 @@ +dataset: + train: + class: CoDesignDataset + mmap_dir: ./datasets/pepbdb/processed + specify_index: ./datasets/pepbdb/processed/train_index.txt + backbone_only: false + cluster: ./datasets/pepbdb/train.cluster + valid: + class: CoDesignDataset + mmap_dir: ./datasets/pepbdb/processed + specify_index: ./datasets/pepbdb/processed/valid_index.txt + backbone_only: false + +dataloader: + shuffle: true + num_workers: 4 + wrapper: + class: DynamicBatchWrapper + complexity: n**2 + ubound_per_batch: 60000 # batch size ~24 + +trainer: + class: AutoEncoderTrainer + config: + max_epoch: 150 # the best checkpoint should be obatained at about epoch 457 + save_topk: 10 + save_dir: ./ckpts/autoencoder_fixseq + patience: 10 + metric_min_better: true + + optimizer: + class: AdamW + lr: 1.0e-4 + + scheduler: + class: ReduceLROnPlateau + factor: 0.8 + patience: 15 + mode: min + frequency: val_epoch + min_lr: 5.0e-6 + +model: + class: AutoEncoder + embed_size: 128 + hidden_size: 128 + latent_size: 0 + latent_n_channel: 1 + n_layers: 3 + n_channel: 14 # all atom + h_kl_weight: 0.0 + z_kl_weight: 0.6 + coord_loss_ratio: 1.0 + coord_loss_weights: + Xloss: 1.0 + ca_Xloss: 1.0 + bb_bond_lengths_loss: 1.0 + sc_bond_lengths_loss: 1.0 + bb_dihedral_angles_loss: 0.0 + sc_chi_angles_loss: 0.5 + anchor_at_ca: true + mode: fixseq + additional_noise_scale: 1.0 diff --git a/configs/pepbdb/ldm/setup_latent_guidance.yaml b/configs/pepbdb/ldm/setup_latent_guidance.yaml new file mode 100644 index 0000000000000000000000000000000000000000..5b0fa3f7c69b2e6faec43b306d36c874c216be3d --- /dev/null +++ b/configs/pepbdb/ldm/setup_latent_guidance.yaml @@ -0,0 +1,12 @@ +dataset: + test: + class: CoDesignDataset + mmap_dir: ./datasets/pepbdb/processed + specify_index: ./datasets/pepbdb/processed/train_index.txt + backbone_only: false + +dataloader: + num_workers: 2 + batch_size: 32 + +backbone_only: false diff --git a/configs/pepbdb/ldm/train_codesign.yaml b/configs/pepbdb/ldm/train_codesign.yaml new file mode 100644 index 0000000000000000000000000000000000000000..842fef81b3b8c83eaa1b3303fe94d996790b67cb --- /dev/null +++ b/configs/pepbdb/ldm/train_codesign.yaml @@ -0,0 +1,61 @@ +dataset: + train: + - class: CoDesignDataset + mmap_dir: ./datasets/pepbdb/processed + specify_index: ./datasets/pepbdb/processed/train_index.txt + backbone_only: false + cluster: ./datasets/pepbdb/train.cluster + use_covariance_matrix: true + valid: + class: CoDesignDataset + mmap_dir: ./datasets/pepbdb/processed + specify_index: ./datasets/pepbdb/processed/valid_index.txt + backbone_only: false + use_covariance_matrix: true + +dataloader: + shuffle: true + num_workers: 4 + wrapper: + class: DynamicBatchWrapper + complexity: n**2 + ubound_per_batch: 60000 # batch size ~32 + +trainer: + class: LDMTrainer + criterion: Loss + config: + max_epoch: 500 # the best checkpoint should be obtained at around epoch 380 + save_topk: 10 + val_freq: 10 + save_dir: ./ckpts/LDM_codesign + patience: 10 + metric_min_better: true + + optimizer: + class: AdamW + lr: 1.0e-4 + + scheduler: + class: ReduceLROnPlateau + factor: 0.6 + patience: 3 + mode: min + frequency: val_epoch + min_lr: 5.0e-6 + +model: + class: LDMPepDesign + autoencoder_ckpt: "" + autoencoder_no_randomness: true + hidden_size: 128 + num_steps: 100 + n_layers: 3 + n_rbf: 32 + cutoff: 3.0 # the coordinates are in standard space + dist_rbf: 32 + 
dist_rbf_cutoff: 7.0 + diffusion_opt: + trans_seq_type: Diffusion + trans_pos_type: Diffusion + max_gen_position: 60 \ No newline at end of file diff --git a/configs/pepbdb/ldm/train_fixseq.yaml b/configs/pepbdb/ldm/train_fixseq.yaml new file mode 100644 index 0000000000000000000000000000000000000000..dc7811e1e7814b88d684b0fb07806120673d118d --- /dev/null +++ b/configs/pepbdb/ldm/train_fixseq.yaml @@ -0,0 +1,63 @@ +dataset: + train: + - class: CoDesignDataset + mmap_dir: ./datasets/pepbdb/processed + specify_index: ./datasets/pepbdb/processed/train_index.txt + backbone_only: false + cluster: ./datasets/pepbdb/train.cluster + use_covariance_matrix: true + valid: + class: CoDesignDataset + mmap_dir: ./datasets/pepbdb/processed + specify_index: ./datasets/pepbdb/processed/valid_index.txt + backbone_only: false + use_covariance_matrix: true + +dataloader: + shuffle: true + num_workers: 4 + wrapper: + class: DynamicBatchWrapper + complexity: n**2 + ubound_per_batch: 60000 # batch size ~32 + +trainer: + class: LDMTrainer + criterion: RMSD + config: + max_epoch: 1000 # the best checkpoint will be obtained at about 900 epoch + save_topk: 10 + val_freq: 10 + save_dir: ./ckpts/LDM_fixseq + patience: 10 + metric_min_better: true + + optimizer: + class: AdamW + lr: 1.0e-4 + + scheduler: + class: ReduceLROnPlateau + factor: 0.6 + patience: 3 + mode: min + frequency: val_epoch + min_lr: 5.0e-6 + +model: + class: LDMPepDesign + autoencoder_ckpt: "" + autoencoder_no_randomness: true + hidden_size: 128 + num_steps: 100 + n_layers: 6 + n_rbf: 32 + cutoff: 3.0 # the coordinates are in standard space + dist_rbf: 0 + dist_rbf_cutoff: 0.0 + diffusion_opt: + trans_seq_type: Diffusion + trans_pos_type: Diffusion + std: 20.0 + mode: fixseq + max_gen_position: 60 diff --git a/configs/pepbdb/test_codesign.yaml b/configs/pepbdb/test_codesign.yaml new file mode 100644 index 0000000000000000000000000000000000000000..9463cda358811be4171ede18d44869f9a59cc11b --- /dev/null +++ b/configs/pepbdb/test_codesign.yaml @@ -0,0 +1,18 @@ +dataset: + test: + class: CoDesignDataset + mmap_dir: ./datasets/pepbdb/processed + specify_index: ./datasets/pepbdb/processed/test_index.txt + backbone_only: false + use_covariance_matrix: true + +dataloader: + num_workers: 4 + batch_size: 64 + +backbone_only: false +n_samples: 40 + +sample_opt: + energy_func: default + energy_lambda: 0.8 diff --git a/configs/pepbdb/test_fixseq.yaml b/configs/pepbdb/test_fixseq.yaml new file mode 100644 index 0000000000000000000000000000000000000000..1ece4dccc929e540b8ab979ded74804212437a1f --- /dev/null +++ b/configs/pepbdb/test_fixseq.yaml @@ -0,0 +1,19 @@ +dataset: + test: + class: CoDesignDataset + mmap_dir: ./datasets/pepbdb/processed + specify_index: ./datasets/pepbdb/processed/test_index.txt + backbone_only: false + use_covariance_matrix: true + +dataloader: + num_workers: 4 + batch_size: 64 + +backbone_only: false +struct_only: true +n_samples: 10 + +sample_opt: + energy_func: default + energy_lambda: 0.8 diff --git a/configs/pepbench/autoencoder/train_codesign.yaml b/configs/pepbench/autoencoder/train_codesign.yaml new file mode 100644 index 0000000000000000000000000000000000000000..dcc63f3f64e77bded791513d02ba108dcd14439c --- /dev/null +++ b/configs/pepbench/autoencoder/train_codesign.yaml @@ -0,0 +1,66 @@ +dataset: + train: + - class: CoDesignDataset + mmap_dir: ./datasets/train_valid/processed + specify_index: ./datasets/train_valid/processed/train_index.txt + backbone_only: false + cluster: ./datasets/train_valid/train.cluster + - class: 
CoDesignDataset + mmap_dir: ./datasets/ProtFrag/processed + backbone_only: false + valid: + class: CoDesignDataset + mmap_dir: ./datasets/train_valid/processed + specify_index: ./datasets/train_valid/processed/valid_index.txt + backbone_only: false + +dataloader: + shuffle: true + num_workers: 4 + wrapper: + class: DynamicBatchWrapper + complexity: n**2 + ubound_per_batch: 60000 # batch size ~24 + +trainer: + class: AutoEncoderTrainer + config: + max_epoch: 100 + save_topk: 10 + save_dir: ./ckpts/autoencoder_codesign + patience: 10 + metric_min_better: true + + optimizer: + class: AdamW + lr: 1.0e-4 + + scheduler: + class: ReduceLROnPlateau + factor: 0.8 + patience: 5 + mode: min + frequency: val_epoch + min_lr: 5.0e-6 + +model: + class: AutoEncoder + embed_size: 128 + hidden_size: 128 + latent_size: 8 + latent_n_channel: 1 + n_layers: 3 + n_channel: 14 # all atom + h_kl_weight: 0.3 + z_kl_weight: 0.5 + coord_loss_ratio: 0.5 + coord_loss_weights: + Xloss: 1.0 + ca_Xloss: 1.0 + bb_bond_lengths_loss: 1.0 + sc_bond_lengths_loss: 1.0 + bb_dihedral_angles_loss: 0.0 + sc_chi_angles_loss: 0.5 + relative_position: false + anchor_at_ca: true + mask_ratio: 0.25 \ No newline at end of file diff --git a/configs/pepbench/autoencoder/train_fixseq.yaml b/configs/pepbench/autoencoder/train_fixseq.yaml new file mode 100644 index 0000000000000000000000000000000000000000..091fa943dc6739a7e9187f716fa0252473c7e5a6 --- /dev/null +++ b/configs/pepbench/autoencoder/train_fixseq.yaml @@ -0,0 +1,62 @@ +dataset: + train: + class: CoDesignDataset + mmap_dir: ./datasets/train_valid/processed + specify_index: ./datasets/train_valid/processed/train_index.txt + backbone_only: false + cluster: ./datasets/train_valid/train.cluster + valid: + class: CoDesignDataset + mmap_dir: ./datasets/train_valid/processed + specify_index: ./datasets/train_valid/processed/valid_index.txt + backbone_only: false + +dataloader: + shuffle: true + num_workers: 4 + wrapper: + class: DynamicBatchWrapper + complexity: n**2 + ubound_per_batch: 60000 # batch size ~24 + +trainer: + class: AutoEncoderTrainer + config: + max_epoch: 500 # the best checkpoint should be obatained at about epoch 457 + save_topk: 10 + save_dir: ./ckpts/autoencoder_fixseq + patience: 10 + metric_min_better: true + + optimizer: + class: AdamW + lr: 1.0e-4 + + scheduler: + class: ReduceLROnPlateau + factor: 0.8 + patience: 15 + mode: min + frequency: val_epoch + min_lr: 5.0e-6 + +model: + class: AutoEncoder + embed_size: 128 + hidden_size: 128 + latent_size: 0 + latent_n_channel: 1 + n_layers: 3 + n_channel: 14 # all atom + h_kl_weight: 0.0 + z_kl_weight: 1.0 + coord_loss_ratio: 1.0 + coord_loss_weights: + Xloss: 1.0 + ca_Xloss: 1.0 + bb_bond_lengths_loss: 1.0 + sc_bond_lengths_loss: 1.0 + bb_dihedral_angles_loss: 0.0 + sc_chi_angles_loss: 0.5 + anchor_at_ca: true + mode: fixseq diff --git a/configs/pepbench/ldm/setup_latent_guidance.yaml b/configs/pepbench/ldm/setup_latent_guidance.yaml new file mode 100644 index 0000000000000000000000000000000000000000..d3dfaf2ef85d3c085a6fb83eff33be609fe0db03 --- /dev/null +++ b/configs/pepbench/ldm/setup_latent_guidance.yaml @@ -0,0 +1,12 @@ +dataset: + test: + class: CoDesignDataset + mmap_dir: ./datasets/train_valid/processed + specify_index: ./datasets/train_valid/processed/train_index.txt + backbone_only: false + +dataloader: + num_workers: 2 + batch_size: 32 + +backbone_only: false diff --git a/configs/pepbench/ldm/train_codesign.yaml b/configs/pepbench/ldm/train_codesign.yaml new file mode 100644 index 
0000000000000000000000000000000000000000..a6353eab8c4815218457fc7ffa6d7ef7239789ee --- /dev/null +++ b/configs/pepbench/ldm/train_codesign.yaml @@ -0,0 +1,60 @@ +dataset: + train: + class: CoDesignDataset + mmap_dir: ./datasets/train_valid/processed + specify_index: ./datasets/train_valid/processed/train_index.txt + backbone_only: false + cluster: ./datasets/train_valid/train.cluster + use_covariance_matrix: true + valid: + class: CoDesignDataset + mmap_dir: ./datasets/train_valid/processed + specify_index: ./datasets/train_valid/processed/valid_index.txt + backbone_only: false + use_covariance_matrix: true + +dataloader: + shuffle: true + num_workers: 4 + wrapper: + class: DynamicBatchWrapper + complexity: n**2 + ubound_per_batch: 60000 # batch size ~32 + +trainer: + class: LDMTrainer + criterion: Loss + config: + max_epoch: 500 # the best checkpoint should be obtained at around epoch 380 + save_topk: 10 + val_freq: 10 + save_dir: ./ckpts/LDM_codesign + patience: 10 + metric_min_better: true + + optimizer: + class: AdamW + lr: 1.0e-4 + + scheduler: + class: ReduceLROnPlateau + factor: 0.6 + patience: 3 + mode: min + frequency: val_epoch + min_lr: 5.0e-6 + +model: + class: LDMPepDesign + autoencoder_ckpt: "" + autoencoder_no_randomness: true + hidden_size: 128 + num_steps: 100 + n_layers: 3 + n_rbf: 32 + cutoff: 3.0 # the coordinates are in standard space + dist_rbf: 32 + dist_rbf_cutoff: 7.0 + diffusion_opt: + trans_seq_type: Diffusion + trans_pos_type: Diffusion \ No newline at end of file diff --git a/configs/pepbench/ldm/train_fixseq.yaml b/configs/pepbench/ldm/train_fixseq.yaml new file mode 100644 index 0000000000000000000000000000000000000000..3c4240f20b87e11008cf003b1c29369263af43a3 --- /dev/null +++ b/configs/pepbench/ldm/train_fixseq.yaml @@ -0,0 +1,61 @@ +dataset: + train: + class: CoDesignDataset + mmap_dir: ./datasets/train_valid/processed + specify_index: ./datasets/train_valid/processed/train_index.txt + backbone_only: false + cluster: ./datasets/train_valid/train.cluster + use_covariance_matrix: true + valid: + class: CoDesignDataset + mmap_dir: ./datasets/train_valid/processed + specify_index: ./datasets/train_valid/processed/valid_index.txt + backbone_only: false + use_covariance_matrix: true + +dataloader: + shuffle: true + num_workers: 4 + wrapper: + class: DynamicBatchWrapper + complexity: n**2 + ubound_per_batch: 60000 # batch size ~32 + +trainer: + class: LDMTrainer + criterion: RMSD + config: + max_epoch: 1000 # the best checkpoint will be obtained at about 720 epoch + save_topk: 10 + val_freq: 10 + save_dir: ./ckpts/LDM_fixseq + patience: 10 + metric_min_better: true + + optimizer: + class: AdamW + lr: 1.0e-4 + + scheduler: + class: ReduceLROnPlateau + factor: 0.6 + patience: 3 + mode: min + frequency: val_epoch + min_lr: 5.0e-6 + +model: + class: LDMPepDesign + autoencoder_ckpt: "" + autoencoder_no_randomness: true + hidden_size: 128 + num_steps: 100 + n_layers: 3 + n_rbf: 32 + cutoff: 3.0 # the coordinates are in standard space + dist_rbf: 0 + dist_rbf_cutoff: 0.0 + diffusion_opt: + trans_seq_type: Diffusion + trans_pos_type: Diffusion + mode: fixseq diff --git a/configs/pepbench/test_codesign.yaml b/configs/pepbench/test_codesign.yaml new file mode 100644 index 0000000000000000000000000000000000000000..ecfe48c641fabf3fc1b795af058071c23af49ac5 --- /dev/null +++ b/configs/pepbench/test_codesign.yaml @@ -0,0 +1,17 @@ +dataset: + test: + class: CoDesignDataset + mmap_dir: ./datasets/LNR/processed + backbone_only: false + use_covariance_matrix: true + +dataloader: + 
num_workers: 4 + batch_size: 64 + +backbone_only: false +n_samples: 40 + +sample_opt: + energy_func: default + energy_lambda: 0.8 diff --git a/configs/pepbench/test_fixseq.yaml b/configs/pepbench/test_fixseq.yaml new file mode 100644 index 0000000000000000000000000000000000000000..26e5e9df9139154b64c3d6040fb1297f8bfb5343 --- /dev/null +++ b/configs/pepbench/test_fixseq.yaml @@ -0,0 +1,18 @@ +dataset: + test: + class: CoDesignDataset + mmap_dir: ./datasets/LNR/processed + backbone_only: false + use_covariance_matrix: true + +dataloader: + num_workers: 4 + batch_size: 64 + +backbone_only: false +struct_only: true +n_samples: 10 + +sample_opt: + energy_func: default + energy_lambda: 0.5 diff --git a/data/__init__.py b/data/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..aad29b45b11b2989aa28ba1c20ee034dd037d0e9 --- /dev/null +++ b/data/__init__.py @@ -0,0 +1,53 @@ +#!/usr/bin/python +# -*- coding:utf-8 -*- +from .dataset_wrapper import MixDatasetWrapper +from .codesign import CoDesignDataset +from .resample import ClusterResampler + + +import torch +from torch.utils.data import DataLoader + +import utils.register as R +from utils.logger import print_log + +def create_dataset(config: dict): + splits = [] + for split_name in ['train', 'valid', 'test']: + split_config = config.get(split_name, None) + if split_config is None: + splits.append(None) + continue + if isinstance(split_config, list): + dataset = MixDatasetWrapper( + *[R.construct(cfg) for cfg in split_config] + ) + else: + dataset = R.construct(split_config) + splits.append(dataset) + return splits # train/valid/test + + +def create_dataloader(dataset, config: dict, n_gpu: int=1, validation: bool=False): + if 'wrapper' in config: + dataset = R.construct(config['wrapper'], dataset=dataset) + batch_size = config.get('batch_size', n_gpu) # default 1 on each gpu + if validation: + batch_size = config.get('val_batch_size', batch_size) + shuffle = config.get('shuffle', False) + num_workers = config.get('num_workers', 4) + collate_fn = dataset.collate_fn if hasattr(dataset, 'collate_fn') else None + if n_gpu > 1: + sampler = torch.utils.data.distributed.DistributedSampler(dataset, shuffle=shuffle) + batch_size = int(batch_size / n_gpu) + print_log(f'Batch size on a single GPU: {batch_size}') + else: + sampler = None + return DataLoader( + dataset=dataset, + batch_size=batch_size, + num_workers=num_workers, + shuffle=(shuffle and sampler is None), + collate_fn=collate_fn, + sampler=sampler + ) \ No newline at end of file diff --git a/data/codesign.py b/data/codesign.py new file mode 100644 index 0000000000000000000000000000000000000000..eafc1deb87d49a7aa241562bed1a558ff8e6b659 --- /dev/null +++ b/data/codesign.py @@ -0,0 +1,208 @@ + +import os +from typing import Optional, Any + +import numpy as np +import torch +from torch.nn.utils.rnn import pad_sequence + +from utils import register as R +from utils.const import sidechain_atoms + +from data.converter.list_blocks_to_pdb import list_blocks_to_pdb + +from .format import VOCAB, Block, Atom +from .mmap_dataset import MMAPDataset +from .resample import ClusterResampler + + + +def calculate_covariance_matrix(point_cloud): + # Calculate the covariance matrix of the point cloud + covariance_matrix = np.cov(point_cloud, rowvar=False) + return covariance_matrix + + +@R.register('CoDesignDataset') +class CoDesignDataset(MMAPDataset): + + MAX_N_ATOM = 14 + + def __init__( + self, + mmap_dir: str, + backbone_only: bool, # only backbone (N, CA, C, O) or full-atom + 
specify_data: Optional[str] = None, + specify_index: Optional[str] = None, + padding_collate: bool = False, + cluster: Optional[str] = None, + use_covariance_matrix: bool = False + ) -> None: + super().__init__(mmap_dir, specify_data, specify_index) + self.mmap_dir = mmap_dir + self.backbone_only = backbone_only + self._lengths = [len(prop[-1].split(',')) + int(prop[1]) for prop in self._properties] + self.padding_collate = padding_collate + self.resampler = ClusterResampler(cluster) if cluster else None # should only be used in training! + self.use_covariance_matrix = use_covariance_matrix + + self.dynamic_idxs = [i for i in range(len(self))] + self.update_epoch() # should be called every epoch + + def update_epoch(self): + if self.resampler is not None: + self.dynamic_idxs = self.resampler(len(self)) + + def get_len(self, idx): + return self._lengths[self.dynamic_idxs[idx]] + + def get_summary(self, idx: int): + props = self._properties[idx] + _id = self._indexes[idx][0].split('.')[0] + ref_pdb = os.path.join(self.mmap_dir, '..', 'pdbs', _id + '.pdb') + rec_chain, lig_chain = props[4], props[5] + return _id, ref_pdb, rec_chain, lig_chain + + def __getitem__(self, idx: int): + idx = self.dynamic_idxs[idx] + rec_blocks, lig_blocks = super().__getitem__(idx) + # receptor, (lig_chain_id, lig_blocks) = super().__getitem__(idx) + # pocket = {} + # for i in self._properties[idx][-1].split(','): + # chain, i = i.split(':') + # if chain not in pocket: + # pocket[chain] = [] + # pocket[chain].append(int(i)) + # rec_blocks = [] + # for chain_id, blocks in receptor: + # for i in pocket[chain_id]: + # rec_blocks.append(blocks[i]) + pocket_idx = [int(i) for i in self._properties[idx][-1].split(',')] + rec_position_ids = [i + 1 for i, _ in enumerate(rec_blocks)] + rec_blocks = [rec_blocks[i] for i in pocket_idx] + rec_position_ids = [rec_position_ids[i] for i in pocket_idx] + rec_blocks = [Block.from_tuple(tup) for tup in rec_blocks] + lig_blocks = [Block.from_tuple(tup) for tup in lig_blocks] + + # for block in lig_blocks: + # block.units = [Atom('CA', [0, 0, 0], 'C')] + # if idx == 0: + # print(self._properties[idx]) + # print(''.join(VOCAB.abrv_to_symbol(block.abrv) for block in lig_blocks)) + # list_blocks_to_pdb([ + # rec_blocks, lig_blocks + # ], ['B', 'A'], 'pocket.pdb') + + mask = [0 for _ in rec_blocks] + [1 for _ in lig_blocks] + position_ids = rec_position_ids + [i + 1 for i, _ in enumerate(lig_blocks)] + X, S, atom_mask = [], [], [] + for block in rec_blocks + lig_blocks: + symbol = VOCAB.abrv_to_symbol(block.abrv) + atom2coord = { unit.name: unit.get_coord() for unit in block.units } + bb_pos = np.mean(list(atom2coord.values()), axis=0).tolist() + coords, coord_mask = [], [] + for atom_name in VOCAB.backbone_atoms + sidechain_atoms.get(symbol, []): + if atom_name in atom2coord: + coords.append(atom2coord[atom_name]) + coord_mask.append(1) + else: + coords.append(bb_pos) + coord_mask.append(0) + n_pad = self.MAX_N_ATOM - len(coords) + for _ in range(n_pad): + coords.append(bb_pos) + coord_mask.append(0) + + X.append(coords) + S.append(VOCAB.symbol_to_idx(symbol)) + atom_mask.append(coord_mask) + + X, atom_mask = torch.tensor(X, dtype=torch.float), torch.tensor(atom_mask, dtype=torch.bool) + mask = torch.tensor(mask, dtype=torch.bool) + if self.backbone_only: + X, atom_mask = X[:, :4], atom_mask[:, :4] + + if self.use_covariance_matrix: + cov = calculate_covariance_matrix(X[~mask][:, 1][atom_mask[~mask][:, 1]].numpy()) # only use the receptor to derive the affine transformation + eps = 1e-4 
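+ # add a small diagonal jitter so the covariance stays positive definite for the Cholesky factorization below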
+ cov = cov + eps * np.identity(cov.shape[0]) + L = torch.from_numpy(np.linalg.cholesky(cov)).float().unsqueeze(0) + else: + L = None + + item = { + 'X': X, # [N, 14] or [N, 4] if backbone_only == True + 'S': torch.tensor(S, dtype=torch.long), # [N] + 'position_ids': torch.tensor(position_ids, dtype=torch.long), # [N] + 'mask': mask, # [N], 1 for generation + 'atom_mask': atom_mask, # [N, 14] or [N, 4], 1 for having records in the PDB + 'lengths': len(S), + } + if L is not None: + item['L'] = L + return item + + def collate_fn(self, batch): + if self.padding_collate: + results = {} + pad_idx = VOCAB.symbol_to_idx(VOCAB.PAD) + for key in batch[0]: + values = [item[key] for item in batch] + if values[0] is None: + results[key] = None + continue + if key == 'lengths': + results[key] = torch.tensor(values, dtype=torch.long) + elif key == 'S': + results[key] = pad_sequence(values, batch_first=True, padding_value=pad_idx) + else: + results[key] = pad_sequence(values, batch_first=True, padding_value=0) + return results + else: + results = {} + for key in batch[0]: + values = [item[key] for item in batch] + if values[0] is None: + results[key] = None + continue + if key == 'lengths': + results[key] = torch.tensor(values, dtype=torch.long) + else: + results[key] = torch.cat(values, dim=0) + return results + + +@R.register('ShapeDataset') +class ShapeDataset(CoDesignDataset): + def __init__( + self, + mmap_dir: str, + specify_data: Optional[str] = None, + specify_index: Optional[str] = None, + padding_collate: bool = False, + cluster: Optional[str] = None + ) -> None: + super().__init__(mmap_dir, False, specify_data, specify_index, padding_collate, cluster) + self.ca_idx = VOCAB.backbone_atoms.index('CA') + + def __getitem__(self, idx: int): + item = super().__getitem__(idx) + + # refine coordinates to CA and the atom furthest from CA + X = item['X'] # [N, 14, 3] + atom_mask = item['atom_mask'] + ca_x = X[:, self.ca_idx].unsqueeze(1) # [N, 1, 3] + sc_x = X[:, 4:] # [N, 10, 3], sidechain atom indexes + dist = torch.norm(sc_x - ca_x, dim=-1) # [N, 10] + dist = dist.masked_fill(~atom_mask[:, 4:], 1e10) + furthest_atom_x = sc_x[torch.arange(sc_x.shape[0]), torch.argmax(dist, dim=-1)] # [N, 3] + X = torch.cat([ca_x, furthest_atom_x.unsqueeze(1)], dim=1) + + item['X'] = X + return item + + +if __name__ == '__main__': + import sys + dataset = CoDesignDataset(sys.argv[1], backbone_only=True) + print(dataset[0]) diff --git a/data/converter/blocks_interface.py b/data/converter/blocks_interface.py new file mode 100644 index 0000000000000000000000000000000000000000..2367a059cdf25e42dccee38abd841f95e40e8ffd --- /dev/null +++ b/data/converter/blocks_interface.py @@ -0,0 +1,89 @@ +#!/usr/bin/python +# -*- coding:utf-8 -*- +import numpy as np + + +def blocks_to_coords(blocks): + max_n_unit = 0 + coords, masks = [], [] + for block in blocks: + coords.append([unit.get_coord() for unit in block.units]) + max_n_unit = max(max_n_unit, len(coords[-1])) + masks.append([1 for _ in coords[-1]]) + + for i in range(len(coords)): + num_pad = max_n_unit - len(coords[i]) + coords[i] = coords[i] + [[0, 0, 0] for _ in range(num_pad)] + masks[i] = masks[i] + [0 for _ in range(num_pad)] + + return np.array(coords), np.array(masks).astype('bool') # [N, M, 3], [N, M], M == max_n_unit, in mask 0 is for padding + + +def dist_matrix_from_coords(coords1, masks1, coords2, masks2): + dist = np.linalg.norm(coords1[:, None] - coords2[None, :], axis=-1) # [N1, N2, M] + dist = dist + np.logical_not(masks1[:, None] * masks2[None, :]) * 1e6 # 
[N1, N2, M] + dist = np.min(dist, axis=-1) # [N1, N2] + return dist + + +def dist_matrix_from_blocks(blocks1, blocks2): + blocks_coord, blocks_mask = blocks_to_coords(blocks1 + blocks2) + blocks1_coord, blocks1_mask = blocks_coord[:len(blocks1)], blocks_mask[:len(blocks1)] + blocks2_coord, blocks2_mask = blocks_coord[len(blocks1):], blocks_mask[len(blocks1):] + dist = dist_matrix_from_coords(blocks1_coord, blocks1_mask, blocks2_coord, blocks2_mask) + return dist + + +def blocks_interface(blocks1, blocks2, dist_th): + dist = dist_matrix_from_blocks(blocks1, blocks2) + + on_interface = dist < dist_th + indexes1 = np.nonzero(on_interface.sum(axis=1) > 0)[0] + indexes2 = np.nonzero(on_interface.sum(axis=0) > 0)[0] + + blocks1 = [blocks1[i] for i in indexes1] + blocks2 = [blocks2[i] for i in indexes2] + + return (blocks1, blocks2), (indexes1, indexes2) + + +def add_cb(input_array): + #from protein mpnn + #The virtual Cβ coordinates were calculated using ideal angle and bond length definitions: b = Cα - N, c = C - Cα, a = cross(b, c), Cβ = -0.58273431*a + 0.56802827*b - 0.54067466*c + Cα. + N,CA,C,O = input_array + b = CA - N + c = C - CA + a = np.cross(b,c) + CB = np.around(-0.58273431*a + 0.56802827*b - 0.54067466*c + CA,3) + return CB #np.array([N,CA,C,CB,O]) + + +def blocks_to_cb_coords(blocks): + cb_coords = [] + for block in blocks: + try: + cb_coords.append(block.get_unit_by_name('CB').get_coord()) + except KeyError: + tmp_coord = np.array([ + block.get_unit_by_name('N').get_coord(), + block.get_unit_by_name('CA').get_coord(), + block.get_unit_by_name('C').get_coord(), + block.get_unit_by_name('O').get_coord() + ]) + cb_coords.append(add_cb(tmp_coord)) + return np.array(cb_coords) + + +def blocks_cb_interface(blocks1, blocks2, dist_th=8.0): + cb_coords1 = blocks_to_cb_coords(blocks1) + cb_coords2 = blocks_to_cb_coords(blocks2) + dist = np.linalg.norm(cb_coords1[:, None] - cb_coords2[None, :], axis=-1) # [N1, N2] + + on_interface = dist < dist_th + indexes1 = np.nonzero(on_interface.sum(axis=1) > 0)[0] + indexes2 = np.nonzero(on_interface.sum(axis=0) > 0)[0] + + blocks1 = [blocks1[i] for i in indexes1] + blocks2 = [blocks2[i] for i in indexes2] + + return (blocks1, blocks2), (indexes1, indexes2) \ No newline at end of file diff --git a/data/converter/blocks_to_data.py b/data/converter/blocks_to_data.py new file mode 100644 index 0000000000000000000000000000000000000000..1e6f58acde507ee859209ac627727a4223667003 --- /dev/null +++ b/data/converter/blocks_to_data.py @@ -0,0 +1,110 @@ +#!/usr/bin/python +# -*- coding:utf-8 -*- +from typing import List + +import numpy as np + +from data.format import VOCAB, Block +from utils import const + + +def blocks_to_data(*blocks_list: List[List[Block]]): + B, A, X, atom_positions, block_lengths, segment_ids = [], [], [], [], [], [] + atom_mask, is_ca = [], [] + topo_edge_index, topo_edge_attr, atom_names = [], [], [] + last_c_node_id = None + for i, blocks in enumerate(blocks_list): + if len(blocks) == 0: + continue + cur_B, cur_A, cur_X, cur_atom_positions, cur_block_lengths = [], [], [], [], [] + cur_atom_mask, cur_is_ca = [], [] + # other nodes + for block in blocks: + b, symbol = VOCAB.abrv_to_idx(block.abrv), VOCAB.abrv_to_symbol(block.abrv) + x, a, positions, m, ca = [], [], [], [], [] + atom2node_id = {} + if symbol == '?': + atom_missing = {} + else: + atom_missing = { atom_name: True for atom_name in const.backbone_atoms + const.sidechain_atoms[symbol] } + for atom in block: + atom2node_id[atom.name] = len(A) + len(cur_A) + len(a) + 
a.append(VOCAB.atom_to_idx(atom.get_element())) + x.append(atom.get_coord()) + pos_code = ''.join((c for c in atom.get_pos_code() if not c.isdigit())) + positions.append(VOCAB.atom_pos_to_idx(pos_code)) + if atom.name in atom_missing: + atom_missing[atom.name] = False + m.append(1) + ca.append(atom.name == 'CA') + atom_names.append(atom.name) + for atom_name in atom_missing: + if atom_missing[atom_name]: + atom2node_id[atom_name] = len(A) + len(cur_A) + len(a) + a.append(VOCAB.atom_to_idx(atom_name[0])) # only C, N, O, S in proteins + x.append([0, 0, 0]) + pos_code = ''.join((c for c in atom_name[1:] if not c.isdigit())) + positions.append(VOCAB.atom_pos_to_idx(pos_code)) + m.append(0) + ca.append(atom_name == 'CA') + atom_names.append(atom_name) + block_len = len(a) + cur_B.append(b) + cur_A.extend(a) + cur_X.extend(x) + cur_atom_positions.extend(positions) + cur_block_lengths.append(block_len) + cur_atom_mask.extend(m) + cur_is_ca.extend(ca) + + # topology edges + for src, dst, bond_type in const.sidechain_bonds.get(VOCAB.abrv_to_symbol(block.abrv), []): + src, dst = atom2node_id[src], atom2node_id[dst] + topo_edge_index.append((src, dst)) # no direction + topo_edge_index.append((dst, src)) + topo_edge_attr.append(bond_type) + topo_edge_attr.append(bond_type) + if last_c_node_id is not None and ('CA' in atom2node_id): + src, dst = last_c_node_id, atom2node_id['N'] + topo_edge_index.append((src, dst)) # no direction + topo_edge_index.append((dst, src)) + topo_edge_attr.append(4) + topo_edge_attr.append(4) + if 'CA' not in atom2node_id: + last_c_node_id = None + else: + last_c_node_id = atom2node_id['C'] + + # update coordinates of the global node to the center + # cur_X[0] = np.mean(cur_X[1:], axis=0) + cur_segment_ids = [i for _ in cur_B] + + # finish these blocks + B.extend(cur_B) + A.extend(cur_A) + X.extend(cur_X) + atom_positions.extend(cur_atom_positions) + block_lengths.extend(cur_block_lengths) + segment_ids.extend(cur_segment_ids) + atom_mask.extend(cur_atom_mask) + is_ca.extend(cur_is_ca) + + X = np.array(X).tolist() + topo_edge_index = np.array(topo_edge_index).T.tolist() + topo_edge_attr = (np.array(topo_edge_attr) - 1).tolist() # type starts from 0 but bond type starts from 1 + + data = { + 'X': X, # [Natom, 2, 3] + 'B': B, # [Nb], block (residue) type + 'A': A, # [Natom] + 'atom_positions': atom_positions, # [Natom] + 'block_lengths': block_lengths, # [Nresidue] + 'segment_ids': segment_ids, # [Nresidue] + 'atom_mask': atom_mask, # [Natom] + 'is_ca': is_ca, # [Natom] + 'atom_names': atom_names, # [Natom] + 'topo_edge_index': topo_edge_index, # atom level + 'topo_edge_attr': topo_edge_attr + } + + return data \ No newline at end of file diff --git a/data/converter/list_blocks_to_pdb.py b/data/converter/list_blocks_to_pdb.py new file mode 100644 index 0000000000000000000000000000000000000000..79708dde86666ee7c00841dd353e3f30b9a26a36 --- /dev/null +++ b/data/converter/list_blocks_to_pdb.py @@ -0,0 +1,61 @@ +#!/usr/bin/python +# -*- coding:utf-8 -*- +import os +from typing import List + +import numpy as np + +from Bio.PDB import PDBParser, PDBIO +from Bio.PDB.Structure import Structure as BStructure +from Bio.PDB.Model import Model as BModel +from Bio.PDB.Chain import Chain as BChain +from Bio.PDB.Residue import Residue as BResidue +from Bio.PDB.Atom import Atom as BAtom + +from data.format import Block, Atom, VOCAB + + +def list_blocks_to_pdb(list_blocks: List[List[Block]], chain_names: List[str], out_path: str) -> None: + ''' + Convert pdb file to a list of lists of blocks 
using Biopython. + Each chain will be a list of blocks. + + Parameters: + list_blocks: A list of lists of blocks. Each list of blocks will be parsed into one chain in the pdb + chain_names: name of chains + out_path: Path to the pdb file + + ''' + pdb_id = os.path.basename(os.path.splitext(out_path)[0]) + structure = BStructure(id=pdb_id) + model = BModel(id=0) + for blocks, chain_name in zip(list_blocks, chain_names): + chain = BChain(id=chain_name) + for i, block in enumerate(blocks): + chain.add(_block_to_biopython(block, i)) + model.add(chain) + structure.add(model) + io = PDBIO() + io.set_structure(structure) + io.save(out_path) + + +def _block_to_biopython(block: Block, pos_code: int) -> BResidue: + _id = (' ', pos_code, ' ') + residue = BResidue(_id, block.abrv, ' ') + for i, atom in enumerate(block): + fullname = ' ' + atom.name + while len(fullname) < 4: + fullname += ' ' + bio_atom = BAtom( + name=atom, + coord=np.array(atom.coordinate, dtype=np.float32), + bfactor=0, + occupancy=1.0, + altloc=' ', + fullname=fullname, + serial_number=i, + element=atom.element + ) + residue.add(bio_atom) + return residue \ No newline at end of file diff --git a/data/converter/pdb_to_list_blocks.py b/data/converter/pdb_to_list_blocks.py new file mode 100644 index 0000000000000000000000000000000000000000..0d5df0c0df502b6e51d3eb1f0cbbdd8b232868d3 --- /dev/null +++ b/data/converter/pdb_to_list_blocks.py @@ -0,0 +1,99 @@ +#!/usr/bin/python +# -*- coding:utf-8 -*- +from typing import Dict, List, Optional, Union + +from Bio.PDB import PDBParser + +from data.format import Block, Atom + + +def pdb_to_list_blocks(pdb: str, selected_chains: Optional[List[str]]=None, return_chain_ids: bool=False, dict_form: bool=False) -> Union[List[List[Block]], Dict[str, List[Block]]]: + ''' + Convert pdb file to a list of lists of blocks using Biopython. + Each chain will be a list of blocks. + + Parameters: + pdb: Path to the pdb file + selected_chains: List of selected chain ids. The returned list will be ordered + according to the ordering of chain ids in this parameter. If not specified, + all chains will be returned. e.g. ['A', 'B'] + return_chain_ids: Whether to return the ids of each chain + dict_form: Whether to return chains in dict form (chain id as the key and blocks + as the value) + + Returns: + A list of lists of blocks. Each chain in the pdb file will be parsed into + one list of blocks. + example: + [ + [residueA1, residueA2, ...], # chain A + [residueB1, residueB2, ...] # chain B + ], + where each residue is instantiated by Block data class. + ''' + + parser = PDBParser(QUIET=True) + structure = parser.get_structure('anonym', pdb) + + list_blocks, chain_ids, chains = [], {}, [] + + for model in structure.get_models(): # use model 1 only + structure = model + break + + for chain in structure.get_chains(): + + _id = chain.get_id() + if (selected_chains is not None) and (_id not in selected_chains): + continue + + residues, res_ids = [], {} + + for residue in chain: + abrv = residue.get_resname() + hetero_flag, res_number, insert_code = residue.get_id() + res_id = f'{res_number}-{insert_code}' + if hetero_flag == 'W': + continue # residue from glucose (WAT) or water (HOH) + if hetero_flag.strip() != '' and res_id in res_ids: + continue # the solvent (e.g. 
H_EDO (EDO)) + if abrv in ['EDO', 'HOH', 'BME']: # solvent or other molecules + continue + if abrv == 'MSE': + abrv = 'MET' # MET is usually transformed to MSE for structural analysis + + # filter Hs because not all data include them + atoms = [ Atom(atom.get_id(), atom.get_coord().tolist(), atom.element) for atom in residue if atom.element != 'H' ] + block = Block(abrv, atoms, id=(res_number, insert_code)) + if block.is_residue(): + residues.append(block) + res_ids[res_id] = True + + if len(residues) == 0: # not a chain + continue + + chain_ids[_id] = len(list_blocks) + list_blocks.append(residues) + chains.append(_id) + + # reorder + if selected_chains is not None: + list_blocks = [list_blocks[chain_ids[chain_id]] for chain_id in selected_chains] + chains = selected_chains + + if dict_form: + return { chain: blocks for chain, blocks in zip(chains, list_blocks)} + + if return_chain_ids: + return list_blocks, chains + + return list_blocks + + +if __name__ == '__main__': + import sys + list_blocks = pdb_to_list_blocks(sys.argv[1]) + print(f'{sys.argv[1]} parsed') + print(f'number of chains: {len(list_blocks)}') + for i, chain in enumerate(list_blocks): + print(f'chain {i} lengths: {len(chain)}') \ No newline at end of file diff --git a/data/dataset_wrapper.py b/data/dataset_wrapper.py new file mode 100644 index 0000000000000000000000000000000000000000..818ffa635f2b6d7cb16994d4fc600713565d64cc --- /dev/null +++ b/data/dataset_wrapper.py @@ -0,0 +1,115 @@ +from typing import Callable +from tqdm import tqdm +from math import log + +import numpy as np +import torch +import sympy + +from utils import register as R + + +class MixDatasetWrapper(torch.utils.data.Dataset): + def __init__(self, *datasets, collate_fn: Callable=None) -> None: + super().__init__() + self.datasets = datasets + self.cum_len = [] + self.total_len = 0 + for dataset in datasets: + self.total_len += len(dataset) + self.cum_len.append(self.total_len) + self.collate_fn = self.datasets[0].collate_fn if collate_fn is None else collate_fn + if hasattr(datasets[0], '_lengths'): + self._lengths = [] + for dataset in datasets: + self._lengths.extend(dataset._lengths) + + def update_epoch(self): + for dataset in self.datasets: + if hasattr(dataset, 'update_epoch'): + dataset.update_epoch() + + def get_len(self, idx): + return self._lengths[idx] + + def __len__(self): + return self.total_len + + def __getitem__(self, idx): + last_cum_len = 0 + for i, cum_len in enumerate(self.cum_len): + if idx < cum_len: + return self.datasets[i].__getitem__(idx - last_cum_len) + last_cum_len = cum_len + return None # this is not possible + + +@R.register('DynamicBatchWrapper') +class DynamicBatchWrapper(torch.utils.data.Dataset): + def __init__(self, dataset, complexity, ubound_per_batch) -> None: + super().__init__() + self.dataset = dataset + self.indexes = [i for i in range(len(dataset))] + self.complexity = complexity + self.eval_func = sympy.lambdify('n', sympy.simplify(complexity)) + self.ubound_per_batch = ubound_per_batch + self.total_size = None + self.batch_indexes = [] + self._form_batch() + + def __getattr__(self, attr): + if attr in self.__dict__: + return self.__dict__[attr] + elif hasattr(self.dataset, attr): + return getattr(self.dataset, attr) + else: + raise AttributeError(f"'DynamicBatchWrapper'(or '{type(self.dataset)}') object has no attribute '{attr}'") + + def update_epoch(self): + if hasattr(self.dataset, 'update_epoch'): + self.dataset.update_epoch() + self._form_batch() + + ########## overload with your criterion ########## 
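+ # _form_batch shuffles the item indexes, evaluates the user-supplied complexity expression
+ # (lambdified over n) on each item's length, and greedily packs items into a batch until the
+ # accumulated cost would exceed ubound_per_batch; items whose individual cost is already above
+ # the bound are skipped for that epoch. Hypothetical example: with complexity 'n ** 2' and
+ # ubound_per_batch 40000, lengths 10, 200, 50 and 120 cost 100, 40000, 2500 and 14400, so one
+ # possible packing (depending on the shuffle) is [200] alone and [10, 50, 120] together.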
+ def _form_batch(self): + + np.random.shuffle(self.indexes) + last_batch_indexes = self.batch_indexes + self.batch_indexes = [] + + cur_complexity = 0 + batch = [] + + for i in tqdm(self.indexes): + item_len = self.eval_func(self.dataset.get_len(i)) + if item_len > self.ubound_per_batch: + continue + cur_complexity += item_len + if cur_complexity > self.ubound_per_batch: + self.batch_indexes.append(batch) + batch = [] + cur_complexity = item_len + batch.append(i) + self.batch_indexes.append(batch) + + if self.total_size is None: + self.total_size = len(self.batch_indexes) + else: + # control the lengths of the dataset, otherwise the dataloader will raise error + if len(self.batch_indexes) < self.total_size: + num_add = self.total_size - len(self.batch_indexes) + self.batch_indexes = self.batch_indexes + last_batch_indexes[:num_add] + else: + self.batch_indexes = self.batch_indexes[:self.total_size] + + def __len__(self): + return len(self.batch_indexes) + + def __getitem__(self, idx): + return [self.dataset[i] for i in self.batch_indexes[idx]] + + def collate_fn(self, batched_batch): + batch = [] + for minibatch in batched_batch: + batch.extend(minibatch) + return self.dataset.collate_fn(batch) \ No newline at end of file diff --git a/data/format.py b/data/format.py new file mode 100644 index 0000000000000000000000000000000000000000..bd1c350dbc99a66f0e1aa1c79b5249652083a381 --- /dev/null +++ b/data/format.py @@ -0,0 +1,220 @@ +#!/usr/bin/python +# -*- coding:utf-8 -*- +from copy import copy +from typing import List, Tuple, Iterator, Optional + +from utils import const + + +class MoleculeVocab: + + MAX_ATOM_NUMBER = 14 + + def __init__(self): + self.backbone_atoms = ['N', 'CA', 'C', 'O'] + self.PAD, self.MASK, self.UNK, self.LAT = '#', '*', '?', '&' # pad / mask / unk / latent node + specials = [# special added + (self.PAD, 'PAD'), (self.MASK, 'MASK'), (self.UNK, 'UNK'), # pad / mask / unk + (self.LAT, '') # latent node in latent space + ] + + aas = const.aas + + # sms = [(e.lower(), e) for e in const.periodic_table] + sms = [] # disable small molecule vocabulary + + self.atom_pad, self.atom_mask, self.atom_latent = 'pad', 'msk', 'lat' # Avoid conflict with atom P + self.atom_pos_pad, self.atom_pos_mask, self.atom_pos_latent = 'pad', 'msk', 'lat' + self.atom_pos_sm = 'sml' # small molecule + + # block level vocab + self.idx2block = specials + aas + sms + self.symbol2idx, self.abrv2idx = {}, {} + for i, (symbol, abrv) in enumerate(self.idx2block): + self.symbol2idx[symbol] = i + self.abrv2idx[abrv] = i + self.special_mask = [1 for _ in specials] + [0 for _ in aas] + [0 for _ in sms] + + # atom level vocab + self.idx2atom = [self.atom_pad, self.atom_mask, self.atom_latent] + const.periodic_table + self.idx2atom_pos = [self.atom_pos_pad, self.atom_pos_mask, self.atom_pos_latent, '', 'A', 'B', 'G', 'D', 'E', 'Z', 'H', 'XT', 'P', self.atom_pos_sm] # SM is for atoms in small molecule, 'P' for O1P, O2P, O3P + self.atom2idx, self.atom_pos2idx = {}, {} + self.atom2idx = {} + for i, atom in enumerate(self.idx2atom): + self.atom2idx[atom] = i + for i, atom_pos in enumerate(self.idx2atom_pos): + self.atom_pos2idx[atom_pos] = i + + # block level APIs + + def abrv_to_symbol(self, abrv): + idx = self.abrv_to_idx(abrv) + return None if idx is None else self.idx2block[idx][0] + + def symbol_to_abrv(self, symbol): + idx = self.symbol_to_idx(symbol) + return None if idx is None else self.idx2block[idx][1] + + def abrv_to_idx(self, abrv): + abrv = abrv.upper() + return self.abrv2idx.get(abrv, 
self.abrv2idx['UNK']) + + def symbol_to_idx(self, symbol): + # symbol = symbol.upper() + return self.symbol2idx.get(symbol, self.abrv2idx['UNK']) + + def idx_to_symbol(self, idx): + return self.idx2block[idx][0] + + def idx_to_abrv(self, idx): + return self.idx2block[idx][1] + + def get_pad_idx(self): + return self.symbol_to_idx(self.PAD) + + def get_mask_idx(self): + return self.symbol_to_idx(self.MASK) + + def get_special_mask(self): + return copy(self.special_mask) + + # atom level APIs + + def get_atom_pad_idx(self): + return self.atom2idx[self.atom_pad] + + def get_atom_mask_idx(self): + return self.atom2idx[self.atom_mask] + + def get_atom_latent_idx(self): + return self.atom2idx[self.atom_latent] + + def get_atom_pos_pad_idx(self): + return self.atom_pos2idx[self.atom_pos_pad] + + def get_atom_pos_mask_idx(self): + return self.atom_pos2idx[self.atom_pos_mask] + + def get_atom_pos_latent_idx(self): + return self.atom_pos2idx[self.atom_pos_latent] + + def idx_to_atom(self, idx): + return self.idx2atom[idx] + + def atom_to_idx(self, atom): + atom = atom.upper() + return self.atom2idx.get(atom, self.atom2idx[self.atom_mask]) + + def idx_to_atom_pos(self, idx): + return self.idx2atom_pos[idx] + + def atom_pos_to_idx(self, atom_pos): + return self.atom_pos2idx.get(atom_pos, self.atom_pos2idx[self.atom_pos_mask]) + + # sizes + + def get_num_atom_type(self): + return len(self.idx2atom) + + def get_num_atom_pos(self): + return len(self.idx2atom_pos) + + def get_num_block_type(self): + return len(self.special_mask) - sum(self.special_mask) + + def __len__(self): + return len(self.symbol2idx) + + # others + @property + def ca_channel_idx(self): + return self.backbone_atoms.index('CA') + + +VOCAB = MoleculeVocab() + + +class Atom: + def __init__(self, atom_name: str, coordinate: List[float], element: str, pos_code: str=None): + self.name = atom_name + self.coordinate = coordinate + self.element = element + if pos_code is None: + pos_code = atom_name.lstrip(element) + self.pos_code = pos_code + else: + self.pos_code = pos_code + + def get_element(self): + return self.element + + def get_coord(self): + return copy(self.coordinate) + + def get_pos_code(self): + return self.pos_code + + def __str__(self) -> str: + return self.name + + def __repr__(self) -> str: + return f"Atom ({self.name}): {self.element}({self.pos_code}) [{','.join(['{:.4f}'.format(num) for num in self.coordinate])}]" + + def to_tuple(self): + return ( + self.name, + self.coordinate, + self.element, + self.pos_code + ) + + @classmethod + def from_tuple(self, data): + return Atom( + atom_name=data[0], + coordinate=data[1], + element=data[2], + pos_code=data[3] + ) + + +class Block: + def __init__(self, abrv: str, units: List[Atom], id: Optional[any]=None) -> None: + self.abrv: str = abrv + self.units: List[Atom] = units + self._uname2idx = { unit.name: i for i, unit in enumerate(self.units) } + self.id = id + + def __len__(self) -> int: + return len(self.units) + + def __iter__(self) -> Iterator[Atom]: + return iter(self.units) + + def get_unit_by_name(self, name: str) -> Atom: + idx = self._uname2idx[name] + return self.units[idx] + + def has_unit(self, name: str) -> bool: + return name in self._uname2idx + + def to_tuple(self): + return ( + self.abrv, + [unit.to_tuple() for unit in self.units], + self.id + ) + + def is_residue(self): + return self.has_unit('CA') and self.has_unit('N') and self.has_unit('C') and self.has_unit('O') + + @classmethod + def from_tuple(self, data): + return Block( + abrv=data[0], + 
units=[Atom.from_tuple(unit_data) for unit_data in data[1]], + id=data[2] + ) + + def __repr__(self) -> str: + return f"Block ({self.abrv}):\n\t" + '\n\t'.join([repr(at) for at in self.units]) + '\n' \ No newline at end of file diff --git a/data/mmap_dataset.py b/data/mmap_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..8e146f0cbd25a95f34f0cdcc14c321d7f9642cd0 --- /dev/null +++ b/data/mmap_dataset.py @@ -0,0 +1,112 @@ +#!/usr/bin/python +# -*- coding:utf-8 -*- +import os +import io +import gzip +import json +import mmap +from typing import Optional +from tqdm import tqdm + +import torch + + +def compress(x): + serialized_x = json.dumps(x).encode() + buf = io.BytesIO() + with gzip.GzipFile(fileobj=buf, mode='wb', compresslevel=6) as f: + f.write(serialized_x) + compressed = buf.getvalue() + return compressed + + +def decompress(compressed_x): + buf = io.BytesIO(compressed_x) + with gzip.GzipFile(fileobj=buf, mode="rb") as f: + serialized_x = f.read().decode() + x = json.loads(serialized_x) + return x + + +def _find_measure_unit(num_bytes): + size, measure_unit = num_bytes, 'Bytes' + for unit in ['KB', 'MB', 'GB']: + if size > 1000: + size /= 1024 + measure_unit = unit + else: + break + return size, measure_unit + + +def create_mmap(iterator, out_dir, total_len=None, commit_batch=10000): + + if not os.path.exists(out_dir): + os.makedirs(out_dir) + + data_file_path = os.path.join(out_dir, 'data.bin') + data_file = open(data_file_path, 'wb') + index_file = open(os.path.join(out_dir, 'index.txt'), 'w') + + i, offset, n_finished = 0, 0, 0 + progress_bar = tqdm(iterator, total=total_len) + for _id, x, properties, entry_idx in iterator: + progress_bar.set_description(f'Processing {_id}') + compressed_x = compress(x) + bin_length = data_file.write(compressed_x) + properties = '\t'.join([str(prop) for prop in properties]) + index_file.write(f'{_id}\t{offset}\t{offset + bin_length}\t{properties}\n') # tuple of (_id, start, end), data slice is [start, end) + offset += bin_length + i += 1 + + if entry_idx > n_finished: + progress_bar.update(entry_idx - n_finished) + n_finished = entry_idx + if total_len is not None: + expected_size = os.fstat(data_file.fileno()).st_size / n_finished * total_len + expected_size, measure_unit = _find_measure_unit(expected_size) + progress_bar.set_postfix({f'{i} saved; Estimated total size ({measure_unit})': expected_size}) + + if i % commit_batch == 0: + data_file.flush() # save from memory to disk + index_file.flush() + + + data_file.close() + index_file.close() + + +class MMAPDataset(torch.utils.data.Dataset): + + def __init__(self, mmap_dir: str, specify_data: Optional[str]=None, specify_index: Optional[str]=None) -> None: + super().__init__() + + self._indexes = [] + self._properties = [] + _index_path = os.path.join(mmap_dir, 'index.txt') if specify_index is None else specify_index + with open(_index_path, 'r') as f: + for line in f.readlines(): + messages = line.strip().split('\t') + _id, start, end = messages[:3] + _property = messages[3:] + self._indexes.append((_id, int(start), int(end))) + self._properties.append(_property) + _data_path = os.path.join(mmap_dir, 'data.bin') if specify_data is None else specify_data + self._data_file = open(_data_path, 'rb') + self._mmap = mmap.mmap(self._data_file.fileno(), 0, access=mmap.ACCESS_READ) + + def __del__(self): + self._mmap.close() + self._data_file.close() + + def __len__(self): + return len(self._indexes) + + def __getitem__(self, idx: int): + if idx < 0 or idx >= len(self): + raise 
IndexError(idx) + + _, start, end = self._indexes[idx] + data = decompress(self._mmap[start:end]) + + return data \ No newline at end of file diff --git a/data/resample.py b/data/resample.py new file mode 100644 index 0000000000000000000000000000000000000000..6096ceef78c877cc8ab342961a3ee389db6d2187 --- /dev/null +++ b/data/resample.py @@ -0,0 +1,19 @@ +#!/usr/bin/python +# -*- coding:utf-8 -*- +import numpy as np + + +class ClusterResampler: + def __init__(self, cluster_path: str) -> None: + idx2prob = [] + with open(cluster_path, 'r') as fin: + for line in fin: + cluster_n_member = int(line.strip().split('\t')[-1]) + idx2prob.append(1 / cluster_n_member) + total = sum(idx2prob) + idx2prob = [p / total for p in idx2prob] + self.idx2prob = np.array(idx2prob) + + def __call__(self, n_sample:int, replace: bool=False): + idxs = np.random.choice(len(self.idx2prob), size=n_sample, replace=replace, p=self.idx2prob) + return idxs \ No newline at end of file diff --git a/env.yaml b/env.yaml new file mode 100644 index 0000000000000000000000000000000000000000..b09c07dfeec7cadd72799686ac30ac385d047679 --- /dev/null +++ b/env.yaml @@ -0,0 +1,32 @@ +name: PepGLAD +channels: + - pytorch + - nvidia + - bioconda + - pyg + - salilab + - conda-forge + - defaults +dependencies: + - python=3.9 + - pytorch::pytorch=1.13.1 + - pytorch::pytorch-cuda=11.7 + - nvidia::cudatoolkit=11.7.0 + - pyg::pytorch-scatter + - mkl=2024.0.0 + - salilab::dssp + - anaconda::libboost=1.73.0 + - mmseqs2 + - openmm=8.0.0 + - pdbfixer + - pip + - pip: + - biopython==1.80 + - rdkit-pypi==2022.3.5 + - ray + - sympy + - scipy + - freesasa + - tensorboard + - pyyaml + - tqdm \ No newline at end of file diff --git a/evaluation/__init__.py b/evaluation/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..85e9a3bf118a787e4d729204019f189ab56b3be3 --- /dev/null +++ b/evaluation/__init__.py @@ -0,0 +1,3 @@ +#!/usr/bin/python +# -*- coding:utf-8 -*- + diff --git a/evaluation/dG/RosettaFastRelaxUtil.xml b/evaluation/dG/RosettaFastRelaxUtil.xml new file mode 100644 index 0000000000000000000000000000000000000000..711829099891d4019287acca6170ceaf0630116d --- /dev/null +++ b/evaluation/dG/RosettaFastRelaxUtil.xml @@ -0,0 +1,190 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + PruneBadRotamers name="prune_bad_rotamers" probability_cut="0.01" /> + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + \ No newline at end of file diff --git a/evaluation/dG/base.py b/evaluation/dG/base.py new file mode 100644 index 0000000000000000000000000000000000000000..e0e47d3329680af9a4d0714ac227541e8a071b32 --- /dev/null +++ b/evaluation/dG/base.py @@ -0,0 +1,148 @@ +#!/usr/bin/python +# -*- coding:utf-8 -*- +import os +import re +import json +from typing import Optional, Tuple, List +from dataclasses import dataclass + +from .energy import pyrosetta_fastrelax, pyrosetta_interface_energy, rfdiff_refine + + +@dataclass +class RelaxTask: + in_path: str + current_path: str + info: dict + status: str + rec_chain: str + pep_chain: str + rfdiff_relax: bool = False + dG: Optional[float] = None + + def set_dG(self, dG): + self.dG = dG + + def get_in_path_with_tag(self, tag): + name, ext = os.path.splitext(self.in_path) + new_path = f'{name}_{tag}{ext}' + return new_path + + def set_current_path_tag(self, tag): + new_path = 
self.get_in_path_with_tag(tag) + self.current_path = new_path + return new_path + + def check_current_path_exists(self): + ok = os.path.exists(self.current_path) + if not ok: + self.mark_failure() + if os.path.getsize(self.current_path) == 0: + ok = False + self.mark_failure() + os.unlink(self.current_path) + return ok + + def update_if_finished(self, tag): + out_path = self.get_in_path_with_tag(tag) + if os.path.exists(out_path) and os.path.getsize(out_path) > 0: + # print('Already finished', out_path) + self.set_current_path_tag(tag) + self.mark_success() + return True + return False + + def can_proceed(self): + self.check_current_path_exists() + return self.status != 'failed' + + def mark_success(self): + self.status = 'success' + + def mark_failure(self): + self.status = 'failed' + + +class TaskScanner: + + def __init__(self, results, n_sample, rfdiff_relax): + super().__init__() + self.results = results + self.n_sample = n_sample + self.rfdiff_relax = rfdiff_relax + self.visited = set() + + def scan(self) -> List[RelaxTask]: + tasks = [] + root_dir = os.path.dirname(self.results) + with open(self.results, 'r') as fin: + lines = fin.readlines() + for line in lines: + item = json.loads(line) + if item['number'] >= self.n_sample: + continue + _id = f"{item['id']}_{item['number']}" + if _id in self.visited: + continue + gen_pdb = os.path.split(item['gen_pdb'])[-1] + # subdir = gen_pdb.split('_')[0] + subdir = '_'.join(gen_pdb.split('_')[:-2]) + gen_pdb = os.path.join(root_dir, 'candidates', subdir, gen_pdb) + tasks.append(RelaxTask( + in_path=gen_pdb, + current_path=gen_pdb, + info=item, + status='created', + rec_chain=item['rec_chain'], + pep_chain=item['lig_chain'], + rfdiff_relax=self.rfdiff_relax + )) + self.visited.add(_id) + return tasks + + def scan_dataset(self) -> List[RelaxTask]: + tasks = [] + root_dir = os.path.dirname(self.results) + with open(self.results, 'r') as fin: # index file of datasets + lines = fin.readlines() + for line in lines: + line = line.strip('\n').split('\t') + _id = line[0] + item = { + 'id': _id, + 'number': 0 + } + pdb_path = os.path.join(root_dir, 'pdbs', _id + '.pdb') + tasks.append(RelaxTask( + in_path=pdb_path, + current_path=pdb_path, + info=item, + status='created', + rec_chain=line[7], + pep_chain=line[8], + rfdiff_relax=self.rfdiff_relax + )) + self.visited.add(_id) + return tasks + + +def run_pyrosetta(task: RelaxTask): + if not task.can_proceed() : + return task + # if task.update_if_finished('rosetta'): + # return task + + out_path = task.set_current_path_tag('rosetta') + try: + if task.rfdiff_relax: + rfdiff_refine(task.in_path, out_path, task.pep_chain) + else: + pyrosetta_fastrelax(task.in_path, out_path, task.pep_chain, rfdiff_config=task.rfdiff_relax) + dG = pyrosetta_interface_energy(out_path, [task.rec_chain], [task.pep_chain]) + task.mark_success() + except Exception as e: + print(e) + dG = 1e10 + task.mark_failure() + task.set_dG(dG) + return task diff --git a/evaluation/dG/energy.py b/evaluation/dG/energy.py new file mode 100644 index 0000000000000000000000000000000000000000..c9d24620817b73ec78a41d642ef9d46a287bc11e --- /dev/null +++ b/evaluation/dG/energy.py @@ -0,0 +1,236 @@ +#!/usr/bin/python +# -*- coding:utf-8 -*- +''' + From https://github.com/luost26/diffab/blob/main/diffab/tools/relax/pyrosetta_relaxer.py +''' +import os +import time +import pyrosetta +from pyrosetta.rosetta.protocols.analysis import InterfaceAnalyzerMover +# for fast relax +from pyrosetta.rosetta import protocols +from pyrosetta.rosetta.protocols.relax 
import FastRelax +from pyrosetta.rosetta.core.pack.task import TaskFactory +from pyrosetta.rosetta.core.pack.task import operation +from pyrosetta.rosetta.core.select import residue_selector as selections +from pyrosetta.rosetta.core.select.movemap import MoveMapFactory, move_map_action +from pyrosetta.rosetta.core.scoring import ScoreType + +from Bio.PDB import PDBIO, PDBParser +from Bio.PDB.Structure import Structure as BStructure +from Bio.PDB.Model import Model as BModel +from Bio.PDB.Chain import Chain as BChain + + +pyrosetta.init(' '.join([ + '-mute', 'all', + '-use_input_sc', + '-ignore_unrecognized_res', + '-ignore_zero_occupancy', 'false', + '-load_PDB_components', 'false', + '-relax:default_repeats', '2', + '-no_fconfig', + # below are from https://github.com/nrbennet/dl_binder_design/blob/main/mpnn_fr/dl_interface_design.py + # '-beta_nov16', + '-use_terminal_residues', 'true', + '-in:file:silent_struct_type', 'binary' +])) + + +def current_milli_time(): + return round(time.time() * 1000) + + +def get_scorefxn(scorefxn_name:str): + """ + Gets the scorefxn with appropriate corrections. + Taken from: https://gist.github.com/matteoferla/b33585f3aeab58b8424581279e032550 + """ + import pyrosetta + + corrections = { + 'beta_july15': False, + 'beta_nov16': False, + 'gen_potential': False, + 'restore_talaris_behavior': False, + } + if 'beta_july15' in scorefxn_name or 'beta_nov15' in scorefxn_name: + # beta_july15 is ref2015 + corrections['beta_july15'] = True + elif 'beta_nov16' in scorefxn_name: + corrections['beta_nov16'] = True + elif 'genpot' in scorefxn_name: + corrections['gen_potential'] = True + pyrosetta.rosetta.basic.options.set_boolean_option('corrections:beta_july15', True) + elif 'talaris' in scorefxn_name: #2013 and 2014 + corrections['restore_talaris_behavior'] = True + else: + pass + for corr, value in corrections.items(): + pyrosetta.rosetta.basic.options.set_boolean_option(f'corrections:{corr}', value) + return pyrosetta.create_score_function(scorefxn_name) + + +class RelaxRegion(object): + + def __init__(self, scorefxn='ref2015', max_iter=1000, subset='nbrs', move_bb=True, rfdiff_config=False): + super().__init__() + + if rfdiff_config: + self.scorefxn = get_scorefxn('beta_nov16') + xml = os.path.join(os.path.dirname(__file__), 'RosettaFastRelaxUtil.xml') + objs = protocols.rosetta_scripts.XmlObjects.create_from_file(xml) + self.fast_relax = objs.get_mover('FastRelax') + self.fast_relax.max_iter(max_iter) + else: + self.scorefxn = get_scorefxn(scorefxn) + self.fast_relax = FastRelax() + self.fast_relax.set_scorefxn(self.scorefxn) + self.fast_relax.max_iter(max_iter) + + assert subset in ('all', 'target', 'nbrs') + self.subset = subset + self.move_bb = move_bb + + def __call__(self, pdb_path, ligand_chains): # flexible_residue_first, flexible_residue_last): + pose = pyrosetta.pose_from_pdb(pdb_path) + start_t = current_milli_time() + original_pose = pose.clone() + + tf = TaskFactory() + tf.push_back(operation.InitializeFromCommandline()) + tf.push_back(operation.RestrictToRepacking()) # Only allow residues to repack. No design at any position. 
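+ # The selectors assembled below restrict relaxation to the generated region: ChainSelector
+ # picks the ligand (peptide) chains, NeighborhoodResidueSelector optionally widens this to
+ # nearby receptor residues (subset='nbrs'), and PreventRepackingRLT with flip_subset=True
+ # freezes the rotamers of everything outside that subset. The MoveMapFactory then enables
+ # backbone torsions only for the peptide selection (when move_bb is True) and sidechain chi
+ # angles for the chosen subset before FastRelax is applied.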
+ + # Create selector for the region to be relaxed + # Turn off design and repacking on irrelevant positions + # if flexible_residue_first[-1] == ' ': + # flexible_residue_first = flexible_residue_first[:-1] + # if flexible_residue_last[-1] == ' ': + # flexible_residue_last = flexible_residue_last[:-1] + if self.subset != 'all': + chain_selectors = [selections.ChainSelector(chain) for chain in ligand_chains] + if len(chain_selectors) == 1: + gen_selector = chain_selectors[0] + else: + gen_selector = selections.OrResidueSelector(chain_selectors[0], chain_selectors[1]) + for selector in chain_selectors[2:]: + gen_selector = selections.OrResidueSelector(gen_selector, selector) + # gen_selector = selections.ChainSelector(pep_chain) + # gen_selector = selections.ResidueIndexSelector() + # gen_selector.set_index_range( + # pose.pdb_info().pdb2pose(*flexible_residue_first), + # pose.pdb_info().pdb2pose(*flexible_residue_last), + # ) + nbr_selector = selections.NeighborhoodResidueSelector() + nbr_selector.set_focus_selector(gen_selector) + nbr_selector.set_include_focus_in_subset(True) + + if self.subset == 'nbrs': + subset_selector = nbr_selector + elif self.subset == 'target': + subset_selector = gen_selector + + prevent_repacking_rlt = operation.PreventRepackingRLT() + prevent_subset_repacking = operation.OperateOnResidueSubset( + prevent_repacking_rlt, + subset_selector, + flip_subset=True, + ) + tf.push_back(prevent_subset_repacking) + + scorefxn = self.scorefxn + fr = self.fast_relax + + pose = original_pose.clone() + # pos_list = pyrosetta.rosetta.utility.vector1_unsigned_long() + # for pos in range(pose.pdb_info().pdb2pose(*flexible_residue_first), pose.pdb_info().pdb2pose(*flexible_residue_last)+1): + # pos_list.append(pos) + # basic_idealize(pose, pos_list, scorefxn, fast=True) + + mmf = MoveMapFactory() + if self.move_bb: + mmf.add_bb_action(move_map_action.mm_enable, gen_selector) + mmf.add_chi_action(move_map_action.mm_enable, subset_selector) + mm = mmf.create_movemap_from_pose(pose) + + fr.set_movemap(mm) + fr.set_task_factory(tf) + fr.apply(pose) + + e_before = scorefxn(original_pose) + e_relax = scorefxn(pose) + # print('\n\n[Finished in %.2f secs]' % ((current_milli_time() - start_t) / 1000)) + # print(' > Energy (before): %.4f' % scorefxn(original_pose)) + # print(' > Energy (optimized): %.4f' % scorefxn(pose)) + return pose, e_before, e_relax + + +def pyrosetta_fastrelax(pdb_path, out_path, pep_chain, rfdiff_config=False): + minimizer = RelaxRegion(rfdiff_config=rfdiff_config) + pose_min, _, _ = minimizer( + pdb_path=pdb_path, + ligand_chains=[pep_chain] + ) + pose_min.dump_pdb(out_path) + + +def _rename_chain(pdb_path, out_path, src_pep_chain, tgt_pep_chain, tgt_rec_chain): + + io = PDBIO() + parser = PDBParser() + + structure = parser.get_structure('anonymous', pdb_path) + + new_mapping = {} + pep_chain, rec_chain = BChain(id=tgt_pep_chain), BChain(id=tgt_rec_chain) + + for model in structure: + for chain in model: + if chain.get_id() == src_pep_chain: + new_mapping[src_pep_chain] = tgt_pep_chain + for res in chain: + pep_chain.add(res.copy()) + else: + new_mapping[chain.get_id()] = tgt_rec_chain + for res in chain: + rec_chain.add(res.copy()) + + structure = BStructure(id=structure.get_id()) + model = BModel(id=0) + model.add(pep_chain) + model.add(rec_chain) + structure.add(model) + + io.set_structure(structure) + io.save(out_path) + + return new_mapping + + +def rfdiff_refine(pdb_path, out_path, pep_chain): + # rename peptide chain to A and receptor to B + new_mapping = 
_rename_chain(pdb_path, out_path, pep_chain, 'A', 'B') + + # force fields from RFDiffusion + get_scorefxn('beta_nov16') + xml = os.path.join(os.path.dirname(__file__), 'RosettaFastRelaxUtil.xml') + objs = protocols.rosetta_scripts.XmlObjects.create_from_file(xml) + fastrelax = objs.get_mover('FastRelax') + pose = pyrosetta.pose_from_pdb(out_path) + fastrelax.apply(pose) + pose.dump_pdb(out_path) + + # get back to original chain ids + reverse_mapping = { new_mapping[key]: key for key in new_mapping } + _rename_chain(out_path, out_path, 'A', reverse_mapping['A'], reverse_mapping['B']) + + +def pyrosetta_interface_energy(pdb_path, receptor_chains, ligand_chains, return_dict=False): + pose = pyrosetta.pose_from_pdb(pdb_path) + interface = ''.join(ligand_chains) + '_' + ''.join(receptor_chains) + mover = InterfaceAnalyzerMover(interface) + mover.set_pack_separated(True) + mover.apply(pose) + if return_dict: + return pose.scores + return pose.scores['dG_separated'] diff --git a/evaluation/dG/openmm_relaxer.py b/evaluation/dG/openmm_relaxer.py new file mode 100644 index 0000000000000000000000000000000000000000..f2c0d178a724795cc28426088cf2287981f5b3ca --- /dev/null +++ b/evaluation/dG/openmm_relaxer.py @@ -0,0 +1,107 @@ +import os +import time +import io +import logging +import pdbfixer +import openmm +from openmm import app as openmm_app +from openmm import unit +ENERGY = unit.kilocalories_per_mole +LENGTH = unit.angstroms + + +class ForceFieldMinimizer(object): + + def __init__(self, stiffness=10.0, max_iterations=0, tolerance=2.39*unit.kilocalories_per_mole, platform='CUDA'): + super().__init__() + self.stiffness = stiffness + self.max_iterations = max_iterations + self.tolerance = tolerance + assert platform in ('CUDA', 'CPU') + self.platform = platform + + def _fix(self, pdb_str): + fixer = pdbfixer.PDBFixer(pdbfile=io.StringIO(pdb_str)) + fixer.findNonstandardResidues() + fixer.replaceNonstandardResidues() + + fixer.findMissingResidues() + fixer.findMissingAtoms() + fixer.addMissingAtoms(seed=0) + fixer.addMissingHydrogens() + + out_handle = io.StringIO() + openmm_app.PDBFile.writeFile(fixer.topology, fixer.positions, out_handle, keepIds=True) + return out_handle.getvalue() + + def _get_pdb_string(self, topology, positions): + with io.StringIO() as f: + openmm_app.PDBFile.writeFile(topology, positions, f, keepIds=True) + return f.getvalue() + + def _minimize(self, pdb_str): + pdb = openmm_app.PDBFile(io.StringIO(pdb_str)) + + force_field = openmm_app.ForceField("charmm36.xml") # referring to http://docs.openmm.org/latest/userguide/application/02_running_sims.html + constraints = openmm_app.HBonds + system = force_field.createSystem(pdb.topology, constraints=constraints) + + # Add constraints to non-generated regions + force = openmm.CustomExternalForce("0.5 * k * ((x-x0)^2 + (y-y0)^2 + (z-z0)^2)") + force.addGlobalParameter("k", self.stiffness) + for p in ["x0", "y0", "z0"]: + force.addPerParticleParameter(p) + + for i, a in enumerate(pdb.topology.atoms()): + if a.element.name != 'hydrogen': + force.addParticle(i, pdb.positions[i]) + + system.addForce(force) + + # Set up the integrator and simulation + integrator = openmm.LangevinIntegrator(0, 0.01, 0.0) + platform = openmm.Platform.getPlatformByName("CUDA") + simulation = openmm_app.Simulation(pdb.topology, system, integrator, platform) + simulation.context.setPositions(pdb.positions) + + # Perform minimization + ret = {} + state = simulation.context.getState(getEnergy=True, getPositions=True) + ret["einit"] = 
state.getPotentialEnergy().value_in_unit(ENERGY) + ret["posinit"] = state.getPositions(asNumpy=True).value_in_unit(LENGTH) + + simulation.minimizeEnergy(maxIterations=self.max_iterations, tolerance=self.tolerance) + + state = simulation.context.getState(getEnergy=True, getPositions=True) + ret["efinal"] = state.getPotentialEnergy().value_in_unit(ENERGY) + ret["pos"] = state.getPositions(asNumpy=True).value_in_unit(LENGTH) + ret["min_pdb"] = self._get_pdb_string(simulation.topology, state.getPositions()) + + return ret['min_pdb'], ret + + def _add_energy_remarks(self, pdb_str, ret): + pdb_lines = pdb_str.splitlines() + pdb_lines.insert(1, "REMARK 1 FINAL ENERGY: {:.3f} KCAL/MOL".format(ret['efinal'])) + pdb_lines.insert(1, "REMARK 1 INITIAL ENERGY: {:.3f} KCAL/MOL".format(ret['einit'])) + return "\n".join(pdb_lines) + + def __call__(self, pdb_str, out_path, return_info=True): + if '\n' not in pdb_str and pdb_str.lower().endswith(".pdb"): + with open(pdb_str) as f: + pdb_str = f.read() + + pdb_fixed = self._fix(pdb_str) + pdb_min, ret = self._minimize(pdb_fixed) + pdb_min = self._add_energy_remarks(pdb_min, ret) + with open(out_path, 'w') as f: + f.write(pdb_min) + if return_info: + return pdb_min, ret + else: + return pdb_min + + +if __name__ == '__main__': + import sys + force_field = ForceFieldMinimizer() + force_field(sys.argv[1], sys.argv[2]) diff --git a/evaluation/dG/run.py b/evaluation/dG/run.py new file mode 100644 index 0000000000000000000000000000000000000000..09800367f74ff9cb31ee13c56894a87741ccf2fe --- /dev/null +++ b/evaluation/dG/run.py @@ -0,0 +1,92 @@ +#!/usr/bin/python +# -*- coding:utf-8 -*- +import os +import json +import argparse +import statistics + +import ray + +from utils.logger import print_log + +from .base import TaskScanner, run_pyrosetta + +# @ray.remote(num_gpus=1/8, num_cpus=1) +# def run_openmm_remote(task): +# return run_openmm(task) + + +@ray.remote(num_cpus=1) +def run_pyrosetta_remote(task): + return run_pyrosetta(task) + + +@ray.remote +def pipeline_pyrosetta(task): + funcs = [ + run_pyrosetta_remote, + ] + for fn in funcs: + task = fn.remote(task) + return ray.get(task) + + +def parse(): + parser = argparse.ArgumentParser(description='calculating dG using pyrosetta') + parser.add_argument('--results', type=str, required=True, help='Path to the summary of the results (.jsonl)') + parser.add_argument('--n_sample', type=int, default=float('inf'), help='Maximum number of samples for calculation') + parser.add_argument('--rfdiff_relax', action='store_true', help='Use rfdiff fastrelax') + parser.add_argument('--out_path', type=str, default=None, help='Output path, default dG_report.jsonl under the same directory as results') + return parser.parse_args() + + +def main(args): + # output summary + if args.out_path is None: + args.out_path = os.path.join(os.path.dirname(args.results), 'dG_report.jsonl') + results = {} + + # parallel + ray.init() + scanner = TaskScanner(args.results, args.n_sample, args.rfdiff_relax) + if args.results.endswith('txt'): + tasks = scanner.scan_dataset() + else: + tasks = scanner.scan() + futures = [pipeline_pyrosetta.remote(t) for t in tasks] + if len(futures) > 0: + print_log(f'Submitted {len(futures)} tasks.') + while len(futures) > 0: + done_ids, futures = ray.wait(futures, num_returns=1) + for done_id in done_ids: + done_task = ray.get(done_id) + print_log(f'Remaining {len(futures)}. 
Finished {done_task.current_path}, dG {done_task.dG}') + _id, number = done_task.info['id'], done_task.info['number'] + if _id not in results: + results[_id] = { + 'min': float('inf'), + 'all': {} + } + results[_id]['all'][number] = done_task.dG + results[_id]['min'] = min(results[_id]['min'], done_task.dG) + + # write results + for _id in results: + success = 0 + for n in results[_id]['all']: + if results[_id]['all'][n] < 0: + success += 1 + results[_id]['success rate'] = success / len(results[_id]['all']) + json.dump(results, open(args.out_path, 'w'), indent=2) + + # show results + vals = [results[_id]['min'] for _id in results] + print(f'median: {statistics.median(vals)}, mean: {sum(vals) / len(vals)}') + success = [results[_id]['success rate'] for _id in results] + print(f'mean success rate: {sum(success) / len(success)}') + + +if __name__ == '__main__': + import random + random.seed(12) + main(parse()) diff --git a/evaluation/diversity.py b/evaluation/diversity.py new file mode 100644 index 0000000000000000000000000000000000000000..3e8de4802e53f71cebe3110bae975d0a6be3748f --- /dev/null +++ b/evaluation/diversity.py @@ -0,0 +1,68 @@ +#!/usr/bin/python +# -*- coding:utf-8 -*- +from typing import List + +import numpy as np +from scipy.cluster.hierarchy import linkage, fcluster +from scipy.spatial.distance import squareform +from scipy.stats.contingency import association + +from evaluation.seq_metric import align_sequences + + +def seq_diversity(seqs: List[str], th: float=0.4) -> float: + ''' + th: sequence distance + ''' + dists = [] + for i, seq1 in enumerate(seqs): + dists.append([]) + for j, seq2 in enumerate(seqs): + _, sim = align_sequences(seq1, seq2) + dists[i].append(1 - sim) + dists = np.array(dists) + Z = linkage(squareform(dists), 'single') + cluster = fcluster(Z, t=th, criterion='distance') + return len(np.unique(cluster)) / len(seqs), cluster + + +def struct_diversity(structs: np.ndarray, th: float=4.0) -> float: + ''' + structs: N*L*3, alpha carbon coordinates + th: threshold for clustering (distance < th) + ''' + ca_dists = np.sum((structs[:, None] - structs[None, :]) ** 2, axis=-1) # [N, N, L] + rmsd = np.sqrt(np.mean(ca_dists, axis=-1)) + Z = linkage(squareform(rmsd), 'single') # since the distances might not be euclidean distances (e.g. 
rmsd) + cluster = fcluster(Z, t=th, criterion='distance') + return len(np.unique(cluster)) / structs.shape[0], cluster + + +def diversity(seqs: List[str], structs: np.ndarray): + seq_div, seq_clu = seq_diversity(seqs) + if structs is None: + return seq_div, None, seq_div, None + struct_div, struct_clu = struct_diversity(structs) + co_div = np.sqrt(seq_div * struct_div) + + n_seq_clu, n_struct_clu = np.max(seq_clu), np.max(struct_clu) # clusters start from 1 + if n_seq_clu == 1 or n_struct_clu == 1: + consistency = 1.0 if n_seq_clu == n_struct_clu else 0.0 + else: + table = [[0 for _ in range(n_struct_clu)] for _ in range(n_seq_clu)] + for seq_c, struct_c in zip(seq_clu, struct_clu): + table[seq_c - 1][struct_c - 1] += 1 + consistency = association(np.array(table), method='cramer') + + return seq_div, struct_div, co_div, consistency + + +if __name__ == '__main__': + N, L = 100, 10 + a = np.random.randn(N, L, 3) + print(struct_diversity(a)) + from utils.const import aas + aas = [tup[0] for tup in aas] + seqs = np.random.randint(0, len(aas), (N, L)) + seqs = [''.join([aas[i] for i in idx]) for idx in seqs] + print(seq_diversity(seqs, 0.4)) \ No newline at end of file diff --git a/evaluation/dockq.py b/evaluation/dockq.py new file mode 100644 index 0000000000000000000000000000000000000000..7a3e2ee926891379f5e89eae509979079cbae392 --- /dev/null +++ b/evaluation/dockq.py @@ -0,0 +1,15 @@ +#!/usr/bin/python +# -*- coding:utf-8 -*- +import os +import re + +from globals import DOCKQ_DIR + + +def dockq(mod_pdb: str, native_pdb: str, pep_chain: str): + p = os.popen(f'{os.path.join(DOCKQ_DIR, "DockQ.py")} {mod_pdb} {native_pdb} -model_chain1 {pep_chain} -native_chain1 {pep_chain} -no_needle') + text = p.read() + p.close() + res = re.search(r'DockQ\s+([0-1]\.[0-9]+)', text) + score = float(res.group(1)) + return score \ No newline at end of file diff --git a/evaluation/rmsd.py b/evaluation/rmsd.py new file mode 100644 index 0000000000000000000000000000000000000000..7deb7a2ca3a54935447966a8e56462247233abd9 --- /dev/null +++ b/evaluation/rmsd.py @@ -0,0 +1,11 @@ +#!/usr/bin/python +# -*- coding:utf-8 -*- +import torch +import numpy as np + + +# a: [N, 3], b: [N, 3] +def compute_rmsd(a, b, aligned=False): # amino acids level rmsd + dist = np.sum((a - b) ** 2, axis=-1) + rmsd = np.sqrt(dist.sum() / a.shape[0]) + return float(rmsd) \ No newline at end of file diff --git a/evaluation/seq_metric.py b/evaluation/seq_metric.py new file mode 100644 index 0000000000000000000000000000000000000000..6cd184685e82f5cd605c2dde20a88aef343ac32f --- /dev/null +++ b/evaluation/seq_metric.py @@ -0,0 +1,71 @@ +#!/usr/bin/python +# -*- coding:utf-8 -*- +from math import sqrt + +from Bio.Align import substitution_matrices, PairwiseAligner + + +def aar(candidate, reference): + hit = 0 + for a, b in zip(candidate, reference): + if a == b: + hit += 1 + return hit / len(reference) + + +def align_sequences(sequence_A, sequence_B, **kwargs): + """ + Performs a global pairwise alignment between two sequences + using the BLOSUM62 matrix and the Needleman-Wunsch algorithm + as implemented in Biopython. Returns the alignment, the sequence + identity and the residue mapping between both original sequences. 
+ """ + + sub_matrice = substitution_matrices.load('BLOSUM62') + aligner = PairwiseAligner() + aligner.substitution_matrix = sub_matrice + alns = aligner.align(sequence_A, sequence_B) + + best_aln = alns[0] + aligned_A, aligned_B = best_aln + + base = sqrt(aligner.score(sequence_A, sequence_A) * aligner.score(sequence_B, sequence_B)) + seq_id = aligner.score(sequence_A, sequence_B) / base + return (aligned_A, aligned_B), seq_id + + +def slide_aar(candidate, reference, aar_func): + ''' + e.g. + candidate: AILPV + reference: ILPVH + + should be matched as + AILPV + ILPVH + + To do this, we slide the candidate and calculate the maximum aar: + A + AI + AIL + AILP + AILPV + ILPV + LPV + PV + V + ''' + special_token = ' ' + ref_len = len(reference) + padded_candidate = special_token * (ref_len - 1) + candidate + special_token * (ref_len - 1) + value = 0 + for start in range(len(padded_candidate) - ref_len + 1): + value = max(value, aar_func(padded_candidate[start:start + ref_len], reference)) + return value + + +if __name__ == '__main__': + print(align_sequences('PKGYAAPSA', 'KPAVYKFTL')) + print(align_sequences('KPAVYKFTL', 'PKGYAAPSA')) + print(align_sequences('PKGYAAPSA', 'PKGYAAPSA')) + print(align_sequences('KPAVYKFTL', 'KPAVYKFTL')) \ No newline at end of file diff --git a/generate.py b/generate.py new file mode 100644 index 0000000000000000000000000000000000000000..9cc3d7fbdfc1de879244b6499c1cce86ff35df1c --- /dev/null +++ b/generate.py @@ -0,0 +1,235 @@ +#!/usr/bin/python +# -*- coding:utf-8 -*- +import argparse +import json +import os +import pickle as pkl +from tqdm import tqdm +from copy import deepcopy +from multiprocessing import Pool + +import yaml +import torch +from torch.utils.data import DataLoader + +import models +from utils.config_utils import overwrite_values +from data.converter.pdb_to_list_blocks import pdb_to_list_blocks +from data.converter.list_blocks_to_pdb import list_blocks_to_pdb +from data.format import VOCAB, Atom +from data import create_dataloader, create_dataset +from utils.logger import print_log +from utils.random_seed import setup_seed +from utils.const import sidechain_atoms + + +def get_best_ckpt(ckpt_dir): + with open(os.path.join(ckpt_dir, 'checkpoint', 'topk_map.txt'), 'r') as f: + ls = f.readlines() + ckpts = [] + for l in ls: + k,v = l.strip().split(':') + k = float(k) + v = v.split('/')[-1] + ckpts.append((k,v)) + + # ckpts = sorted(ckpts, key=lambda x:x[0]) + best_ckpt = ckpts[0][1] + return os.path.join(ckpt_dir, 'checkpoint', best_ckpt) + + +def to_device(data, device): + if isinstance(data, dict): + for key in data: + data[key] = to_device(data[key], device) + elif isinstance(data, list) or isinstance(data, tuple): + res = [to_device(item, device) for item in data] + data = type(data)(res) + elif hasattr(data, 'to'): + data = data.to(device) + return data + + +def clamp_coord(coord): + # some models (e.g. 
diffab) will output very large coordinates (absolute value >1000) which will corrupt the pdb file + new_coord = [] + for val in coord: + if abs(val) >= 1000: + val = 0 + new_coord.append(val) + return new_coord + + +def overwrite_blocks(blocks, seq=None, X=None): + if seq is not None: + assert len(blocks) == len(seq), f'{len(blocks)} {len(seq)}' + new_blocks = [] + for i, block in enumerate(blocks): + block = deepcopy(block) + if seq is None: + abrv = block.abrv + else: + abrv = VOCAB.symbol_to_abrv(seq[i]) + if block.abrv != abrv: + if X is None: + block.units = [atom for atom in block.units if atom.name in VOCAB.backbone_atoms] + if X is not None: + coords = X[i] + atoms = VOCAB.backbone_atoms + sidechain_atoms[VOCAB.abrv_to_symbol(abrv)] + block.units = [ + Atom(atom_name, clamp_coord(coord), atom_name[0]) for atom_name, coord in zip(atoms, coords) + ] + block.abrv = abrv + new_blocks.append(block) + return new_blocks + + +def generate_wrapper(model, sample_opt={}): + if isinstance(model, models.AutoEncoder): + def wrapper(batch): + X, S, ppls = model.test(batch['X'], batch['S'], batch['mask'], batch['position_ids'], batch['lengths'], batch['atom_mask']) + return X, S, ppls + elif isinstance(model, models.LDMPepDesign): + def wrapper(batch): + X, S, ppls = model.sample(batch['X'], batch['S'], batch['mask'], batch['position_ids'], batch['lengths'], batch['atom_mask'], + L=batch['L'] if 'L' in batch else None, sample_opt=sample_opt) + return X, S, ppls + else: + raise NotImplementedError(f'Wrapper for {type(model)} not implemented') + return wrapper + + +def save_data( + _id, n, + x_pkl_file, s_pkl_file, pmetric_pkl_file, + ref_pdb, rec_chain, lig_chain, ref_save_dir, cand_save_dir, + seq_only, struct_only, backbone_only + ): + + X, S, pmetric = pkl.load(open(x_pkl_file, 'rb')), pkl.load(open(s_pkl_file, 'rb')), pkl.load(open(pmetric_pkl_file, 'rb')) + os.remove(x_pkl_file), os.remove(s_pkl_file), os.remove(pmetric_pkl_file) + if seq_only: + X = None + elif struct_only: + S = None + rec_blocks, lig_blocks = pdb_to_list_blocks(ref_pdb, selected_chains=[rec_chain, lig_chain]) + ref_pdb = os.path.join(ref_save_dir, _id + "_ref.pdb") + list_blocks_to_pdb([rec_blocks, lig_blocks], [rec_chain, lig_chain], ref_pdb) + # os.system(f'cp {ref_pdb} {os.path.join(ref_save_dir, _id + "_ref.pdb")}') + ref_seq = ''.join([VOCAB.abrv_to_symbol(block.abrv) for block in lig_blocks]) + lig_blocks = overwrite_blocks(lig_blocks, S, X) + gen_seq = ''.join([VOCAB.abrv_to_symbol(block.abrv) for block in lig_blocks]) + save_dir = os.path.join(cand_save_dir, _id) + if not os.path.exists(save_dir): + os.makedirs(save_dir) + gen_pdb = os.path.join(save_dir, _id + f'_gen_{n}.pdb') + list_blocks_to_pdb([rec_blocks, lig_blocks], [rec_chain, lig_chain], gen_pdb) + + return { + 'id': _id, + 'number': n, + 'gen_pdb': gen_pdb, + 'ref_pdb': ref_pdb, + 'pmetric': pmetric, + 'rec_chain': rec_chain, + 'lig_chain': lig_chain, + 'ref_seq': ref_seq, + 'gen_seq': gen_seq, + 'seq_only': seq_only, + 'struct_only': struct_only, + 'backbone_only': backbone_only + } + + +def main(args, opt_args): + config = yaml.safe_load(open(args.config, 'r')) + config = overwrite_values(config, opt_args) + struct_only = config.get('struct_only', False) + seq_only = config.get('seq_only', False) + assert not (seq_only and struct_only) + backbone_only = config.get('backbone_only', False) + # load model + b_ckpt = args.ckpt if args.ckpt.endswith('.ckpt') else get_best_ckpt(args.ckpt) + ckpt_dir = os.path.split(os.path.split(b_ckpt)[0])[0] + print(f'Using 
checkpoint {b_ckpt}') + model = torch.load(b_ckpt, map_location='cpu') + device = torch.device('cpu' if args.gpu == -1 else f'cuda:{args.gpu}') + model.to(device) + model.eval() + + # load data + _, _, test_set = create_dataset(config['dataset']) + test_loader = create_dataloader(test_set, config['dataloader']) + + # save path + if args.save_dir is None: + save_dir = os.path.join(ckpt_dir, 'results') + else: + save_dir = args.save_dir + ref_save_dir = os.path.join(save_dir, 'references') + cand_save_dir = os.path.join(save_dir, 'candidates') + for directory in [ref_save_dir, cand_save_dir]: + if not os.path.exists(directory): + os.makedirs(directory) + + + fout = open(os.path.join(save_dir, 'results.jsonl'), 'w') + item_idx = 0 + + # multiprocessing + pool = Pool(args.n_cpu) + + n_samples = config.get('n_samples', 1) + + pbar = tqdm(total=n_samples * len(test_loader)) + for n in range(n_samples): + item_idx = 0 + with torch.no_grad(): + for batch in test_loader: + batch = to_device(batch, device) + batch_X, batch_S, batch_pmetric = generate_wrapper(model, deepcopy(config.get('sample_opt', {})))(batch) + + # parallel + inputs = [] + for X, S, pmetric in zip(batch_X, batch_S, batch_pmetric): + _id, ref_pdb, rec_chain, lig_chain = test_set.get_summary(item_idx) + # save temporary pickle file + x_pkl_file = os.path.join(save_dir, _id + f'_gen_{n}_X.pkl') + pkl.dump(X, open(x_pkl_file, 'wb')) + s_pkl_file = os.path.join(save_dir, _id + f'_gen_{n}_S.pkl') + pkl.dump(S, open(s_pkl_file, 'wb')) + pmetric_pkl_file = os.path.join(save_dir, _id + f'_gen_{n}_pmetric.pkl') + pkl.dump(pmetric, open(pmetric_pkl_file, 'wb')) + inputs.append(( + _id, n, + x_pkl_file, s_pkl_file, pmetric_pkl_file, + ref_pdb, rec_chain, lig_chain, ref_save_dir, cand_save_dir, + seq_only, struct_only, backbone_only + )) + item_idx += 1 + + results = pool.starmap(save_data, inputs) + for result in results: + fout.write(json.dumps(result) + '\n') + + pbar.update(1) + + fout.close() + + +def parse(): + parser = argparse.ArgumentParser(description='Generate peptides given epitopes') + parser.add_argument('--config', type=str, required=True, help='Path to the test configuration') + parser.add_argument('--ckpt', type=str, required=True, help='Path to checkpoint') + parser.add_argument('--save_dir', type=str, default=None, help='Directory to save generated peptides') + + parser.add_argument('--gpu', type=int, default=0, help='GPU to use, -1 for cpu') + parser.add_argument('--n_cpu', type=int, default=4, help='Number of CPU to use (for parallelly saving the generated results)') + return parser.parse_known_args() + + +if __name__ == '__main__': + args, opt_args = parse() + print_log(f'Overwritting args: {opt_args}') + setup_seed(12) + main(args, opt_args) diff --git a/globals.py b/globals.py new file mode 100644 index 0000000000000000000000000000000000000000..a1464bd0d9163003c6dd19f81c6a6f2d01f597a7 --- /dev/null +++ b/globals.py @@ -0,0 +1,21 @@ +#!/usr/bin/python +# -*- coding:utf-8 -*- +import os +''' +Two parts: +1. basic variables +2. benchmark definitions and configs for data processing +''' +# 1. 
basic variables +PROJ_DIR = os.path.split(__file__)[0] +# cache directory +CACHE_DIR = os.path.join(PROJ_DIR, '__cache__') +if not os.path.exists(CACHE_DIR): + os.makedirs(CACHE_DIR) + +# DockQ +# IMPORTANT: change it to your path to DockQ project) +DOCKQ_DIR = os.path.join(PROJ_DIR, 'evaluation', 'DockQ') +if not os.path.exists(DOCKQ_DIR): + os.system(f'cd {os.path.join(PROJ_DIR, "evaluation")}; git clone --branch v1.0 --depth 1 https://github.com/bjornwallner/DockQ.git') + os.system(f'cd {DOCKQ_DIR}; make') diff --git a/models/LDM/diffusion/dpm_full.py b/models/LDM/diffusion/dpm_full.py new file mode 100644 index 0000000000000000000000000000000000000000..69728ea25ad836481dfa6b926d3901a05801140c --- /dev/null +++ b/models/LDM/diffusion/dpm_full.py @@ -0,0 +1,342 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +import functools +from tqdm.auto import tqdm + +from torch.autograd import grad +from torch_scatter import scatter_mean + +from utils.nn_utils import variadic_meshgrid + +from .transition import construct_transition + +from ...dyMEAN.modules.am_egnn import AMEGNN +from ...dyMEAN.modules.radial_basis import RadialBasis + + +def low_trianguler_inv(L): + # L: [bs, 3, 3] + L_inv = torch.linalg.solve_triangular(L, torch.eye(3).unsqueeze(0).expand_as(L).to(L.device), upper=False) + return L_inv + + +class EpsilonNet(nn.Module): + + def __init__( + self, + input_size, + hidden_size, + n_channel, + n_layers=3, + edge_size=0, + n_rbf=0, + cutoff=1.0, + dropout=0.1, + additional_pos_embed=True + ): + super().__init__() + + atom_embed_size = hidden_size // 4 + edge_embed_size = hidden_size // 4 + pos_embed_size, seg_embed_size = input_size, input_size + # enc_input_size = input_size + seg_embed_size + 3 + (pos_embed_size if additional_pos_embed else 0) + enc_input_size = input_size + 3 + (pos_embed_size if additional_pos_embed else 0) + self.encoder = AMEGNN( + enc_input_size, hidden_size, hidden_size, n_channel, + channel_nf=atom_embed_size, radial_nf=hidden_size, + in_edge_nf=edge_embed_size + edge_size, n_layers=n_layers, residual=True, + dropout=dropout, dense=False, n_rbf=n_rbf, cutoff=cutoff) + self.hidden2input = nn.Linear(hidden_size, input_size) + # self.pos_embed2latent = nn.Linear(hidden_size, pos_embed_size) + # self.segment_embedding = nn.Embedding(2, seg_embed_size) + self.edge_embedding = nn.Embedding(2, edge_embed_size) + + def forward( + self, H_noisy, X_noisy, position_embedding, ctx_edges, inter_edges, + atom_embeddings, atom_weights, mask_generate, beta, + ctx_edge_attr=None, inter_edge_attr=None): + """ + Args: + H_noisy: (N, hidden_size) + X_noisy: (N, 14, 3) + mask_generate: (N) + batch_ids: (N) + beta: (N) + Returns: + eps_H: (N, hidden_size) + eps_X: (N, 14, 3) + """ + t_embed = torch.stack([beta, torch.sin(beta), torch.cos(beta)], dim=-1) + # seg_embed = self.segment_embedding(mask_generate.long()) + if position_embedding is None: + # in_feat = torch.cat([H_noisy, t_embed, seg_embed], dim=-1) # [N, hidden_size * 2 + 3] + in_feat = torch.cat([H_noisy, t_embed], dim=-1) # [N, hidden_size * 2 + 3] + else: + # in_feat = torch.cat([H_noisy, t_embed, self.pos_embed2latent(position_embedding), seg_embed], dim=-1) # [N, hidden_size * 3 + 3] + in_feat = torch.cat([H_noisy, t_embed, position_embedding], dim=-1) # [N, hidden_size * 3 + 3] + edges = torch.cat([ctx_edges, inter_edges], dim=-1) + edge_embed = torch.cat([ + torch.zeros_like(ctx_edges[0]), torch.ones_like(inter_edges[0]) + ], dim=-1) + edge_embed = self.edge_embedding(edge_embed) + if 
ctx_edge_attr is None: + edge_attr = edge_embed + else: + edge_attr = torch.cat([ + edge_embed, + torch.cat([ctx_edge_attr, inter_edge_attr], dim=0)], + dim=-1 + ) # [E, embed size + edge_attr_size] + next_H, next_X = self.encoder(in_feat, X_noisy, edges, ctx_edge_attr=edge_attr, channel_attr=atom_embeddings, channel_weights=atom_weights) + + # equivariant vector features changes + eps_X = next_X - X_noisy + eps_X = torch.where(mask_generate[:, None, None].expand_as(eps_X), eps_X, torch.zeros_like(eps_X)) + + # invariant scalar features changes + next_H = self.hidden2input(next_H) + eps_H = next_H - H_noisy + eps_H = torch.where(mask_generate[:, None].expand_as(eps_H), eps_H, torch.zeros_like(eps_H)) + + return eps_H, eps_X + + +class FullDPM(nn.Module): + + def __init__( + self, + latent_size, + hidden_size, + n_channel, + num_steps, + n_layers=3, + dropout=0.1, + trans_pos_type='Diffusion', + trans_seq_type='Diffusion', + trans_pos_opt={}, + trans_seq_opt={}, + n_rbf=0, + cutoff=1.0, + std=10.0, + additional_pos_embed=True, + dist_rbf=0, + dist_rbf_cutoff=7.0 + ): + super().__init__() + self.eps_net = EpsilonNet( + latent_size, hidden_size, n_channel, n_layers=n_layers, edge_size=dist_rbf, + n_rbf=n_rbf, cutoff=cutoff, dropout=dropout, additional_pos_embed=additional_pos_embed) + if dist_rbf > 0: + self.dist_rbf = RadialBasis(dist_rbf, dist_rbf_cutoff) + self.num_steps = num_steps + self.trans_x = construct_transition(trans_pos_type, num_steps, trans_pos_opt) + self.trans_h = construct_transition(trans_seq_type, num_steps, trans_seq_opt) + + self.register_buffer('std', torch.tensor(std, dtype=torch.float)) + + def _normalize_position(self, X, batch_ids, mask_generate, atom_mask, L=None): + ctx_mask = (~mask_generate[:, None].expand_as(atom_mask)) & atom_mask + ctx_mask[:, 0] = 0 + ctx_mask[:, 2:] = 0 # only retain CA + centers = scatter_mean(X[ctx_mask], batch_ids[:, None].expand_as(ctx_mask)[ctx_mask], dim=0) # [bs, 3] + centers = centers[batch_ids].unsqueeze(1) # [N, 1, 3] + if L is None: + X = (X - centers) / self.std + else: + with torch.no_grad(): + L_inv = low_trianguler_inv(L) + # print(L_inv[0]) + X = X - centers + X = torch.matmul(L_inv[batch_ids][..., None, :, :], X.unsqueeze(-1)).squeeze(-1) + return X, centers + + def _unnormalize_position(self, X_norm, centers, batch_ids, L=None): + if L is None: + X = X_norm * self.std + centers + else: + X = torch.matmul(L[batch_ids][..., None, :, :], X_norm.unsqueeze(-1)).squeeze(-1) + centers + return X + + @torch.no_grad() + def _get_batch_ids(self, mask_generate, lengths): + + # batch ids + batch_ids = torch.zeros_like(mask_generate).long() + batch_ids[torch.cumsum(lengths, dim=0)[:-1]] = 1 + batch_ids.cumsum_(dim=0) + + return batch_ids + + @torch.no_grad() + def _get_edges(self, mask_generate, batch_ids, lengths): + row, col = variadic_meshgrid( + input1=torch.arange(batch_ids.shape[0], device=batch_ids.device), + size1=lengths, + input2=torch.arange(batch_ids.shape[0], device=batch_ids.device), + size2=lengths, + ) # (row, col) + + is_ctx = mask_generate[row] == mask_generate[col] + is_inter = ~is_ctx + ctx_edges = torch.stack([row[is_ctx], col[is_ctx]], dim=0) # [2, Ec] + inter_edges = torch.stack([row[is_inter], col[is_inter]], dim=0) # [2, Ei] + return ctx_edges, inter_edges + + @torch.no_grad() + def _get_edge_dist(self, X, edges, atom_mask): + ''' + Args: + X: [N, 14, 3] + edges: [2, E] + atom_mask: [N, 14] + ''' + ca_x = X[:, 1] # [N, 3] + no_ca_mask = torch.logical_not(atom_mask[:, 1]) # [N] + ca_x[no_ca_mask] = X[:, 
0][no_ca_mask] # latent coordinates + dist = torch.norm(ca_x[edges[0]] - ca_x[edges[1]], dim=-1) # [N] + return dist + + def forward(self, H_0, X_0, position_embedding, mask_generate, lengths, atom_embeddings, atom_mask, L=None, t=None, sample_structure=True, sample_sequence=True): + # if L is not None: + # L = L / self.std + batch_ids = self._get_batch_ids(mask_generate, lengths) + batch_size = batch_ids.max() + 1 + if t == None: + t = torch.randint(0, self.num_steps + 1, (batch_size,), dtype=torch.long, device=H_0.device) + X_0, centers = self._normalize_position(X_0, batch_ids, mask_generate, atom_mask, L) + + if sample_structure: + X_noisy, eps_X = self.trans_x.add_noise(X_0, mask_generate, batch_ids, t) + else: + X_noisy, eps_X = X_0, torch.zeros_like(X_0) + if sample_sequence: + H_noisy, eps_H = self.trans_h.add_noise(H_0, mask_generate, batch_ids, t) + else: + H_noisy, eps_H = H_0, torch.zeros_like(H_0) + + ctx_edges, inter_edges = self._get_edges(mask_generate, batch_ids, lengths) + if hasattr(self, 'dist_rbf'): + ctx_edge_attr = self._get_edge_dist(self._unnormalize_position(X_noisy, centers, batch_ids, L), ctx_edges, atom_mask) + inter_edge_attr = self._get_edge_dist(self._unnormalize_position(X_noisy, centers, batch_ids, L), inter_edges, atom_mask) + ctx_edge_attr = self.dist_rbf(ctx_edge_attr).view(ctx_edges.shape[1], -1) + inter_edge_attr = self.dist_rbf(inter_edge_attr).view(inter_edges.shape[1], -1) + else: + ctx_edge_attr, inter_edge_attr = None, None + + beta = self.trans_x.get_timestamp(t)[batch_ids] # [N] + eps_H_pred, eps_X_pred = self.eps_net( + H_noisy, X_noisy, position_embedding, ctx_edges, inter_edges, atom_embeddings, atom_mask.float(), mask_generate, beta, + ctx_edge_attr=ctx_edge_attr, inter_edge_attr=inter_edge_attr) + + loss_dict = {} + + # equivariant vector feature loss, TODO: latent channel + if sample_structure: + mask_loss = mask_generate[:, None] & atom_mask + loss_X = F.mse_loss(eps_X_pred[mask_loss], eps_X[mask_loss], reduction='none').sum(dim=-1) # (Ntgt * n_latent_channel) + loss_X = loss_X.sum() / (mask_loss.sum().float() + 1e-8) + loss_dict['X'] = loss_X + else: + loss_dict['X'] = 0 + + # invariant scalar feature loss + if sample_sequence: + loss_H = F.mse_loss(eps_H_pred[mask_generate], eps_H[mask_generate], reduction='none').sum(dim=-1) # [N] + loss_H = loss_H.sum() / (mask_generate.sum().float() + 1e-8) + loss_dict['H'] = loss_H + else: + loss_dict['H'] = 0 + + return loss_dict + + @torch.no_grad() + def sample(self, H, X, position_embedding, mask_generate, lengths, atom_embeddings, atom_mask, + L=None, sample_structure=True, sample_sequence=True, pbar=False, energy_func=None, energy_lambda=0.01 + ): + """ + Args: + H: contextual hidden states, (N, latent_size) + X: contextual atomic coordinates, (N, 14, 3) + L: cholesky decomposition of the covariance matrix \Sigma=LL^T, (bs, 3, 3) + energy_func: guide diffusion towards lower energy landscape + """ + # if L is not None: + # L = L / self.std + batch_ids = self._get_batch_ids(mask_generate, lengths) + X, centers = self._normalize_position(X, batch_ids, mask_generate, atom_mask, L) + # print(X[0, 0]) + + # Set the orientation and position of residues to be predicted to random values + if sample_structure: + X_rand = torch.randn_like(X) # [N, 14, 3] + X_init = torch.where(mask_generate[:, None, None].expand_as(X), X_rand, X) + else: + X_init = X + + if sample_sequence: + H_rand = torch.randn_like(H) + H_init = torch.where(mask_generate[:, None].expand_as(H), H_rand, H) + else: + H_init = H + + # 
traj = {self.num_steps: (self._unnormalize_position(X_init, centers, batch_ids, L), H_init)} + traj = {self.num_steps: (X_init, H_init)} + if pbar: + pbar = functools.partial(tqdm, total=self.num_steps, desc='Sampling') + else: + pbar = lambda x: x + for t in pbar(range(self.num_steps, 0, -1)): + X_t, H_t = traj[t] + # X_t, _ = self._normalize_position(X_t, batch_ids, mask_generate, atom_mask, L) + X_t, H_t = torch.round(X_t, decimals=4), torch.round(H_t, decimals=4) # reduce numerical error + # print(t, 'input', X_t[0, 0] * 1000) + + # beta = self.trans_x.var_sched.betas[t].view(1).repeat(X_t.shape[0]) + beta = self.trans_x.get_timestamp(t).view(1).repeat(X_t.shape[0]) + t_tensor = torch.full([X_t.shape[0], ], fill_value=t, dtype=torch.long, device=X_t.device) + + ctx_edges, inter_edges = self._get_edges(mask_generate, batch_ids, lengths) + if hasattr(self, 'dist_rbf'): + ctx_edge_attr = self._get_edge_dist(self._unnormalize_position(X_t, centers, batch_ids, L), ctx_edges, atom_mask) + inter_edge_attr = self._get_edge_dist(self._unnormalize_position(X_t, centers, batch_ids, L), inter_edges, atom_mask) + ctx_edge_attr = self.dist_rbf(ctx_edge_attr).view(ctx_edges.shape[1], -1) + inter_edge_attr = self.dist_rbf(inter_edge_attr).view(inter_edges.shape[1], -1) + else: + ctx_edge_attr, inter_edge_attr = None, None + eps_H, eps_X = self.eps_net( + H_t, X_t, position_embedding, ctx_edges, inter_edges, atom_embeddings, atom_mask.float(), mask_generate, beta, + ctx_edge_attr=ctx_edge_attr, inter_edge_attr=inter_edge_attr) + if energy_func is not None: + with torch.enable_grad(): + cur_X_state = X_t.clone().double() + cur_X_state.requires_grad = True + energy = energy_func( + X=self._unnormalize_position(cur_X_state, centers.double(), batch_ids, L.double()), + mask_generate=mask_generate, batch_ids=batch_ids) + energy_eps_X = grad([energy], [cur_X_state], create_graph=False, retain_graph=False)[0].float() + # print(energy_lambda, energy / mask_generate.sum()) + energy_eps_X[~mask_generate] = 0 + energy_eps_X = -energy_eps_X + # print(t, 'energy', energy_eps_X[mask_generate][0, 0] * 1000) + else: + energy_eps_X = None + + # print(t, 'eps X', eps_X[mask_generate][0, 0] * 1000) + H_next = self.trans_h.denoise(H_t, eps_H, mask_generate, batch_ids, t_tensor) + X_next = self.trans_x.denoise(X_t, eps_X, mask_generate, batch_ids, t_tensor, guidance=energy_eps_X, guidance_weight=energy_lambda) + # print(t, 'output', X_next[mask_generate][0, 0] * 1000) + # if t == 90: + # aa + + if not sample_structure: + X_next = X_t + if not sample_sequence: + H_next = H_t + + # traj[t-1] = (self._unnormalize_position(X_next, centers, batch_ids, L), H_next) + traj[t-1] = (X_next, H_next) + traj[t] = (self._unnormalize_position(traj[t][0], centers, batch_ids, L).cpu(), traj[t][1].cpu()) + # traj[t] = tuple(x.cpu() for x in traj[t]) # Move previous states to cpu memory. 
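+        # After the reverse process finishes, traj[0] still holds the final latent
+        # state in the normalized frame; it is mapped back to the original coordinate
+        # frame below (through the Cholesky factor L when full-covariance
+        # normalization is used, or the scalar std otherwise) before being returned,
+        # consistent with how the intermediate states in traj are unnormalized above.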
+ traj[0] = (self._unnormalize_position(traj[0][0], centers, batch_ids, L), traj[0][1]) + return traj diff --git a/models/LDM/diffusion/transition.py b/models/LDM/diffusion/transition.py new file mode 100644 index 0000000000000000000000000000000000000000..71d6ccea53769f0364163854c036efaaec0c3584 --- /dev/null +++ b/models/LDM/diffusion/transition.py @@ -0,0 +1,153 @@ +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F + + +def construct_transition(_type, num_steps, opt): + if _type == 'Diffusion': + return ContinuousTransition(num_steps, opt) + elif _type == 'FlowMatching': + return FlowMatchingTransition(num_steps, opt) + else: + raise NotImplementedError(f'transition type {_type} not implemented') + + +class VarianceSchedule(nn.Module): + + def __init__(self, num_steps=100, s=0.01): + super().__init__() + T = num_steps + t = torch.arange(0, num_steps+1, dtype=torch.float) + f_t = torch.cos( (np.pi / 2) * ((t/T) + s) / (1 + s) ) ** 2 + alpha_bars = f_t / f_t[0] + + betas = 1 - (alpha_bars[1:] / alpha_bars[:-1]) + betas = torch.cat([torch.zeros([1]), betas], dim=0) + betas = betas.clamp_max(0.999) + + sigmas = torch.zeros_like(betas) + for i in range(1, betas.size(0)): + sigmas[i] = ((1 - alpha_bars[i-1]) / (1 - alpha_bars[i])) * betas[i] + sigmas = torch.sqrt(sigmas) + + self.register_buffer('betas', betas) + self.register_buffer('alpha_bars', alpha_bars) + self.register_buffer('alphas', 1 - betas) + self.register_buffer('sigmas', sigmas) + + +class ContinuousTransition(nn.Module): + + def __init__(self, num_steps, var_sched_opt={}): + super().__init__() + self.var_sched = VarianceSchedule(num_steps, **var_sched_opt) + + def get_timestamp(self, t): + # use beta as timestamp + return self.var_sched.betas[t] + + def add_noise(self, p_0, mask_generate, batch_ids, t): + """ + Args: + p_0: [N, ...] + mask_generate: [N] + batch_ids: [N] + t: [batch_size] + """ + expand_shape = [p_0.shape[0]] + [1 for _ in p_0.shape[1:]] + mask_generate = mask_generate.view(*expand_shape) + + alpha_bar = self.var_sched.alpha_bars[t] # [batch_size] + alpha_bar = alpha_bar[batch_ids] # [N] + + c0 = torch.sqrt(alpha_bar).view(*expand_shape) + c1 = torch.sqrt(1 - alpha_bar).view(*expand_shape) + + e_rand = torch.randn_like(p_0) # [N, 14, 3] + supervise_e_rand = e_rand.clone() + p_noisy = c0*p_0 + c1*e_rand + p_noisy = torch.where(mask_generate.expand_as(p_0), p_noisy, p_0) + + return p_noisy, supervise_e_rand + + def denoise(self, p_t, eps_p, mask_generate, batch_ids, t, guidance=None, guidance_weight=1.0): + # IMPORTANT: + # clampping alpha is to fix the instability issue at the first step (t=T) + # it seems like a problem with the ``improved ddpm''. 
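+        # The update below is the standard DDPM ancestral sampling step:
+        #   p_{t-1} = 1/sqrt(alpha_t) * (p_t - (1 - alpha_t)/sqrt(1 - alpha_bar_t) * eps) + sigma_t * z,
+        # with z ~ N(0, I) for t > 1 and z = 0 at the last step. When an energy
+        # guidance term is provided, it is folded into the predicted noise as
+        #   eps <- eps - sqrt(1 - alpha_bar_t) * guidance
+        # before applying the update (classifier-guidance style).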
+ expand_shape = [p_t.shape[0]] + [1 for _ in p_t.shape[1:]] + mask_generate = mask_generate.view(*expand_shape) + + alpha = self.var_sched.alphas[t].clamp_min( + self.var_sched.alphas[-2] + )[batch_ids] + alpha_bar = self.var_sched.alpha_bars[t][batch_ids] + sigma = self.var_sched.sigmas[t][batch_ids].view(*expand_shape) + + c0 = ( 1.0 / torch.sqrt(alpha + 1e-8) ).view(*expand_shape) + c1 = ( (1 - alpha) / torch.sqrt(1 - alpha_bar + 1e-8) ).view(*expand_shape) + + z = torch.where( + (t > 1).view(*expand_shape).expand_as(p_t), + torch.randn_like(p_t), + torch.zeros_like(p_t), + ) + + if guidance is not None: + eps_p = eps_p - torch.sqrt(1 - alpha_bar).view(*expand_shape) * guidance + + # if guidance is not None: + # p_next = c0 * (p_t - c1 * eps_p) + sigma * z + sigma * sigma * guidance_weight * guidance + # else: + # p_next = c0 * (p_t - c1 * eps_p) + sigma * z + p_next = c0 * (p_t - c1 * eps_p) + sigma * z + p_next = torch.where(mask_generate.expand_as(p_t), p_next, p_t) + return p_next + + +# TODO: flow matching (uniform or OT), not done yet +class FlowMatchingTransition(nn.Module): + + def __init__(self, num_steps, opt={}): + super().__init__() + self.num_steps = num_steps + # TODO: number of steps T or T + 1 + c1 = torch.arange(0, num_steps + 1).float() / num_steps + c0 = 1 - c1 + self.register_buffer('c0', c0) + self.register_buffer('c1', c1) + + def get_timestamp(self, t): + # use c1 as timestamp + return self.c1[t] + + def add_noise(self, p_0, mask_generate, batch_ids, t): + """ + Args: + p_0: [N, ...] + mask_generate: [N] + batch_ids: [N] + t: [batch_size] + """ + expand_shape = [p_0.shape[0]] + [1 for _ in p_0.shape[1:]] + mask_generate = mask_generate.view(*expand_shape) + + c0 = self.c0[t][batch_ids].view(*expand_shape) + c1 = self.c1[t][batch_ids].view(*expand_shape) + + e_rand = torch.randn_like(p_0) # [N, 14, 3] + p_noisy = c0*p_0 + c1*e_rand + p_noisy = torch.where(mask_generate.expand_as(p_0), p_noisy, p_0) + + return p_noisy, (e_rand - p_0) + + def denoise(self, p_t, eps_p, mask_generate, batch_ids, t): + # IMPORTANT: + # clampping alpha is to fix the instability issue at the first step (t=T) + # it seems like a problem with the ``improved ddpm''. 
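+        # Note: the alpha-clamping remark above is inherited from the diffusion
+        # transition. Here denoising is a single Euler step along the linear
+        # interpolation path p_t = c0 * p_0 + c1 * eps (with c1 = t / num_steps),
+        # whose velocity (eps - p_0) is the regression target returned by add_noise:
+        #   p_{t-1} = p_t - v_pred / num_steps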
+ expand_shape = [p_t.shape[0]] + [1 for _ in p_t.shape[1:]] + mask_generate = mask_generate.view(*expand_shape) + + p_next = p_t - eps_p / self.num_steps + p_next = torch.where(mask_generate.expand_as(p_t), p_next, p_t) + return p_next diff --git a/models/LDM/energies/dist.py b/models/LDM/energies/dist.py new file mode 100644 index 0000000000000000000000000000000000000000..1a59e790432c19cc5738588d5f46cec3679ab908 --- /dev/null +++ b/models/LDM/energies/dist.py @@ -0,0 +1,81 @@ +#!/usr/bin/python +# -*- coding:utf-8 -*- +import torch +import torch.nn.functional as F + +from data.format import VOCAB +from utils.nn_utils import graph_to_batch + + +@torch.no_grad() +def continuous_bool(x, k=1000): + return (x > 0).float() + + +def _consec_dist_loss(gen_X, gen_X_mask, lb, ub, eps=1e-6): + consec_dist = torch.norm(gen_X[..., 1:, :] - gen_X[..., :-1, :], dim=-1) # [bs, max_L - 1] + consec_lb_loss = lb - consec_dist # [bs, max_L - 1] + consec_ub_loss = consec_dist - ub # [bs, max_L - 1] + + consec_lb_invalid = (consec_dist < lb) & gen_X_mask[..., 1:] + consec_ub_invalid = (consec_dist > ub) & gen_X_mask[..., 1:] + consec_loss = torch.where(consec_lb_invalid, consec_lb_loss, torch.zeros_like(consec_lb_loss)) + consec_loss = torch.where(consec_ub_invalid, consec_ub_loss, consec_loss) + + consec_loss = consec_loss.sum(-1) / (consec_lb_invalid + consec_ub_invalid + eps).sum(-1) + consec_loss = torch.sum(consec_loss) # consistent loss scale across different batch size + return consec_loss + + +def _inner_clash_loss(gen_X, gen_X_mask, mean, eps=1e-6): + dist = torch.norm(gen_X[..., :, None, :] - gen_X[..., None, :, :], dim=-1) # [bs, max_L, max_L] + dist_mask = gen_X_mask[..., :, None] & gen_X_mask[..., None, :] # [bs, max_L, max_L] + pos = torch.cumsum(torch.ones_like(gen_X_mask, dtype=torch.long), dim=-1) # [bs, max_L] + non_consec_mask = torch.abs(pos[..., :, None] - pos[..., None, :]) > 1 # [bs, max_L, max_L] + + clash_loss = mean - dist + clash_loss_mask = (clash_loss > 0) & dist_mask & non_consec_mask # [bs, max_L, max_L] + clash_loss = torch.where(clash_loss_mask, clash_loss, torch.zeros_like(clash_loss)) + + clash_loss = clash_loss.sum(-1).sum(-1) / (clash_loss_mask.sum(-1).sum(-1) + eps) + clash_loss = torch.sum(clash_loss) # consistent loss scale across different residue number and batch size + return clash_loss + + +def _outer_clash_loss(ctx_X, ctx_X_mask, gen_X, gen_X_mask, mean, eps=1e-6): + dist = torch.norm(gen_X[..., :, None, :] - ctx_X[..., None, :, :], dim=-1) # [bs, max_gen_L, max_ctx_L] + dist_mask = gen_X_mask[..., :, None] & ctx_X_mask[..., None, :] # [bs, max_gen_L, max_ctx_L] + clash_loss = mean - dist # [bs, max_gen_L, max_ctx_L] + clash_loss_mask = (clash_loss > 0) & dist_mask # [bs, max_gen_L, max_ctx_L] + clash_loss = torch.where(clash_loss_mask, clash_loss, torch.zeros_like(clash_loss)) + + clash_loss = clash_loss.sum(-1).sum(-1) / (clash_loss_mask.sum(-1).sum(-1) + eps) + clash_loss = torch.sum(clash_loss) # consistent loss scale across different residue number and batch size + return clash_loss + + +def dist_energy(X, mask_generate, batch_ids, mean, std, tolerance=3, **kwargs): + mean, std = round(mean, 4), round(std, 4) + lb, ub = mean - tolerance * std, mean + tolerance * std + + X = X.clone() # [N, 3] + + ctx_X, ctx_batch_ids = X[~mask_generate], batch_ids[~mask_generate] + gen_X, gen_batch_ids = X[mask_generate], batch_ids[mask_generate] + ctx_X = ctx_X[:, VOCAB.ca_channel_idx] # CA (alpha carbon) + gen_X = gen_X[:, 0] # latent one + + # to batch representation + 
ctx_X, ctx_X_mask = graph_to_batch(ctx_X, ctx_batch_ids, mask_is_pad=False) # [bs, max_ctx_L, 3] + gen_X, gen_X_mask = graph_to_batch(gen_X, gen_batch_ids, mask_is_pad=False) # [bs, max_gen_L, 3] + + # consecutive + consec_loss = _consec_dist_loss(gen_X, gen_X_mask, lb, ub) + + # inner clash + inner_clash_loss = _inner_clash_loss(gen_X, gen_X_mask, mean) + + # outer clash + outer_clash_loss = _outer_clash_loss(ctx_X, ctx_X_mask, gen_X, gen_X_mask, mean) + + return consec_loss + inner_clash_loss + outer_clash_loss diff --git a/models/LDM/ldm.py b/models/LDM/ldm.py new file mode 100644 index 0000000000000000000000000000000000000000..10d404e0a5c71d07cdc1ac9f0e8df82057fc6c7b --- /dev/null +++ b/models/LDM/ldm.py @@ -0,0 +1,208 @@ +#!/usr/bin/python +# -*- coding:utf-8 -*- +import enum + +import torch +import torch.nn as nn + +import utils.register as R +from utils.oom_decorator import oom_decorator +from data.format import VOCAB + +from .diffusion.dpm_full import FullDPM +from .energies.dist import dist_energy +from ..autoencoder.model import AutoEncoder + + +@R.register('LDMPepDesign') +class LDMPepDesign(nn.Module): + + def __init__( + self, + autoencoder_ckpt, + autoencoder_no_randomness, + hidden_size, + num_steps, + n_layers, + dist_rbf=0, + dist_rbf_cutoff=7.0, + n_rbf=0, + cutoff=1.0, + max_gen_position=30, + mode='codesign', + h_loss_weight=None, + diffusion_opt={}): + super().__init__() + self.autoencoder_no_randomness = autoencoder_no_randomness + self.latent_idx = VOCAB.symbol_to_idx(VOCAB.LAT) + + self.autoencoder: AutoEncoder = torch.load(autoencoder_ckpt, map_location='cpu') + for param in self.autoencoder.parameters(): + param.requires_grad = False + self.autoencoder.eval() + + self.train_sequence, self.train_structure = True, True + if mode == 'fixbb': + self.train_structure = False + elif mode == 'fixseq': + self.train_sequence = False + + latent_size = self.autoencoder.latent_size if self.train_sequence else hidden_size + + self.abs_position_encoding = nn.Embedding(max_gen_position, latent_size) + self.diffusion = FullDPM( + latent_size=latent_size, + hidden_size=hidden_size, + n_channel=self.autoencoder.n_channel, + num_steps=num_steps, + n_layers=n_layers, + n_rbf=n_rbf, + cutoff=cutoff, + dist_rbf=dist_rbf, + dist_rbf_cutoff=dist_rbf_cutoff, + **diffusion_opt + ) + if self.train_sequence: + self.hidden2latent = nn.Linear(hidden_size, self.autoencoder.latent_size) + if h_loss_weight is None: + self.h_loss_weight = self.autoencoder.latent_n_channel * 3 / self.autoencoder.latent_size # make loss_X and loss_H about the same size + else: + self.h_loss_weight = h_loss_weight + if self.train_structure: + # for better constrained sampling + self.consec_dist_mean, self.consec_dist_std = None, None + + @oom_decorator + def forward(self, X, S, mask, position_ids, lengths, atom_mask, L=None): + ''' + L: [bs, 3, 3], cholesky decomposition of the covariance matrix \Sigma = LL^T + ''' + + # encode latent_H_0 (N*d) and latent_X_0 (N*3) + with torch.no_grad(): + self.autoencoder.eval() + H, Z, _, _ = self.autoencoder.encode(X, S, mask, position_ids, lengths, atom_mask, no_randomness=self.autoencoder_no_randomness) + + # diffusion model + if self.train_sequence: + S = S.clone() + S[mask] = self.latent_idx + + with torch.no_grad(): + H_0, (atom_embeddings, _) = self.autoencoder.aa_feature(S, position_ids) + position_embedding = self.abs_position_encoding(torch.where(mask, position_ids + 1, torch.zeros_like(position_ids))) + + if self.train_sequence: + H_0 = self.hidden2latent(H_0) + H_0 = 
H_0.clone() + H_0[mask] = H + + if self.train_structure: + X = X.clone() + X[mask] = self.autoencoder._fill_latent_channels(Z) + atom_mask = atom_mask.clone() + atom_mask_gen = atom_mask[mask] + atom_mask_gen[:, :self.autoencoder.latent_n_channel] = 1 + atom_mask_gen[:, self.autoencoder.latent_n_channel:] = 0 + atom_mask[mask] = atom_mask_gen + del atom_mask_gen + else: # fixbb, only retain backbone atoms in masked region + atom_mask = self.autoencoder._remove_sidechain_atom_mask(atom_mask, mask) + + loss_dict = self.diffusion.forward( + H_0=H_0, + X_0=X, + position_embedding=position_embedding, + mask_generate=mask, + lengths=lengths, + atom_embeddings=atom_embeddings, + atom_mask=atom_mask, + L=L, + sample_structure=self.train_structure, + sample_sequence=self.train_sequence + ) + + # loss + loss = 0 + if self.train_sequence: + loss = loss + loss_dict['H'] * self.h_loss_weight + if self.train_structure: + loss = loss + loss_dict['X'] + + return loss, loss_dict + + def set_consec_dist(self, mean: float, std: float): + self.consec_dist_mean = mean + self.consec_dist_std = std + + def latent_geometry_guidance(self, X, mask_generate, batch_ids, tolerance=3, **kwargs): + assert self.consec_dist_mean is not None and self.consec_dist_std is not None, \ + 'Please run set_consec_dist(self, mean, std) to setup guidance parameters' + return dist_energy( + X, mask_generate, batch_ids, + self.consec_dist_mean, self.consec_dist_std, + tolerance=tolerance, **kwargs + ) + + @torch.no_grad() + def sample( + self, + X, S, mask, position_ids, lengths, atom_mask, L=None, + sample_opt={ + 'pbar': False, + 'energy_func': None, + 'energy_lambda': 0.0, + 'autoencoder_n_iter': 1 + }, + return_tensor=False, + optimize_sidechain=True, + ): + self.autoencoder.eval() + # diffusion sample + if self.train_sequence: + S = S.clone() + S[mask] = self.latent_idx + + H_0, (atom_embeddings, _) = self.autoencoder.aa_feature(S, position_ids) + position_embedding = self.abs_position_encoding(torch.where(mask, position_ids + 1, torch.zeros_like(position_ids))) + + if self.train_sequence: + H_0 = self.hidden2latent(H_0) + H_0 = H_0.clone() + H_0[mask] = 0 # no possibility for leakage + + if self.train_structure: + X = X.clone() + X[mask] = 0 + atom_mask = atom_mask.clone() + atom_mask_gen = atom_mask[mask] + atom_mask_gen[:, :self.autoencoder.latent_n_channel] = 1 + atom_mask_gen[:, self.autoencoder.latent_n_channel:] = 0 + atom_mask[mask] = atom_mask_gen + del atom_mask_gen + else: # fixbb, only retain backbone atoms in masked region + atom_mask = self.autoencoder._remove_sidechain_atom_mask(atom_mask, mask) + + sample_opt['sample_sequence'] = self.train_sequence + sample_opt['sample_structure'] = self.train_structure + if 'energy_func' in sample_opt: + if sample_opt['energy_func'] is None: + pass + elif sample_opt['energy_func'] == 'default': + sample_opt['energy_func'] = self.latent_geometry_guidance + # otherwise this should be a function + autoencoder_n_iter = sample_opt.pop('autoencoder_n_iter', 1) + + traj = self.diffusion.sample(H_0, X, position_embedding, mask, lengths, atom_embeddings, atom_mask, L, **sample_opt) + X_0, H_0 = traj[0] + X_0, H_0 = X_0[mask][:, :self.autoencoder.latent_n_channel], H_0[mask] + + # autodecoder decode + batch_X, batch_S, batch_ppls = self.autoencoder.test( + X, S, mask, position_ids, lengths, atom_mask, + given_laten_H=H_0, given_latent_X=X_0, return_tensor=return_tensor, + allow_unk=False, optimize_sidechain=optimize_sidechain, + n_iter=autoencoder_n_iter + ) + + return batch_X, batch_S, 
batch_ppls diff --git a/models/__init__.py b/models/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..947611bf89717e16ebc8c6f05a4a090447d812fa --- /dev/null +++ b/models/__init__.py @@ -0,0 +1,4 @@ +#!/usr/bin/python +# -*- coding:utf-8 -*- +from .autoencoder.model import AutoEncoder +from .LDM.ldm import LDMPepDesign diff --git a/models/autoencoder/backbone/api.py b/models/autoencoder/backbone/api.py new file mode 100644 index 0000000000000000000000000000000000000000..4c5970a2ad6c7c8d012fe79cf2bc2f99c9b417f6 --- /dev/null +++ b/models/autoencoder/backbone/api.py @@ -0,0 +1,33 @@ +#!/usr/bin/python +# -*- coding:utf-8 -*- +import torch +import torch.nn as nn + +from utils.nn_utils import graph_to_batch + +from .backbone import FrameBuilder + + +class BackboneModel(nn.Module): + def __init__(self) -> None: + super().__init__() + self.backbone_builder = FrameBuilder() + + def forward(self, X, batch_ids): + ''' + X: [N, 14, 3], predicted all-atom coordinates (obviously with a lot of invalidities) + assume the first 4 are N, CA, C, O + S: [N], predicted sequence + ''' + + # to batch-form representations + X, mask = graph_to_batch(X, batch_ids, mask_is_pad=False) + C = mask.long() + + # rectify backbones + R, t, q = self.backbone_builder.inverse(X, C) + X_bb = self.backbone_builder(R, t, C) # [bs, L, 4, 3] + X = torch.cat([X_bb, X[:, :, 4:]], dim=-2) # [bs, L, 14, 3] + + # get back to our graph representations + return X[mask] \ No newline at end of file diff --git a/models/autoencoder/backbone/backbone.py b/models/autoencoder/backbone/backbone.py new file mode 100644 index 0000000000000000000000000000000000000000..6788ee54786c7afb94f7b90caf800ffc32112058 --- /dev/null +++ b/models/autoencoder/backbone/backbone.py @@ -0,0 +1,146 @@ +#!/usr/bin/python +# -*- coding:utf-8 -*- +""" + Modified from https://github.com/generatebio/chroma/blob/main/chroma/layers/structure/backbone.py +""" +from typing import Optional, Tuple + +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F + +from ..sidechain.structure import geometry + + +def compose_translation( + R_a: torch.Tensor, t_a: torch.Tensor, t_b: torch.Tensor +) -> torch.Tensor: + """Compose translation component of `T_compose = T_a * T_b` (broadcastable). + + Args: + R_a (torch.Tensor): Transform `T_a` rotation matrix with shape `(...,3,3)`. + t_a (torch.Tensor): Transform `T_a` translation with shape `(...,3)`. + t_b (torch.Tensor): Transform `T_b` translation with shape `(...,3)`. + + Returns: + t_composed (torch.Tensor): Composed transform `a * b` translation vector with + shape `(...,3)`. + """ + t_composed = t_a + (R_a @ t_b.unsqueeze(-1)).squeeze(-1) + return t_composed + + +class FrameBuilder(nn.Module): + """Build protein backbones from rigid residue poses. + + Inputs: + R (torch.Tensor): Rotation of residue orientiations + with shape `(num_batch, num_residues, 3, 3)`. If `None`, + then `q` must be provided instead. + t (torch.Tensor): Translation of residue orientiations + with shape `(num_batch, num_residues, 3)`. This is the + location of the C-alpha coordinates. + C (torch.Tensor): Chain map with shape `(num_batch, num_residues)`. + q (Tensor, optional): Quaternions representing residue orientiations + with shape `(num_batch, num_residues, 4)`. 
+ + Outputs: + X (torch.Tensor): All-atom protein coordinates with shape + `(num_batch, num_residues, 4, 3)` + """ + + def __init__(self, distance_eps: float = 1e-3): + super().__init__() + + # Build idealized backbone fragment + t = torch.tensor( + [ + [1.459, 0.0, 0.0], # N-C via Engh & Huber is 1.459 + [0.0, 0.0, 0.0], # CA is origin + [-0.547, 0.0, -1.424], # C is placed 1.525 A @ 111 degrees from N + ], + dtype=torch.float32, + ).reshape([1, 1, 3, 3]) + R = torch.eye(3).reshape([1, 1, 1, 3, 3]) + self.register_buffer("_t_atom", t) + self.register_buffer("_R_atom", R) + + # Carbonyl geometry from CHARMM all36_prot ALA definition + self._length_C_O = 1.2297 + self._angle_CA_C_O = 122.5200 + self._dihedral_Np_CA_C_O = 180 + self.distance_eps = distance_eps + + def _build_O(self, X_chain: torch.Tensor, C: torch.LongTensor): + """Build backbone carbonyl oxygen.""" + # Build carboxyl groups + X_N, X_CA, X_C = X_chain.unbind(-2) + + # TODO: fix this behavior for termini + mask_next = (C > 0).float()[:, 1:].unsqueeze(-1) + X_N_next = F.pad(mask_next * X_N[:, 1:,], (0, 0, 0, 1),) + + num_batch, num_residues = C.shape + ones = torch.ones(list(C.shape), dtype=torch.float32, device=C.device) + X_O = geometry.extend_atoms( + X_N_next, + X_CA, + X_C, + self._length_C_O * ones, + self._angle_CA_C_O * ones, + self._dihedral_Np_CA_C_O * ones, + degrees=True, + ) + mask = (C > 0).float().reshape(list(C.shape) + [1, 1]) + X = mask * torch.stack([X_N, X_CA, X_C, X_O], dim=-2) + return X + + def forward( + self, + R: torch.Tensor, + t: torch.Tensor, + C: torch.LongTensor, + q: Optional[torch.Tensor] = None, + ): + assert q is None or R is None + + if R is None: + # (B,N,1,3,3) and (B,N,1,3) + R = geometry.rotations_from_quaternions( + q, normalize=True, eps=self.distance_eps + ) + + R = R.unsqueeze(-3) + t_frame = t.unsqueeze(-2) + X_chain = compose_translation(R, t_frame, self._t_atom) + X = self._build_O(X_chain, C) + return X + + def inverse( + self, X: torch.Tensor, C: torch.LongTensor + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """Reconstruct transformations from poses. + + Inputs: + X (torch.Tensor): All-atom protein coordinates with shape + `(num_batch, num_residues, 4, 3)` + C (torch.Tensor): Chain map with shape `(num_batch, num_residues)`. + + Outputs: + R (torch.Tensor): Rotation of residue orientiations + with shape `(num_batch, num_residues, 3, 3)`. + t (torch.Tensor): Translation of residue orientiations + with shape `(num_batch, num_residues, 3)`. This is the + location of the C-alpha coordinates. + q (torch.Tensor): Quaternions representing residue orientiations + with shape `(num_batch, num_residues, 4)`. 
+ """ + X_bb = X[:, :, :4, :] + R, t = geometry.frames_from_backbone(X_bb, distance_eps=self.distance_eps) + q = geometry.quaternions_from_rotations(R, eps=self.distance_eps) + mask = (C > 0).float().unsqueeze(-1) + R = mask.unsqueeze(-1) * R + t = mask * t + q = mask * q + return R, t, q \ No newline at end of file diff --git a/models/autoencoder/model.py b/models/autoencoder/model.py new file mode 100644 index 0000000000000000000000000000000000000000..0f885a64dd8d8566ca10b31491e4905086ad0bd9 --- /dev/null +++ b/models/autoencoder/model.py @@ -0,0 +1,494 @@ +#!/usr/bin/python +# -*- coding:utf-8 -*- +import math + +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch_scatter import scatter_mean + +from data.format import VOCAB +from utils import register as R +from utils.oom_decorator import oom_decorator +from utils.const import aas +from utils.nn_utils import variadic_meshgrid + +from .sidechain.api import SideChainModel +from .backbone.api import BackboneModel + +from ..dyMEAN.modules.am_egnn import AMEGNN # adaptive-multichannel egnn +from ..dyMEAN.nn_utils import SeparatedAminoAcidFeature, ProteinFeature + + +def create_encoder( + name, + atom_embed_size, + embed_size, + hidden_size, + n_channel, + n_layers, + dropout, + n_rbf, + cutoff +): + if name == 'dyMEAN': + encoder = AMEGNN( + embed_size, hidden_size, hidden_size, n_channel, + channel_nf=atom_embed_size, radial_nf=hidden_size, + in_edge_nf=0, n_layers=n_layers, residual=True, + dropout=dropout, dense=False, n_rbf=n_rbf, cutoff=cutoff) + else: + raise NotImplementedError(f'Encoder {encoder} not implemented') + + return encoder + + + +@R.register('AutoEncoder') +class AutoEncoder(nn.Module): + def __init__( + self, + embed_size, + hidden_size, + latent_size, + n_channel, + latent_n_channel=1, + mask_id=VOCAB.get_mask_idx(), + latent_id=VOCAB.symbol_to_idx(VOCAB.LAT), + max_position=2048, + relative_position=False, + CA_channel_idx=VOCAB.backbone_atoms.index('CA'), + n_layers=3, + dropout=0.1, + mask_ratio=0.0, + fix_alpha_carbon=False, + h_kl_weight=0.1, + z_kl_weight=0.5, + coord_loss_weights={ + 'Xloss': 1.0, + 'ca_Xloss': 0.0, + 'bb_bond_lengths_loss': 1.0, + 'sc_bond_lengths_loss': 1.0, + 'bb_dihedral_angles_loss': 0.0, # this significantly poison the training + 'sc_chi_angles_loss': 0.5 + }, + coord_loss_ratio=0.5, # (1 - r)*seq + r * coord + coord_prior_var=1.0, # sigma^2 + anchor_at_ca=False, + share_decoder=False, + n_rbf=0, + cutoff=0, + encoder='dyMEAN', + mode='codesign', # codesign, fixbb (inverse folding), fixseq (structure prediction) + additional_noise_scale=0.0 # whether to add additional noise on coordinates to enhance robustness + ) -> None: + super().__init__() + self.mask_id = mask_id + self.latent_id = latent_id + self.ca_channel_idx = CA_channel_idx + self.n_channel = n_channel + self.mask_ratio = mask_ratio + self.fix_alpha_carbon = fix_alpha_carbon + self.h_kl_weight = h_kl_weight + self.z_kl_weight = z_kl_weight + self.coord_loss_weights = coord_loss_weights + self.coord_loss_ratio = coord_loss_ratio + self.mode = mode + self.latent_size = 0 if self.mode == 'fixseq' else latent_size + self.latent_n_channel = 0 if self.mode == 'fixbb' else latent_n_channel + self.anchor_at_ca = anchor_at_ca + self.coord_prior_var = coord_prior_var + self.additional_noise_scale = additional_noise_scale + + if self.fix_alpha_carbon: assert self.latent_n_channel == 1, f'Specifying fix alpha carbon (use Ca as the latent coordinate) but number of latent channels is not 1' + if self.anchor_at_ca: 
assert self.latent_n_channel == 1, f'Specifying anchor_at_ca as True but number of latent channels is not 1' + if self.mode == 'fixseq': assert self.coord_loss_ratio == 1.0, f'Specifying fixseq mode but coordination loss ratio is not 1.0: {self.coord_loss_ratio}' + if self.mode == 'fixbb': assert self.coord_loss_ratio == 0.0, f'Specifying fixbb mode but coordination loss ratio is not 0.0: {self.coord_loss_ratio}' + + atom_embed_size = embed_size // 4 + self.aa_feature = SeparatedAminoAcidFeature( + embed_size, atom_embed_size, + max_position=max_position, + relative_position=relative_position, + fix_atom_weights=True + ) + self.protein_feature = ProteinFeature() + + self.encoder = create_encoder( + name = encoder, + atom_embed_size = atom_embed_size, + embed_size = embed_size, + hidden_size = hidden_size, + n_channel = n_channel, + n_layers = n_layers, + dropout = dropout, + n_rbf = n_rbf, + cutoff = cutoff + ) + + if self.mode != 'fixbb': + self.sidechain_decoder = create_encoder( + name = encoder, + atom_embed_size = atom_embed_size, + embed_size = embed_size, + hidden_size = hidden_size, + n_channel = n_channel, + n_layers = n_layers, + dropout = dropout, + n_rbf = n_rbf, + cutoff = cutoff + ) + self.backbone_model = BackboneModel() + self.sidechain_model = SideChainModel() + self.W_Z_log_var = nn.Linear(hidden_size, latent_n_channel * 3) + + if self.mode != 'fixseq': + self.W_mean = nn.Linear(hidden_size, latent_size) + self.W_log_var = nn.Linear(hidden_size, latent_size) + # self.hidden2latent = nn.Linear(hidden_size, latent_size) + self.latent2hidden = nn.Linear(latent_size, hidden_size) + self.merge_S_H = nn.Linear(hidden_size * 2, hidden_size) + + if share_decoder: + self.seq_decoder = self.sidechain_decoder + else: + self.seq_decoder = create_encoder( + name = encoder, + atom_embed_size = atom_embed_size, + embed_size = embed_size, + hidden_size = hidden_size, + n_channel = n_channel, + n_layers = n_layers, + dropout = dropout, + n_rbf = n_rbf, + cutoff = cutoff + ) + + # residue type index mapping, from original index to 0~20, 0 is unk + self.unk_idx = 0 + self.s_map = [0 for _ in range(len(VOCAB))] + self.s_remap = [0 for _ in range(len(aas) + 1)] + self.s_remap[0] = VOCAB.symbol_to_idx(VOCAB.UNK) + for i, (a, _) in enumerate(aas): + original_idx = VOCAB.symbol_to_idx(a) + self.s_map[original_idx] = i + 1 # start from 1 + self.s_remap[i + 1] = original_idx + self.s_map = nn.Parameter(torch.tensor(self.s_map, dtype=torch.long), requires_grad=False) + self.s_remap = nn.Parameter(torch.tensor(self.s_remap, dtype=torch.long), requires_grad=False) + + if self.mode != 'fixseq': + self.seq_linear = nn.Linear(hidden_size, len(self.s_remap)) + + + @torch.no_grad() + def prepare_inputs(self, X, S, mask, atom_mask, lengths): + + # batch ids + batch_ids = self.get_batch_ids(S, lengths) + + # edges + row, col = variadic_meshgrid( + input1=torch.arange(batch_ids.shape[0], device=batch_ids.device), + size1=lengths, + input2=torch.arange(batch_ids.shape[0], device=batch_ids.device), + size2=lengths, + ) # (row, col) + + is_ctx = mask[row] == mask[col] + is_inter = ~is_ctx + ctx_edges = torch.stack([row[is_ctx], col[is_ctx]], dim=0) # [2, Ec] + inter_edges = torch.stack([row[is_inter], col[is_inter]], dim=0) # [2, Ei] + + return ctx_edges, inter_edges, batch_ids + + @torch.no_grad() + def get_batch_ids(self, S, lengths): + batch_ids = torch.zeros_like(S) + batch_ids[torch.cumsum(lengths, dim=0)[:-1]] = 1 + batch_ids.cumsum_(dim=0) + return batch_ids + + def rsample(self, H, Z, Z_centers, 
no_randomness=False): + ''' + H: [N, latent_size] + Z: [N, latent_channel, 3] + Z_centers: [N, latent_channel, 3] + ''' + + if self.mode != 'fixseq': + data_size = H.shape[0] + H_mean = self.W_mean(H) + H_log_var = -torch.abs(self.W_log_var(H)) #Following Mueller et al., z_log_var is log(\sigma^2) + H_kl_loss = -0.5 * torch.sum(1.0 + H_log_var - H_mean * H_mean - torch.exp(H_log_var)) / data_size + H_vecs = H_mean if no_randomness else H_mean + torch.exp(H_log_var / 2) * torch.randn_like(H_mean) + else: + H_vecs, H_kl_loss = None, 0 + + if self.mode != 'fixbb': + data_size = Z.shape[0] + Z_mean_delta = Z - Z_centers + Z_log_var = -torch.abs(self.W_Z_log_var(H)).view(-1, self.latent_n_channel, 3) + Z_kl_loss = -0.5 * torch.sum(1.0 + Z_log_var - math.log(self.coord_prior_var) - Z_mean_delta * Z_mean_delta / self.coord_prior_var - torch.exp(Z_log_var) / self.coord_prior_var) / data_size + Z_vecs = Z if no_randomness else Z + torch.exp(Z_log_var / 2) * torch.randn_like(Z) + else: + Z_vecs, Z_kl_loss = None, 0 + + return H_vecs, Z_vecs, H_kl_loss, Z_kl_loss + + def _get_latent_channels(self, X, atom_mask): + atom_weights = atom_mask.float() # 1 for atom, 0 for padding/missing, [N, 14] + if hasattr(self, 'fix_alpha_carbon') and self.fix_alpha_carbon: + return X[:, self.ca_channel_idx].unsqueeze(1) # use alpha carbon as latent channel + elif self.latent_n_channel == 1: + X = (X * atom_weights.unsqueeze(-1)).sum(1) # [N, 3] + X = X / atom_weights.sum(-1).unsqueeze(-1) # [N, 3] + return X.unsqueeze(1) + elif self.latent_n_channel == 5: + bb_X = X[:, :4] + X = (X * atom_weights.unsqueeze(-1)).sum(1) # [N, 3] + X = X / atom_weights.sum(-1).unsqueeze(-1) # [N, 3] + X = torch.cat([bb_X, X.unsqueeze(1)], dim=1) # [N, 5, 3] + return X + else: + raise NotImplementedError(f'Latent number of channels: {self.latent_n_channel} not implemented') + + def _get_latent_channel_anchors(self, X, atom_mask): + if self.anchor_at_ca: + return X[:, self.ca_channel_idx].unsqueeze(1) + else: + return self._get_latent_channels(X, atom_mask) + + def _fill_latent_channels(self, latent_X): + if self.latent_n_channel == 1: + return latent_X.repeat(1, self.n_channel, 1) + elif self.latent_n_channel == 5: + bb_X = latent_X[:, :4] + sc_X = latent_X[:, 4].unsqueeze(1).repeat(1, self.n_channel - 4, 1) + return torch.cat([bb_X, sc_X], dim=1) + else: + raise NotImplementedError(f'Latent number of channels: {self.latent_n_channel} not implemented') + + def _remove_sidechain_atom_mask(self, atom_mask, mask_generate): + atom_mask = atom_mask.clone() + bb_mask = atom_mask[mask_generate] + bb_mask[:, 4:] = 0 # only backbone atoms are visible + atom_mask[mask_generate] = bb_mask + return atom_mask + + @torch.no_grad() + def _mask_pep(self, S, atom_mask, mask_generate): + assert self.mask_ratio > 0 + S, atom_mask = S.clone(), atom_mask.clone() + pep_S = S[mask_generate] + do_mask = torch.rand_like(pep_S, dtype=torch.float) < self.mask_ratio + pep_S[do_mask] = self.mask_id + + S[mask_generate] = pep_S + atom_mask[mask_generate ]= self._remove_sidechain_atom_mask(atom_mask[mask_generate], do_mask) + + return S, atom_mask + + def encode(self, X, S, mask, position_ids, lengths, atom_mask, no_randomness=False): + true_X = X.clone() + + ctx_edges, inter_edges, batch_ids = self.prepare_inputs(X, S, mask, atom_mask, lengths) + H_0, (atom_embeddings, _) = self.aa_feature(S, position_ids) + + edges = torch.cat([ctx_edges, inter_edges], dim=1) + atom_weights = atom_mask.float() # 1 for atom, 0 for padding/missing, [N, 14] + + H, pred_X = 
self.encoder(H_0, X, edges, channel_attr=atom_embeddings, channel_weights=atom_weights) + H = H[mask] + + if self.mode != 'fixbb': + if hasattr(self, 'fix_alpha_carbon') and self.fix_alpha_carbon: + Z = self._get_latent_channels(true_X, atom_mask) + else: + Z = self._get_latent_channels(pred_X, atom_mask) + Z_centers = self._get_latent_channel_anchors(true_X, atom_mask) + Z, Z_centers = Z[mask], Z_centers[mask] + else: + Z, Z_centers = None, None + + # resample + latent_H, latent_X, H_kl_loss, X_kl_loss = self.rsample(H, Z, Z_centers, no_randomness) + return latent_H, latent_X, H_kl_loss, X_kl_loss + + def decode(self, X, S, H, Z, mask, position_ids, lengths, atom_mask, teacher_forcing): + X, S, atom_mask = X.clone(), S.clone(), atom_mask.clone() + true_S = S[mask].clone() + if self.mode != 'fixbb': # fill coordinates with latent points + X[mask] = self._fill_latent_channels(Z) + if self.mode != 'fixseq': # fill sequences with mask token + S[mask] = self.latent_id + H_from_latent = self.latent2hidden(H) + + if self.mode == 'fixbb': # only backbone atoms are visible + atom_mask = self._remove_sidechain_atom_mask(atom_mask, mask) + elif self.mode == 'codesign': # all channels are visible when deciding the sequence (all dummy atoms) + atom_mask[mask] = 1 + else: # fixseq mode does not need to change atom mask + pass + + ctx_edges, inter_edges, batch_ids = self.prepare_inputs(X, S, mask, atom_mask, lengths) + edges = torch.cat([ctx_edges, inter_edges], dim=1) + + # decode sequence + if self.mode != 'fixseq': + H_0, (atom_embeddings, _) = self.aa_feature(S, position_ids) + H_0 = H_0.clone() + H_0[mask] = H_from_latent # TODO: how about the position encoding + H, _ = self.seq_decoder(H_0, X, edges, channel_attr=atom_embeddings, channel_weights=atom_mask.float()) + pred_S_logits = self.seq_linear(H[mask]) # [Ntgt, 21] + S = S.clone() + if teacher_forcing: # teacher forcing + S[mask] = true_S + else: # inference + S[mask] = self.s_remap[torch.argmax(pred_S_logits, dim=-1)] + else: + pred_S_logits = None + + # decode sidechain + if self.mode != 'fixbb': + H_0, (atom_embeddings, atom_weights) = self.aa_feature(S, position_ids) + H_0 = H_0.clone() + if self.mode != 'fixseq': + H_0[mask] = self.merge_S_H(torch.cat([H_from_latent, H_0[mask]], dim=-1)) + # H_0[mask] = H_from_latent + atom_mask = atom_mask.clone() + atom_mask[mask] = atom_weights.bool()[mask] & atom_mask[mask] # reset atomic visibility of the reconstruction part with the decoded sequence + _, pred_X = self.sidechain_decoder(H_0, X, edges, channel_attr=atom_embeddings, channel_weights=atom_mask.float()) + pred_X = pred_X[mask] + else: + pred_X = None + + return pred_S_logits, pred_X + + @oom_decorator + def forward(self, X, S, mask, position_ids, lengths, atom_mask, teacher_forcing=True): + true_X, true_S = X[mask].clone(), S[mask].clone() + + # encode: H (N*d), Z (N*3) + if self.mask_ratio > 0: + input_S, input_atom_mask = self._mask_pep(S, atom_mask, mask) + else: + input_S, input_atom_mask = S, atom_mask + H, Z, H_kl_loss, Z_kl_loss = self.encode(X, input_S, mask, position_ids, lengths, input_atom_mask) + + if self.mode != 'fixbb': + coord_reg_loss = F.mse_loss(Z, self._get_latent_channel_anchors(true_X, atom_mask[mask])) + else: + coord_reg_loss = 0 + + # add noise to improve robustness + if self.training: + noise = torch.randn_like(Z) * getattr(self, 'additional_noise_scale', 0.0) + Z = Z + noise + + # decode: S (N), Z (N * 14 * 3) with atom mask + recon_S_logits, recon_X = self.decode(X, S, H, Z, mask, position_ids, lengths, 
atom_mask, teacher_forcing) + + # sequence reconstruction loss + if self.mode != 'fixseq': + seq_recon_loss = F.cross_entropy(recon_S_logits, self.s_map[true_S]) + # aar + with torch.no_grad(): + aar = (torch.argmax(recon_S_logits, dim=-1) == self.s_map[true_S]).sum() / len(recon_S_logits) + else: + seq_recon_loss, aar = 0, 1.0 + + # coordinates reconstruction loss + if self.mode != 'fixbb': + xloss_mask = atom_mask[mask] + batch_ids = self.get_batch_ids(S, lengths)[mask] + segment_ids = torch.ones_like(true_S, device=true_S.device, dtype=torch.long) + if self.n_channel == 4: # backbone only + loss_profile = {} + else: + true_struct_profile = self.protein_feature.get_struct_profile(true_X, true_S, batch_ids, self.aa_feature, segment_ids, xloss_mask) + recon_struct_profile = self.protein_feature.get_struct_profile(recon_X, true_S, batch_ids, self.aa_feature, segment_ids, xloss_mask) + loss_profile = { key + '_loss': F.l1_loss(recon_struct_profile[key], true_struct_profile[key]) for key in recon_struct_profile } + + # mse + xloss = F.mse_loss(recon_X[xloss_mask], true_X[xloss_mask]) + loss_profile['Xloss'] = xloss + + # CA mse + ca_xloss_mask = xloss_mask[:, self.ca_channel_idx] + ca_xloss = F.mse_loss(recon_X[:, self.ca_channel_idx][ca_xloss_mask], true_X[:, self.ca_channel_idx][ca_xloss_mask]) + loss_profile['ca_Xloss'] = ca_xloss + + struct_recon_loss = 0 + for name in loss_profile: + struct_recon_loss = struct_recon_loss + self.coord_loss_weights[name] * loss_profile[name] + else: + struct_recon_loss, loss_profile = 0, {} + + recon_loss = (1 - self.coord_loss_ratio) * (seq_recon_loss + self.h_kl_weight * H_kl_loss) + \ + self.coord_loss_ratio * (struct_recon_loss + self.z_kl_weight * Z_kl_loss) + + return recon_loss, (seq_recon_loss, aar), (struct_recon_loss, loss_profile), (H_kl_loss, Z_kl_loss, coord_reg_loss) + + def _reconstruct(self, X, S, mask, position_ids, lengths, atom_mask, given_laten_H=None, given_latent_X=None, allow_unk=False, optimize_sidechain=True, idealize=False, no_randomness=False): + if given_laten_H is None and given_latent_X is None: + # encode: H (N*d), Z (N*3) + H, Z, _, _ = self.encode(X, S, mask, position_ids, lengths, atom_mask, no_randomness=no_randomness) + + else: + H, Z = given_laten_H, given_latent_X + + # decode: S (N), Z (N * 14 * 3) with atom mask + recon_S_logits, recon_X = self.decode(X, S, H, Z, mask, position_ids, lengths, atom_mask, teacher_forcing=False) + batch_ids = self.get_batch_ids(S, lengths)[mask] + + if self.mode != 'fixseq': + if not allow_unk: + recon_S_logits[:, 0] = float('-inf') + + # map aa index back + recon_S = self.s_remap[torch.argmax(recon_S_logits, dim=-1)] + # ppls + snll_all = F.cross_entropy(recon_S_logits, torch.argmax(recon_S_logits, dim=-1), reduction='none') + batch_ppls = scatter_mean(snll_all, batch_ids, dim=0) + else: + recon_S = S[mask] + batch_ppls = torch.zeros(batch_ids.max() + 1, device=recon_X.device).float() + + if self.mode == 'fixseq' or (self.mode != 'fixbb' and idealize): + # rectify backbone + recon_X = self.backbone_model(recon_X, batch_ids) + # rectify sidechain + recon_X = self.sidechain_model(recon_X, recon_S, batch_ids, optimize_sidechain) + + return recon_X, recon_S, batch_ppls, batch_ids + + @torch.no_grad() + def test(self, X, S, mask, position_ids, lengths, atom_mask, given_laten_H=None, given_latent_X=None, return_tensor=False, allow_unk=False, optimize_sidechain=True, idealize=False, n_iter=1): + + no_randomness = given_laten_H is not None # in reconstruction mode, with latent variable derived 
from diffusion model + for i in range(n_iter): + recon_X, recon_S, batch_ppls, batch_ids = self._reconstruct(X, S, mask, position_ids, lengths, atom_mask, given_laten_H, given_latent_X, allow_unk, optimize_sidechain, idealize, no_randomness) + X, S = X.clone(), S.clone() + if self.mode != 'fixbb': + X[mask] = recon_X + if self.mode != 'fixseq': + S[mask] = recon_S + given_laten_H, given_latent_X = None, None # let the model encode and decode for later iterations + + if return_tensor: + return recon_X, recon_S, batch_ppls + + batch_X, batch_S = [], [] + batch_ppls = batch_ppls.tolist() + for i, l in enumerate(lengths): + cur_mask = batch_ids == i + if self.mode == 'fixbb': + batch_X.append(None) + else: + batch_X.append(recon_X[cur_mask].tolist()) + if self.mode == 'fixseq': + batch_S.append(None) + else: + batch_S.append(''.join([VOCAB.idx_to_symbol(s) for s in recon_S[cur_mask]])) + + return batch_X, batch_S, batch_ppls diff --git a/models/autoencoder/sidechain/api.py b/models/autoencoder/sidechain/api.py new file mode 100644 index 0000000000000000000000000000000000000000..ad4d79184b1a003178bfb349294c4553b5bf2ab6 --- /dev/null +++ b/models/autoencoder/sidechain/api.py @@ -0,0 +1,67 @@ +#!/usr/bin/python +# -*- coding:utf-8 -*- +import torch +import torch.nn as nn +import torch.nn.functional as F + +from utils.nn_utils import graph_to_batch +from data.format import VOCAB + +from .sidechain import SideChainBuilder, ChiAngles +from .constants import AA20 + + +class SideChainModel(nn.Module): + def __init__(self): + super().__init__() + self.sidechain_builder = SideChainBuilder() + self.chi_angle_calc = ChiAngles() + + aa_index_inverse_mapping = torch.tensor([VOCAB.symbol_to_idx(a) for a in AA20], dtype=torch.long) + aa_index_mapping = torch.ones(aa_index_inverse_mapping.max() + 1, dtype=torch.long) * 20 # set 20 to unk (0~19 are natural amino acids) + aa_index_mapping[aa_index_inverse_mapping] = torch.arange(20) + self.register_buffer('aa_index_mapping', aa_index_mapping) + + def forward(self, X, S, batch_ids, optimize=True): + ''' + X: [N, 14, 3], predicted all-atom coordinates (obviously with a lot of invalidities) + S: [N], predicted sequence + ''' + # do sequence index mapping from our vocabulary to the sidechain builder native indexes + S = self.aa_index_mapping[S] + + # to batch-form representations + X, mask = graph_to_batch(X, batch_ids, mask_is_pad=False) + S, _ = graph_to_batch(S, batch_ids) + C = mask.long() + + # rectify sidechains + chi, _ = self.chi_angle_calc(X, C, S) + ori_X = X.clone() + if optimize: # optimize chi so that the resulted atoms have similar positions with the predicted ones + with torch.enable_grad(): + chi = chi.clone() + chi.requires_grad = True + delta, lr, step, last_mse = 1e-4, 1, 0, 100 + optimizer = torch.optim.Adam([chi], lr=lr) + while True: + X, mask_X = self.sidechain_builder(ori_X[:, :, :4], C, S, chi) + mask_X = mask_X.squeeze(-1) # [bs, L, 14] + X, mask_X = X[:, :, 4:], mask_X[:, :, 4:].bool() + mse = F.mse_loss(X[mask_X], ori_X[:, :, 4:][mask_X]) # only on sidechain + if torch.abs(mse - last_mse) < delta: + break + mse.backward() + # chi.data = chi.data - lr * chi.grad.data + # chi.grad.zero_() + optimizer.step() + optimizer.zero_grad() + last_mse = mse.detach() + step += 1 + chi = chi.detach() + # print(f'optimized {step} steps, mse {last_mse}') + + X, _ = self.sidechain_builder(ori_X[:, :, :4], C, S, chi) + + # get back to our graph representations + return X[mask] \ No newline at end of file diff --git 
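The chi-refinement loop in `SideChainModel.forward` above follows a generic pattern: optimize a small set of parameters with Adam until the reconstruction MSE stops improving. A hedged sketch of that pattern in isolation (`fit_params_to_targets`, `build_fn`, and the `max_steps` cap are illustrative additions, not repository code):

```python
import torch
import torch.nn.functional as F

def fit_params_to_targets(params, build_fn, target, delta=1e-4, lr=1.0, max_steps=500):
    # Optimize `params` so that build_fn(params) matches `target`,
    # stopping when the MSE change falls below `delta` (plateau criterion).
    params = params.clone().requires_grad_(True)
    opt = torch.optim.Adam([params], lr=lr)
    last = torch.tensor(float("inf"))
    for _ in range(max_steps):
        mse = F.mse_loss(build_fn(params), target)
        if torch.abs(mse - last) < delta:
            break
        opt.zero_grad()
        mse.backward()
        opt.step()
        last = mse.detach()
    return params.detach()

# toy usage: fit a scale factor so that scale * x approximates y
x, y = torch.randn(10, 3), torch.randn(10, 3)
fitted_scale = fit_params_to_targets(torch.ones(1), lambda s: s * x, y)
```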
a/models/autoencoder/sidechain/constants/__init__.py b/models/autoencoder/sidechain/constants/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..40384f4aeafdc4cd88915b3749005de77fc3c79e --- /dev/null +++ b/models/autoencoder/sidechain/constants/__init__.py @@ -0,0 +1,16 @@ +# Copyright Generate Biomedicines, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from .geometry import AA_GEOMETRY +from .sequence import * diff --git a/models/autoencoder/sidechain/constants/geometry.py b/models/autoencoder/sidechain/constants/geometry.py new file mode 100644 index 0000000000000000000000000000000000000000..e8781a1d396a80821c4fc8e2377de591500b506c --- /dev/null +++ b/models/autoencoder/sidechain/constants/geometry.py @@ -0,0 +1,558 @@ +# Copyright Generate Biomedicines, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""Dictionary containing ideal internal coordinates and chi angle assignments + for building amino acid 3D coordinates""" +from typing import Dict + +AA_GEOMETRY: Dict[str, dict] = { + "ALA": { + "atoms": ["CB"], + "chi_indices": [], + "parents": [["N", "C", "CA"]], + "types": {"C": "C", "CA": "CT1", "CB": "CT3", "N": "NH1", "O": "O"}, + "z-angles": [111.09], + "z-dihedrals": [123.23], + "z-lengths": [1.55], + }, + "ARG": { + "atoms": ["CB", "CG", "CD", "NE", "CZ", "NH1", "NH2"], + "chi_indices": [1, 2, 3, 4], + "parents": [ + ["N", "C", "CA"], + ["N", "CA", "CB"], + ["CA", "CB", "CG"], + ["CB", "CG", "CD"], + ["CG", "CD", "NE"], + ["CD", "NE", "CZ"], + ["NH1", "NE", "CZ"], + ], + "types": { + "C": "C", + "CA": "CT1", + "CB": "CT2", + "CD": "CT2", + "CG": "CT2", + "CZ": "C", + "N": "NH1", + "NE": "NC2", + "NH1": "NC2", + "NH2": "NC2", + "O": "O", + }, + "z-angles": [112.26, 115.95, 114.01, 107.09, 123.05, 118.06, 122.14], + "z-dihedrals": [123.64, 180.0, 180.0, 180.0, 180.0, 180.0, 178.64], + "z-lengths": [1.56, 1.55, 1.54, 1.5, 1.34, 1.33, 1.33], + }, + "ASN": { + "atoms": ["CB", "CG", "OD1", "ND2"], + "chi_indices": [1, 2], + "parents": [ + ["N", "C", "CA"], + ["N", "CA", "CB"], + ["CA", "CB", "CG"], + ["OD1", "CB", "CG"], + ], + "types": { + "C": "C", + "CA": "CT1", + "CB": "CT2", + "CG": "CC", + "N": "NH1", + "ND2": "NH2", + "O": "O", + "OD1": "O", + }, + "z-angles": [113.04, 114.3, 122.56, 116.15], + "z-dihedrals": [121.18, 180.0, 180.0, -179.19], + "z-lengths": [1.56, 1.53, 1.23, 1.35], + }, + "ASP": { + "atoms": ["CB", "CG", "OD1", "OD2"], + "chi_indices": [1, 2], + "parents": [ + ["N", "C", "CA"], + ["N", "CA", "CB"], + ["CA", "CB", "CG"], + ["OD1", "CB", "CG"], + ], + "types": { + "C": "C", + "CA": "CT1", + "CB": "CT2A", + "CG": "CC", + "N": "NH1", + "O": "O", + "OD1": "OC", + "OD2": "OC", + }, + "z-angles": [114.1, 112.6, 117.99, 117.7], + "z-dihedrals": [122.33, 180.0, 180.0, -170.23], + "z-lengths": [1.56, 1.52, 1.26, 1.25], + }, + "CYS": { + "atoms": ["CB", "SG"], + "chi_indices": [1], + "parents": [["N", "C", "CA"], ["N", "CA", "CB"]], + "types": {"C": "C", "CA": "CT1", "CB": "CT2", "N": "NH1", "O": "O", "SG": "S"}, + "z-angles": [111.98, 113.87], + "z-dihedrals": [121.79, 180.0], + "z-lengths": [1.56, 1.84], + }, + "GLN": { + "atoms": ["CB", "CG", "CD", "OE1", "NE2"], + "chi_indices": [1, 2, 3], + "parents": [ + ["N", "C", "CA"], + ["N", "CA", "CB"], + ["CA", "CB", "CG"], + ["CB", "CG", "CD"], + ["OE1", "CG", "CD"], + ], + "types": { + "C": "C", + "CA": "CT1", + "CB": "CT2", + "CD": "CC", + "CG": "CT2", + "N": "NH1", + "NE2": "NH2", + "O": "O", + "OE1": "O", + }, + "z-angles": [111.68, 115.52, 112.5, 121.52, 116.84], + "z-dihedrals": [121.91, 180.0, 180.0, 180.0, 179.57], + "z-lengths": [1.55, 1.55, 1.53, 1.23, 1.35], + }, + "GLU": { + "atoms": ["CB", "CG", "CD", "OE1", "OE2"], + "chi_indices": [1, 2, 3], + "parents": [ + ["N", "C", "CA"], + ["N", "CA", "CB"], + ["CA", "CB", "CG"], + ["CB", "CG", "CD"], + ["OE1", "CG", "CD"], + ], + "types": { + "C": "C", + "CA": "CT1", + "CB": "CT2A", + "CD": "CC", + "CG": "CT2", + "N": "NH1", + "O": "O", + "OE1": "OC", + "OE2": "OC", + }, + "z-angles": [111.71, 115.69, 115.73, 114.99, 120.08], + "z-dihedrals": [121.9, 180.0, 180.0, 180.0, -179.1], + "z-lengths": [1.55, 1.56, 1.53, 1.26, 1.25], + }, + "GLY": { + "atoms": [], + "chi_indices": [], + "parents": [], + "types": {"C": "C", "CA": "CT2", "N": "NH1", "O": "O"}, + "z-angles": [], + "z-dihedrals": [], + "z-lengths": [], + }, + "HIS": { + "atoms": ["CB", "CG", "ND1", "CD2", "CE1", 
"NE2"], + "chi_indices": [1, 2], + "parents": [ + ["N", "C", "CA"], + ["N", "CA", "CB"], + ["CA", "CB", "CG"], + ["ND1", "CB", "CG"], + ["CB", "CG", "ND1"], + ["CB", "CG", "CD2"], + ], + "types": { + "C": "C", + "CA": "CT1", + "CB": "CT2", + "CD2": "CPH1", + "CE1": "CPH2", + "CG": "CPH1", + "N": "NH1", + "ND1": "NR1", + "NE2": "NR2", + "O": "O", + }, + "z-angles": [109.99, 114.05, 124.1, 129.6, 107.03, 110.03], + "z-dihedrals": [122.46, 180.0, 90.0, -171.29, -173.21, 171.99], + "z-lengths": [1.55, 1.5, 1.38, 1.36, 1.35, 1.38], + }, + "HSD": { + "atoms": ["CB", "CG", "ND1", "CD2", "CE1", "NE2"], + "chi_indices": [1, 2], + "parents": [ + ["N", "C", "CA"], + ["N", "CA", "CB"], + ["CA", "CB", "CG"], + ["ND1", "CB", "CG"], + ["CB", "CG", "ND1"], + ["CB", "CG", "CD2"], + ], + "types": { + "C": "C", + "CA": "CT1", + "CB": "CT2", + "CD2": "CPH1", + "CE1": "CPH2", + "CG": "CPH1", + "N": "NH1", + "ND1": "NR1", + "NE2": "NR2", + "O": "O", + }, + "z-angles": [109.99, 114.05, 124.1, 129.6, 107.03, 110.03], + "z-dihedrals": [122.46, 180.0, 90.0, -171.29, -173.21, 171.99], + "z-lengths": [1.55, 1.5, 1.38, 1.36, 1.35, 1.38], + }, + "HSE": { + "atoms": ["CB", "CG", "ND1", "CD2", "CE1", "NE2"], + "chi_indices": [], + "parents": [ + ["N", "C", "CA"], + ["N", "CA", "CB"], + ["CA", "CB", "CG"], + ["ND1", "CB", "CG"], + ["CB", "CG", "ND1"], + ["CB", "CG", "CD2"], + ], + "types": { + "C": "C", + "CA": "CT1", + "CB": "CT2", + "CD2": "CPH1", + "CE1": "CPH2", + "CG": "CPH1", + "N": "NH1", + "ND1": "NR2", + "NE2": "NR1", + "O": "O", + }, + "z-angles": [111.67, 116.94, 120.17, 129.71, 105.2, 105.8], + "z-dihedrals": [123.52, 180.0, 90.0, -178.26, -179.2, 178.66], + "z-lengths": [1.56, 1.51, 1.39, 1.36, 1.32, 1.38], + }, + "HSP": { + "atoms": ["CB", "CG", "ND1", "CD2", "CE1", "NE2"], + "chi_indices": [], + "parents": [ + ["N", "C", "CA"], + ["N", "CA", "CB"], + ["CA", "CB", "CG"], + ["ND1", "CB", "CG"], + ["CB", "CG", "ND1"], + ["CB", "CG", "CD2"], + ], + "types": { + "C": "C", + "CA": "CT1", + "CB": "CT2A", + "CD2": "CPH1", + "CE1": "CPH2", + "CG": "CPH1", + "N": "NH1", + "ND1": "NR3", + "NE2": "NR3", + "O": "O", + }, + "z-angles": [109.38, 114.18, 122.94, 128.93, 108.9, 106.93], + "z-dihedrals": [125.13, 180.0, 90.0, -165.26, -167.62, 167.13], + "z-lengths": [1.55, 1.52, 1.37, 1.35, 1.33, 1.37], + }, + "ILE": { + "atoms": ["CB", "CG1", "CG2", "CD1"], + "chi_indices": [1, 3], + "parents": [ + ["N", "C", "CA"], + ["N", "CA", "CB"], + ["CG1", "CA", "CB"], + ["CA", "CB", "CG1"], + ], + "types": { + "C": "C", + "CA": "CT1", + "CB": "CT1", + "CD": "CT3", + "CG1": "CT2", + "CG2": "CT3", + "N": "NH1", + "O": "O", + }, + "z-angles": [112.93, 113.63, 113.93, 114.09], + "z-dihedrals": [124.22, 180.0, -130.04, 180.0], + "z-lengths": [1.57, 1.55, 1.55, 1.54], + }, + "LEU": { + "atoms": ["CB", "CG", "CD1", "CD2"], + "chi_indices": [1, 2], + "parents": [ + ["N", "C", "CA"], + ["N", "CA", "CB"], + ["CA", "CB", "CG"], + ["CD1", "CB", "CG"], + ], + "types": { + "C": "C", + "CA": "CT1", + "CB": "CT2", + "CD1": "CT3", + "CD2": "CT3", + "CG": "CT1", + "N": "NH1", + "O": "O", + }, + "z-angles": [112.12, 117.46, 110.48, 112.57], + "z-dihedrals": [121.52, 180.0, 180.0, 120.0], + "z-lengths": [1.55, 1.55, 1.54, 1.54], + }, + "LYS": { + "atoms": ["CB", "CG", "CD", "CE", "NZ"], + "chi_indices": [1, 2, 3, 4], + "parents": [ + ["N", "C", "CA"], + ["N", "CA", "CB"], + ["CA", "CB", "CG"], + ["CB", "CG", "CD"], + ["CG", "CD", "CE"], + ], + "types": { + "C": "C", + "CA": "CT1", + "CB": "CT2", + "CD": "CT2", + "CE": "CT2", + "CG": "CT2", + "N": 
"NH1", + "NZ": "NH3", + "O": "O", + }, + "z-angles": [111.36, 115.76, 113.28, 112.33, 110.46], + "z-dihedrals": [122.23, 180.0, 180.0, 180.0, 180.0], + "z-lengths": [1.56, 1.54, 1.54, 1.53, 1.46], + }, + "MET": { + "atoms": ["CB", "CG", "SD", "CE"], + "chi_indices": [1, 2, 3], + "parents": [ + ["N", "C", "CA"], + ["N", "CA", "CB"], + ["CA", "CB", "CG"], + ["CB", "CG", "SD"], + ], + "types": { + "C": "C", + "CA": "CT1", + "CB": "CT2", + "CE": "CT3", + "CG": "CT2", + "N": "NH1", + "O": "O", + "SD": "S", + }, + "z-angles": [111.88, 115.92, 110.28, 98.94], + "z-dihedrals": [121.62, 180.0, 180.0, 180.0], + "z-lengths": [1.55, 1.55, 1.82, 1.82], + }, + "PHE": { + "atoms": ["CB", "CG", "CD1", "CD2", "CE1", "CE2", "CZ"], + "chi_indices": [1, 2], + "parents": [ + ["N", "C", "CA"], + ["N", "CA", "CB"], + ["CA", "CB", "CG"], + ["CD1", "CB", "CG"], + ["CB", "CG", "CD1"], + ["CB", "CG", "CD2"], + ["CG", "CD1", "CE1"], + ], + "types": { + "C": "C", + "CA": "CT1", + "CB": "CT2", + "CD1": "CA", + "CD2": "CA", + "CE1": "CA", + "CE2": "CA", + "CG": "CA", + "CZ": "CA", + "N": "NH1", + "O": "O", + }, + "z-angles": [112.45, 112.76, 120.32, 120.76, 120.63, 120.62, 119.93], + "z-dihedrals": [122.49, 180.0, 90.0, -177.96, -177.37, 177.2, -0.12], + "z-lengths": [1.56, 1.51, 1.41, 1.41, 1.4, 1.4, 1.4], + }, + "PRO": { + "atoms": ["CB", "CG", "CD"], + "chi_indices": [1, 2], + "parents": [["N", "C", "CA"], ["N", "CA", "CB"], ["CA", "CB", "CG"]], + "types": { + "C": "C", + "CA": "CP1", + "CB": "CP2", + "CD": "CP3", + "CG": "CP2", + "N": "N", + "O": "O", + }, + "z-angles": [111.74, 104.39, 103.21], + "z-dihedrals": [113.74, 31.61, -34.59], + "z-lengths": [1.54, 1.53, 1.53], + }, + "SER": { + "atoms": ["CB", "OG"], + "chi_indices": [1], + "parents": [["N", "C", "CA"], ["N", "CA", "CB"]], + "types": { + "C": "C", + "CA": "CT1", + "CB": "CT2", + "N": "NH1", + "O": "O", + "OG": "OH1", + }, + "z-angles": [111.4, 112.45], + "z-dihedrals": [124.75, 180.0], + "z-lengths": [1.56, 1.43], + }, + "THR": { + "atoms": ["CB", "OG1", "CG2"], + "chi_indices": [1], + "parents": [["N", "C", "CA"], ["N", "CA", "CB"], ["OG1", "CA", "CB"]], + "types": { + "C": "C", + "CA": "CT1", + "CB": "CT1", + "CG2": "CT3", + "N": "NH1", + "O": "O", + "OG1": "OH1", + }, + "z-angles": [112.74, 112.16, 115.91], + "z-dihedrals": [126.46, 180.0, -124.13], + "z-lengths": [1.57, 1.43, 1.53], + }, + "TRP": { + "atoms": ["CB", "CG", "CD2", "CD1", "CE2", "NE1", "CE3", "CZ3", "CH2", "CZ2"], + "chi_indices": [1, 2], + "parents": [ + ["N", "C", "CA"], + ["N", "CA", "CB"], + ["CA", "CB", "CG"], + ["CD2", "CB", "CG"], + ["CD1", "CG", "CD2"], + ["CG", "CD2", "CE2"], + ["CE2", "CG", "CD2"], + ["CE2", "CD2", "CE3"], + ["CD2", "CE3", "CZ3"], + ["CE3", "CZ3", "CH2"], + ], + "types": { + "C": "C", + "CA": "CT1", + "CB": "CT2", + "CD1": "CA", + "CD2": "CPT", + "CE2": "CPT", + "CE3": "CAI", + "CG": "CY", + "CH2": "CA", + "CZ2": "CAI", + "CZ3": "CA", + "N": "NH1", + "NE1": "NY", + "O": "O", + }, + "z-angles": [ + 111.23, + 115.14, + 123.95, + 129.18, + 106.65, + 107.87, + 132.54, + 118.16, + 120.97, + 120.87, + ], + "z-dihedrals": [ + 122.68, + 180.0, + 90.0, + -172.81, + -0.08, + 0.14, + 179.21, + -0.2, + 0.1, + 0.01, + ], + "z-lengths": [1.56, 1.52, 1.44, 1.37, 1.41, 1.37, 1.4, 1.4, 1.4, 1.4], + }, + "TYR": { + "atoms": ["CB", "CG", "CD1", "CD2", "CE1", "CE2", "CZ", "OH"], + "chi_indices": [1, 2], + "parents": [ + ["N", "C", "CA"], + ["N", "CA", "CB"], + ["CA", "CB", "CG"], + ["CD1", "CB", "CG"], + ["CB", "CG", "CD1"], + ["CB", "CG", "CD2"], + ["CG", "CD1", "CE1"], + ["CE1", 
"CE2", "CZ"], + ], + "types": { + "C": "C", + "CA": "CT1", + "CB": "CT2", + "CD1": "CA", + "CD2": "CA", + "CE1": "CA", + "CE2": "CA", + "CG": "CA", + "CZ": "CA", + "N": "NH1", + "O": "O", + "OH": "OH1", + }, + "z-angles": [112.34, 112.94, 120.49, 120.46, 120.4, 120.56, 120.09, 120.25], + "z-dihedrals": [122.27, 180.0, 90.0, -176.46, -175.49, 175.32, -0.19, -178.98], + "z-lengths": [1.56, 1.51, 1.41, 1.41, 1.4, 1.4, 1.4, 1.41], + }, + "VAL": { + "atoms": ["CB", "CG1", "CG2"], + "chi_indices": [1], + "parents": [["N", "C", "CA"], ["N", "CA", "CB"], ["CG1", "CA", "CB"]], + "types": { + "C": "C", + "CA": "CT1", + "CB": "CT1", + "CG1": "CT3", + "CG2": "CT3", + "N": "NH1", + "O": "O", + }, + "z-angles": [111.23, 113.97, 112.17], + "z-dihedrals": [122.95, 180.0, 123.99], + "z-lengths": [1.57, 1.54, 1.54], + }, +} diff --git a/models/autoencoder/sidechain/constants/sequence.py b/models/autoencoder/sidechain/constants/sequence.py new file mode 100644 index 0000000000000000000000000000000000000000..4d05dd13182daeb2016caeae78558938691bc883 --- /dev/null +++ b/models/autoencoder/sidechain/constants/sequence.py @@ -0,0 +1,112 @@ +# Copyright Generate Biomedicines, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Constants used across protein representations. + +These constants standardize protein tokenization alphabets, ideal structure +geometries and topologies, etc. 
+""" +from .geometry import AA_GEOMETRY + +# Standard tokenization for Omniprot and Omniprot-interacting models +OMNIPROT_TOKENS = "ABCDEFGHIKLMNOPQRSTUVWYXZ*-#" +POTTS_EXTENDED_TOKENS = "ACDEFGHIKLMNPQRSTVWY-*#" +PAD = "-" +START = "@" +STOP = "*" +MASK = "#" +DNA_TOKENS = "ACGT" +RNA_TOKENS = "AGCU" +PROTEIN_TOKENS = "ACDEFGHIKLMNPQRSTVWY" + +# Minimal 20-letter alphabet and corresponding triplet codes +AA20 = "ACDEFGHIKLMNPQRSTVWY" +AA20_3_TO_1 = { + "ALA": "A", + "ARG": "R", + "ASN": "N", + "ASP": "D", + "CYS": "C", + "GLN": "Q", + "GLU": "E", + "GLY": "G", + "HIS": "H", + "ILE": "I", + "LEU": "L", + "LYS": "K", + "MET": "M", + "PHE": "F", + "PRO": "P", + "SER": "S", + "THR": "T", + "TRP": "W", + "TYR": "Y", + "VAL": "V", +} +AA20_1_TO_3 = { + "A": "ALA", + "R": "ARG", + "N": "ASN", + "D": "ASP", + "C": "CYS", + "Q": "GLN", + "E": "GLU", + "G": "GLY", + "H": "HIS", + "I": "ILE", + "L": "LEU", + "K": "LYS", + "M": "MET", + "F": "PHE", + "P": "PRO", + "S": "SER", + "T": "THR", + "W": "TRP", + "Y": "TYR", + "V": "VAL", +} +AA20_3 = [AA20_1_TO_3[aa] for aa in AA20] + +# Adding noncanonical amino acids +NONCANON_AA = [ + "HSD", + "HSE", + "HSC", + "HSP", + "MSE", + "CSO", + "SEC", + "CSX", + "HIP", + "SEP", + "TPO", +] +AA31_3 = AA20_3 + NONCANON_AA + +# Chain alphabet for PDB chain naming +CHAIN_ALPHABET = "_ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789" + +# Standard atom indexing +ATOMS_BB = ["N", "CA", "C", "O"] + +ATOM_SYMMETRIES = { + "ARG": [("NH1", "NH2")], # Correct handling of NH1 and NH2 is relabeling + "ASP": [("OD1", "OD2")], + "GLU": [("OE1", "OE2")], + "PHE": [("CD1", "CD2"), ("CE1", "CE2")], + "TYR": [("CD1", "CD2"), ("CE1", "CE2")], +} + +AA20_NUM_ATOMS = [4 + len(AA_GEOMETRY[aa]["atoms"]) for aa in AA20_3] +AA20_NUM_CHI = [len(AA_GEOMETRY[aa]["chi_indices"]) for aa in AA20_3] diff --git a/models/autoencoder/sidechain/sidechain.py b/models/autoencoder/sidechain/sidechain.py new file mode 100644 index 0000000000000000000000000000000000000000..e0e926771c254eb246e0e88a3f4909f9c8e0d3a1 --- /dev/null +++ b/models/autoencoder/sidechain/sidechain.py @@ -0,0 +1,804 @@ +# From Chroma: https://github.com/generatebio/chroma/tree/main +# Many thanks!!!! + +# Copyright Generate Biomedicines, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Layers for modeling protein side chain conformations. + +This module contains layers for building, measuring, and scoring protein side +chain conformations in a differentiable way. These can be used for tasks such +as building differentiable all-atom structures from chi-angles, computing chi +angles from existing structures, and scoring or optimizing side chains using +symmetry-aware rmsds. +""" + + +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F + +from . 
import constants +from .structure.geometry import ( + dihedrals, + extend_atoms, + frames_from_backbone, + quaternions_from_rotations, + rotations_from_quaternions, +) + + +class SideChainBuilder(nn.Module): + """Protein side chain builder from chi angles. + + When only partial information is given such as chi angles, this module + will default to using the ideal geometries given in the CHARMM toppar + topology files. + + `Optimization of the additive CHARMM all-atom protein force + field targeting improved sampling of the backbone phi, + psi and side-chain chi1 and chi2 dihedral angles` + + Inputs: + X (tensor): Backbone coordinates with shape + `(batch_size, num_residues, 4, 3)`. + C (tensor): Chain map with shape `(batch_size, num_residues)`. + S (tensor): Sequence tokens with shape `(batch_size, num_residues)`. + chi (tensor): Backbone chi angles with shape + `(batch_size, num_residues, 4)`. + + Outputs: + X (tensor): All-atom coordinates with shape + `(batch_size, num_residues, 14, 3)`. + mask_X (tensor): Atomic mask with shape + `(batch_size, num_residues, 14, 1)` + """ + + def __init__(self, distance_eps=1e-6): + super(SideChainBuilder, self).__init__() + self.num_atoms = 10 + self.num_chi = 4 + self.num_aa = len(constants.AA20) + self.distance_eps = distance_eps + + self._init_maps() + + def _init_maps(self): + """Build geometry and topology maps in tensor form.""" + + shape = (3, self.num_atoms, self.num_aa) + self.register_buffer("_Z", torch.zeros(shape, dtype=torch.float)) + self.register_buffer("_parents", torch.zeros(shape, dtype=torch.long)) + self.register_buffer( + "_chi_ix", 10 * torch.ones((self.num_chi, self.num_aa), dtype=torch.long) + ) + + for i, aa in enumerate(constants.AA20_3): + aa_dict = constants.AA_GEOMETRY[aa] + atoms_parents = ["N", "CA", "C", "O"] + aa_dict["atoms"] + for j, atom in enumerate(aa_dict["atoms"]): + # Internal coordinates per atom + self._Z[0, j, i] = aa_dict["z-lengths"][j] + self._Z[1, j, i] = aa_dict["z-angles"][j] + self._Z[2, j, i] = aa_dict["z-dihedrals"][j] + + # Parent indices per atom + parents = [atoms_parents.index(p) for p in aa_dict["parents"][j]] + self._parents[0, j, i] = parents[0] + self._parents[1, j, i] = parents[1] + self._parents[2, j, i] = parents[2] + + # Map which chi angles are flexible + for j, parent_ix in enumerate(aa_dict["chi_indices"]): + self._chi_ix[j, i] = parent_ix + + # Convert angles from degrees to radians + self._Z[1:, :, :] = self._Z[1:, :, :] * np.pi / 180.0 + + # Manually fix Arginine, for which CHARMM places NH1 in trans to CD + self._Z[2, 5, constants.AA20.index("R")] = 0.0 + + def forward(self, X, C, S, chi=None): + num_batch, num_residues = list(S.shape) + + if X.shape[2] > 4: + X = X[:, :, :4, :] + + # Expand sequence indexing tensors for gathering residue-specific info + # (B,L) => (B,L,4) + S_expand3 = S.unsqueeze(-1).expand(-1, -1, 4) + # (B,L) => (B,AA,ATOM,L) + S_expand4 = S.reshape([num_batch, 1, 1, num_residues]).expand( + -1, 3, self.num_atoms, -1 + ) + + def _gather(Z): + Z_expand = Z.unsqueeze(0).expand([num_batch, -1, -1, -1]) + # (B,3,ATOM,AA) @ (B,3,ATOM,L) => (B,3,ATOM,L) => (B,L,3,ATOM) + Z_i = torch.gather(Z_expand, -1, S_expand4).permute([0, 3, 1, 2]) + return Z_i + + # Build ideal geometry length, angle, and dihedral tensors 3x(B,L,10) + B, A, D = _gather(self._Z).unbind(-2) + + if chi is not None: + # Scatter chi angles (B,L,4) onto their corresponding dihedrals (B,L,10) + # (4,AA) => (B,AA,4) + chi_ix_expand = ( + self._chi_ix.unsqueeze(0).expand([num_batch, -1, 
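The `_gather` helper above pulls residue-type-specific rows out of a fixed geometry table by indexing with the sequence tensor. A minimal illustration of that gather pattern with a simplified 1-D table (the names and shapes here are toy values, not the module's):

```python
import torch

num_aa, L = 20, 7
table = torch.rand(num_aa)                       # one value per amino-acid type
S = torch.randint(0, num_aa, (2, L))             # (B, L) sequence indices

table_expand = table.unsqueeze(0).expand(2, -1)  # (B, AA)
per_residue = torch.gather(table_expand, -1, S)  # (B, L): table[S[b, l]]
assert torch.equal(per_residue[0, 3], table[S[0, 3]])
```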
-1]).transpose(-2, -1) + ) + # (B,AA,4) @ (B,L,4) => (B,L,4) + chi_ix_i = torch.gather(chi_ix_expand, -2, S_expand3) + + # Scatter extra chi angles into an extra pad dimension & re-slice + # (B,L,10) <- (B,L,4),(B,L,4) => (B,L,10) + D_pad = F.pad(D, (0, 1)) + D_pad = torch.scatter(D_pad, -1, chi_ix_i, chi) + D = D_pad[:, :, : self.num_atoms] + + # Build indices of parent atoms (B,L,3,10) + X_full = F.pad(X, (0, 0, 0, self.num_atoms)) + parents = _gather(self._parents) + + # Build atom i given current frame + for i in range(self.num_atoms): + # Gather parents (B,L,A,3) => (B,L,3,3) + parents_expand = parents[:, :, :, i].unsqueeze(-1).expand(-1, -1, -1, 3) + # (B,L,A,3) @ (B,L,3,3) => (B,L,3,3) + X1, X2, X3 = torch.gather(X_full, -2, parents_expand).unbind(-2) + + # Extend atom i + X4 = extend_atoms( + X1, + X2, + X3, + B[:, :, i], + A[:, :, i], + D[:, :, i], + degrees=False, + distance_eps=self.distance_eps, + ) + + # Scatter + # X[:,:,i+4,:] = X4 + # scatter_ix = (i+4) * torch.ones( + # (num_batch,num_residues,1,3), dtype=torch.long + # ) + # print(X_full.shape, X4.shape, scatter_ix.shape, i+4) + # print(scatter_ix) + # X_full.scatter_(-2, scatter_ix, X4.unsqueeze(-2)) + # X_full = torch.scatter(X_full, -2, scatter_ix, X4) + # X_full = X_full + 0.1*X4.mean() + + # For some reason direct scatter causes autograd issues + X4_expand = F.pad(X4.unsqueeze(-2), (0, 0, 4 + i, 9 - i)) + X_full = X_full + X4_expand + + # DEBUG: TEST + if False: + D_reconstruct = dihedrals(X1, X2, X3, X4) + D_error = ( + (torch.cos(D[:, :, i]) - torch.cos(D_reconstruct)) ** 2 + + (torch.sin(D[:, :, i]) - torch.sin(D_reconstruct)) ** 2 + ).mean() + print(D_error) + + mask_X = atom_mask(C, S).unsqueeze(-1) + X_full = mask_X * X_full + return X_full, mask_X + + +class ChiAngles(nn.Module): + """Computes Chi-angles from an all-atom protein structure. + + Inputs: + X (tensor): Atomic coordinates with shape + `(batch_size, num_residues, 14, 3)`. + C (tensor): Chain map with shape `(batch_size, num_residues)`. + S (tensor): Sequence tokens with shape `(batch_size, num_residues)`. + + Outputs: + chi (tensor): Backbone chi angles with shape + `(batch_size, num_residues, 4)`. + mask_chi (tensor): Chi angle mask with shape + `(batch_size, num_residues, 4)`. 
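The pad-and-add idiom above (used because an in-place scatter interfered with autograd, per the comment) writes the newly placed atom into slot `i + 4` of an otherwise zero atom axis. A standalone sketch of the same trick with toy shapes:

```python
import torch
import torch.nn.functional as F

B, L, num_slots = 2, 5, 14
X_full = torch.zeros(B, L, num_slots, 3)   # zero-padded atom axis
X4 = torch.randn(B, L, 3)                  # coordinates of the atom being placed
i = 3                                      # sidechain atom index -> slot i + 4

# Pad the single atom out to the full 14-slot axis, then add instead of scattering.
X4_expand = F.pad(X4.unsqueeze(-2), (0, 0, 4 + i, 9 - i))  # (B, L, 14, 3)
X_full = X_full + X4_expand
assert torch.equal(X_full[:, :, 4 + i], X4)
```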
+ """ + + def __init__(self, distance_eps=1e-6): + super(ChiAngles, self).__init__() + self.num_atoms = 10 + self.num_chi = 4 + self.num_aa = len(constants.AA20) + + self.distance_eps = distance_eps + + self._init_maps() + + def _init_maps(self): + """Build geometry and topology maps in tensor form.""" + + self.register_buffer( + "_chi_atom_sets", + torch.zeros((self.num_aa, self.num_chi, 4), dtype=torch.long), + ) + + for i, aa in enumerate(constants.AA20_3): + aa_dict = constants.AA_GEOMETRY[aa] + atoms_names = ["N", "CA", "C", "O"] + aa_dict["atoms"] + + # Map which chi angles are flexible + for j, parent_ix in enumerate(aa_dict["chi_indices"]): + atom_quartet = aa_dict["parents"][parent_ix] + [ + aa_dict["atoms"][parent_ix] + ] + for k, atom in enumerate(atom_quartet): + self._chi_atom_sets[i, j, k] = atoms_names.index(atom) + + def forward(self, X, C, S): + num_batch, num_residues = list(S.shape) + # (B,L) => (B,L,16) + S_expand = S.unsqueeze(-1).expand([-1, -1, 16]) + + # (AA,CHI,ATOM) => (AA,16) => (B,AA,16) + chi_indices_per_aa = self._chi_atom_sets.reshape([1, self.num_aa, 16]) + chi_indices_per_aa = chi_indices_per_aa.expand([num_batch, -1, -1]) + + # (B,AA,16) @ (B,L,16) => (B,L,16) => (B,L,16) + chi_indices = torch.gather(chi_indices_per_aa, -2, S_expand) + chi_indices = chi_indices.unsqueeze(-1).expand([-1, -1, -1, 3]) + + # (B,L,14,3) @ (B,L,16,3) => (B,L,16,3) => (B,L,4,4,3) => (B,L,4) + X_chi = torch.gather(X, -2, chi_indices) + X_1, X_2, X_3, X_4 = X_chi.reshape([num_batch, num_residues, 4, 4, 3]).unbind( + -2 + ) + + chi = dihedrals(X_1, X_2, X_3, X_4, distance_eps=self.distance_eps) + + mask_chi = chi_mask(C, S) + chi = chi * mask_chi + return chi, mask_chi + + +class SideChainSymmetryRenamer(nn.Module): + """Rename atom to their 180-degree symmetry alternatives via permutation. + + Inputs: + X (tensor): Atomic coordinates with shape + `(batch_size, num_residues, 14, 3)`. + S (tensor): Sequence tokens with shape `(batch_size, num_residues)`. + + Outputs: + X_renamed (tensor): Renamed atomic coordinates with shape + `(batch_size, num_residues, 14, 3)`. 
+ """ + + def __init__(self): + super(SideChainSymmetryRenamer, self).__init__() + self.num_atoms = 10 + self.num_aa = len(constants.AA20) + + # Build symmetry indices give alternative atom labelings + self.register_buffer( + "_symmetry_indices", + torch.arange(self.num_atoms).unsqueeze(0).repeat(self.num_aa, 1), + ) + for i, aa in enumerate(constants.AA20_3): + if aa in constants.ATOM_SYMMETRIES: + for aa_1, aa_2 in constants.ATOM_SYMMETRIES[aa]: + atom_names = constants.AA_GEOMETRY[aa]["atoms"] + ix_1 = atom_names.index(aa_1) + ix_2 = atom_names.index(aa_2) + self._symmetry_indices[i, ix_1] = ix_2 + self._symmetry_indices[i, ix_2] = ix_1 + + def _gather_per_residue(self, AA_table, S): + num_batch, num_residues = list(S.shape) + + # (B,L) => (B,L,ATOM) + S_expand = S.unsqueeze(-1).expand([-1, -1, self.num_atoms]) + + # (AA,ATOM) => (B,AA,ATOM) + value_per_aa = AA_table.unsqueeze(0).expand([num_batch, -1, -1]) + + # (B,AA,ATOM) @ (B,L,ATOM) => (B,L,ATOM) + value_per_residue = torch.gather(value_per_aa, -2, S_expand) + return value_per_residue + + def forward(self, X, S): + alt_indices = self._gather_per_residue(self._symmetry_indices, S) + alt_indices = alt_indices.unsqueeze(-1).expand(-1, -1, -1, 3) + + X_bb, X_sc = X[:, :, :4, :], X[:, :, 4:, :] + X_sc_alternate = torch.gather(X_sc, -2, alt_indices) + X_alternate = torch.cat([X_bb, X_sc_alternate], dim=-2) + return X_alternate + + +class AllAtomFrameBuilder(nn.Module): + """Build all-atom protein structure from oriented C-alphas and chi angles. + + Inputs: + x (Tensor): C-alpha coordinates with shape `(num_batch, num_residues, 3)`. + q (Tensor): Quaternions representing C-alpha orientiations with shape + with shape `(num_batch, num_residues, 4)`. + chi (tensor): Backbone chi angles with shape + `(num_batch, num_residues, 4)`. + C (tensor): Chain map with shape `(num_batch, num_residues)`. + S (tensor): Sequence tokens with shape `(num_batch, num_residues)`. 
+ + Outputs: + X (Tensor): All-atom protein coordinates with shape + `(num_batch, num_residues, 14, 3)` + """ + + def __init__(self): + super(AllAtomFrameBuilder, self).__init__() + self.sidechain_builder = SideChainBuilder() + self.chi_angles = ChiAngles() + + # Build idealized backbone fragment + # IC +N CA *C O 1.3558 116.8400 180.0000 122.5200 1.2297 + dX = torch.tensor( + [ + [1.459, 0.0, 0.0], # N-C via Engh & Huber is 1.459 + [0.0, 0.0, 0.0], # CA is origin + [-0.547, 0.0, -1.424], # C is placed 1.525 A @ 111 degrees from N + ], + dtype=torch.float32, + ) + self.register_buffer("_dX_local", dX) + + def forward(self, x, q, chi, C, S): + # Build backbone + R = rotations_from_quaternions(q, normalize=True) + dX = torch.einsum("ay,nixy->niax", self._dX_local, R) + X_chain = x.unsqueeze(-2) + dX + + # Build carboxyl groups + X_N, X_CA, X_C = X_chain.unbind(-2) + + # TODO: fix this behavior for termini + mask_next = (C > 0).float()[:, 1:].unsqueeze(-1) + X_N_next = F.pad(mask_next * X_N[:, 1:,], (0, 0, 0, 1),) + + num_batch, num_residues = C.shape + ones = torch.ones(list(C.shape), dtype=torch.float32, device=C.device) + X_O = extend_atoms( + X_N_next, + X_CA, + X_C, + 1.2297 * ones, + 122.5200 * ones, + 180 * ones, + degrees=True, + ) + X_bb = torch.stack([X_N, X_CA, X_C, X_O], dim=-2) + + # Build sidechains + X, mask_atoms = self.sidechain_builder(X_bb, C, S, chi) + return X, mask_atoms + + def inverse(self, X, C, S): + X_bb = X[:, :, :4, :] + R, x = frames_from_backbone(X_bb) + q = quaternions_from_rotations(R) + chi, mask_chi = self.chi_angles(X, C, S) + return x, q, chi + + +class LossSideChainRMSD(nn.Module): + """Compute side chain RMSDs per residues from an all-atom protein structure. + + Inputs: + X (tensor): Atomic coordinates with shape + `(batch_size, num_residues, 14, 3)`. + X_target (tensor): Atomic coordinates with shape + `(batch_size, num_residues, 14, 3)`. + S (tensor): Sequence tokens with shape `(batch_size, num_residues)`. + + Outputs: + chi (tensor): Backbone chi angles with shape + `(batch_size, num_residues, 4)`. + """ + + def __init__(self, rmsd_eps=1e-2): + super(LossSideChainRMSD, self).__init__() + self.num_atoms = 10 + self.num_aa = len(constants.AA20) + + self.rmsd_eps = rmsd_eps + self.renamer = SideChainSymmetryRenamer() + + def _rmsd(self, X, X_target, atom_mask): + sd = atom_mask * ((X - X_target) ** 2).sum(-1) + rmsd = torch.sqrt( + sd.sum(-1) / (atom_mask.sum(-1) + self.rmsd_eps) + self.rmsd_eps + ) + return rmsd + + def forward(self, X, X_target, C, S, include_symmetry=True): + mask_atoms = atom_mask(C, S) + + X_alt = self.renamer(X, S)[:, :, 4:, :] + X = X[:, :, 4:, :] + X_target = X_target[:, :, 4:, :] + mask_atoms = mask_atoms[:, :, 4:] + + rmsd = self._rmsd(X, X_target, mask_atoms) + if include_symmetry: + rmsd_alternate = self._rmsd(X_alt, X_target, mask_atoms) + + # rmsd = torch.minimum(rmsd, rmsd_alternate) + rmsd = torch.stack([rmsd, rmsd_alternate], -1).min(-1)[0] + rmsd = (C > 0).float() * rmsd + return rmsd + + +class LossFrameAlignedGraph(nn.Module): + """Compute the frame-aligned loss on a nearest neighbors graph. + + Args: + num_neighbors (int): Number of neighbors to build in the graph. Default + is 30. + + Inputs: + X (tensor): Atomic coordinates with shape + `(batch_size, num_residues, 14, 3)`. + X_target (tensor): Atomic coordinates with shape + `(batch_size, num_residues, 14, 3)`. + C (tensor): Chain map with shape `(batch_size, num_residues)`. + S (tensor): Sequence tokens with shape `(batch_size, num_residues)`. 
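`AllAtomFrameBuilder` places an idealized local N/CA/C fragment by rotating it with the per-residue rotation matrix and translating it by the C-alpha position (the `einsum("ay,nixy->niax", ...)` call above). A toy sketch with a single residue and an explicit rotation about z; the quaternion machinery is left out:

```python
import math
import torch

dX_local = torch.tensor([[ 1.459, 0.0,  0.000],   # N
                         [ 0.000, 0.0,  0.000],   # CA (origin of the local frame)
                         [-0.547, 0.0, -1.424]])  # C

c, s = math.cos(0.3), math.sin(0.3)
R = torch.tensor([[c, -s, 0.0], [s, c, 0.0], [0.0, 0.0, 1.0]]).reshape(1, 1, 3, 3)
x_ca = torch.randn(1, 1, 3)                         # (B, L, 3) C-alpha positions

dX = torch.einsum("ay,nixy->niax", dX_local, R)     # rotate each fragment atom
X_chain = x_ca.unsqueeze(-2) + dX                   # (B, L, 3, 3): N, CA, C
assert torch.allclose(X_chain[0, 0, 1], x_ca[0, 0])  # CA stays at the input position
```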
+ + Outputs: + D (tensor): Per-residue losses with shape `(batch_size, num_residues)`. + """ + + def __init__( + self, + num_neighbors=30, + distance_eps=1e-2, + distance_scale=10.0, + interface_only=False, + ): + super(LossFrameAlignedGraph, self).__init__() + self.distance_eps = distance_eps + self.distance_scale = distance_scale + + self.renamer = SideChainSymmetryRenamer() + self.graph_builder = protein_graph.ProteinGraph(num_neighbors) + self.interface_only = interface_only + + def _frame_ij(self, X, edge_idx): + # Build local frames + num_batch, num_residues, num_atoms, _ = X.shape + + # Build frames at neighbor j (B,L,K,3,3), (B,L,K,3) + X_bb_flat = X[:, :, :4, :].reshape([num_batch, num_residues, -1]) + X_j_flat = graph.collect_neighbors(X_bb_flat, edge_idx) + X_j = X_j_flat.reshape([num_batch, num_residues, -1, 4, 3]) + R_j, X_j_CA = frames_from_backbone(X_j, distance_eps=self.distance_eps) + + # (B,L,1,A,3) - (B,L,K,1,3) => (B,L,K,A,3) + X_ij = X.unsqueeze(-3) - X_j_CA.unsqueeze(-2) + + # Rotate displacements into local frames + r_ij = torch.einsum("nijax,nijxy->nijay", X_ij, R_j) + return r_ij + + def _dist(self, r_ij_1, r_ij_2): + D_sq = (r_ij_1 - r_ij_2) ** 2 + D = torch.sqrt(D_sq.sum(-1) + self.distance_eps) + return D + + def forward(self, X, X_target, C, S): + if X_target.size(2) == 14: + mask_atoms = atom_mask(C, S) + X_alt = self.renamer(X, S) + elif X_target.size(2) == 4: + mask_atoms = (C > 0).float().unsqueeze(-1).expand([-1, -1, 4]) + X_alt = X + else: + raise Exception( + "Size of atom dimension must be 4 (backbone) or 14 (all-atom)." + ) + + # Build the union graph + custom_mask_2D = None + if self.interface_only: + custom_mask_2D = torch.ne(C.unsqueeze(1), C.unsqueeze(2)).float() + edge_idx_model, _ = self.graph_builder( + X[:, :, :4, :], C, custom_mask_2D=custom_mask_2D + ) + edge_idx_target, _ = self.graph_builder( + X_target[:, :, :4, :], C, custom_mask_2D=custom_mask_2D + ) + edge_idx = torch.cat([edge_idx_model, edge_idx_target], 2) + + # Build frame-aligned displacement vectors (B,N,K,A,3) + r_ij = self._frame_ij(X, edge_idx) + r_ij_alt = self._frame_ij(X_alt, edge_idx) + r_ij_target = self._frame_ij(X_target, edge_idx) + + # Build 2D masks (B,N,K,A) + num_batch, num_residues, num_atoms, _ = X.shape + mask_residues = (C > 0).float() + # (B,N,1,A) + mask_i = mask_atoms.reshape([num_batch, num_residues, 1, num_atoms]) + # (B,N,K,1) + mask_j = graph.collect_neighbors(mask_residues.unsqueeze(-1), edge_idx) + mask_ij = mask_i * mask_j + + # Build frame-aligned displacement vectors (B,N,N,A) + D = mask_ij * self._dist(r_ij, r_ij_target) + D_alt = mask_ij * self._dist(r_ij_alt, r_ij_target) + + # Which definition of atom j gives a better score? (B,N) + mask_reduce = mask_ij.sum([-2, -1]) + D_j = D.sum([-2, -1]) / (mask_reduce + self.distance_eps) + D_j_alt = D_alt.sum([-2, -1]) / (mask_reduce + self.distance_eps) + D_j_min = torch.stack([D_j, D_j_alt], -1).min(-1)[0] + + # Return as a per-residue loss + return D_j_min + + +class LossAllAtomDistances(nn.Module): + """Compute the interatomic distance loss on a nearest neighbors graph. + + Args: + num_neighbors (int): Number of neighbors to build in the graph. Default + is 30. + + Inputs: + X (tensor): Atomic coordinates with shape + `(batch_size, num_residues, 14, 3)`. + X_target (tensor): Atomic coordinates with shape + `(batch_size, num_residues, 14, 3)`. + C (tensor): Chain map with shape `(batch_size, num_residues)`. + S (tensor): Sequence tokens with shape `(batch_size, num_residues)`. 
+ + Outputs: + D (tensor): Per-residue losses with shape `(batch_size, num_residues)`. + """ + + def __init__(self, num_neighbors=30, distance_eps=1e-2): + super(LossAllAtomDistances, self).__init__() + self.distance_eps = distance_eps + + self.graph_builder = protein_graph.ProteinGraph(num_neighbors) + + def _dist_ij(self, X, edge_idx): + # Build local frames + num_batch, num_residues, num_atoms, _ = X.shape + + # Build frames at neighbor j (B,L,K,), (B,L,K,A,3) + X_flat = X.reshape([num_batch, num_residues, -1]) + X_j_flat = graph.collect_neighbors(X_flat, edge_idx) + X_j = X_j_flat.reshape([num_batch, num_residues, -1, num_atoms, 3]) + X_i = X.unsqueeze(2).expand([-1, -1, X_j.shape[2], -1, -1]) + + X_ij = torch.cat([X_i, X_j], -2) + D_ij = torch.sqrt( + ((X_ij.unsqueeze(-2) - X_ij.unsqueeze(-3)) ** 2).sum(-1) + self.distance_eps + ) + return D_ij + + def _mask_ij(self, C, S, edge_idx): + # (B,L,A) + mask_atoms = atom_mask(C, S) + + mask_j = graph.collect_neighbors(mask_atoms, edge_idx) + mask_i = mask_atoms.unsqueeze(2).expand([-1, -1, edge_idx.shape[2], -1]) + mask_ij = torch.cat([mask_i, mask_j], -1) + + mask_D = mask_ij.unsqueeze(-1) * mask_ij.unsqueeze(-2) + return mask_D + + def forward(self, X, X_target, C, S): + # Build the union graph + edge_idx_model, _ = self.graph_builder(X[:, :, :4, :], C) + edge_idx_target, _ = self.graph_builder(X_target[:, :, :4, :], C) + edge_idx = torch.cat([edge_idx_model, edge_idx_target], 2) + + mask_ij = self._mask_ij(C, S, edge_idx) + D_model = self._dist_ij(X, edge_idx) + D_target = self._dist_ij(X_target, edge_idx) + + loss = torch.sqrt((D_model - D_target) ** 2 + self.distance_eps) + loss_i = (mask_ij * loss).sum([2, 3, 4]) / ( + mask_ij.sum([2, 3, 4]) + self.distance_eps + ) + return loss_i + + +class LossSidechainClashes(nn.Module): + """Count sidechain clashes in a structure using a nearest neighbors graph. + + This uses the Van der Waals radii based definition of bonding + in pymol as described at https://pymolwiki.org/index.php/Connect_cutoff. + + Args: + num_neighbors (int, optional): Number of neighbors to + build in the graph. Default is 30. + connect_cutoff (float, optional): Bonding cutoff used in formula + `D_clash_cutoff = D_vdw / 2. + self.connect_cutoff`. Default is + 0.35. + use_smooth_cutoff (bool, optional): If True, use a differentiable + definition of clashes by replacing `D < cutoff` with + `sigmoid(smooth_cutoff_alpha * (cutoff - D))`. Default is False. + smooth_cutoff_alpha (float, optional): Steepness parameter for + differentiable clashes, as `alpha -> infinity` it will behave as + discrete cutoff. Default is 1.0. + + Inputs: + X (tensor): Atomic coordinates with shape + `(batch_size, num_residues, 14, 3)`. + C (tensor): Chain map with shape `(batch_size, num_residues)`. + S (tensor): Sequence tokens with shape `(batch_size, num_residues)`. + mask_j (tensor, optional): Binary mask encoding which side chains + should be tested for clashing. + + Outputs: + clashes (tensor): Per-residue number of clashes with shape + `(batch_size, num_residues)`. 
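`LossAllAtomDistances` compares eps-smoothed interatomic distance maps between model and target over a nearest-neighbor union graph. A hedged sketch of the same idea without the graph machinery or atom masks (all-vs-all pairs on a single flat point cloud; helper names are illustrative):

```python
import torch

def pairwise_dist(X, eps=1e-2):
    # Smoothed all-vs-all distances for points X of shape (N, 3).
    d2 = ((X.unsqueeze(0) - X.unsqueeze(1)) ** 2).sum(-1)
    return torch.sqrt(d2 + eps)

def distance_map_loss(X_model, X_target, eps=1e-2):
    D_m, D_t = pairwise_dist(X_model, eps), pairwise_dist(X_target, eps)
    return torch.sqrt((D_m - D_t) ** 2 + eps).mean()

loss = distance_map_loss(torch.randn(16, 3), torch.randn(16, 3))
```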
+ """ + + def __init__( + self, + num_neighbors=30, + distance_eps=1e-3, + connect_cutoff=0.35, + use_smooth_cutoff=False, + smooth_cutoff_alpha=1.0, + ): + super(LossSidechainClashes, self).__init__() + self.distance_eps = distance_eps + self.graph_builder = protein_graph.ProteinGraph(num_neighbors) + self.connect_cutoff = connect_cutoff + self.use_smooth_cutoff = use_smooth_cutoff + self.smooth_cutoff_alpha = smooth_cutoff_alpha + + def _dist_ij(self, X, edge_idx): + num_batch, num_residues, num_atoms, _ = X.shape + + # Build frames at neighbor j (B,L,K,), (B,L,K,A,3) + X_flat = X.reshape([num_batch, num_residues, -1]) + X_j_flat = graph.collect_neighbors(X_flat, edge_idx) + X_j = X_j_flat.reshape([num_batch, num_residues, -1, num_atoms, 3]) + X_i = X.unsqueeze(2).expand([-1, -1, X_j.shape[2], -1, -1]) + + D_ij = torch.sqrt( + ((X_i.unsqueeze(-2) - X_j.unsqueeze(-3)) ** 2).sum(-1) + self.distance_eps + ) + return D_ij + + def _mask_ij(self, C, S, edge_idx, mask_j=None): + # (B,L,A) + mask_atoms = atom_mask(C, S) + + # Mask only present atoms + mask_atoms_j = mask_atoms + if mask_j is not None: + mask_atoms_j = mask_atoms_j * mask_j.unsqueeze(-1) + mask_j = graph.collect_neighbors(mask_atoms_j, edge_idx) + mask_i = mask_atoms.unsqueeze(2).expand([-1, -1, edge_idx.shape[2], -1]) + mask_D = mask_i.unsqueeze(-1) * mask_j.unsqueeze(-2) + + # Mask self interactions + node_idx = torch.arange(C.shape[1], device=C.device).reshape([1, -1, 1]) + mask_ne = torch.ne(edge_idx, node_idx) + mask_D = mask_D * mask_ne.reshape(list(mask_ne.shape) + [1, 1]) + return mask_D + + def _gather_vdw_radii(self, C, S): + vdw_radii = {"C": 1.7, "N": 1.55, "O": 1.52, "S": 1.8} + + # Van der waal radii per atom per residue [AA,ATOM] + R = torch.zeros([20, 14], device=C.device) + for i, aa in enumerate(constants.AA20_3): + atoms = constants.ATOMS_BB + constants.AA_GEOMETRY[aa]["atoms"] + for j, atom_name in enumerate(atoms): + R[i, j] = vdw_radii[atom_name[0]] + + # (B, AA, ATOM) @ (B, L, ATOM) => (B, L, ATOM) + R = R.reshape([1, 20, 14]).expand([C.shape[0], -1, -1]) + S = S.unsqueeze(-1).expand([-1, -1, 14]) + atom_radii = torch.gather(R, 1, S) + return atom_radii + + def _gather_vdw_diameters(self, C, S, edge_idx): + num_batch, num_residues, num_neighbors = edge_idx.shape + + # Gather van der Waals radii + radii_i = self._gather_vdw_radii(C, S) + radii_j = graph.collect_neighbors(radii_i, edge_idx) + radii_i = radii_i.reshape([num_batch, num_residues, 1, -1, 1]) + radii_j = radii_j.reshape([num_batch, num_residues, num_neighbors, 1, -1]) + + D_vdw = radii_i + radii_j + return D_vdw + + def forward(self, X, C, S, edge_idx=None, mask_j=None, mask_ij=None): + # Compute sidechain interatomic distances + if edge_idx is None: + edge_idx, mask_ij = self.graph_builder(X[:, :, :4, :], C) + + # Distance with shape [B,L,K,AI,AJ] + mask_clash_ij = self._mask_ij(C, S, edge_idx, mask_j) + if mask_ij is not None: + mask_clash_ij = mask_clash_ij * mask_ij.reshape( + list(mask_ij.shape) + [1, 1] + ) + D = self._dist_ij(X, edge_idx) + D_vdw = self._gather_vdw_diameters(C, S, edge_idx) + D_clash_cutoff = D_vdw / 2.0 + self.connect_cutoff + + # Optionally use a smooth definition of clashes that is differentiable + if self.use_smooth_cutoff: + bond_clash = mask_clash_ij * torch.sigmoid( + self.smooth_cutoff_alpha * (D_clash_cutoff - D) + ) + else: + bond_clash = mask_clash_ij * (D < D_clash_cutoff).float() + + # Only cound outgoing clashes from sidechain atoms at i + bond_clash = bond_clash[:, :, :, 4:, :] + + clashes = 
bond_clash.sum([2, 3, 4]) + return clashes + + +def _gather_atom_mask(C, S, atoms_per_aa, num_atoms): + device = S.device + atoms_per_aa = torch.tensor(atoms_per_aa, dtype=torch.long) + atoms_per_aa = atoms_per_aa.to(device).unsqueeze(0).expand(S.shape[0], -1) + + # (B,A) @ (B,L) => (B,L) + atoms_per_residue = torch.gather(atoms_per_aa, -1, S) + atoms_per_residue = (C > 0).float() * atoms_per_residue + + ix_expand = torch.arange(num_atoms, device=device).reshape([1, 1, -1]) + mask_atoms = ix_expand < atoms_per_residue.unsqueeze(-1) + mask_atoms = mask_atoms.float() + return mask_atoms + + +def atom_mask(C, S): + """Constructs a all-atom coordinate mask from a sequence and chain map. + + Inputs: + C (tensor): Chain map with shape `(batch_size, num_residues)`. + S (tensor): Sequence tokens with shape `(batch_size, num_residues)`. + + Outputs: + mask_atoms (tensor): Atomic mask with shape + `(batch_size, num_residues, 14)`. + """ + return _gather_atom_mask(C, S, constants.AA20_NUM_ATOMS, 14) + + +def chi_mask(C, S): + """Constructs a all-atom coordinate mask from a sequence and chain map. + + Inputs: + C (tensor): Chain map with shape `(batch_size, num_residues)`. + S (tensor): Sequence tokens with shape `(batch_size, num_residues)`. + + Outputs: + mask_atoms (tensor): Chi angle mask with shape + `(batch_size, num_residues, 4)`. + """ + return _gather_atom_mask(C, S, constants.AA20_NUM_CHI, 4) diff --git a/models/autoencoder/sidechain/structure/geometry.py b/models/autoencoder/sidechain/structure/geometry.py new file mode 100644 index 0000000000000000000000000000000000000000..939143c2d4928adb154fedede267912bf7919860 --- /dev/null +++ b/models/autoencoder/sidechain/structure/geometry.py @@ -0,0 +1,681 @@ +# Copyright Generate Biomedicines, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Layers for measuring and building atomic geometries in proteins. + +This module contains pytorch layers for computing common geometric features of +protein backbones in a differentiable way and for converting between internal +and Cartesian coordinate representations. +""" + +from typing import Optional, Tuple + +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F + + +class Distances(nn.Module): + """Euclidean distance layer (pairwise). + + This layer computes batched pairwise Euclidean distances, where the input + tensor is treated as a batch of vectors with the final dimension as the + feature dimension and the dimension for pairwise expansion can be specified. + + Args: + distance_eps (float, optional): Small parameter to adde to squared + distances to make gradients smooth near 0. + + Inputs: + X (tensor): Input coordinates with shape `([...], length, [...], 3)`. + dim (int, optional): Dimension upon which to expand to pairwise + distances. Defaults to -2. + mask (tensor, optional): Masking tensor with shape + `([...], length, [...])`. 
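`_gather_atom_mask` turns a per-type atom count into a per-residue boolean mask over the 14 atom slots: residue i exposes its first `count[S_i]` slots. A toy sketch of that construction (the count table below is made up for illustration):

```python
import torch

counts_per_aa = torch.tensor([5, 6, 8, 4])   # toy table: atoms per amino-acid type
S = torch.tensor([[2, 0, 3]])                # (B, L) sequence indices

counts = torch.gather(counts_per_aa.unsqueeze(0).expand(1, -1), -1, S)   # (B, L)
mask = torch.arange(14).reshape(1, 1, -1) < counts.unsqueeze(-1)         # (B, L, 14)
print(mask.sum(-1))  # tensor([[8, 5, 4]])
```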
+ + Outputs: + D (tensor): Distances with shape `([...], length, length, [...])` + """ + + def __init__(self, distance_eps=1e-3): + super(Distances, self).__init__() + self.distance_eps = distance_eps + + def forward( + self, X: torch.Tensor, mask: Optional[torch.Tensor] = None, dim: float = -2 + ) -> torch.Tensor: + dim_expand = dim if dim < 0 else dim + 1 + dX = X.unsqueeze(dim_expand - 1) - X.unsqueeze(dim_expand) + D_square = torch.sum(dX ** 2, -1) + D = torch.sqrt(D_square + self.distance_eps) + if mask is not None: + mask_expand = mask.unsqueeze(dim) * mask.unsqueeze(dim + 1) + D = mask_expand * D + return D + + +class VirtualAtomsCA(nn.Module): + """Virtual atoms layer, branching from backbone C-alpha carbons. + + This layer places virtual atom coordinates relative to backbone coordinates + in a differentiable way. + + Args: + virtual_type (str, optional): Type of virtual atom to place. Currently + supported types are `dicons`, a virtual placement that was + optimized to predict potential rotamer interactions, and `cbeta` + which places a virtual C-beta carbon assuming ideal geometry. + distance_eps (float, optional): Small parameter to add to squared + distances to make gradients smooth near 0. + + Inputs: + X (Tensor): Backbone coordinates with shape + `(num_batch, num_residues, num_atom_types, 3)`. + C (Tensor): Chain map tensor with shape `(num_batch, num_residues)`. + + Outputs: + X_virtual (Tensor): Virtual coordinates with shape + `(num_batch, num_residues, 3)`. + """ + + def __init__(self, virtual_type="dicons", distance_eps=1e-3): + super(VirtualAtomsCA, self).__init__() + self.distance_eps = distance_eps + + """ + Geometry specifications + dicons + Length CA-X: 2.3866 + Angle N-CA-X: 111.0269 + Dihedral C-N-CA-X: -138.886412 + + cbeta + Length CA-X: 1.532 (Engh and Huber, 2001) + Angle N-CA-X: 109.5 (tetrahedral geometry) + Dihedral C-N-CA-X: -125.25 (109.5 / 2 - 180) + """ + self.virtual_type = virtual_type + virtual_geometries = { + "dicons": [2.3866, 111.0269, -138.8864122], + "cbeta": [1.532, 109.5, -125.25], + } + self.virtual_geometries = virtual_geometries + self.distance_eps = distance_eps + + def geometry(self): + bond, angle, dihedral = self.virtual_geometries[self.virtual_type] + return bond, angle, dihedral + + def forward(self, X: torch.Tensor, C: torch.LongTensor) -> torch.Tensor: + bond, angle, dihedral = self.geometry() + + ones = torch.ones([1, 1], device=X.device) + bonds = bond * ones + angles = angle * ones + dihedrals = dihedral * ones + + # Build reference frame + # 1.C -> 2.N -> 3.CA -> 4.X + X_N, X_CA, X_C, X_O = X.unbind(2) + X_virtual = extend_atoms( + X_C, + X_N, + X_CA, + bonds, + angles, + dihedrals, + degrees=True, + distance_eps=self.distance_eps, + ) + + # Mask missing positions + mask = (C > 0).type(torch.float32).unsqueeze(-1) + X_virtual = mask * X_virtual + return X_virtual + + +def normed_vec(V: torch.Tensor, distance_eps: float = 1e-3) -> torch.Tensor: + """Normalized vectors with distance smoothing. + + This normalization is computed as `U = V / sqrt(|V|^2 + eps)` to avoid cusps + and gradient discontinuities. + + Args: + V (Tensor): Batch of vectors with shape `(..., num_dims)`. + distance_eps (float, optional): Distance smoothing parameter for + for computing distances as `sqrt(sum_sq) -> sqrt(sum_sq + eps)`. + Default: 1E-3. + + Returns: + U (Tensor): Batch of normalized vectors with shape `(..., num_dims)`. 
+ """ + # Unit vector from i to j + mag_sq = (V ** 2).sum(dim=-1, keepdim=True) + mag = torch.sqrt(mag_sq + distance_eps) + U = V / mag + return U + + +def normed_cross( + V1: torch.Tensor, V2: torch.Tensor, distance_eps: float = 1e-3 +) -> torch.Tensor: + """Normalized cross product between vectors. + + This normalization is computed as `U = V / sqrt(|V|^2 + eps)` to avoid cusps + and gradient discontinuities. + + Args: + V1 (Tensor): Batch of vectors with shape `(..., 3)`. + V2 (Tensor): Batch of vectors with shape `(..., 3)`. + distance_eps (float, optional): Distance smoothing parameter for + for computing distances as `sqrt(sum_sq) -> sqrt(sum_sq + eps)`. + Default: 1E-3. + + Returns: + C (Tensor): Batch of cross products `v_1 x v_2` with shape `(..., 3)`. + """ + C = normed_vec(torch.cross(V1, V2, dim=-1), distance_eps=distance_eps) + return C + + +def lengths( + atom_i: torch.Tensor, atom_j: torch.Tensor, distance_eps: float = 1e-3 +) -> torch.Tensor: + """Batched bond lengths given batches of atom i and j. + + Args: + atom_i (Tensor): Atom `i` coordinates with shape `(..., 3)`. + atom_j (Tensor): Atom `j` coordinates with shape `(..., 3)`. + distance_eps (float, optional): Distance smoothing parameter for + for computing distances as `sqrt(sum_sq) -> sqrt(sum_sq + eps)`. + Default: 1E-3. + + Returns: + L (Tensor): Elementwise bond lengths `||x_i - x_j||` with shape `(...)`. + """ + # Bond length of i-j + dX = atom_j - atom_i + L = torch.sqrt((dX ** 2).sum(dim=-1) + distance_eps) + return L + + +def angles( + atom_i: torch.Tensor, + atom_j: torch.Tensor, + atom_k: torch.Tensor, + distance_eps: float = 1e-3, + degrees: bool = False, +) -> torch.Tensor: + """Batched bond angles given atoms `i-j-k`. + + Args: + atom_i (Tensor): Atom `i` coordinates with shape `(..., 3)`. + atom_j (Tensor): Atom `j` coordinates with shape `(..., 3)`. + atom_k (Tensor): Atom `k` coordinates with shape `(..., 3)`. + distance_eps (float, optional): Distance smoothing parameter for + for computing distances as `sqrt(sum_sq) -> sqrt(sum_sq + eps)`. + Default: 1E-3. + degrees (bool, optional): If True, convert to degrees. Default: False. + + Returns: + A (Tensor): Elementwise bond angles with shape `(...)`. + """ + # Bond angle of i-j-k + U_ji = normed_vec(atom_i - atom_j, distance_eps=distance_eps) + U_jk = normed_vec(atom_k - atom_j, distance_eps=distance_eps) + inner_prod = torch.einsum("bix,bix->bi", U_ji, U_jk) + inner_prod = torch.clamp(inner_prod, -1, 1) + A = torch.acos(inner_prod) + if degrees: + A = A * 180.0 / np.pi + return A + + +def dihedrals( + atom_i: torch.Tensor, + atom_j: torch.Tensor, + atom_k: torch.Tensor, + atom_l: torch.Tensor, + distance_eps: float = 1e-3, + degrees: bool = False, +) -> torch.Tensor: + """Batched bond dihedrals given atoms `i-j-k-l`. + + Args: + atom_i (Tensor): Atom `i` coordinates with shape `(..., 3)`. + atom_j (Tensor): Atom `j` coordinates with shape `(..., 3)`. + atom_k (Tensor): Atom `k` coordinates with shape `(..., 3)`. + atom_l (Tensor): Atom `l` coordinates with shape `(..., 3)`. + distance_eps (float, optional): Distance smoothing parameter for + for computing distances as `sqrt(sum_sq) -> sqrt(sum_sq + eps)`. + Default: 1E-3. + degrees (bool, optional): If True, convert to degrees. Default: False. + + Returns: + D (Tensor): Elementwise bond dihedrals with shape `(...)`. 
+ """ + U_ij = normed_vec(atom_j - atom_i, distance_eps=distance_eps) + U_jk = normed_vec(atom_k - atom_j, distance_eps=distance_eps) + U_kl = normed_vec(atom_l - atom_k, distance_eps=distance_eps) + normal_ijk = normed_cross(U_ij, U_jk, distance_eps=distance_eps) + normal_jkl = normed_cross(U_jk, U_kl, distance_eps=distance_eps) + # _inner_product = lambda a, b: torch.einsum("bix,bix->bi", a, b) + _inner_product = lambda a, b: (a * b).sum(-1) + cos_dihedrals = _inner_product(normal_ijk, normal_jkl) + angle_sign = _inner_product(U_ij, normal_jkl) + cos_dihedrals = torch.clamp(cos_dihedrals, -1, 1) + D = torch.sign(angle_sign) * torch.acos(cos_dihedrals) + if degrees: + D = D * 180.0 / np.pi + return D + + +def extend_atoms( + X_1: torch.Tensor, + X_2: torch.Tensor, + X_3: torch.Tensor, + lengths: torch.Tensor, + angles: torch.Tensor, + dihedrals: torch.Tensor, + distance_eps: float = 1e-3, + degrees: bool = False, +) -> torch.Tensor: + """Place atom `X_4` given `X_1`, `X_2`, `X_3` and internal coordinates. + + ___________________ + | X_1 - X_2 | + | | | + | X_3 - [X_4] | + |___________________| + + This uses a similar approach as NERF: + Parsons et al, Computational Chemistry (2005). + https://doi.org/10.1002/jcc.20237 + See the reference for further explanation about converting from internal + coordinates to Cartesian coordinates. + + Args: + X_1 (Tensor): First atom coordinates with shape `(..., 3)`. + X_2 (Tensor): Second atom coordinates with shape `(..., 3)`. + X_3 (Tensor): Third atom coordinates with shape `(..., 3)`. + lengths (Tensor): Bond lengths `X_3-X_4` with shape `(...)`. + angles (Tensor): Bond angles `X_2-X_3-X_4` with shape `(...)`. + dihedrals (Tensor): Bond dihedrals `X_1-X_2-X_3-X_4` with shape `(...)`. + distance_eps (float, optional): Distance smoothing parameter for + for computing distances as `sqrt(sum_sq) -> sqrt(sum_sq + eps)`. + This preserves differentiability for zero distances. Default: 1E-3. + degrees (bool, optional): If True, inputs are treated as degrees. + Default: False. + + Returns: + X_4 (Tensor): Placed atom with shape `(..., 3)`. + """ + if degrees: + angles *= np.pi / 180.0 + dihedrals *= np.pi / 180.0 + + r_32 = X_2 - X_3 + r_12 = X_2 - X_1 + n_1 = normed_vec(r_32, distance_eps=distance_eps) + n_2 = normed_cross(n_1, r_12, distance_eps=distance_eps) + n_3 = normed_cross(n_1, n_2, distance_eps=distance_eps) + + lengths = lengths.unsqueeze(-1) + cos_angle = torch.cos(angles).unsqueeze(-1) + sin_angle = torch.sin(angles).unsqueeze(-1) + cos_dihedral = torch.cos(dihedrals).unsqueeze(-1) + sin_dihedral = torch.sin(dihedrals).unsqueeze(-1) + + X_4 = X_3 + lengths * ( + cos_angle * n_1 + + (sin_angle * sin_dihedral) * n_2 + + (sin_angle * cos_dihedral) * n_3 + ) + return X_4 + + +class InternalCoords(nn.Module): + """Internal coordinates layer. + + This layer computes internal coordinates (ICs) from a batch of protein + backbones. To make the ICs differentiable everywhere, this layer replaces + distance calculations of the form `sqrt(sum_sq)` with smooth, non-cusped + approximation `sqrt(sum_sq + eps)`. + + Args: + distance_eps (float, optional): Small parameter to add to squared + distances to make gradients smooth near 0. + + Inputs: + X (Tensor): Backbone coordinates with shape + `(num_batch, num_residues, num_atom_types, 3)`. + C (Tensor): Chain map tensor with shape + `(num_batch, num_residues)`. 
+ + Outputs: + dihedrals (Tensor): Backbone dihedral angles with shape + `(num_batch, num_residues, 4)` + angles (Tensor): Backbone bond lengths with shape + `(num_batch, num_residues, 4)` + lengths (Tensor): Backbone bond lengths with shape + `(num_batch, num_residues, 4)` + """ + + def __init__(self, distance_eps=1e-3): + super(InternalCoords, self).__init__() + self.distance_eps = distance_eps + + def forward( + self, + X: torch.Tensor, + C: Optional[torch.Tensor] = None, + return_masks: bool = False, + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + mask = (C > 0).float() + X_chain = X[:, :, :3, :] + num_batch, num_residues, _, _ = X_chain.shape + X_chain = X_chain.reshape(num_batch, 3 * num_residues, 3) + + # This function historically returns the angle complement + _lengths = lambda Xi, Xj: lengths(Xi, Xj, distance_eps=self.distance_eps) + _angles = lambda Xi, Xj, Xk: np.pi - angles( + Xi, Xj, Xk, distance_eps=self.distance_eps + ) + _dihedrals = lambda Xi, Xj, Xk, Xl: dihedrals( + Xi, Xj, Xk, Xl, distance_eps=self.distance_eps + ) + + # Compute internal coordinates associated with -[N]-[CA]-[C]- + NCaC_L = _lengths(X_chain[:, 1:, :], X_chain[:, :-1, :]) + NCaC_A = _angles(X_chain[:, :-2, :], X_chain[:, 1:-1, :], X_chain[:, 2:, :]) + NCaC_D = _dihedrals( + X_chain[:, :-3, :], + X_chain[:, 1:-2, :], + X_chain[:, 2:-1, :], + X_chain[:, 3:, :], + ) + + # Compute internal coordinates associated with [C]=[O] + _, X_CA, X_C, X_O = X.unbind(dim=2) + X_N_next = X[:, 1:, 0, :] + O_L = _lengths(X_C, X_O) + O_A = _angles(X_CA, X_C, X_O) + O_D = _dihedrals(X_N_next, X_CA[:, :-1, :], X_C[:, :-1, :], X_O[:, :-1, :]) + + if C is None: + C = torch.zeros_like(mask) + + # Mask nonphysical bonds and angles + # Note: this could probably also be expressed as a Conv, unclear + # which is faster and this probably not rate-limiting. + C = C * (mask.type(torch.long)) + ii = torch.stack(3 * [C], dim=-1).view([num_batch, -1]) + L0, L1 = ii[:, :-1], ii[:, 1:] + A0, A1, A2 = ii[:, :-2], ii[:, 1:-1], ii[:, 2:] + D0, D1, D2, D3 = ii[:, :-3], ii[:, 1:-2], ii[:, 2:-1], ii[:, 3:] + + # Mask for linear backbone + mask_L = torch.eq(L0, L1) + mask_A = torch.eq(A0, A1) * torch.eq(A0, A2) + mask_D = torch.eq(D0, D1) * torch.eq(D0, D2) * torch.eq(D0, D3) + mask_L = mask_L.type(torch.float32) + mask_A = mask_A.type(torch.float32) + mask_D = mask_D.type(torch.float32) + + # Masks for branched oxygen + mask_O_D = torch.eq(C[:, :-1], C[:, 1:]) + mask_O_D = mask_O_D.type(torch.float32) + mask_O_A = mask + mask_O_L = mask + + def _pad_pack(D, A, L, O_D, O_A, O_L): + # Pad and pack together the components + D = F.pad(D, (1, 2)) + A = F.pad(A, (0, 2)) + L = F.pad(L, (0, 1)) + O_D = F.pad(O_D, (0, 1)) + D, A, L = [x.reshape(num_batch, num_residues, 3) for x in [D, A, L]] + _pack = lambda a, b: torch.cat([a, b.unsqueeze(-1)], dim=-1) + L = _pack(L, O_L) + A = _pack(A, O_A) + D = _pack(D, O_D) + return D, A, L + + D, A, L = _pad_pack(NCaC_D, NCaC_A, NCaC_L, O_D, O_A, O_L) + mask_D, mask_A, mask_L = _pad_pack( + mask_D, mask_A, mask_L, mask_O_D, mask_O_A, mask_O_L + ) + mask_expand = mask.unsqueeze(-1) + mask_D = mask_expand * mask_D + mask_A = mask_expand * mask_A + mask_L = mask_expand * mask_L + + D = mask_D * D + A = mask_A * A + L = mask_L * L + + if not return_masks: + return D, A, L + else: + return D, A, L, mask_D, mask_A, mask_L + + +class VirtualAtomsCA(nn.Module): + """Virtual atoms layer, branching from backbone C-alpha carbons. 
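+
+    Supported placements are an idealized C-beta carbon (`cbeta`) and an
+    optimized interaction site (`dicons`); see `virtual_type` below.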
+ + This layer places virtual atom coordinates relative to backbone coordinates + in a differentiable way. + + Args: + virtual_type (str, optional): Type of virtual atom to place. Currently + supported types are `dicons`, a virtual placement that was + optimized to predict potential rotamer interactions, and `cbeta` + which places a virtual C-beta carbon assuming ideal geometry. + distance_eps (float, optional): Small parameter to add to squared + distances to make gradients smooth near 0. + + Inputs: + X (Tensor): Backbone coordinates with shape + `(num_batch, num_residues, num_atom_types, 3)`. + C (Tensor): Chain map tensor with shape `(num_batch, num_residues)`. + + Outputs: + X_virtual (Tensor): Virtual coordinates with shape + `(num_batch, num_residues, 3)`. + """ + + def __init__(self, virtual_type="dicons", distance_eps=1e-3): + super(VirtualAtomsCA, self).__init__() + self.distance_eps = distance_eps + + """ + Geometry specifications + dicons + Length CA-X: 2.3866 + Angle N-CA-X: 111.0269 + Dihedral C-N-CA-X: -138.886412 + + cbeta + Length CA-X: 1.532 (Engh and Huber, 2001) + Angle N-CA-X: 109.5 (tetrahedral geometry) + Dihedral C-N-CA-X: -125.25 (109.5 / 2 - 180) + """ + self.virtual_type = virtual_type + virtual_geometries = { + "dicons": [2.3866, 111.0269, -138.8864122], + "cbeta": [1.532, 109.5, -125.25], + } + self.virtual_geometries = virtual_geometries + self.distance_eps = distance_eps + + def geometry(self): + bond, angle, dihedral = self.virtual_geometries[self.virtual_type] + return bond, angle, dihedral + + def forward(self, X: torch.Tensor, C: torch.LongTensor) -> torch.Tensor: + bond, angle, dihedral = self.geometry() + + ones = torch.ones([1, 1], device=X.device) + bonds = bond * ones + angles = angle * ones + dihedrals = dihedral * ones + + # Build reference frame + # 1.C -> 2.N -> 3.CA -> 4.X + X_N, X_CA, X_C, X_O = X.unbind(2) + X_virtual = extend_atoms( + X_C, + X_N, + X_CA, + bonds, + angles, + dihedrals, + degrees=True, + distance_eps=self.distance_eps, + ) + + # Mask missing positions + mask = (C > 0).type(torch.float32).unsqueeze(-1) + X_virtual = mask * X_virtual + return X_virtual + + +def quaternions_from_rotations(R: torch.Tensor, eps: float = 1e-3) -> torch.Tensor: + """Convert a batch of rotation matrices to quaternions. + + See en.wikipedia.org/wiki/Quaternions_and_spatial_rotation for further + details on converting between quaternions and rotation matrices. + + Args: + R (tensor): Batch of rotation matrices with shape `(..., 3, 3)`. + + Returns: + q (tensor): Batch of quaternion vectors with shape `(..., 4)`. Quaternion + is in the order `[angle, axis_x, axis_y, axis_z]`. + """ + + batch_dims = list(R.shape)[:-2] + R_flat = R.reshape(batch_dims + [9]) + R00, R01, R02, R10, R11, R12, R20, R21, R22 = R_flat.unbind(-1) + + # Quaternion possesses both an axis and angle of rotation + _sqrt = lambda r: torch.sqrt(F.relu(r) + eps) + q_angle = _sqrt(1 + R00 + R11 + R22).unsqueeze(-1) + magnitudes = _sqrt( + 1 + torch.stack([R00 - R11 - R22, -R00 + R11 - R22, -R00 - R11 + R22], -1) + ) + signs = torch.sign(torch.stack([R21 - R12, R02 - R20, R10 - R01], -1)) + q_axis = signs * magnitudes + + # Normalize (for safety and a missing factor of 2) + q_unc = torch.cat((q_angle, q_axis), -1) + q = normed_vec(q_unc, distance_eps=eps) + return q + + +def rotations_from_quaternions( + q: torch.Tensor, normalize: bool = False, eps: float = 1e-3 +) -> torch.Tensor: + """Convert a batch of quaternions to rotation matrices. 
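+
+    Quaternions are expected in the `[angle, axis_x, axis_y, axis_z]` order
+    produced by `quaternions_from_rotations`; set `normalize=True` if the input
+    is not already unit-norm.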
+ + See en.wikipedia.org/wiki/Quaternions_and_spatial_rotation for further + details on converting between quaternions and rotation matrices. + + Returns: + q (tensor): Batch of quaternion vectors with shape `(..., 4)`. Quaternion + is in the order `[angle, axis_x, axis_y, axis_z]`. + normalize (boolean, optional): Option to normalize the quaternion before + conversion. + + Args: + R (tensor): Batch of rotation matrices with shape `(..., 3, 3)`. + """ + batch_dims = list(q.shape)[:-1] + if normalize: + q = normed_vec(q, distance_eps=eps) + + a, b, c, d = q.unbind(-1) + a2, b2, c2, d2 = a ** 2, b ** 2, c ** 2, d ** 2 + R = torch.stack( + [ + a2 + b2 - c2 - d2, + 2 * b * c - 2 * a * d, + 2 * b * d + 2 * a * c, + 2 * b * c + 2 * a * d, + a2 - b2 + c2 - d2, + 2 * c * d - 2 * a * b, + 2 * b * d - 2 * a * c, + 2 * c * d + 2 * a * b, + a2 - b2 - c2 + d2, + ], + dim=-1, + ) + + R = R.view(batch_dims + [3, 3]) + return R + + +def frames_from_backbone(X: torch.Tensor, distance_eps: float = 1e-3): + """Convert a backbone into local reference frames. + + Args: + X (Tensor): Backbone coordinates with shape `(..., 4, 3)`. + distance_eps (float, optional): Distance smoothing parameter for + for computing distances as `sqrt(sum_sq) -> sqrt(sum_sq + eps)`. + Default: 1E-3. + + Returns: + R (Tensor): Reference frames with shape `(..., 3, 3)`. + X_CA (Tensor): C-alpha coordinates with shape `(..., 3)` + """ + X_N, X_CA, X_C, X_O = X.unbind(-2) + u_CA_N = normed_vec(X_N - X_CA, distance_eps) + u_CA_C = normed_vec(X_C - X_CA, distance_eps) + n_1 = u_CA_N + n_2 = normed_cross(n_1, u_CA_C, distance_eps) + n_3 = normed_cross(n_1, n_2, distance_eps) + R = torch.stack([n_1, n_2, n_3], -1) + return R, X_CA + + +def hat(omega: torch.Tensor) -> torch.Tensor: + """ + Maps [x,y,z] to [[0,-z,y], [z,0,-x], [-y, x, 0]] + Args: + omega (torch.tensor): of size (*, 3) + Returns: + hat{omega} (torch.tensor): of size (*, 3, 3) skew symmetric element in so(3) + """ + target = torch.zeros(*omega.size()[:-1], 9, device=omega.device) + index1 = torch.tensor([7, 2, 3], device=omega.device).expand( + *target.size()[:-1], -1 + ) + index2 = torch.tensor([5, 6, 1], device=omega.device).expand( + *target.size()[:-1], -1 + ) + return ( + target.scatter(-1, index1, omega) + .scatter(-1, index2, -omega) + .reshape(*target.size()[:-1], 3, 3) + ) + + +def V(omega: torch.Tensor, eps: float = 1e-6) -> torch.Tensor: + I = torch.eye(3, device=omega.device).expand(*omega.size()[:-1], 3, 3) + theta = omega.pow(2).sum(dim=-1, keepdim=True).add(eps).sqrt()[..., None] + omega_hat = hat(omega) + M1 = ((1 - theta.cos()) / theta.pow(2)) * (omega_hat) + M2 = ((theta - theta.sin()) / theta.pow(3)) * (omega_hat @ omega_hat) + return I + M1 + M2 diff --git a/models/dyMEAN/model.py b/models/dyMEAN/model.py new file mode 100644 index 0000000000000000000000000000000000000000..aa6996567ae18c3ee40045879a63c79a1949ec17 --- /dev/null +++ b/models/dyMEAN/model.py @@ -0,0 +1,302 @@ +#!/usr/bin/python +# -*- coding:utf-8 -*- +import math + +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch_scatter import scatter_mean + +from data.format import VOCAB +from utils.nn_utils import variadic_meshgrid +from utils import register as R +from utils.oom_decorator import oom_decorator + +from .modules.am_egnn import AMEGNN +from .nn_utils import SeparatedAminoAcidFeature, ProteinFeature + + +@R.register('dyMEAN') +class dyMEAN(nn.Module): + def __init__( + self, + embed_size, + hidden_size, + n_channel, + num_classes=len(VOCAB), + 
mask_id=VOCAB.get_mask_idx(), + max_position=2048, + CA_channel_idx=VOCAB.backbone_atoms.index('CA'), + n_layers=3, + iter_round=3, + dropout=0.1, + fix_atom_weights=False, + relative_position=False, + mode='codesign', # fixbb, fixseq(structure prediction), codesign + std=10.0 + ) -> None: + super().__init__() + self.mask_id = mask_id + self.num_classes = num_classes + self.ca_channel_idx = CA_channel_idx + self.round = iter_round + self.mode = mode + self.std = std + + atom_embed_size = embed_size // 4 + self.aa_feature = SeparatedAminoAcidFeature( + embed_size, atom_embed_size, + max_position=max_position, + relative_position=relative_position, + fix_atom_weights=fix_atom_weights + ) + self.protein_feature = ProteinFeature() + + self.memory_ffn = nn.Sequential( + nn.SiLU(), + nn.Linear(hidden_size, hidden_size), + nn.SiLU(), + nn.Linear(hidden_size, embed_size) + ) + if self.mode != 'fixseq': + self.ffn_residue = nn.Sequential( + nn.SiLU(), + nn.Linear(hidden_size, hidden_size), + nn.SiLU(), + nn.Linear(hidden_size, self.num_classes) + ) + else: + self.prmsd_ffn = nn.Sequential( + nn.SiLU(), + nn.Linear(hidden_size, hidden_size), + nn.SiLU(), + nn.Linear(hidden_size, 1) + ) + self.gnn = AMEGNN( + embed_size, hidden_size, hidden_size, n_channel, + channel_nf=atom_embed_size, radial_nf=hidden_size, + in_edge_nf=0, n_layers=n_layers, residual=True, + dropout=dropout, dense=False) + + # training related cache + self.batch_constants = {} + + def init_mask(self, X, S, xmask, smask, batch_ids): + X, S = X.clone(), S.clone() # [N, 14, 3], [N] + n_channel, n_dim = X.shape[1:] + if self.mode != 'fixseq': + S[smask] = self.mask_id + + if self.mode != 'fixbb': + receptor_centers = scatter_mean(X[~xmask][:, self.ca_channel_idx], batch_ids[~xmask], dim=0) # [bs, 3] + ligand_ca = torch.randn_like(X[xmask][:, self.ca_channel_idx]) * self.std + receptor_centers[batch_ids[xmask]] # [Nlig, 3] + ligand_X = ligand_ca.unsqueeze(1).repeat(1, n_channel, 1) + ligand_X = ligand_X + torch.randn_like(ligand_X) * self.std * 0.1 # smaller scale + X[xmask] = ligand_X + + return X, S + + def message_passing(self, X, S, position_ids, ctx_edges, inter_edges, atom_weights, memory_H=None, smooth_prob=None, smooth_mask=None): + # embeddings + H_0, (atom_embeddings, _) = self.aa_feature(S, position_ids, smooth_prob=smooth_prob, smooth_mask=smooth_mask) + + if memory_H is not None: + H_0 = H_0 + self.memory_ffn(memory_H) + edges = torch.cat([ctx_edges, inter_edges], dim=1) + + H, pred_X = self.gnn(H_0, X, edges, + channel_attr=atom_embeddings, + channel_weights=atom_weights) + + + pred_logits = None if self.mode == 'fixseq' else self.ffn_residue(H) + + return pred_logits, pred_X, H # [N, num_classes], [N, n_channel, 3], [N, hidden_size] + + @torch.no_grad() + def prepare_inputs(self, X, S, xmask, smask, lengths): + + # batch ids + batch_ids = torch.zeros_like(S) + batch_ids[torch.cumsum(lengths, dim=0)[:-1]] = 1 + batch_ids.cumsum_(dim=0) + + # initialization + X, S = self.init_mask(X, S, xmask, smask, batch_ids) + aa_cnt = smask.sum() + + # edges + row, col = variadic_meshgrid( + input1=torch.arange(batch_ids.shape[0], device=batch_ids.device), + size1=lengths, + input2=torch.arange(batch_ids.shape[0], device=batch_ids.device), + size2=lengths, + ) # (row, col) + + is_ctx = xmask[row] == xmask[col] + is_inter = ~is_ctx + ctx_edges = torch.stack([row[is_ctx], col[is_ctx]], dim=0) # [2, Ec] + inter_edges = torch.stack([row[is_inter], col[is_inter]], dim=0) # [2, Ei] + + special_mask = torch.tensor(VOCAB.get_special_mask(), 
device=S.device, dtype=torch.long) + special_mask = special_mask.repeat(aa_cnt, 1).bool() + + return X, S, aa_cnt, ctx_edges, inter_edges, special_mask, batch_ids + + def normalize(self, X, batch_ids, mask): + centers = scatter_mean(X[~mask], batch_ids[~mask], dim=0) # [bs, 4, 3] + centers = centers.mean(dim=1)[batch_ids].unsqueeze(1) # [N, 4, 3] + X = (X - centers) / self.std + return X, centers + + def _forward(self, X, S, xmask, smask, special_mask, position_ids, ctx_edges, inter_edges, atom_weights): + + # sequence and structure loss + r_pred_S_logits, pred_S_dist = [], None + memory_H = None + # message passing + for t in range(self.round): + pred_S_logits, pred_X, H = self.message_passing( + X, S, position_ids, ctx_edges, inter_edges, + atom_weights, memory_H, pred_S_dist, smask) + r_pred_S_logits.append(pred_S_logits) + memory_H = H + # 1. update X + X = X.clone() + X[xmask] = pred_X[xmask] + + if self.mode != 'fixseq': + # 2. update S + S = S.clone() + if t == self.round - 1: + S[smask] = torch.argmax(pred_S_logits[smask].masked_fill(special_mask, float('-inf')), dim=-1) + else: + pred_S_dist = torch.softmax(pred_S_logits[smask].masked_fill(special_mask, float('-inf')), dim=-1) + + if self.mode == 'fixseq': + # predicted rmsd + prmsd = self.prmsd_ffn(H[xmask]).squeeze() # [N_ab] + else: + prmsd = None + + return H, S, r_pred_S_logits, pred_X, prmsd + + @oom_decorator + def forward(self, X, S, mask, position_ids, lengths, atom_mask, context_ratio=0): + ''' + :param X: [N, 14, 3] + :param S: [N] + :param smask: [N] + :param position_ids: [N], residue position ids + :param context_ratio: float, rate of context provided in masked sequence, should be [0, 1) and anneal to 0 in training + ''' + # clone ground truth coordinates, sequence + true_X, true_S = X.clone(), S.clone() + xmask, smask = mask, mask + + # provide some ground truth for annealing sequence training + if context_ratio > 0: + not_ctx_mask = torch.rand_like(smask, dtype=torch.float) >= context_ratio + smask = torch.logical_and(smask, not_ctx_mask) + + # prepare + X, S, aa_cnt, ctx_edges, inter_edges, special_mask, batch_ids = self.prepare_inputs(X, S, xmask, smask, lengths) + atom_weights = torch.logical_or(atom_mask, xmask.unsqueeze(1)).float() if self.mode != 'fixbb' else atom_mask.float() + X, centers = self.normalize(X, batch_ids, xmask) + true_X, _ = self.normalize(true_X, batch_ids, xmask) + + # get results + H, pred_S, r_pred_S_logits, pred_X, prmsd = self._forward( + X, S, xmask, smask, special_mask, position_ids, ctx_edges, inter_edges, atom_weights) + # # unnormalize + # pred_X = pred_X * self.std + centers + + # sequence negtive log likelihood + snll = 0 + if self.mode != 'fixseq': + for logits in r_pred_S_logits: + snll = snll + F.cross_entropy(logits[smask].masked_fill(special_mask, float('-inf')), true_S[smask], reduction='sum') / (aa_cnt + 1e-10) + snll = snll / self.round + + # coordination loss + if self.mode != 'fixbb': + segment_ids, gen_X, ref_X = torch.ones_like(pred_S[xmask], device=pred_X.device, dtype=torch.long), pred_X[xmask], true_X[xmask] + # backbone bond lengths loss + bb_bond_loss = F.l1_loss( + self.protein_feature._cal_backbone_bond_lengths(gen_X, batch_ids[xmask], segment_ids, atom_mask[xmask]), + self.protein_feature._cal_backbone_bond_lengths(ref_X, batch_ids[xmask], segment_ids, atom_mask[xmask]) + ) + # side-chain bond lengths loss + sc_bond_loss = F.l1_loss( + self.protein_feature._cal_sidechain_bond_lengths(true_S[xmask], gen_X, self.aa_feature, atom_mask[xmask]), + 
self.protein_feature._cal_sidechain_bond_lengths(true_S[xmask], ref_X, self.aa_feature, atom_mask[xmask]) + ) + # mse + xloss_mask = atom_mask.unsqueeze(-1).repeat(1, 1, 3) & mask.unsqueeze(-1).unsqueeze(-1) # [N, 14, 3] + xloss = F.mse_loss(pred_X[xloss_mask], true_X[xloss_mask]) + # CA pair-wise distance + dist_loss = F.l1_loss( + torch.norm(pred_X[:, self.ca_channel_idx][inter_edges.T[0]] - pred_X[:, self.ca_channel_idx][inter_edges.T[1]], dim=-1), + torch.norm(true_X[:, self.ca_channel_idx][inter_edges.T[0]] - true_X[:, self.ca_channel_idx][inter_edges.T[1]], dim=-1) + ) + struct_loss = bb_bond_loss + sc_bond_loss + xloss + dist_loss + else: + struct_loss, bb_bond_loss, sc_bond_loss, xloss, dist_loss = 0, 0, 0, 0, 0 + + if self.mode != 'fixbb': + # predicted rmsd + prmsd_loss = 0 # TODO: residue-wise rmsd + pdev_loss = prmsd_loss# + prmsd_i_loss + else: + pdev_loss, prmsd_loss = None, None + + # comprehensive loss, 5 for similar scale + loss = snll + 5 * struct_loss + (0 if pdev_loss is None else pdev_loss)# + 0 * ed_loss + + # AAR + with torch.no_grad(): + aa_hit = pred_S[smask] == true_S[smask] + aar = aa_hit.long().sum() / (aa_hit.shape[0] + 1e-10) + + return loss, (snll, aar), (struct_loss, (bb_bond_loss, sc_bond_loss, xloss, dist_loss)), (pdev_loss, prmsd_loss)# , (ed_loss, r_ed_losses) + + def sample(self, X, S, mask, position_ids, lengths, atom_mask, greedy=False): + gen_X, gen_S = X.clone(), S.clone() + xmask, smask = mask, mask + + # prepare + X, S, aa_cnt, ctx_edges, inter_edges, special_mask, batch_ids = self.prepare_inputs(X, S, xmask, smask, lengths) + atom_weights = torch.logical_or(atom_mask, xmask.unsqueeze(1)).float() if self.mode != 'fixbb' else atom_weights.float() + X, centers = self.normalize(X, batch_ids, xmask) + + # get results + H, pred_S, r_pred_S_logits, pred_X, prmsd = self._forward(X, S, xmask, smask, special_mask, position_ids, ctx_edges, inter_edges, atom_weights) + # unnormalize + pred_X = pred_X * self.std + centers + + + if self.mode != 'fixseq': + + logits = r_pred_S_logits[-1][smask] + logits = logits.masked_fill(special_mask, float('-inf')) # mask special tokens + + if greedy: + gen_S[smask] = torch.argmax(logits, dim=-1) # [n] + else: + prob = F.softmax(logits, dim=-1) + gen_S[smask] = torch.multinomial(prob, num_samples=1).squeeze() + snll_all = F.cross_entropy(logits, gen_S[smask], reduction='none') + else: + snll_all = torch.zeros_like(gen_S[smask]).float() + + gen_X[xmask] = pred_X[xmask] + + batch_X, batch_S, batch_ppls = [], [], [] + for i, l in enumerate(lengths): + cur_mask = mask & (batch_ids == i) + batch_X.append(gen_X[cur_mask].tolist()) + batch_S.append(''.join([VOCAB.idx_to_symbol(s) for s in gen_S[cur_mask]])) + batch_ppls.append( + torch.exp(snll_all[cur_mask[mask]].sum() / cur_mask.sum()).item() + ) + return batch_X, batch_S, batch_ppls diff --git a/models/dyMEAN/modules/am_egnn.py b/models/dyMEAN/modules/am_egnn.py new file mode 100644 index 0000000000000000000000000000000000000000..a417df2ca1c393c652917b374d4c95379fc5ea40 --- /dev/null +++ b/models/dyMEAN/modules/am_egnn.py @@ -0,0 +1,372 @@ +#!/usr/bin/python +# -*- coding:utf-8 -*- +import torch +import torch.nn as nn + +from utils.decorators import singleton + +from .radial_basis import RadialBasis + + +class RadialLinear(nn.Module): + def __init__(self, n_rbf, cutoff): + super().__init__() + self.rbf = RadialBasis(n_rbf, cutoff) + self.linear = nn.Linear(n_rbf, 1) + + def forward(self, d): + ''' + args: + d: distance feature [N, ...] 
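+                (trailing dimensions are flattened, expanded with the radial
+                basis, mapped to a scalar by a linear layer, then reshaped back)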
+ returns: + radial: the same shape with input d, [N, ...] + ''' + output_shape = d.shape + radial = self.rbf(d.view(-1)) # [N*d1*d2..., n_rbf] + radial = self.linear(radial).squeeze(-1) + return radial.view(*output_shape) + + +class AMEGNN(nn.Module): + + def __init__(self, in_node_nf, hidden_nf, out_node_nf, n_channel, channel_nf, + radial_nf, in_edge_nf=0, act_fn=nn.SiLU(), n_layers=4, + residual=True, dropout=0.1, dense=False, n_rbf=0, cutoff=1.0): + super().__init__() + ''' + :param in_node_nf: Number of features for 'h' at the input + :param hidden_nf: Number of hidden features + :param out_node_nf: Number of features for 'h' at the output + :param n_channel: Number of channels of coordinates + :param in_edge_nf: Number of features for the edge features + :param act_fn: Non-linearity + :param n_layers: Number of layer for the EGNN + :param residual: Use residual connections, we recommend not changing this one + :param dropout: probability of dropout + :param dense: if dense, then context states will be concatenated for all layers, + coordination will be averaged + ''' + self.hidden_nf = hidden_nf + self.n_layers = n_layers + + self.dropout = nn.Dropout(dropout) + + self.linear_in = nn.Linear(in_node_nf, self.hidden_nf) + + self.dense = dense + if dense: + self.linear_out = nn.Linear(self.hidden_nf * (n_layers + 1), out_node_nf) + else: + self.linear_out = nn.Linear(self.hidden_nf, out_node_nf) + + for i in range(0, n_layers): + self.add_module(f'gcl_{i}', AM_E_GCL( + self.hidden_nf, self.hidden_nf, self.hidden_nf, n_channel, channel_nf, radial_nf, + edges_in_d=in_edge_nf, act_fn=act_fn, residual=residual, dropout=dropout, n_rbf=n_rbf, cutoff=cutoff + )) + self.out_layer = AM_E_GCL( + self.hidden_nf, self.hidden_nf, self.hidden_nf, n_channel, channel_nf, + radial_nf, edges_in_d=in_edge_nf, act_fn=act_fn, residual=residual, n_rbf=n_rbf, cutoff=cutoff + ) + + def forward(self, h, x, edges, channel_attr, channel_weights, ctx_edge_attr=None, x_update_mask=None): + h = self.linear_in(h) + h = self.dropout(h) + + ctx_states, ctx_coords = [], [] + for i in range(0, self.n_layers): + h, x = self._modules[f'gcl_{i}']( + h, edges, x, channel_attr, channel_weights, + edge_attr=ctx_edge_attr, x_update_mask=x_update_mask) + + ctx_states.append(h) + ctx_coords.append(x) + + h, x = self.out_layer( + h, edges, x, channel_attr, channel_weights, + edge_attr=ctx_edge_attr, x_update_mask=x_update_mask) + ctx_states.append(h) + ctx_coords.append(x) + if self.dense: + h = torch.cat(ctx_states, dim=-1) + x = torch.mean(torch.stack(ctx_coords), dim=0) + h = self.dropout(h) + h = self.linear_out(h) + return h, x + +''' +Below are the implementation of the adaptive multi-channel message passing mechanism +''' + +@singleton +class RollerPooling(nn.Module): + ''' + Adaptive average pooling for the adaptive scaler + ''' + def __init__(self, n_channel) -> None: + super().__init__() + self.n_channel = n_channel + with torch.no_grad(): + pool_matrix = [] + ones = torch.ones((n_channel, n_channel), dtype=torch.float) + for i in range(n_channel): + # i start from 0 instead of 1 !!! 
(less readable but higher implemetation efficiency) + window_size = n_channel - i + mat = torch.triu(ones) - torch.triu(ones, diagonal=window_size) + pool_matrix.append(mat / window_size) + self.pool_matrix = torch.stack(pool_matrix) + + def forward(self, hidden, target_size): + ''' + :param hidden: [n_edges, n_channel] + :param target_size: [n_edges] + ''' + pool_mat = self.pool_matrix.to(hidden.device).type(hidden.dtype) + pool_mat = pool_mat[target_size - 1] # [n_edges, n_channel, n_channel] + hidden = hidden.unsqueeze(-1) # [n_edges, n_channel, 1] + return torch.bmm(pool_mat, hidden) # [n_edges, n_channel, 1] + + +class AM_E_GCL(nn.Module): + ''' + Adaptive Multi-Channel E(n) Equivariant Convolutional Layer + ''' + + def __init__(self, input_nf, output_nf, hidden_nf, n_channel, channel_nf, radial_nf, + edges_in_d=0, node_attr_d=0, act_fn=nn.SiLU(), residual=True, attention=False, + normalize=False, coords_agg='mean', tanh=False, dropout=0.1, n_rbf=0, cutoff=1.0): + super(AM_E_GCL, self).__init__() + + input_edge = input_nf * 2 + self.residual = residual + self.attention = attention + self.normalize = normalize + self.coords_agg = coords_agg + self.tanh = tanh + self.epsilon = 1e-8 + + self.dropout = nn.Dropout(dropout) + + input_edge = input_nf * 2 + self.edge_mlp = nn.Sequential( + nn.Linear(input_edge + radial_nf + edges_in_d, hidden_nf), + act_fn, + nn.Linear(hidden_nf, hidden_nf), + act_fn) + self.radial_linear = nn.Linear(channel_nf ** 2, radial_nf) + + self.node_mlp = nn.Sequential( + nn.Linear(hidden_nf + input_nf + node_attr_d, hidden_nf), + act_fn, + nn.Linear(hidden_nf, output_nf)) + + layer = nn.Linear(hidden_nf, n_channel, bias=False) + torch.nn.init.xavier_uniform_(layer.weight, gain=0.001) + + coord_mlp = [] + coord_mlp.append(nn.Linear(hidden_nf, hidden_nf)) + coord_mlp.append(act_fn) + coord_mlp.append(layer) + if self.tanh: + coord_mlp.append(nn.Tanh()) + self.coord_mlp = nn.Sequential(*coord_mlp) + + if self.attention: + self.att_mlp = nn.Sequential( + nn.Linear(hidden_nf, 1), + nn.Sigmoid()) + + if n_rbf > 1: + self.rbf_linear = RadialLinear(n_rbf, cutoff) + + def edge_model(self, source, target, radial, edge_attr): + ''' + :param source: [n_edge, input_size] + :param target: [n_edge, input_size] + :param radial: [n_edge, d, d] + :param edge_attr: [n_edge, edge_dim] + ''' + radial = radial.reshape(radial.shape[0], -1) # [n_edge, d ^ 2] + + if edge_attr is None: # Unused. 
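+            # without extra edge features, the message is built from the two
+            # endpoint node states and the channel-aware radial term only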
+ out = torch.cat([source, target, radial], dim=1) + else: + out = torch.cat([source, target, radial, edge_attr], dim=1) + out = self.edge_mlp(out) + out = self.dropout(out) + + if self.attention: + att_val = self.att_mlp(out) + out = out * att_val + return out + + def node_model(self, x, edge_index, edge_attr, node_attr): + ''' + :param x: [bs * n_node, input_size] + :param edge_index: list of [n_edge], [n_edge] + :param edge_attr: [n_edge, hidden_size], refers to message from i to j + :param node_attr: [bs * n_node, node_dim] + ''' + row, col = edge_index + agg = unsorted_segment_sum(edge_attr, row, num_segments=x.size(0)) # [bs * n_node, hidden_size] + # print_log(f'agg1, {torch.isnan(agg).sum()}', level='DEBUG') + if node_attr is not None: + agg = torch.cat([x, agg, node_attr], dim=1) + else: + agg = torch.cat([x, agg], dim=1) # [bs * n_node, input_size + hidden_size] + # print_log(f'agg, {torch.isnan(agg).sum()}', level='DEBUG') + out = self.node_mlp(agg) # [bs * n_node, output_size] + # print_log(f'out, {torch.isnan(out).sum()}', level='DEBUG') + out = self.dropout(out) + if self.residual: + out = x + out + return out, agg + + def coord_model(self, coord, edge_index, coord_diff, edge_feat, channel_attr, channel_weights, x_update_mask=None): + ''' + coord: [N, n_channel, d] + edge_index: list of [n_edge], [n_edge] + coord_diff: [n_edge, n_channel, d] + edge_feat: [n_edge, hidden_size] + channel_attr: [N, n_channel, channel_nf] + channel_weights: [N, n_channel] + x_update_mask: [N, n_channel], 1 for updating coordinates + ''' + row, col = edge_index + + # first pooling, then element-wise multiply + n_channel = channel_weights.shape[-1] + edge_feat = self.coord_mlp(edge_feat) # [n_edge, n_channel] + channel_sum = (channel_weights != 0).long().sum(-1) # [N] + pooled_edge_feat = RollerPooling(n_channel)(edge_feat, channel_sum[row]) # [n_edge, n_channel, 1] + trans = coord_diff * pooled_edge_feat # [n_edge, n_channel, d] + + # aggregate + if self.coords_agg == 'sum': + agg = unsorted_segment_sum(trans, row, num_segments=coord.size(0)) + elif self.coords_agg == 'mean': + agg = unsorted_segment_mean(trans, row, num_segments=coord.size(0)) # [N, n_channel, d] + else: + raise Exception('Wrong coords_agg parameter' % self.coords_agg) + if x_update_mask is None: + coord = coord + agg + else: + x_update_mask = x_update_mask.unsqueeze(-1).float() # [N, n_channel, 1] + coord = coord + x_update_mask * agg + return coord + + def forward(self, h, edge_index, coord, channel_attr, channel_weights, + edge_attr=None, node_attr=None, x_update_mask=None): + ''' + h: [bs * n_node, hidden_size] + edge_index: list of [n_row] and [n_col] where n_row == n_col (with no cutoff, n_row == bs * n_node * (n_node - 1)) + coord: [bs * n_node, n_channel, d] + channel_attr: [bs * n_node, n_channel, channel_nf] + channel_weights: [bs * n_node, n_channel] + x_update_mask: [bs * n_node, n_channel], 1 for updating coordinates + ''' + row, col = edge_index + + radial, coord_diff = coord2radial(edge_index, coord, channel_attr, channel_weights, self.radial_linear, getattr(self, 'rbf_linear', None)) + + edge_feat = self.edge_model(h[row], h[col], radial, edge_attr) # [n_edge, hidden_size] + coord = self.coord_model(coord, edge_index, coord_diff, edge_feat, channel_attr, channel_weights, x_update_mask) # [bs * n_node, n_channel, d] + h, agg = self.node_model(h, edge_index, edge_feat, node_attr) + + return h, coord + + +def unsorted_segment_sum(data, segment_ids, num_segments): + ''' + :param data: [n_edge, *dimensions] + :param 
segment_ids: [n_edge] + :param num_segments: [bs * n_node] + ''' + expand_dims = tuple(data.shape[1:]) + result_shape = (num_segments, ) + expand_dims + for _ in expand_dims: + segment_ids = segment_ids.unsqueeze(-1) + segment_ids = segment_ids.expand(-1, *expand_dims) + result = data.new_full(result_shape, 0) # Init empty result tensor. + result.scatter_add_(0, segment_ids, data) + return result + + +def unsorted_segment_mean(data, segment_ids, num_segments): + ''' + :param data: [n_edge, *dimensions] + :param segment_ids: [n_edge] + :param num_segments: [bs * n_node] + ''' + expand_dims = tuple(data.shape[1:]) + result_shape = (num_segments, ) + expand_dims + for _ in expand_dims: + segment_ids = segment_ids.unsqueeze(-1) + segment_ids = segment_ids.expand(-1, *expand_dims) + result = data.new_full(result_shape, 0) # Init empty result tensor. + count = data.new_full(result_shape, 0) + result.scatter_add_(0, segment_ids, data) + count.scatter_add_(0, segment_ids, torch.ones_like(data)) + return result / count.clamp(min=1) + + +CONSTANT = 1 +NUM_SEG = 1 # if you do not have enough memory or you have large attr_size, increase this parameter + +def coord2radial(edge_index, coord, attr, channel_weights, linear_map, rbf_linear=None): + ''' + :param edge_index: tuple([n_edge], [n_edge]) which is tuple of (row, col) + :param coord: [N, n_channel, d] + :param attr: [N, n_channel, attr_size], attribute embedding of each channel + :param channel_weights: [N, n_channel], weights of different channels + :param linear_map: nn.Linear, map features to d_out + :param num_seg: split row/col into segments to reduce memory cost + ''' + row, col = edge_index + + radials = [] + + seg_size = (len(row) + NUM_SEG - 1) // NUM_SEG + + for i in range(NUM_SEG): + start = i * seg_size + end = min(start + seg_size, len(row)) + if end <= start: + break + seg_row, seg_col = row[start:end], col[start:end] + + coord_msg = torch.norm( + coord[seg_row].unsqueeze(2) - coord[seg_col].unsqueeze(1), # [n_edge, n_channel, n_channel, d] + dim=-1, keepdim=False) # [n_edge, n_channel, n_channel] + if rbf_linear: + coord_msg = rbf_linear(coord_msg) + + coord_msg = coord_msg * torch.bmm( + channel_weights[seg_row].unsqueeze(2), + channel_weights[seg_col].unsqueeze(1) + ) # [n_edge, n_channel, n_channel] + + radial = torch.bmm( + attr[seg_row].transpose(-1, -2), # [n_edge, attr_size, n_channel] + coord_msg) # [n_edge, attr_size, n_channel] + radial = torch.bmm(radial, attr[seg_col]) # [n_edge, attr_size, attr_size] + radial = radial.reshape(radial.shape[0], -1) # [n_edge, attr_size * attr_size] + if rbf_linear: # do not need normalization + radial = linear_map(radial) + else: + radial_norm = torch.norm(radial, dim=-1, keepdim=True) + CONSTANT # post norm + radial = linear_map(radial) / radial_norm # [n_edge, d_out] + + radials.append(radial) + + radials = torch.cat(radials, dim=0) # [N_edge, d_out] + + # generate coord_diff by first mean src then minused by dst + # message passed from col to row + channel_mask = (channel_weights != 0).long() # [N, n_channel] + channel_sum = channel_mask.sum(-1) # [N] + pooled_col_coord = (coord[col] * channel_mask[col].unsqueeze(-1)).sum(1) # [n_edge, d] + pooled_col_coord = pooled_col_coord / channel_sum[col].unsqueeze(-1) # [n_edge, d], denominator cannot be 0 since no pad node exists + coord_diff = coord[row] - pooled_col_coord.unsqueeze(1) # [n_edge, n_channel, d] + + return radials, coord_diff \ No newline at end of file diff --git a/models/dyMEAN/modules/am_enc.py 
b/models/dyMEAN/modules/am_enc.py new file mode 100644 index 0000000000000000000000000000000000000000..eac121418362c80c7d80269088eaf82fea508274 --- /dev/null +++ b/models/dyMEAN/modules/am_enc.py @@ -0,0 +1,93 @@ +#!/usr/bin/python +# -*- coding:utf-8 -*- +import torch +import torch.nn as nn + +from torch_scatter import scatter_softmax +from .am_egnn import AM_E_GCL + + +class AMEncoder(nn.Module): + + def __init__(self, in_node_nf, hidden_nf, out_node_nf, n_channel, channel_nf, + radial_nf, in_edge_nf=0, act_fn=nn.SiLU(), n_layers=4, + residual=True, dropout=0.1, dense=False): + super().__init__() + ''' + :param in_node_nf: Number of features for 'h' at the input + :param hidden_nf: Number of hidden features + :param out_node_nf: Number of features for 'h' at the output + :param n_channel: Number of channels of coordinates + :param in_edge_nf: Number of features for the edge features + :param act_fn: Non-linearity + :param n_layers: Number of layer for the EGNN + :param residual: Use residual connections, we recommend not changing this one + :param dropout: probability of dropout + :param dense: if dense, then context states will be concatenated for all layers, + coordination will be averaged + ''' + self.hidden_nf = hidden_nf + self.n_layers = n_layers + + self.dropout = nn.Dropout(dropout) + + self.linear_in = nn.Linear(in_node_nf, self.hidden_nf) + + self.dense = dense + if dense: + self.linear_out = nn.Linear(self.hidden_nf * (n_layers + 1), out_node_nf) + else: + self.linear_out = nn.Linear(self.hidden_nf, out_node_nf) + + for i in range(0, n_layers): + self.add_module(f'ctx_gcl_{i}', AM_E_GCL( + self.hidden_nf, self.hidden_nf, self.hidden_nf, n_channel, channel_nf, radial_nf, + edges_in_d=in_edge_nf, act_fn=act_fn, residual=residual, dropout=dropout + )) + self.add_module(f'inter_gcl_{i}', AM_E_GCL( + self.hidden_nf, self.hidden_nf, self.hidden_nf, n_channel, channel_nf, radial_nf, + edges_in_d=in_edge_nf, act_fn=act_fn, residual=residual, dropout=dropout + )) + self.out_layer = AM_E_GCL( + self.hidden_nf, self.hidden_nf, self.hidden_nf, n_channel, channel_nf, + radial_nf, edges_in_d=in_edge_nf, act_fn=act_fn, residual=residual + ) + + def forward(self, h, x, ctx_edges, inter_mask, inter_x, inter_edges, update_mask, inter_update_mask, channel_attr, channel_weights, + ctx_edge_attr=None): + h = self.linear_in(h) + h = self.dropout(h) + inter_h = h[inter_mask] + inter_channel_attr = channel_attr[inter_mask] + inter_channel_weights = channel_weights[inter_mask] + + ctx_states, ctx_coords, inter_coords = [], [], [] + for i in range(0, self.n_layers): + h, x = self._modules[f'ctx_gcl_{i}']( + h, ctx_edges, x, channel_attr, channel_weights, + edge_attr=ctx_edge_attr) + # synchronization of the shadow paratope (native -> shadow) + inter_h = inter_h.clone() + inter_h[inter_update_mask] = h[update_mask] + inter_h, inter_x = self._modules[f'inter_gcl_{i}']( + inter_h, inter_edges, inter_x, inter_channel_attr, inter_channel_weights + ) + # synchronization of the shadow paratope (shadow -> native) + h = h.clone() + h[inter_mask] = inter_h + ctx_states.append(h) + ctx_coords.append(x) + inter_coords.append(inter_x) + + h, x = self.out_layer( + h, ctx_edges, x, channel_attr, channel_weights, + edge_attr=ctx_edge_attr) + ctx_states.append(h) + ctx_coords.append(x) + if self.dense: + h = torch.cat(ctx_states, dim=-1) + x = torch.mean(torch.stack(ctx_coords), dim=0) + inter_x = torch.mean(torch.stack(inter_coords), dim=0) + h = self.dropout(h) + h = self.linear_out(h) + return h, x, inter_x \ No 
newline at end of file diff --git a/models/dyMEAN/modules/radial_basis.py b/models/dyMEAN/modules/radial_basis.py new file mode 100644 index 0000000000000000000000000000000000000000..5e83f1fe2ce97a9fb13b721009fab83aed60f05c --- /dev/null +++ b/models/dyMEAN/modules/radial_basis.py @@ -0,0 +1,225 @@ +""" +Copyright (c) Facebook, Inc. and its affiliates. + +This source code is licensed under the MIT license found in the +LICENSE file in the root directory of this source tree. +""" + +import math + +import numpy as np +import torch +import torch.nn as nn +from torch import Tensor +from scipy.special import binom + + +class GaussianSmearing(torch.nn.Module): + def __init__( + self, + start: float = 0.0, + stop: float = 5.0, + num_gaussians: int = 50, + ): + super().__init__() + offset = torch.linspace(start, stop, num_gaussians) + self.coeff = -0.5 / (offset[1] - offset[0]).item()**2 + self.register_buffer('offset', offset) + + def forward(self, dist: Tensor) -> Tensor: + dist = dist.view(-1, 1) - self.offset.view(1, -1) + return torch.exp(self.coeff * torch.pow(dist, 2)) + + +class PolynomialEnvelope(torch.nn.Module): + """ + Polynomial envelope function that ensures a smooth cutoff. + + Parameters + ---------- + exponent: int + Exponent of the envelope function. + """ + + def __init__(self, exponent): + super().__init__() + assert exponent > 0 + self.p = exponent + self.a = -(self.p + 1) * (self.p + 2) / 2 + self.b = self.p * (self.p + 2) + self.c = -self.p * (self.p + 1) / 2 + + def forward(self, d_scaled): + env_val = ( + 1 + + self.a * d_scaled ** self.p + + self.b * d_scaled ** (self.p + 1) + + self.c * d_scaled ** (self.p + 2) + ) + return torch.where(d_scaled < 1, env_val, torch.zeros_like(d_scaled)) + + +class ExponentialEnvelope(torch.nn.Module): + """ + Exponential envelope function that ensures a smooth cutoff, + as proposed in Unke, Chmiela, Gastegger, Schütt, Sauceda, Müller 2021. + SpookyNet: Learning Force Fields with Electronic Degrees of Freedom + and Nonlocal Effects + """ + + def __init__(self): + super().__init__() + + def forward(self, d_scaled): + env_val = torch.exp( + -(d_scaled ** 2) / ((1 - d_scaled) * (1 + d_scaled)) + ) + return torch.where(d_scaled < 1, env_val, torch.zeros_like(d_scaled)) + + +class SphericalBesselBasis(torch.nn.Module): + """ + 1D spherical Bessel basis + + Parameters + ---------- + num_radial: int + Controls maximum frequency. + cutoff: float + Cutoff distance in Angstrom. + """ + + def __init__( + self, + num_radial: int, + cutoff: float, + ): + super().__init__() + self.norm_const = math.sqrt(2 / (cutoff ** 3)) + # cutoff ** 3 to counteract dividing by d_scaled = d / cutoff + + # Initialize frequencies at canonical positions + self.frequencies = torch.nn.Parameter( + data=torch.tensor( + np.pi * np.arange(1, num_radial + 1, dtype=np.float32) + ), + requires_grad=True, + ) + + def forward(self, d_scaled): + return ( + self.norm_const + / d_scaled[:, None] + * torch.sin(self.frequencies * d_scaled[:, None]) + ) # (num_edges, num_radial) + + +class BernsteinBasis(torch.nn.Module): + """ + Bernstein polynomial basis, + as proposed in Unke, Chmiela, Gastegger, Schütt, Sauceda, Müller 2021. + SpookyNet: Learning Force Fields with Electronic Degrees of Freedom + and Nonlocal Effects + + Parameters + ---------- + num_radial: int + Controls maximum frequency. + pregamma_initial: float + Initial value of exponential coefficient gamma. 
+ Default: gamma = 0.5 * a_0**-1 = 0.94486, + inverse softplus -> pregamma = log e**gamma - 1 = 0.45264 + """ + + def __init__( + self, + num_radial: int, + pregamma_initial: float = 0.45264, + ): + super().__init__() + prefactor = binom(num_radial - 1, np.arange(num_radial)) + self.register_buffer( + "prefactor", + torch.tensor(prefactor, dtype=torch.float), + persistent=False, + ) + + self.pregamma = torch.nn.Parameter( + data=torch.tensor(pregamma_initial, dtype=torch.float), + requires_grad=True, + ) + self.softplus = torch.nn.Softplus() + + exp1 = torch.arange(num_radial) + self.register_buffer("exp1", exp1[None, :], persistent=False) + exp2 = num_radial - 1 - exp1 + self.register_buffer("exp2", exp2[None, :], persistent=False) + + def forward(self, d_scaled): + gamma = self.softplus(self.pregamma) # constrain to positive + exp_d = torch.exp(-gamma * d_scaled)[:, None] + return ( + self.prefactor * (exp_d ** self.exp1) * ((1 - exp_d) ** self.exp2) + ) + + +class RadialBasis(torch.nn.Module): + """ + + Parameters + ---------- + num_radial: int + Controls maximum frequency. + cutoff: float + Cutoff distance in Angstrom. + rbf: dict = {"name": "spherical_bessel"} + Basis function and its hyperparameters. + envelope: dict = {"name": "polynomial", "exponent": 5} + Envelope function and its hyperparameters. + """ + + def __init__( + self, + num_radial: int, + cutoff: float, + rbf: dict = {"name": "gaussian"}, + envelope: dict = {"name": "polynomial", "exponent": 5}, + ): + super().__init__() + self.inv_cutoff = 1 / cutoff + + env_name = envelope["name"].lower() + env_hparams = envelope.copy() + del env_hparams["name"] + + if env_name == "polynomial": + self.envelope = PolynomialEnvelope(**env_hparams) + elif env_name == "exponential": + self.envelope = ExponentialEnvelope(**env_hparams) + else: + raise ValueError(f"Unknown envelope function '{env_name}'.") + + rbf_name = rbf["name"].lower() + rbf_hparams = rbf.copy() + del rbf_hparams["name"] + + # RBFs get distances scaled to be in [0, 1] + if rbf_name == "gaussian": + self.rbf = GaussianSmearing( + start=0, stop=1, num_gaussians=num_radial, **rbf_hparams + ) + elif rbf_name == "spherical_bessel": + self.rbf = SphericalBesselBasis( + num_radial=num_radial, cutoff=cutoff, **rbf_hparams + ) + elif rbf_name == "bernstein": + self.rbf = BernsteinBasis(num_radial=num_radial, **rbf_hparams) + else: + raise ValueError(f"Unknown radial basis function '{rbf_name}'.") + + def forward(self, d): + d_scaled = d * self.inv_cutoff + return self.rbf(d_scaled) # the default gaussian should not use cutoff envelope + + env = self.envelope(d_scaled) + return env[:, None] * self.rbf(d_scaled) # (nEdges, num_radial) diff --git a/models/dyMEAN/nn_utils.py b/models/dyMEAN/nn_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..fd4efd0f24bb1e47d16dcb8a6014fcaf216b80d6 --- /dev/null +++ b/models/dyMEAN/nn_utils.py @@ -0,0 +1,381 @@ +#!/usr/bin/python +# -*- coding:utf-8 -*- +import torch +import torch.nn as nn +import torch.nn.functional as F + +from data.format import VOCAB + +from utils.nn_utils import sequential_and +from utils import const + + +class SinusoidalPositionEmbedding(nn.Module): + """ + Sin-Cos Positional Embedding + """ + def __init__(self, output_dim): + super(SinusoidalPositionEmbedding, self).__init__() + self.output_dim = output_dim + + def forward(self, position_ids): + device = position_ids.device + position_ids = position_ids[None] # [1, N] + indices = torch.arange(self.output_dim // 2, device=device, 
dtype=torch.float) + indices = torch.pow(10000.0, -2 * indices / self.output_dim) + embeddings = torch.einsum('bn,d->bnd', position_ids, indices) + embeddings = torch.stack([torch.sin(embeddings), torch.cos(embeddings)], dim=-1) + embeddings = embeddings.reshape(-1, self.output_dim) + return embeddings + +# embedding of amino acids. (default: concat residue embedding and atom embedding to one vector) +class AminoAcidEmbedding(nn.Module): + ''' + [residue embedding + position embedding, mean(atom embeddings + atom position embeddings)] + ''' + def __init__(self, num_res_type, num_atom_type, num_atom_pos, res_embed_size, atom_embed_size, + atom_pad_id=VOCAB.get_atom_pad_idx(), max_position=256, relative_position=True): + super().__init__() + self.residue_embedding = nn.Embedding(num_res_type, res_embed_size) + if relative_position: + self.res_pos_embedding = SinusoidalPositionEmbedding(res_embed_size) # relative positional encoding + else: + self.res_pos_embedding = nn.Embedding(max_position, res_embed_size) # absolute position encoding + self.atom_embedding = nn.Embedding(num_atom_type, atom_embed_size) + self.atom_pos_embedding = nn.Embedding(num_atom_pos, atom_embed_size) + self.atom_pad_id = atom_pad_id + self.eps = 1e-10 # for mean of atom embedding (some residues have no atom at all) + + def forward(self, S, RP, A, AP): + ''' + :param S: [N], residue types + :param RP: [N], residue positions + :param A: [N, n_channel], atom types + :param AP: [N, n_channel], atom positions + ''' + res_embed = self.residue_embedding(S) + self.res_pos_embedding(RP) # [N, res_embed_size] + atom_embed = self.atom_embedding(A) + self.atom_pos_embedding(AP) # [N, n_channel, atom_embed_size] + atom_not_pad = (AP != self.atom_pad_id) # [N, n_channel] + denom = torch.sum(atom_not_pad, dim=-1, keepdim=True) + self.eps + atom_embed = torch.sum(atom_embed * atom_not_pad.unsqueeze(-1), dim=1) / denom # [N, atom_embed_size] + return torch.cat([res_embed, atom_embed], dim=-1) # [N, res_embed_size + atom_embed_size] + + +class AminoAcidFeature(nn.Module): + def __init__(self, backbone_only=False) -> None: + super().__init__() + + self.backbone_only = backbone_only + + # number of classes + self.num_aa_type = len(VOCAB) + self.num_atom_type = VOCAB.get_num_atom_type() + self.num_atom_pos = VOCAB.get_num_atom_pos() + + # atom-level special tokens + self.atom_mask_idx = VOCAB.get_atom_mask_idx() + self.atom_pad_idx = VOCAB.get_atom_pad_idx() + self.atom_pos_mask_idx = VOCAB.get_atom_pos_mask_idx() + self.atom_pos_pad_idx = VOCAB.get_atom_pos_pad_idx() + + self.mask_idx = VOCAB.get_mask_idx() + self.unk_idx = VOCAB.symbol_to_idx(VOCAB.UNK) + self.latent_idx = VOCAB.symbol_to_idx(VOCAB.LAT) + + # atoms encoding + residue_atom_type, residue_atom_pos = [], [] + backbone = [VOCAB.atom_to_idx(atom[0]) for atom in VOCAB.backbone_atoms] + backbone_pos = [VOCAB.atom_pos_to_idx(atom[1:]) for atom in VOCAB.backbone_atoms] + n_channel = VOCAB.MAX_ATOM_NUMBER if not backbone_only else 4 + special_mask = VOCAB.get_special_mask() + for i in range(len(VOCAB)): + if i == self.mask_idx or i == self.unk_idx: + # mask or unk + residue_atom_type.append(backbone + [self.atom_mask_idx for _ in range(n_channel - len(backbone))]) + residue_atom_pos.append(backbone_pos + [self.atom_pos_mask_idx for _ in range(n_channel - len(backbone_pos))]) + elif i == self.latent_idx: + # latent index + residue_atom_type.append([VOCAB.get_atom_latent_idx() for _ in range(n_channel)]) + residue_atom_pos.append([VOCAB.get_atom_pos_latent_idx() for _ in 
range(n_channel)]) + elif special_mask[i] == 1: + # other special token (pad) + residue_atom_type.append([self.atom_pad_idx for _ in range(n_channel)]) + residue_atom_pos.append([self.atom_pos_pad_idx for _ in range(n_channel)]) + else: + # normal amino acids + atom_type, atom_pos = backbone, backbone_pos + if not backbone_only: + sidechain_atoms = const.sidechain_atoms[VOCAB.idx_to_symbol(i)] + atom_type = atom_type + [VOCAB.atom_to_idx(atom[0]) for atom in sidechain_atoms] + atom_pos = atom_pos + [VOCAB.atom_pos_to_idx(atom[1]) for atom in sidechain_atoms] + num_pad = n_channel - len(atom_type) + residue_atom_type.append(atom_type + [self.atom_pad_idx for _ in range(num_pad)]) + residue_atom_pos.append(atom_pos + [self.atom_pos_pad_idx for _ in range(num_pad)]) + + # mapping from residue to atom types and positions + self.residue_atom_type = nn.parameter.Parameter( + torch.tensor(residue_atom_type, dtype=torch.long), + requires_grad=False) + self.residue_atom_pos = nn.parameter.Parameter( + torch.tensor(residue_atom_pos, dtype=torch.long), + requires_grad=False) + + # sidechain geometry + if not backbone_only: + sc_bonds, sc_bonds_mask = [], [] + sc_chi_atoms, sc_chi_atoms_mask = [], [] + for i in range(len(VOCAB)): + if special_mask[i] == 1: + sc_bonds.append([]) + sc_chi_atoms.append([]) + else: + symbol = VOCAB.idx_to_symbol(i) + atom_type = VOCAB.backbone_atoms + const.sidechain_atoms[symbol] + atom2channel = { atom: i for i, atom in enumerate(atom_type) } + chi_atoms = const.chi_angles_atoms[VOCAB.symbol_to_abrv(symbol)] + bond_atoms = const.sidechain_bonds[symbol] + sc_chi_atoms.append( + [[atom2channel[atom] for atom in atoms] for atoms in chi_atoms] + ) + bonds = [] + for src_atom, dst_atom, _ in bond_atoms: + bonds.append((atom2channel[src_atom], atom2channel[dst_atom])) + sc_bonds.append(bonds) + max_num_chis = max([len(chis) for chis in sc_chi_atoms]) + max_num_bonds = max([len(bonds) for bonds in sc_bonds]) + for i in range(len(VOCAB)): + num_chis, num_bonds = len(sc_chi_atoms[i]), len(sc_bonds[i]) + num_pad_chis, num_pad_bonds = max_num_chis - num_chis, max_num_bonds - num_bonds + sc_chi_atoms_mask.append( + [1 for _ in range(num_chis)] + [0 for _ in range(num_pad_chis)] + ) + sc_bonds_mask.append( + [1 for _ in range(num_bonds)] + [0 for _ in range(num_pad_bonds)] + ) + sc_chi_atoms[i].extend([[-1, -1, -1, -1] for _ in range(num_pad_chis)]) + sc_bonds[i].extend([(-1, -1) for _ in range(num_pad_bonds)]) + + # mapping residues to their sidechain chi angle atoms and bonds + self.sidechain_chi_angle_atoms = nn.parameter.Parameter( + torch.tensor(sc_chi_atoms, dtype=torch.long), + requires_grad=False) + self.sidechain_chi_mask = nn.parameter.Parameter( + torch.tensor(sc_chi_atoms_mask, dtype=torch.bool), + requires_grad=False + ) + self.sidechain_bonds = nn.parameter.Parameter( + torch.tensor(sc_bonds, dtype=torch.long), + requires_grad=False + ) + self.sidechain_bonds_mask = nn.parameter.Parameter( + torch.tensor(sc_bonds_mask, dtype=torch.bool), + requires_grad=False + ) + + def _construct_atom_type(self, S): + # construct atom types + return self.residue_atom_type[S] + + def _construct_atom_pos(self, S): + # construct atom positions + return self.residue_atom_pos[S] + + @torch.no_grad() + def get_sidechain_chi_angles_atoms(self, S): + chi_angles_atoms = self.sidechain_chi_angle_atoms[S] # [N, max_num_chis, 4] + chi_mask = self.sidechain_chi_mask[S] # [N, max_num_chis] + return chi_angles_atoms, chi_mask + + @torch.no_grad() + def get_sidechain_bonds(self, S): + bonds = 
self.sidechain_bonds[S] # [N, max_num_bond, 2] + bond_mask = self.sidechain_bonds_mask[S] + return bonds, bond_mask + + +class SeparatedAminoAcidFeature(AminoAcidFeature): + ''' + Separate embeddings of atoms and residues + ''' + def __init__( + self, + embed_size, + atom_embed_size, + max_position, + relative_position=True, + fix_atom_weights=False, + backbone_only=False + ) -> None: + super().__init__(backbone_only=backbone_only) + atom_weights_mask = self.residue_atom_type == self.atom_pad_idx + self.register_buffer('atom_weights_mask', atom_weights_mask) + self.fix_atom_weights = fix_atom_weights + if fix_atom_weights: + atom_weights = torch.ones_like(self.residue_atom_type, dtype=torch.float) + else: + atom_weights = torch.randn_like(self.residue_atom_type, dtype=torch.float) + atom_weights[atom_weights_mask] = 0 + self.atom_weight = nn.parameter.Parameter(atom_weights, requires_grad=not fix_atom_weights) + self.zero_atom_weight = nn.parameter.Parameter(torch.zeros_like(atom_weights), requires_grad=False) + + self.aa_embedding = AminoAcidEmbedding( + self.num_aa_type, self.num_atom_type, self.num_atom_pos, + embed_size, atom_embed_size, self.atom_pad_idx, + max_position, relative_position) + + def get_atom_weights(self, residue_types): + weights = torch.where( + self.atom_weights_mask, + self.zero_atom_weight, + self.atom_weight + ) # [num_aa_classes, max_atom_number(n_channel)] + if not self.fix_atom_weights: + weights = F.normalize(weights, dim=-1) + return weights[residue_types] + + def forward(self, S, position_ids, smooth_prob=None, smooth_mask=None): + atom_type = self.residue_atom_type[S] # [N, n_channel] + atom_pos = self.residue_atom_pos[S] # [N, n_channel] + + # residue embedding + pos_embedding = self.aa_embedding.res_pos_embedding(position_ids) + H = self.aa_embedding.residue_embedding(S) + if smooth_prob is not None: + res_embeddings = self.aa_embedding.residue_embedding( + torch.arange(smooth_prob.shape[-1], device=S.device, dtype=S.dtype) + ) # [num_aa_type, embed_size] + H[smooth_mask] = smooth_prob.mm(res_embeddings) + H = H + pos_embedding + + # atom embedding + atom_embedding = self.aa_embedding.atom_embedding(atom_type) +\ + self.aa_embedding.atom_pos_embedding(atom_pos) + atom_weights = self.get_atom_weights(S) + + return H, (atom_embedding, atom_weights) + + +class ProteinFeature: + def __init__(self, backbone_only=False): + self.backbone_only = backbone_only + + def _cal_sidechain_bond_lengths(self, S, X, aa_feature: AminoAcidFeature, atom_mask=None): + bonds, bonds_mask = aa_feature.get_sidechain_bonds(S) + n = torch.nonzero(bonds_mask)[:, 0] # [Nbonds] + src, dst = bonds[bonds_mask].T + src_X, dst_X = X[(n, src)], X[(n, dst)] # [Nbonds, 3] + bond_lengths = torch.norm(dst_X - src_X, dim=-1) + if atom_mask is not None: + mask = atom_mask[(n, src)] & atom_mask[(n, dst)] + bond_lengths = bond_lengths[mask] + return bond_lengths + + def _cal_sidechain_chis(self, S, X, aa_feature: AminoAcidFeature, atom_mask=None): + chi_atoms, chi_mask = aa_feature.get_sidechain_chi_angles_atoms(S) + n = torch.nonzero(chi_mask)[:, 0] # [Nchis] + a0, a1, a2, a3 = chi_atoms[chi_mask].T # [Nchis] + x0, x1, x2, x3 = X[(n, a0)], X[(n, a1)], X[(n, a2)], X[(n, a3)] # [Nchis, 3] + u_0, u_1, u_2 = (x1 - x0), (x2 - x1), (x3 - x2) # [Nchis, 3] + # normals of the two planes + n_1 = F.normalize(torch.cross(u_0, u_1), dim=-1) # [Nchis, 3] + n_2 = F.normalize(torch.cross(u_1, u_2), dim=-1) # [Nchis, 3] + cosChi = (n_1 * n_2).sum(-1) # [Nchis] + eps = 1e-7 + cosChi = torch.clamp(cosChi, -1 + eps, 
1 - eps) + if atom_mask is not None: + mask = atom_mask[(n, a0)] & atom_mask[(n, a1)] & atom_mask[(n, a2)] & atom_mask[(n, a3)] + cosChi = cosChi[mask] + return cosChi + + def _cal_backbone_bond_lengths(self, X, batch_ids, segment_ids, atom_mask=None): + # loss of backbone (...N-CA-C(O)-N...) bond length + # N-CA, CA-C, C=O + bl1 = torch.norm(X[:, 1:4] - X[:, :3], dim=-1) # [N, 3], (N-CA), (CA-C), (C=O) + if atom_mask is not None: + bl1 = bl1[atom_mask[:, 1:4] & atom_mask[:, :3]] + else: + bl1 = bl1.flatten() + # C-N + bl2 = torch.norm(X[1:, 0] - X[:-1, 2], dim=-1) # [N-1] + same_chain_mask = (segment_ids[1:] == segment_ids[:-1]) & (batch_ids[1:] == batch_ids[:-1]) + if atom_mask is not None: + mask = atom_mask[1:, 0] & atom_mask[:-1, 2] & same_chain_mask + else: + mask = same_chain_mask + bl2 = bl2[mask] + bl = torch.cat([bl1, bl2], dim=0) + return bl + + def _cal_backbone_dihedral_angles(self, X, batch_ids, segment_ids, atom_mask=None): + ori_X = X.clone() # used for calculating bond angles + X = X[:, :3].reshape(-1, 3) # [N * 3, 3], N, CA, C + U = F.normalize(X[1:] - X[:-1], dim=-1) # [N * 3 - 1, 3] + + # 1. dihedral angles + u_2, u_1, u_0 = U[:-2], U[1:-1], U[2:] # [N * 3 - 3, 3] + # backbone normals + n_2 = F.normalize(torch.cross(u_2, u_1), dim=-1) + n_1 = F.normalize(torch.cross(u_1, u_0), dim=-1) + # angle between normals + eps = 1e-7 + cosD = (n_2 * n_1).sum(-1) # [(N-1) * 3] + cosD = torch.clamp(cosD, -1 + eps, 1 - eps) + # D = torch.sign((u_2 * n_1).sum(-1)) * torch.acos(cosD) + seg_id_atom = segment_ids.repeat(1, 3).flatten() # [N * 3] + batch_id_atom = batch_ids.repeat(1, 3).flatten() + same_chain_mask = sequential_and( + seg_id_atom[:-3] == seg_id_atom[1:-2], + seg_id_atom[1:-2] == seg_id_atom[2:-1], + seg_id_atom[2:-1] == seg_id_atom[3:], + batch_id_atom[:-3] == batch_id_atom[1:-2], + batch_id_atom[1:-2] == batch_id_atom[2:-1], + batch_id_atom[2:-1] == batch_id_atom[3:] + ) # [N * 3 - 3] + # D = D[same_chain_mask] + if atom_mask is not None: + mask = atom_mask[:, :3].flatten() # [N * 3] + mask = mask[1:] & mask[:-1] # [N * 3 - 1] + mask = mask[:-2] & mask[1:-1] & mask[2:] # [N * 3 - 3] + mask = mask & same_chain_mask + else: + mask = same_chain_mask + + cosD = cosD[mask] + + # # 2. 
bond angles (C_{n-1}-N, N-CA), (N-CA, CA-C), (CA-C, C=O), (CA-C, C-N_{n+1}), (O=C, C-Nn) + # u_0, u_1 = U[:-1], U[1:] # [N*3 - 2, 3] + # cosA1 = ((-u_0) * u_1).sum(-1) # [N*3 - 2], (C_{n-1}-N, N-CA), (N-CA, CA-C), (CA-C, C-N_{n+1}) + # same_chain_mask = sequential_and( + # seg_id_atom[:-2] == seg_id_atom[1:-1], + # seg_id_atom[1:-1] == seg_id_atom[2:] + # ) + # cosA1 = cosA1[same_chain_mask] # [N*3 - 2 * num_chain] + # u_co = F.normalize(ori_X[:, 3] - ori_X[:, 2], dim=-1) # [N, 3], C=O + # u_cca = -U[1::3] # [N, 3], C-CA + # u_cn = U[2::3] # [N-1, 3], C-N_{n+1} + # cosA2 = (u_co * u_cca).sum(-1) # [N], (C=O, C-CA) + # cosA3 = (u_co[:-1] * u_cn).sum(-1) # [N-1], (C=O, C-N_{n+1}) + # same_chain_mask = (seg_id[:-1] == seg_id[1:]) # [N-1] + # cosA3 = cosA3[same_chain_mask] + # cosA = torch.cat([cosA1, cosA2, cosA3], dim=-1) + # cosA = torch.clamp(cosA, -1 + eps, 1 - eps) + + # return cosD, cosA + return cosD + + def get_struct_profile(self, X, S, batch_ids, aa_feature: AminoAcidFeature, segment_ids=None, atom_mask=None): + ''' + X: [N, 14, 3], coordinates of all atoms + batch_ids: [N], indicate which item the residue belongs to + segment_ids: [N], indicate which chain the residue belongs to + aa_feature: AminoAcidFeature, storing geometric constants + atom_mask: [N, 14], 0 for padding/missing + ''' + if segment_ids is None: # default regarded as monomers + segment_ids = torch.ones_like(batch_ids) + return { + 'bb_bond_lengths': self._cal_backbone_bond_lengths(X, batch_ids, segment_ids, atom_mask), + 'sc_bond_lengths': self._cal_sidechain_bond_lengths(S, X, aa_feature, atom_mask), + 'bb_dihedral_angles': self._cal_backbone_dihedral_angles(X, batch_ids, segment_ids, atom_mask), + 'sc_chi_angles': self._cal_sidechain_chis(S, X, aa_feature, atom_mask) + } \ No newline at end of file diff --git a/scripts/data_process/aug_from_monomer.py b/scripts/data_process/aug_from_monomer.py new file mode 100644 index 0000000000000000000000000000000000000000..de6c800d03941f51e38d3e4222ccc69c0736d6fe --- /dev/null +++ b/scripts/data_process/aug_from_monomer.py @@ -0,0 +1,465 @@ +import os +import re +import gzip +import time +import shutil +import argparse +from copy import deepcopy +from tempfile import NamedTemporaryFile +import multiprocessing as mp + +import numpy as np +from Bio.PDB import PDBParser,Chain,Model,Structure, PDBIO +from Bio.PDB.DSSP import dssp_dict_from_pdb_file +from Bio.SeqUtils.ProtParam import ProteinAnalysis +from rdkit.Chem.rdMolDescriptors import CalcTPSA +from freesasa import calcBioPDB +from rdkit.Chem import MolFromSmiles + +from globals import CACHE_DIR, CONTACT_DIST +from utils.logger import print_log +from utils.file_utils import cnt_num_files, get_filename +from data.mmap_dataset import create_mmap +from data.converter.pdb_to_list_blocks import pdb_to_list_blocks +from data.converter.blocks_interface import blocks_cb_interface, blocks_interface + +from .pepbind import clustering + + +def parse(): + parser = argparse.ArgumentParser(description='Filter peptide-like loop from monomers') + parser.add_argument('--database_dir', type=str, required=True, + help='Directory of pdb database processed in monomers') + parser.add_argument('--pdb_dir', type=str, required=True, help='Directory to PDB database') + parser.add_argument('--out_dir', type=str, required=True, help='Output directory') + parser.add_argument('--pocket_th', type=float, default=10.0, help='Threshold for determining pocket') + parser.add_argument('--n_cpu', type=int, default=4, help='Number of CPU to use') + + return 
parser.parse_args() + + +# Constants +AA3TO1 = { + 'ALA':'A', 'VAL':'V', 'PHE':'F', 'PRO':'P', 'MET':'M', + 'ILE':'I', 'LEU':'L', 'ASP':'D', 'GLU':'E', 'LYS':'K', + 'ARG':'R', 'SER':'S', 'THR':'T', 'TYR':'Y', 'HIS':'H', + 'CYS':'C', 'ASN':'N', 'GLN':'Q', 'TRP':'W', 'GLY':'G',} + +hydrophobic_residues=['V','I','L','M','F','W','C'] +charged_residues=['H','R','K','D','E'] + +def add_cb(input_array): + #from protein mpnn + #The virtual Cβ coordinates were calculated using ideal angle and bond length definitions: b = Cα - N, c = C - Cα, a = cross(b, c), Cβ = -0.58273431*a + 0.56802827*b - 0.54067466*c + Cα. + N,CA,C,O = input_array + b = CA - N + c = C - CA + a = np.cross(b,c) + CB = np.around(-0.58273431*a + 0.56802827*b - 0.54067466*c + CA,3) + return CB #np.array([N,CA,C,CB,O]) + +aaSMILES = {'G': 'NCC(=O)O', + 'A': 'N[C@@]([H])(C)C(=O)O', + 'R': 'N[C@@]([H])(CCCNC(=N)N)C(=O)O', + 'N': 'N[C@@]([H])(CC(=O)N)C(=O)O', + 'D': 'N[C@@]([H])(CC(=O)O)C(=O)O', + 'C': 'N[C@@]([H])(CS)C(=O)O', + 'E': 'N[C@@]([H])(CCC(=O)O)C(=O)O', + 'Q': 'N[C@@]([H])(CCC(=O)N)C(=O)O', + 'H': 'N[C@@]([H])(CC1=CN=C-N1)C(=O)O', + 'I': 'N[C@@]([H])(C(CC)C)C(=O)O', + 'L': 'N[C@@]([H])(CC(C)C)C(=O)O', + 'K': 'N[C@@]([H])(CCCCN)C(=O)O', + 'M': 'N[C@@]([H])(CCSC)C(=O)O', + 'F': 'N[C@@]([H])(Cc1ccccc1)C(=O)O', + 'P': 'N1[C@@]([H])(CCC1)C(=O)O', + 'S': 'N[C@@]([H])(CO)C(=O)O', + 'T': 'N[C@@]([H])(C(O)C)C(=O)O', + 'W': 'N[C@@]([H])(CC(=CN2)C1=C2C=CC=C1)C(=O)O', + 'Y': 'N[C@@]([H])(Cc1ccc(O)cc1)C(=O)O', + 'V': 'N[C@@]([H])(C(C)C)C(=O)O'} + + +class Filter: + def __init__( + self, + min_loop_len = 4, + max_loop_len = 25, + min_BSA = 400, + min_relBSA = 0.2, + max_relncBSA = 0.3, + saved_maxlen = 25, + saved_BSA = 400, + saved_relBSA = 0.2, + saved_helix_ratio = 1.0, + saved_strand_ratio = 1.0, + cyclic=False + ) -> None: + + self.re_filter = re.compile(r'D[GPS]|[P]{2,}|C') #https://www.thermofisher.cn/cn/zh/home/life-science/protein-biology/protein-biology-learning-center/protein-biology-resource-library/pierce-protein-methods/peptide-design.html + self.cache_dir = CACHE_DIR + + self.min_loop_len = min_loop_len + self.max_loop_len = max_loop_len + self.min_BSA = min_BSA + self.min_relBSA = min_relBSA + self.max_relncBSA = max_relncBSA + self.saved_maxlen = saved_maxlen + self.saved_BSA = saved_BSA + self.saved_relBSA = saved_relBSA + self.saved_helix_ratio = saved_helix_ratio + self.saved_strand_ratio = saved_strand_ratio + self.cyclic = cyclic + + @classmethod + def get_ss_info(cls, pdb_path: str): + dssp, keys = dssp_dict_from_pdb_file(pdb_path, DSSP='mkdssp') + ss_info = {} + for key in keys: + chain_id, value = key[0], dssp[key] + if chain_id not in ss_info: + ss_info[chain_id] = [] + ss_type = value[1] + if ss_type in ['H', 'G', 'I']: + ss_info[chain_id].append('a') + elif ss_type in ['B', 'E', 'T', 'S']: + ss_info[chain_id].append('b') + elif ss_type == '-': + ss_info[chain_id].append('c') + else: + raise ValueError(f'SS type {ss_type} cannot be recognized!') + return ss_info + + @classmethod + def get_bsa(self, receptor_chain: Chain.Chain, ligand_chain: Chain.Chain): + lig_chain_id = ligand_chain.get_id() + tmp_structure = Structure.Structure('tmp') + tmp_model = Model.Model(0) + tmp_structure.add(tmp_model) + tmp_model.add(ligand_chain) + unbounded_SASA = calcBioPDB(tmp_structure)[0].residueAreas()[lig_chain_id] + unbounded_SASA = [k.total for k in unbounded_SASA.values()] + + tmp_model.add(receptor_chain) + bounded_SASA = calcBioPDB(tmp_structure)[0].residueAreas()[lig_chain_id] + bounded_SASA = [k.total for k in 
bounded_SASA.values()] + + abs_bsa = sum(unbounded_SASA[1:-1]) - sum(bounded_SASA[1:-1]) + rel_bsa = abs_bsa / sum(unbounded_SASA[1:-1]) + rel_nc_bsa = (unbounded_SASA[0] + unbounded_SASA[-1] - bounded_SASA[0] - bounded_SASA[-1]) / (unbounded_SASA[0] + unbounded_SASA[-1]) + + return abs_bsa, rel_bsa, rel_nc_bsa, tmp_structure + + def filter_pdb(self, pdb_path, selected_chains=None): + parser = PDBParser(QUIET=True) + ss_info = self.get_ss_info(pdb_path) + structure = parser.get_structure('anonym', pdb_path) + + for model in structure.get_models(): # use model 1 only + structure = model + break + + results = [] + for chain in structure.get_chains(): + if selected_chains is not None and chain.get_id() not in selected_chains: + continue + chain_ss_info = None if ss_info is None else ss_info[chain.get_id()] + results.extend(self.filter_chain(chain, chain_ss_info)) + + return results + + + def filter_chain(self, chain, ss_info=None): + + non_standard = False + for res in chain: + if res.get_resname() not in AA3TO1: + non_standard = True + break + if non_standard: + return [] + + if len(ss_info) != len(chain): + return [] + + cb_coord = [] + seq = '' + for res in chain: + seq += AA3TO1[res.get_resname()] + try: + cb_coord.append(res['CB'].get_coord()) + except: + tmp_coord = np.array([ + res['N'].get_coord(), + res['CA'].get_coord(), + res['C'].get_coord(), + res['O'].get_coord() + ]) + cb_coord.append(add_cb(tmp_coord)) + cb_coord = np.array(cb_coord) + cb_contact = np.linalg.norm(cb_coord[None,:,:,] - cb_coord[:,None,:],axis=-1) + if self.cyclic: + possible_ss = (cb_contact >= 3.5) & (cb_contact <= 5) #https://www.ncbi.nlm.nih.gov/pmc/articles/PMC4987930/ + else: + possible_ss = np.ones(cb_contact.shape, dtype=bool) + possible_ss = np.triu(np.tril(possible_ss, self.max_loop_len - 1), self.min_loop_len - 1) + ss_pair = np.where(possible_ss) + accepted, saved_spans = [], [] + for i, j in zip(ss_pair[0],ss_pair[1]): + redundant = False + for exist_i, exist_j in saved_spans: + overlap = min(j, exist_j) - max(i, exist_i) + 1 + if overlap / (j - i + 1) > 0.4 or overlap / (exist_j - exist_i + 1) > 0.4: + redundant = True + break + if redundant: + continue + #20A neighbor + min_dist = np.min(cb_contact[i : j + 1], axis=0) + min_dist[max(i - 5, 0):min(j + 6, len(seq))] = 21 + neighbors_20A = np.where(min_dist < 20)[0] + if len(neighbors_20A) < 16: + continue + + #sequence filter + pep_seq = seq[i:j+1] + #cystine 2P and DGDA filter + if self.re_filter.search(pep_seq) is not None: + continue + prot_param=ProteinAnalysis(pep_seq) + aa_percent = prot_param.get_amino_acids_percent() + max_ratio = max(aa_percent.values()) + #Discard if any amino acid represents more than 25% of the total sequence + if max_ratio > 0.25: + continue + hydrophobic_ratio = sum([aa_percent[k] for k in hydrophobic_residues]) + #hydrophobic amino acids exceeds 45% + if hydrophobic_ratio > 0.45: + continue + #charged amino acids exceeds 45% or less than 25% + charged_ratio = sum([aa_percent[k] for k in charged_residues]) + if charged_ratio > 0.45 or charged_ratio < 0.25: + continue + #instablility index>40 + if prot_param.instability_index() >= 40: + continue + + # #TPSA filter (for cell penetration) + # mol_weight = prot_param.molecular_weight() + # pepsmile='O' + # for k in pep_seq: + # pepsmile=pepsmile[:-1] + aaSMILES[k] + # pepsmile = MolFromSmiles(pepsmile) + # tpsa = CalcTPSA(pepsmile) + # if tpsa <= mol_weight * 0.2: + # continue + + #build structure and get BSA + receptor_chain = Chain.Chain('R') + ligand_chain = 
Chain.Chain('L') + for k,res in enumerate(chain): + if k >= i and k <= j: + ligand_chain.add(res.copy()) + elif k in neighbors_20A: + receptor_chain.add(res.copy()) + + abs_bsa, rel_bsa, rel_nc_bsa, tmp_structure = self.get_bsa(receptor_chain, ligand_chain) + if abs_bsa <= self.min_BSA or rel_bsa <= self.min_relBSA or (self.cyclic and rel_nc_bsa >= self.max_relncBSA): + continue + + #prepare for output + length = j - i + 1 + if ss_info is None: + helix_ratio = -1 + strand_ratio = -1 + coil_ratio = -1 + else: + ssa = ss_info[i:j+1] + helix_ratio = ssa.count('a') / length + strand_ratio = ssa.count('b') / length + coil_ratio = ssa.count('c') / length + # helix_ratio = (ssa.count("G") + ssa.count("H") + ssa.count("I") + ssa.count("T")) / length + # strand_ratio = (ssa.count("E") + ssa.count("B")) / length + # coil_ratio = (ssa.count("S")+ssa.count("C")) / length + if length <= self.saved_maxlen and abs_bsa >= self.saved_BSA and rel_bsa >= self.saved_relBSA and helix_ratio <= self.saved_helix_ratio and strand_ratio <= self.saved_strand_ratio: + output_structure = deepcopy(tmp_structure) + else: + output_structure = None + accepted.append(( + i , j, length, abs_bsa, rel_bsa, helix_ratio, strand_ratio, coil_ratio, output_structure + )) + saved_spans.append((i, j)) + + return accepted + + +def get_non_redundant(mmap_dir): + np.random.seed(12) + index_path = os.path.join(mmap_dir, 'index.txt') + parent_dir = mmap_dir + + # load index file + items = {} + with open(index_path, 'r') as fin: + lines = fin.readlines() + for line in lines: + values = line.strip().split('\t') + _id, seq = values[0], values[-1] + chain, pdb_file = _id.split('_') + items[_id] = (seq, chain, pdb_file) + + # make temporary directory + tmp_dir = os.path.join(parent_dir, 'tmp') + if not os.path.exists(tmp_dir): + os.makedirs(tmp_dir) + else: + raise ValueError(f'Working directory {tmp_dir} exists!') + + # 1. get non-redundant dimer by 90% seq-id + fasta = os.path.join(tmp_dir, 'seq.fasta') + with open(fasta, 'w') as fout: + for _id in items: + fout.write(f'>{_id}\n{items[_id][0]}\n') + id2clu, clu2id = clustering(fasta, tmp_dir, 0.9) + non_redundant = [] + for clu in clu2id: + ids = clu2id[clu] + non_redundant.append(np.random.choice(ids)) + print_log(f'Non-redundant entries: {len(non_redundant)}') + shutil.rmtree(tmp_dir) + + # 2. 
construct non_redundant items + indexes = {} + for _id in non_redundant: + _, chain, pdb_file = items[_id] + if pdb_file not in indexes: + indexes[pdb_file] = [] + indexes[pdb_file].append(chain) + + return indexes + + +def mp_worker(data_dir, tmp_dir, pdb_file, selected_chains, pep_filter, pdb_out_dir, queue): + category = pdb_file[4:6] + category_dir = os.path.join(data_dir, category) + path = os.path.join(category_dir, pdb_file) + tmp_file = os.path.join(tmp_dir, f'{pdb_file}.decompressed') + pdb_id = get_filename(pdb_file.split('.')[0]) + + # uncompress the file to the tmp file + with gzip.open(path, 'rb') as fin: + with open(tmp_file, 'wb') as fout: + shutil.copyfileobj(fin, fout) + + files = [] + try: + # biotitie_pdb_file = biotite_pdb.PDBFile.read(tmp_file) + # biotite_struct = biotitie_pdb_file.get_structure(model=1) + # ss_info = { chain: annotate_sse(biotite_struct, chain_id=chain) for chain in selected_chains } + results = pep_filter.filter_pdb(tmp_file, selected_chains=selected_chains) + for item in results: + i, j, struct = item[0], item[1], item[-1] + if struct is None: + continue + io = PDBIO() + io.set_structure(struct) + _id = pdb_id + f'_{i}_{j}' + save_path = os.path.join(pdb_out_dir, _id + '.pdb') + io.save(save_path) + files.append(save_path) + except Exception: # pdbs with missing backbone coordinates or DSSP failed + pass + queue.put((pdb_file, files)) + os.remove(tmp_file) + + +def process_iterator(indexes, data_dir, tmp_dir, out_dir, pocket_th, n_cpu): + pdb_out_dir = os.path.join(out_dir, 'pdbs') + if not os.path.exists(pdb_out_dir): + os.makedirs(pdb_out_dir) + pep_filter = Filter() + + file_cnt, pointer, filenames = 0, 0, list(indexes.keys()) + id2task = {} + queue = mp.Queue() + # initialize tasks + for _ in range(n_cpu): + task_id = filenames[pointer] + id2task[task_id] = mp.Process( + target=mp_worker, + args=(data_dir, tmp_dir, task_id, indexes[task_id], pep_filter, pdb_out_dir, queue) + ) + id2task[task_id].start() + pointer += 1 + + while True: + if len(id2task) == 0: + break + + if not queue.qsize: # no finished ones + time.sleep(1) + continue + + pdb_file, paths = queue.get() + file_cnt += 1 + id2task[pdb_file].join() + del id2task[pdb_file] + + # add the next task + if pointer < len(filenames): + task_id = filenames[pointer] + id2task[task_id] = mp.Process( + target=mp_worker, + args=(data_dir, tmp_dir, task_id, indexes[task_id], pep_filter, pdb_out_dir, queue) + ) + id2task[task_id].start() + pointer += 1 + + # handle processed data + for save_path in paths: + _id = get_filename(save_path) + + list_blocks, chains = pdb_to_list_blocks(save_path, return_chain_ids=True) + if chains[0] == 'L': + list_blocks, chains = (list_blocks[1], list_blocks[0]), (chains[1], chains[0]) + + rec_blocks, lig_blocks = list_blocks + rec_chain, lig_chain = chains + try: + _, (pocket_idx, _) = blocks_cb_interface(rec_blocks, lig_blocks, pocket_th) + except KeyError: + continue + rec_num_units = sum([len(block) for block in rec_blocks]) + lig_num_units = sum([len(block) for block in lig_blocks]) + rec_data = [block.to_tuple() for block in rec_blocks] + lig_data = [block.to_tuple() for block in lig_blocks] + rec_seq = ''.join([AA3TO1[block.abrv] for block in rec_blocks]) + lig_seq = ''.join([AA3TO1[block.abrv] for block in lig_blocks]) + + yield _id, (rec_data, lig_data), [ + len(rec_blocks), len(lig_blocks), rec_num_units, lig_num_units, + rec_chain, lig_chain, rec_seq, lig_seq, + ','.join([str(idx) for idx in pocket_idx]), + ], file_cnt + + +def main(args): + indexes = 
get_non_redundant(args.database_dir) + cnt = len(indexes) + tmp_dir = './tmp' + if not os.path.exists(tmp_dir): + os.makedirs(tmp_dir) + + print_log(f'Processing data from directory: {args.pdb_dir}.') + print_log(f'Number of entries: {cnt}') + create_mmap( + process_iterator(indexes, args.pdb_dir, tmp_dir, args.out_dir, args.pocket_th, args.n_cpu), + args.out_dir, cnt) + + print_log('Finished!') + + shutil.rmtree(tmp_dir) + + +if __name__ == '__main__': + main(parse()) diff --git a/scripts/data_process/monomer.py b/scripts/data_process/monomer.py new file mode 100644 index 0000000000000000000000000000000000000000..7f1ba1be8e3c4c847738018ba0abbcd99d9a5c0e --- /dev/null +++ b/scripts/data_process/monomer.py @@ -0,0 +1,111 @@ +#!/usr/bin/python +# -*- coding:utf-8 -*- +import os +import gzip +import shutil +import argparse + +import numpy as np + +from utils.logger import print_log +from utils.file_utils import get_filename, cnt_num_files +from data.format import VOCAB +from data.converter.pdb_to_list_blocks import pdb_to_list_blocks +from data.converter.blocks_to_data import blocks_to_data +from data.mmap_dataset import create_mmap + + +def parse(): + parser = argparse.ArgumentParser(description='Process PDB to monomers') + parser.add_argument('--pdb_dir', type=str, required=True, + help='Directory of pdb database') + parser.add_argument('--out_dir', type=str, required=True, + help='Output directory') + return parser.parse_args() + + +def process_iterator(data_dir): + + tmp_dir = './tmp' + if not os.path.exists(tmp_dir): + os.makedirs(tmp_dir) + + file_cnt = 0 + for category in os.listdir(data_dir): + category_dir = os.path.join(data_dir, category) + for pdb_file in os.listdir(category_dir): + file_cnt += 1 + path = os.path.join(category_dir, pdb_file) + tmp_file = os.path.join(tmp_dir, f'{pdb_file}.decompressed') + + try: + # uncompress the file to the tmp file + with gzip.open(path, 'rb') as fin: + with open(tmp_file, 'wb') as fout: + shutil.copyfileobj(fin, fout) + + list_blocks, chains = pdb_to_list_blocks(tmp_file, return_chain_ids=True) + except Exception as e: + print_log(f'Parsing {pdb_file} failed: {e}', level='WARN') + continue + + for blocks, chain in zip(list_blocks, chains): + + # find broken chains: sequence starts from N end + filter_blocks, NC_coords = [], [] + for block in blocks: + N_coord, C_coord, CA_coord = None, None, None + for atom in block: + if atom.name == 'N': + N_coord = atom.coordinate + elif atom.name == 'C': + C_coord = atom.coordinate + elif atom.name == 'CA': + CA_coord = atom.coordinate + if N_coord and C_coord and CA_coord: + filter_blocks.append(block) + NC_coords.append(N_coord) + NC_coords.append(C_coord) + + if len(filter_blocks) == 0: # no valid residues + continue + + NC_coords = np.array(NC_coords) + pep_bond_len = np.linalg.norm(NC_coords[1::2][:-1] - NC_coords[2::2], axis=-1) + # broken = np.nonzero(pep_bond_len > 1.5)[0] + + if np.any(pep_bond_len > 1.5): + continue + + blocks = filter_blocks + item_id = chain + '_' + pdb_file + # data = blocks_to_data(blocks) + num_blocks = len(blocks) + num_units = sum([len(block.units) for block in blocks]) + data = [block.to_tuple() for block in blocks] + + seq = ''.join([VOCAB.abrv_to_symbol(block.abrv) for block in blocks]) + + # id, data, properties, whether this entry is finished for producing data + yield item_id, data, [num_blocks, num_units, chain, seq], file_cnt + + if os.path.exists(tmp_file): + os.remove(tmp_file) + + shutil.rmtree(tmp_dir) + +def main(args): + + cnt = cnt_num_files(args.pdb_dir, 
recursive=True) + + print_log(f'Processing data from directory: {args.pdb_dir}.') + print_log(f'Number of entries: {cnt}') + create_mmap( + process_iterator(args.pdb_dir), + args.out_dir, cnt) + + print_log('Finished!') + + +if __name__ == '__main__': + main(parse()) \ No newline at end of file diff --git a/scripts/data_process/pepbdb.py b/scripts/data_process/pepbdb.py new file mode 100644 index 0000000000000000000000000000000000000000..66eb9b1a0622742b4e00adaa62d9e41e6b536845 --- /dev/null +++ b/scripts/data_process/pepbdb.py @@ -0,0 +1,114 @@ +#!/usr/bin/python +# -*- coding:utf-8 -*- +import os +import re +import argparse +from tqdm import tqdm + +from Bio.PDB.MMCIFParser import MMCIFParser +from Bio.PDB import PDBIO + +from data.mmap_dataset import create_mmap +from data.format import VOCAB +from data.converter.pdb_to_list_blocks import pdb_to_list_blocks +from data.converter.list_blocks_to_pdb import list_blocks_to_pdb +from data.converter.blocks_interface import blocks_interface, blocks_cb_interface +from utils.logger import print_log + + +def parse(): + parser = argparse.ArgumentParser(description='Process PepBDB dataset') + parser.add_argument('--index', type=str, default=None, help='Index file of the dataset') + parser.add_argument('--out_dir', type=str, required=True, help='Output Directory') + parser.add_argument('--pocket_th', type=float, default=10.0, + help='Threshold for determining pocket') + return parser.parse_args() + + +def process_iterator(items, pdb_dir, pdb_out_dir, pocket_th): + if not os.path.exists(pdb_out_dir): + os.makedirs(pdb_out_dir) + + for cnt, pdb_id in enumerate(items): + summary = items[pdb_id] + rec_chain, lig_chain = summary['rec_chain'], summary['pep_chain'] + non_standard = 0 + try: + rec_blocks = pdb_to_list_blocks(os.path.join(pdb_dir, pdb_id, 'receptor.pdb'), selected_chains=[rec_chain])[0] + lig_blocks = pdb_to_list_blocks(os.path.join(pdb_dir, pdb_id, 'peptide.pdb'), selected_chains=[lig_chain])[0] + except (KeyError, FileNotFoundError): + continue + _, (_, pep_if_idx) = blocks_interface(rec_blocks, lig_blocks, 6.0) # 6A for atomic interaction + # if len(pep_if_idx) / len(lig_blocks) < 0.3: # too less contacts + # continue + if len(pep_if_idx) == 0: + continue + try: + _, (pocket_idx, _) = blocks_cb_interface(rec_blocks, lig_blocks, pocket_th) # 10A for pocket size based on CB + except KeyError: + print_log(f'{pdb_id} missing backbone atoms') + continue # missing both CB and backbone atoms + rec_num_units = sum([len(block) for block in rec_blocks]) + lig_num_units = sum([len(block) for block in lig_blocks]) + + data = ([block.to_tuple() for block in rec_blocks], [block.to_tuple() for block in lig_blocks]) + rec_seq = ''.join([VOCAB.abrv_to_symbol(block.abrv) for block in rec_blocks]) + lig_seq = ''.join([VOCAB.abrv_to_symbol(block.abrv) for block in lig_blocks]) + + # if '?' in [rec_seq[i] for i in pocket_idx] or '?' in lig_seq: + if '?' in lig_seq: + non_standard = 1 # has non-standard amino acids + + try: + list_blocks_to_pdb( + [rec_blocks, lig_blocks], + [rec_chain, lig_chain], + os.path.join(pdb_out_dir, pdb_id + '.pdb') + ) + except Exception: + # things like XE1 in 4cin_C, unknown atom + continue + + yield pdb_id, data, [ + len(rec_blocks), len(lig_blocks), rec_num_units, lig_num_units, + rec_chain, lig_chain, rec_seq, lig_seq, non_standard, + ','.join([str(idx) for idx in pocket_idx]), + ], cnt + + +def main(args): + + if not os.path.exists(args.out_dir): + os.makedirs(args.out_dir) + + # 1. 
get index file + with open(args.index, 'r') as fin: + lines = fin.readlines() + indexes = {} + for line in lines: + line = re.split(r'\s+', line.strip()) + if line[-1] != 'prot': + continue + pdb_id = line[0] + indexes[pdb_id + '_' + line[1]] = { + 'rec_chain': line[4], + 'pep_chain': line[1] + } + print_log(f'Total {len(indexes)} entries') + # 2. process pdb files into our format (mmap) + create_mmap( + process_iterator( + indexes, + os.path.join(os.path.dirname(args.index), 'pepbdb'), + os.path.join(args.out_dir, 'pdbs'), + args.pocket_th + ), + args.out_dir, len(indexes)) + + print_log('Finished!') + + return + + +if __name__ == '__main__': + main(parse()) \ No newline at end of file diff --git a/scripts/data_process/process.py b/scripts/data_process/process.py new file mode 100644 index 0000000000000000000000000000000000000000..a14bc228f267a21315ee9f6d018674eee8e351be --- /dev/null +++ b/scripts/data_process/process.py @@ -0,0 +1,83 @@ +#!/usr/bin/python +# -*- coding:utf-8 -*- +import os +import argparse + +from data.mmap_dataset import create_mmap +from data.format import VOCAB +from data.converter.pdb_to_list_blocks import pdb_to_list_blocks +from data.converter.blocks_interface import blocks_interface, blocks_cb_interface +from utils.logger import print_log + + +def parse(): + parser = argparse.ArgumentParser(description='Process protein-peptide complexes') + parser.add_argument('--index', type=str, default=None, help='Index file of the dataset') + parser.add_argument('--out_dir', type=str, required=True, help='Output Directory') + parser.add_argument('--pocket_th', type=float, default=10.0, + help='Threshold for determining binding site') + return parser.parse_args() + + +def process_iterator(items, pocket_th): + + for cnt, pdb_id in enumerate(items): + summary = items[pdb_id] + rec_chain, lig_chain = summary['rec_chain'], summary['pep_chain'] + non_standard = 0 + rec_blocks, lig_blocks = pdb_to_list_blocks(summary['pdb_path'], selected_chains=[rec_chain, lig_chain]) + _, (_, pep_if_idx) = blocks_interface(rec_blocks, lig_blocks, 6.0) # 6A for atomic interaction + if len(pep_if_idx) == 0: + continue + try: + _, (pocket_idx, _) = blocks_cb_interface(rec_blocks, lig_blocks, pocket_th) # 10A for pocket size based on CB + except KeyError: + print_log(f'{pdb_id} missing backbone atoms') + continue # missing both CB and backbone atoms + rec_num_units = sum([len(block) for block in rec_blocks]) + lig_num_units = sum([len(block) for block in lig_blocks]) + + data = ([block.to_tuple() for block in rec_blocks], [block.to_tuple() for block in lig_blocks]) + rec_seq = ''.join([VOCAB.abrv_to_symbol(block.abrv) for block in rec_blocks]) + lig_seq = ''.join([VOCAB.abrv_to_symbol(block.abrv) for block in lig_blocks]) + + # if '?' in [rec_seq[i] for i in pocket_idx] or '?' in lig_seq: + if '?' in lig_seq: + non_standard = 1 # has non-standard amino acids + + yield pdb_id, data, [ + len(rec_blocks), len(lig_blocks), rec_num_units, lig_num_units, + rec_chain, lig_chain, rec_seq, lig_seq, non_standard, + ','.join([str(idx) for idx in pocket_idx]), + ], cnt + + +def main(args): + + if not os.path.exists(args.out_dir): + os.makedirs(args.out_dir) + + # 1. get index file + with open(args.index, 'r') as fin: + lines = fin.readlines() + indexes = {} + root_dir = os.path.dirname(args.index) + for line in lines: + line = line.strip().split('\t') + pdb_id = line[0] + indexes[pdb_id] = { + 'rec_chain': line[1], + 'pep_chain': line[2], + 'pdb_path': os.path.join(root_dir, 'pdbs', pdb_id + '.pdb') + } + + # 3. 
process pdb files into our format (mmap) + create_mmap( + process_iterator(indexes, args.pocket_th), + args.out_dir, len(indexes)) + + print_log('Finished!') + + +if __name__ == '__main__': + main(parse()) \ No newline at end of file diff --git a/scripts/data_process/split.py b/scripts/data_process/split.py new file mode 100644 index 0000000000000000000000000000000000000000..636fa3e090e3c3d90f00e7bbb7d6b1bfc8290dd5 --- /dev/null +++ b/scripts/data_process/split.py @@ -0,0 +1,57 @@ +#!/usr/bin/python +# -*- coding:utf-8 -*- +import os +import argparse + +import numpy as np + +from utils.logger import print_log + + +def parse(): + parser = argparse.ArgumentParser(description='Split peptide data') + parser.add_argument('--train_index', type=str, required=True, help='Path for training index') + parser.add_argument('--valid_index', type=str, required=True, help='Path for validation index') + parser.add_argument('--test_index', type=str, default=None, help='Path for test index') + parser.add_argument('--processed_dir', type=str, required=True, help='processed directory') + return parser.parse_args() + + +def read_index(mmap_dir): + items = {} + index = os.path.join(mmap_dir, 'index.txt') + with open(index, 'r') as fin: + lines = fin.readlines() + for line in lines: + values = line.strip().split('\t') + items[values[0]] = line + return items + + +def transform(items, path, out): + ids = {} + with open(path, 'r') as fin: + lines = fin.readlines() + for line in lines: + ids[line.split('\t')[0]] = 1 + with open(out, 'w') as fout: + for _id in ids: fout.write(items[_id]) + + +def main(args): + + # load index file + items = read_index(args.processed_dir) + + # load training/validation/(test) + transform(items, args.train_index, os.path.join(args.processed_dir, 'train_index.txt')) + transform(items, args.valid_index, os.path.join(args.processed_dir, 'valid_index.txt')) + if args.test_index is not None: + transform(items, args.test_index, os.path.join(args.processed_dir, 'test_index.txt')) + + print_log('Done') + + +if __name__ == '__main__': + np.random.seed(12) + main(parse()) \ No newline at end of file diff --git a/scripts/run_exp_pipe.sh b/scripts/run_exp_pipe.sh new file mode 100644 index 0000000000000000000000000000000000000000..76ed8947da4345fc698d571b1fd41ba847459b3e --- /dev/null +++ b/scripts/run_exp_pipe.sh @@ -0,0 +1,95 @@ +#!/bin/bash +########## setup project directory ########## +CODE_DIR=`realpath $(dirname "$0")/..` +echo "Locate the project folder at ${CODE_DIR}" +cd ${CODE_DIR} + +######### check number of args ########## +HELP="Usage example: GPU=0 bash $0 [mode: e.g. 1111]" +if [ -z $1 ]; then + echo "Experiment name missing. ${HELP}" + exit 1; +else + NAME=$1 +fi +if [ -z $2 ]; then + echo "Autoencoder config missing. ${HELP}" + exit 1; +else + AECONFIG=$2 +fi +if [ -z $3 ]; then + echo "LDM config missing. ${HELP}" + exit 1; +else + LDMCONFIG=$3 +fi +if [ -z $4 ]; then + echo "setup LDM dist config missing. ${HELP}" + exit 1; +else + LATENT_DIST_CONFIG=$4 +fi +if [ -z $5 ]; then + echo "LDM test config missing. ${HELP}" + exit 1; +else + TEST_CONFIG=$5 +fi +if [ -z $6 ]; then + MODE=1111 +else + MODE=$6 +fi +echo "Mode: $MODE, [train AE] / [train LDM] / [Generate] / [Evalulation]" +TRAIN_AE_FLAG=${MODE:0:1} +TRAIN_LDM_FLAG=${MODE:1:1} +GENERATE_FLAG=${MODE:2:1} +EVAL_FLAG=${MODE:3:1} + +AE_SAVE_DIR=./exps/$NAME/AE +LDM_SAVE_DIR=./exps/$NAME/LDM +OUTLOG=./exps/$NAME/output.log + +if [[ ! 
-e ./exps/$NAME ]]; then + mkdir -p ./exps/$NAME +elif [[ -e $AE_SAVE_DIR ]] && [ "$TRAIN_AE_FLAG" = "1" ]; then + echo "Directory ${AE_SAVE_DIR} exisits! But training flag is 1!" + exit 1; +elif [[ -e $LDM_SAVE_DIR ]] && [ "$TRAIN_LDM_FLAG" = "1" ]; then + echo "Directory ${LDM_SAVE_DIR} exisits! But training flag is 1!" + exit 1; +fi + +########## train autoencoder ########## +echo "Training Autoencoder with config $AECONFIG:" > $OUTLOG +cat $AECONFIG >> $OUTLOG +if [ "$TRAIN_AE_FLAG" = "1" ]; then + bash scripts/train.sh $AECONFIG --trainer.config.save_dir=$AE_SAVE_DIR +fi + +########## train ldm ########## +echo "Training LDM with config $LDMCONFIG:" >> $OUTLOG +cat $LDMCONFIG >> $OUTLOG +AE_CKPT=`cat ${AE_SAVE_DIR}/version_0/checkpoint/topk_map.txt | head -n 1 | awk -F " " '{print $2}'` +echo "Using Autoencoder checkpoint: ${AE_CKPT}" >> $OUTLOG +if [ "$TRAIN_LDM_FLAG" = "1" ]; then + bash scripts/train.sh $LDMCONFIG --trainer.config.save_dir=$LDM_SAVE_DIR --model.autoencoder_ckpt=$AE_CKPT +fi + +########## get latent distance ########## +LDM_CKPT=`cat ${LDM_SAVE_DIR}/version_0/checkpoint/topk_map.txt | head -n 1 | awk -F " " '{print $2}'` +echo "Get distances in latent space" >> $OUTLOG +python setup_latent_guidance.py --config ${LATENT_DIST_CONFIG} --ckpt ${LDM_CKPT} --gpu ${GPU:0:1} >> $OUTLOG + +########## generate ########## +echo "Generate results Using LDM checkpoint: ${LDM_CKPT}" >> $OUTLOG +if [ "$GENERATE_FLAG" = "1" ]; then + python generate.py --config $TEST_CONFIG --ckpt $LDM_CKPT --gpu ${GPU:0:1} +fi + +########## cal metrics ########## +if [ "$EVAL_FLAG" = "1" ]; then + echo "Evaluation:" >> $OUTLOG + python cal_metrics.py --results ${LDM_SAVE_DIR}/version_0/results/results.jsonl >> $OUTLOG +fi diff --git a/scripts/train.sh b/scripts/train.sh new file mode 100644 index 0000000000000000000000000000000000000000..3397b40d89847866116bd60a63a8f539b23cad29 --- /dev/null +++ b/scripts/train.sh @@ -0,0 +1,49 @@ +#!/bin/bash +########## Instruction ########## +# This script takes three optional environment variables: +# GPU / ADDR / PORT +# e.g. Use gpu 0, 1 and 4 for training, set distributed training +# master address and port to localhost:9901, the command is as follows: +# +# GPU="0,1,4" ADDR=localhost PORT=9901 bash train.sh +# +# Default value: GPU=-1 (use cpu only), ADDR=localhost, PORT=9901 +# Note that if your want to run multiple distributed training tasks, +# either the addresses or ports should be different between +# each pair of tasks. +######### end of instruction ########## + + +########## setup project directory ########## +CODE_DIR=`realpath $(dirname "$0")/..` +echo "Locate the project folder at ${CODE_DIR}" + + +########## parsing yaml configs ########## +if [ -z $1 ]; then + echo "Config missing. 
Usage example: GPU=0,1 bash $0 [optional arguments]" + exit 1; +fi + + +########## setup distributed training ########## +GPU="${GPU:--1}" # default using CPU +MASTER_ADDR="${ADDR:-localhost}" +MASTER_PORT="${PORT:-9901}" +echo "Using GPUs: $GPU" +echo "Master address: ${MASTER_ADDR}, Master port: ${MASTER_PORT}" + +export CUDA_VISIBLE_DEVICES=$GPU +GPU_ARR=(`echo $GPU | tr ',' ' '`) + +if [ ${#GPU_ARR[@]} -gt 1 ]; then + export OMP_NUM_THREADS=2 + PREFIX="torchrun --nproc_per_node=${#GPU_ARR[@]} --rdzv_backend=c10d --rdzv_endpoint=${MASTER_ADDR}:${MASTER_PORT} --nnodes=1" +else + PREFIX="python" +fi + + +########## start training ########## +cd $CODE_DIR +${PREFIX} train.py --gpus "${!GPU_ARR[@]}" --config $@ diff --git a/setup_latent_guidance.py b/setup_latent_guidance.py new file mode 100644 index 0000000000000000000000000000000000000000..f27b11db651b82f5d1ed64d31941dde0e2af8459 --- /dev/null +++ b/setup_latent_guidance.py @@ -0,0 +1,58 @@ +#!/usr/bin/python +# -*- coding:utf-8 -*- +import os +import yaml +import argparse +from tqdm import tqdm + +import torch + +from generate import get_best_ckpt, to_device +from data import create_dataloader, create_dataset + + +def main(args): + config = yaml.safe_load(open(args.config, 'r')) + # load model + b_ckpt = args.ckpt if args.ckpt.endswith('.ckpt') else get_best_ckpt(args.ckpt) + print(f'Using checkpoint {b_ckpt}') + model = torch.load(b_ckpt, map_location='cpu') + device = torch.device('cpu' if args.gpu == -1 else f'cuda:{args.gpu}') + model.to(device) + model.eval() + + # load data + _, _, test_set = create_dataset(config['dataset']) + test_loader = create_dataloader(test_set, config['dataloader']) + + all_dists = [] + + with torch.no_grad(): + for batch in tqdm(test_loader): + batch = to_device(batch, device) + H, Z, _, _ = model.autoencoder.encode( + batch['X'], batch['S'], batch['mask'], batch['position_ids'], + batch['lengths'], batch['atom_mask'], no_randomness=True + ) + pos = batch['position_ids'][batch['mask']] + Z = Z.squeeze(1) + dists = torch.norm(Z[1:] - Z[:-1], dim=-1) # [N] + pos_dist = pos[1:] - pos[:-1] + dists = dists[pos_dist == 1] + all_dists.append(dists) + all_dists = torch.cat(all_dists, dim=0) + mean, std = torch.mean(all_dists), torch.std(all_dists) + print(mean, std) + model.set_consec_dist(mean.item(), std.item()) + torch.save(model, b_ckpt) + +def parse(): + parser = argparse.ArgumentParser(description='Calculate distance between consecutive latent points') + parser.add_argument('--config', type=str, required=True) + parser.add_argument('--ckpt', type=str, required=True) + parser.add_argument('--gpu', type=int, default=0) + return parser.parse_args() + + +if __name__ == '__main__': + main(parse()) \ No newline at end of file diff --git a/train.py b/train.py new file mode 100644 index 0000000000000000000000000000000000000000..bcb68503b03fbd8420d30e68a5b203c26bafa147 --- /dev/null +++ b/train.py @@ -0,0 +1,78 @@ +#!/usr/bin/python +# -*- coding:utf-8 -*- +import os +import argparse + +import yaml +import torch + +from utils.logger import print_log +from utils.random_seed import setup_seed, SEED +from utils.config_utils import overwrite_values +from utils import register as R + +########### Import your packages below ########## +import models +from trainer import create_trainer +from data import create_dataset, create_dataloader +from utils.nn_utils import count_parameters + + +def parse(): + parser = argparse.ArgumentParser(description='training') + + # device + parser.add_argument('--gpus', type=int, nargs='+', 
required=True, help='gpu to use, -1 for cpu') + parser.add_argument("--local_rank", type=int, default=-1, + help="Local rank. Necessary for using the torch.distributed.launch utility.") + + # config + parser.add_argument('--config', type=str, required=True, help='Path to the yaml configure') + parser.add_argument('--seed', type=int, default=SEED, help='Random seed') + + return parser.parse_known_args() + + +def load_ckpt(model, ckpt): + trained_model = torch.load(ckpt, map_location='cpu') + model.load_state_dict(trained_model.state_dict()) + return model + + +def main(args, opt_args): + + # load config + config = yaml.safe_load(open(args.config, 'r')) + config = overwrite_values(config, opt_args) + + ########## define your model ######### + model = R.construct(config['model']) + if 'load_ckpt' in config: + model = load_ckpt(model, config['load_ckpt']) + + ########### load your train / valid set ########### + train_set, valid_set, _ = create_dataset(config['dataset']) + + ########## define your trainer/trainconfig ######### + if len(args.gpus) > 1: + args.local_rank = int(os.environ['LOCAL_RANK']) + torch.cuda.set_device(args.local_rank) + torch.distributed.init_process_group(backend='nccl', world_size=len(args.gpus)) + else: + args.local_rank = -1 + + if args.local_rank <= 0: + print_log(f'Number of parameters: {count_parameters(model) / 1e6} M') + + train_loader = create_dataloader(train_set, config['dataloader'], len(args.gpus)) + valid_loader = create_dataloader(valid_set, config['dataloader'], validation=True) + + trainer = create_trainer(config, model, train_loader, valid_loader) + trainer.train(args.gpus, args.local_rank) + + +if __name__ == '__main__': + args, opt_args = parse() + print_log(f'Overwritting args: {opt_args}') + setup_seed(args.seed) + main(args, opt_args) diff --git a/trainer/__init__.py b/trainer/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..64b4c263f2fad3664169eb69a9f20b4639da5b68 --- /dev/null +++ b/trainer/__init__.py @@ -0,0 +1,17 @@ +#!/usr/bin/python +# -*- coding:utf-8 -*- +from .autoencoder_trainer import AutoEncoderTrainer +from .ldm_trainer import LDMTrainer + +import utils.register as R + + +def create_trainer(config, model, train_loader, valid_loader): + return R.construct( + config['trainer'], + model=model, + train_loader=train_loader, + valid_loader=valid_loader, + save_config=config) + + diff --git a/trainer/abs_trainer.py b/trainer/abs_trainer.py new file mode 100644 index 0000000000000000000000000000000000000000..5af424cb0f6f33b7e2dd14085236f2e52e875973 --- /dev/null +++ b/trainer/abs_trainer.py @@ -0,0 +1,300 @@ +#!/usr/bin/python +# -*- coding:utf-8 -*- +import os +import re +import yaml +from copy import deepcopy +from tqdm import tqdm + +import numpy as np +import torch +from torch.utils.tensorboard import SummaryWriter + +from utils.oom_decorator import OOMReturn, safe_backward +from utils.logger import print_log + +########### Import your packages below ########## + + +class TrainConfig: + def __init__(self, save_dir, max_epoch, warmup=0, + metric_min_better=True, patience=3, + grad_clip=None, save_topk=-1, # -1 for save all + grad_interval=1, # parameter update interval + val_freq=1, # frequence for validation + **kwargs): + self.save_dir = save_dir + self.max_epoch = max_epoch + self.warmup = warmup + self.metric_min_better = metric_min_better + self.patience = patience if patience > 0 else max_epoch + self.grad_clip = grad_clip + self.save_topk = save_topk + self.grad_interval = grad_interval + 
self.val_freq = val_freq + self.__dict__.update(kwargs) + + def add_parameter(self, **kwargs): + self.__dict__.update(kwargs) + + def __str__(self): + return str(self.__class__) + ': ' + str(self.__dict__) + + +class Trainer: + def __init__(self, model, train_loader, valid_loader, config: dict, save_config: dict): + self.model = model + self.config = TrainConfig(**config) + self.save_config = save_config + self.optimizer = self.get_optimizer() + sched_config = self.get_scheduler(self.optimizer) + if sched_config is None: + sched_config = { + 'scheduler': None, + 'frequency': None + } + self.scheduler = sched_config['scheduler'] + self.sched_freq = sched_config['frequency'] + self.train_loader = train_loader + self.valid_loader = valid_loader + + # distributed training + self.local_rank = -1 + + # log + self.version = self._get_version() + self.config.save_dir = os.path.join(self.config.save_dir, f'version_{self.version}') + self.model_dir = os.path.join(self.config.save_dir, 'checkpoint') + self.writer = None # initialize right before training + self.writer_buffer = {} + + # training process recording + self.global_step = 0 + self.valid_global_step = 0 + self.epoch = 0 + self.last_valid_metric = None + self.topk_ckpt_map = [] # smaller index means better ckpt + self.patience = self.config.patience + + @classmethod + def to_device(cls, data, device): + if isinstance(data, dict): + for key in data: + data[key] = cls.to_device(data[key], device) + elif isinstance(data, list) or isinstance(data, tuple): + res = [cls.to_device(item, device) for item in data] + data = type(data)(res) + elif hasattr(data, 'to'): + data = data.to(device) + return data + + def _is_main_proc(self): + return self.local_rank == 0 or self.local_rank == -1 + + def _get_version(self): + version, pattern = -1, r'version_(\d+)' + if os.path.exists(self.config.save_dir): + for fname in os.listdir(self.config.save_dir): + ver = re.findall(pattern, fname) + if len(ver): + version = max(int(ver[0]), version) + return version + 1 + + def is_oom_return(self, value): + return isinstance(value, OOMReturn) + + def _train_epoch(self, device): + if self.train_loader.sampler is not None and self.local_rank != -1: # distributed + self.train_loader.sampler.set_epoch(self.epoch) + t_iter = tqdm(self.train_loader) if self._is_main_proc() else self.train_loader + for batch in t_iter: + batch = self.to_device(batch, device) + loss = self.train_step(batch, self.global_step) + if self.is_oom_return(loss): + print_log(f'Out of memory, local rank {self.local_rank}', level='WARN') + loss = loss.fake_loss + elif torch.isnan(loss): + print_log(f'Loss is nan, local_rank {self.local_rank}', level='WARN') + loss = sum([p.norm() for p in self.model.parameters() if p.dtype == torch.float]) * 0.0 + self.optimizer.zero_grad() + backward_ok = safe_backward(loss, self.model) + if not backward_ok: + print_log(f'Backward out of memory, skip', level='WARN') + loss = loss.detach() # manually delete the computing graph + if self.config.grad_clip is not None: + ori_grad_norm = torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.config.grad_clip) + # recording gradients + self.log('Grad Norm', ori_grad_norm.cpu(), self.global_step) + self.optimizer.step() + if hasattr(t_iter, 'set_postfix'): + t_iter.set_postfix(loss=loss.item(), version=self.version) + self.global_step += 1 + if self.sched_freq == 'batch': + self.scheduler.step() + if self.sched_freq == 'epoch': + self.scheduler.step() + self._train_epoch_end(device) + + def _train_epoch_end(self, 
device): + return + + def _aggregate_val_metric(self, metric_arr): + return np.mean(metric_arr) + + def _valid_epoch_begin(self, device): + return + + def _valid_epoch(self, device): + metric_arr = [] + self.model.eval() + self._valid_epoch_begin(device) + with torch.no_grad(): + t_iter = tqdm(self.valid_loader) if self._is_main_proc() else self.valid_loader + for batch in t_iter: + batch = self.to_device(batch, device) + metric = self.valid_step(batch, self.valid_global_step) + metric_arr.append(metric.cpu().item()) + self.valid_global_step += 1 + + # judge + valid_metric = self._aggregate_val_metric(metric_arr) + if self._is_main_proc(): + save_path = os.path.join(self.model_dir, f'epoch{self.epoch}_step{self.global_step}.ckpt') + module_to_save = self.model.module if self.local_rank == 0 else self.model + torch.save(module_to_save, save_path) + self._maintain_topk_checkpoint(valid_metric, save_path) + print_log(f'Validation: {valid_metric}, save path: {save_path}') + if self._metric_better(valid_metric): + self.patience = self.config.patience + else: + self.patience -= 1 + if self.sched_freq == 'val_epoch': + self.scheduler.step(valid_metric) + self.last_valid_metric = valid_metric + # write valid_metric + for name in self.writer_buffer: + value = np.mean(self.writer_buffer[name]) + if self._is_main_proc(): + print_log(f'{name}: {value}') + self.log(name, value, self.epoch) + self.writer_buffer = {} + self._valid_epoch_end(device) + self.model.train() + + def _valid_epoch_end(self, device): + return + + def _metric_better(self, new): + old = self.last_valid_metric + if old is None: + return True + if self.config.metric_min_better: + return new < old + else: + return old < new + + def _maintain_topk_checkpoint(self, valid_metric, ckpt_path): + topk = self.config.save_topk + if self.config.metric_min_better: + better = lambda a, b: a < b + else: + better = lambda a, b: a > b + insert_pos = len(self.topk_ckpt_map) + for i, (metric, _) in enumerate(self.topk_ckpt_map): + if better(valid_metric, metric): + insert_pos = i + break + self.topk_ckpt_map.insert(insert_pos, (valid_metric, ckpt_path)) + + # maintain topk + if topk > 0: + while len(self.topk_ckpt_map) > topk: + last_ckpt_path = self.topk_ckpt_map[-1][1] + os.remove(last_ckpt_path) + self.topk_ckpt_map.pop() + + # save map + topk_map_path = os.path.join(self.model_dir, 'topk_map.txt') + with open(topk_map_path, 'w') as fout: + for metric, path in self.topk_ckpt_map: + fout.write(f'{metric}: {path}\n') + + def _modify_writer(self): + return + + def train(self, device_ids, local_rank): + # set local rank + self.local_rank = local_rank + # init writer + if self._is_main_proc(): + self.writer = SummaryWriter(self.config.save_dir) + self._modify_writer() + if not os.path.exists(self.model_dir): + os.makedirs(self.model_dir) + with open(os.path.join(self.config.save_dir, 'train_config.yaml'), 'w') as fout: + yaml.safe_dump(self.save_config, fout) + # main device + main_device_id = local_rank if local_rank != -1 else device_ids[0] + device = torch.device('cpu' if main_device_id == -1 else f'cuda:{main_device_id}') + self.model.to(device) + if local_rank != -1: + print_log(f'Using data parallel, local rank {local_rank}, all {device_ids}') + self.model = torch.nn.parallel.DistributedDataParallel( + self.model, device_ids=[local_rank], output_device=local_rank, + find_unused_parameters=True + ) + else: + print_log(f'training on {device_ids}') + for _ in range(self.config.max_epoch): + print_log(f'epoch{self.epoch} starts') if 
self._is_main_proc() else 1 + self._train_epoch(device) + if (self.epoch + 1) % self.config.val_freq == 0: + print_log(f'validating ...') if self._is_main_proc() else 1 + self._valid_epoch(device) + self.epoch += 1 + if self.patience <= 0: + break + + def log(self, name, value, step, val=False, batch_size=1): + if self._is_main_proc(): + if isinstance(value, torch.Tensor): + value = value.cpu().item() + if val: + if name not in self.writer_buffer: + self.writer_buffer[name] = [] + self.writer_buffer[name].extend([value] * batch_size) + else: + self.writer.add_scalar(name, value, step) + + # define optimizer + def get_optimizer(self): + opt_cfg = deepcopy(self.config.optimizer) + cls = getattr(torch.optim, opt_cfg.pop('class')) + # optimizer = cls(self.model.parameters(), **opt_cfg) + optimizer = cls(filter(lambda p: p.requires_grad, self.model.parameters()), **opt_cfg) + return optimizer + + # scheduler example: linear. Return None if no scheduler is needed. + def get_scheduler(self, optimizer): + if not hasattr(self.config, 'scheduler'): + return None + sched_cfg = deepcopy(self.config.scheduler) + cls = getattr(torch.optim.lr_scheduler, sched_cfg.pop('class')) + freq = sched_cfg.pop('frequency') + return { + 'scheduler': cls(optimizer, **sched_cfg), + 'frequency': freq # batch/epoch/val_epoch + } + + ########## Overload these functions below ########## + # train step, note that batch should be dict/list/tuple/instance. Objects with .to(device) attribute will be automatically moved to the same device as the model + def train_step(self, batch, batch_idx): + loss = self.model(batch) + self.log('Loss/train', loss, batch_idx) + return loss + + # validation step + def valid_step(self, batch, batch_idx): + loss = self.model(batch) + self.log('Loss/validation', loss, batch_idx, val=True) + return loss diff --git a/trainer/autoencoder_trainer.py b/trainer/autoencoder_trainer.py new file mode 100644 index 0000000000000000000000000000000000000000..bcf47e2e9350b9f10d067bd732055597ab9c5f9e --- /dev/null +++ b/trainer/autoencoder_trainer.py @@ -0,0 +1,59 @@ +#!/usr/bin/python +# -*- coding:utf-8 -*- +from .abs_trainer import Trainer +from utils import register as R + + +@R.register('AutoEncoderTrainer') +class AutoEncoderTrainer(Trainer): + def __init__(self, model, train_loader, valid_loader, config: dict, save_config: dict): + super().__init__(model, train_loader, valid_loader, config, save_config) + self.max_step = self.config.max_epoch * len(self.train_loader) + + ########## Override start ########## + + def train_step(self, batch, batch_idx): + return self.share_step(batch, batch_idx, val=False) + + def valid_step(self, batch, batch_idx): + return self.share_step(batch, batch_idx, val=True) + + def _train_epoch_end(self, device): + dataset = self.train_loader.dataset + if hasattr(dataset, 'update_epoch'): + dataset.update_epoch() + return super()._train_epoch_end(device) + + ########## Override end ########## + + def share_step(self, batch, batch_idx, val=False): + results = self.model(**batch) + if self.is_oom_return(results): + return results + loss, seq_detail, structure_detail, (h_kl_loss, z_kl_loss, coord_reg_loss) = results + snll, aar = seq_detail + closs, struct_loss_profile = structure_detail + # ed_loss, r_ed_losses = ed_detail + + log_type = 'Validation' if val else 'Train' + + self.log(f'Overall/Loss/{log_type}', loss, batch_idx, val) + + self.log(f'Seq/SNLL/{log_type}', snll, batch_idx, val) + self.log(f'Seq/KLloss/{log_type}', h_kl_loss, batch_idx, val) + 
self.log(f'Seq/AAR/{log_type}', aar, batch_idx, val) + + self.log(f'Struct/CLoss/{log_type}', closs, batch_idx, val) + self.log(f'Struct/KLloss/{log_type}', z_kl_loss, batch_idx, val) + self.log(f'Struct/CoordRegloss/{log_type}', coord_reg_loss, batch_idx, val) + for name in struct_loss_profile: + self.log(f'Struct/{name}/{log_type}', struct_loss_profile[name], batch_idx, val) + # self.log(f'Struct/XLoss/{log_type}', xloss, batch_idx, val) + # self.log(f'Struct/BondLoss/{log_type}', bond_loss, batch_idx, val) + # self.log(f'Struct/SidechainBondLoss/{log_type}', sc_bond_loss, batch_idx, val) + + if not val: + lr = self.optimizer.state_dict()['param_groups'][0]['lr'] + self.log('lr', lr, batch_idx, val) + + return loss diff --git a/trainer/ldm_trainer.py b/trainer/ldm_trainer.py new file mode 100644 index 0000000000000000000000000000000000000000..4edd2a8b90332d9030ba575adfd10ac1fc1efac6 --- /dev/null +++ b/trainer/ldm_trainer.py @@ -0,0 +1,98 @@ +#!/usr/bin/python +# -*- coding:utf-8 -*- +from math import pi, cos + +import torch +from torch_scatter import scatter_mean + +from .abs_trainer import Trainer +from utils import register as R + + +@R.register('LDMTrainer') +class LDMTrainer(Trainer): + def __init__(self, model, train_loader, valid_loader, config: dict, save_config: dict, criterion: str='AAR'): + super().__init__(model, train_loader, valid_loader, config, save_config) + self.max_step = self.config.max_epoch * len(self.train_loader) + self.criterion = criterion + assert criterion in ['AAR', 'RMSD', 'Loss'], f'Criterion {criterion} not implemented' + self.rng_state = None + + ########## Override start ########## + + def train_step(self, batch, batch_idx): + results = self.model(**batch) + if self.is_oom_return(results): + return results + loss, loss_dict = results + + self.log('Overall/Loss/Train', loss, batch_idx, val=False) + + if 'H' in loss_dict: + self.log('Seq/Loss_H/Train', loss_dict['H'], batch_idx, val=False) + + if 'X' in loss_dict: + self.log('Struct/Loss_X/Train', loss_dict['X'], batch_idx, val=False) + + lr = self.optimizer.state_dict()['param_groups'][0]['lr'] + self.log('lr', lr, batch_idx, val=False) + + return loss + + def _valid_epoch_begin(self, device): + self.rng_state = torch.random.get_rng_state() + torch.manual_seed(12) # each validation epoch uses the same initial state + return super()._valid_epoch_begin(device) + + def _valid_epoch_end(self, device): + torch.random.set_rng_state(self.rng_state) + return super()._valid_epoch_end(device) + + def valid_step(self, batch, batch_idx): + loss, loss_dict = self.model(**batch) + self.log('Overall/Loss/Validation', loss, batch_idx, val=True) + if 'H' in loss_dict: self.log('Seq/Loss_H/Validation', loss_dict['H'], batch_idx, val=True) + if 'X' in loss_dict: self.log('Struct/Loss_X/Validation', loss_dict['X'], batch_idx, val=True) + # disable sidechain optimization as it may stuck for early validations where the model is still weak + if self.local_rank != -1: # ddp + sample_X, sample_S, _ = self.model.module.sample(**batch, return_tensor=True, optimize_sidechain=False) + else: + sample_X, sample_S, _ = self.model.sample(**batch, return_tensor=True, optimize_sidechain=False) + mask_generate = batch['mask'] + # batch ids + batch_ids = torch.zeros_like(mask_generate).long() + batch_ids[torch.cumsum(batch['lengths'], dim=0)[:-1]] = 1 + batch_ids.cumsum_(dim=0) + batch_ids = batch_ids[mask_generate] + + if sample_S is not None: + # aar + aar = (batch['S'][mask_generate] == sample_S).float() + aar = 
torch.mean(scatter_mean(aar, batch_ids, dim=-1)) + self.log('Seq/AAR/Validation', aar, batch_idx, val=True) + + # ca rmsd + if sample_X is not None: + atom_mask = batch['atom_mask'][mask_generate][:, 1] + rmsd = ((batch['X'][mask_generate][:, 1][atom_mask] - sample_X[:, 1][atom_mask]) ** 2).sum(-1) # [Ntgt] + rmsd = torch.sqrt(scatter_mean(rmsd, batch_ids[atom_mask], dim=-1)) # [bs] + rmsd = torch.mean(rmsd) + + self.log('Struct/CA_RMSD/Validation', rmsd, batch_idx, val=True) + + if self.criterion == 'AAR': + return aar.detach() + elif self.criterion == 'RMSD': + return rmsd.detach() + elif self.criterion == 'Loss': + return loss.detach() + else: + raise NotImplementedError(f'Criterion {self.criterion} not implemented') + + def _train_epoch_end(self, device): + dataset = self.train_loader.dataset + if hasattr(dataset, 'update_epoch'): + dataset.update_epoch() + return super()._train_epoch_end(device) + + ########## Override end ########## \ No newline at end of file diff --git a/utils/config_utils.py b/utils/config_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..7cf399d6b0370cab81b0504ad8f34b75cf6175c8 --- /dev/null +++ b/utils/config_utils.py @@ -0,0 +1,33 @@ +#!/usr/bin/python +# -*- coding:utf-8 -*- +from typing import List + + +def format_args(args: List[str]): + clean_args = [] + for arg in args: + if not (arg.startswith('-') or arg.startswith('--')): # value + clean_args.append(arg) + else: + arg = arg.lstrip('-').lstrip('-') + clean_args.extend(arg.split('=')) + return clean_args + + +def get_parent_dict(config: dict, key: str): + key_each_depth = key.split('.') + for k in key_each_depth[:-1]: + if k not in config: + raise KeyError(f'Path key {key} not in the dict') + config = config[k] + return config, key_each_depth[-1] # last key + + +def overwrite_values(config, args): + args = format_args(args) + keys, values = args[0::2], args[1::2] + for key, value in zip(keys, values): + parent, last_key = get_parent_dict(config, key) + ori_value = parent[last_key] + parent[last_key] = type(ori_value)(value) + return config diff --git a/utils/const.py b/utils/const.py new file mode 100644 index 0000000000000000000000000000000000000000..5516cbb944b034430653c647184382fe66c60c72 --- /dev/null +++ b/utils/const.py @@ -0,0 +1,733 @@ +#!/usr/bin/python +# -*- coding:utf-8 -*- +# Copyright Generate Biomedicines, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
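+#
+# Layout note: each residue entry below appears to follow a z-matrix convention.
+# Side-chain atoms are listed in build order, and atom i is positioned from the three
+# reference atoms in parents[i] using z-lengths[i] (bond length, in Angstroms),
+# z-angles[i] (bond angle, in degrees) and z-dihedrals[i] (dihedral angle, in degrees);
+# chi_indices marks which of these dihedrals correspond to rotatable side-chain chi angles.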
+ +"""Dictionary containing ideal internal coordinates and chi angle assignments + for building amino acid 3D coordinates""" +from typing import Dict + + +AA_GEOMETRY: Dict[str, dict] = { + "ALA": { + "atoms": ["CB"], + "chi_indices": [], + "parents": [["N", "C", "CA"]], + "types": {"C": "C", "CA": "CT1", "CB": "CT3", "N": "NH1", "O": "O"}, + "z-angles": [111.09], + "z-dihedrals": [123.23], + "z-lengths": [1.55], + }, + "ARG": { + "atoms": ["CB", "CG", "CD", "NE", "CZ", "NH1", "NH2"], + "chi_indices": [1, 2, 3, 4], + "parents": [ + ["N", "C", "CA"], + ["N", "CA", "CB"], + ["CA", "CB", "CG"], + ["CB", "CG", "CD"], + ["CG", "CD", "NE"], + ["CD", "NE", "CZ"], + ["NH1", "NE", "CZ"], + ], + "types": { + "C": "C", + "CA": "CT1", + "CB": "CT2", + "CD": "CT2", + "CG": "CT2", + "CZ": "C", + "N": "NH1", + "NE": "NC2", + "NH1": "NC2", + "NH2": "NC2", + "O": "O", + }, + "z-angles": [112.26, 115.95, 114.01, 107.09, 123.05, 118.06, 122.14], + "z-dihedrals": [123.64, 180.0, 180.0, 180.0, 180.0, 180.0, 178.64], + "z-lengths": [1.56, 1.55, 1.54, 1.5, 1.34, 1.33, 1.33], + }, + "ASN": { + "atoms": ["CB", "CG", "OD1", "ND2"], + "chi_indices": [1, 2], + "parents": [ + ["N", "C", "CA"], + ["N", "CA", "CB"], + ["CA", "CB", "CG"], + ["OD1", "CB", "CG"], + ], + "types": { + "C": "C", + "CA": "CT1", + "CB": "CT2", + "CG": "CC", + "N": "NH1", + "ND2": "NH2", + "O": "O", + "OD1": "O", + }, + "z-angles": [113.04, 114.3, 122.56, 116.15], + "z-dihedrals": [121.18, 180.0, 180.0, -179.19], + "z-lengths": [1.56, 1.53, 1.23, 1.35], + }, + "ASP": { + "atoms": ["CB", "CG", "OD1", "OD2"], + "chi_indices": [1, 2], + "parents": [ + ["N", "C", "CA"], + ["N", "CA", "CB"], + ["CA", "CB", "CG"], + ["OD1", "CB", "CG"], + ], + "types": { + "C": "C", + "CA": "CT1", + "CB": "CT2A", + "CG": "CC", + "N": "NH1", + "O": "O", + "OD1": "OC", + "OD2": "OC", + }, + "z-angles": [114.1, 112.6, 117.99, 117.7], + "z-dihedrals": [122.33, 180.0, 180.0, -170.23], + "z-lengths": [1.56, 1.52, 1.26, 1.25], + }, + "CYS": { + "atoms": ["CB", "SG"], + "chi_indices": [1], + "parents": [["N", "C", "CA"], ["N", "CA", "CB"]], + "types": {"C": "C", "CA": "CT1", "CB": "CT2", "N": "NH1", "O": "O", "SG": "S"}, + "z-angles": [111.98, 113.87], + "z-dihedrals": [121.79, 180.0], + "z-lengths": [1.56, 1.84], + }, + "GLN": { + "atoms": ["CB", "CG", "CD", "OE1", "NE2"], + "chi_indices": [1, 2, 3], + "parents": [ + ["N", "C", "CA"], + ["N", "CA", "CB"], + ["CA", "CB", "CG"], + ["CB", "CG", "CD"], + ["OE1", "CG", "CD"], + ], + "types": { + "C": "C", + "CA": "CT1", + "CB": "CT2", + "CD": "CC", + "CG": "CT2", + "N": "NH1", + "NE2": "NH2", + "O": "O", + "OE1": "O", + }, + "z-angles": [111.68, 115.52, 112.5, 121.52, 116.84], + "z-dihedrals": [121.91, 180.0, 180.0, 180.0, 179.57], + "z-lengths": [1.55, 1.55, 1.53, 1.23, 1.35], + }, + "GLU": { + "atoms": ["CB", "CG", "CD", "OE1", "OE2"], + "chi_indices": [1, 2, 3], + "parents": [ + ["N", "C", "CA"], + ["N", "CA", "CB"], + ["CA", "CB", "CG"], + ["CB", "CG", "CD"], + ["OE1", "CG", "CD"], + ], + "types": { + "C": "C", + "CA": "CT1", + "CB": "CT2A", + "CD": "CC", + "CG": "CT2", + "N": "NH1", + "O": "O", + "OE1": "OC", + "OE2": "OC", + }, + "z-angles": [111.71, 115.69, 115.73, 114.99, 120.08], + "z-dihedrals": [121.9, 180.0, 180.0, 180.0, -179.1], + "z-lengths": [1.55, 1.56, 1.53, 1.26, 1.25], + }, + "GLY": { + "atoms": [], + "chi_indices": [], + "parents": [], + "types": {"C": "C", "CA": "CT2", "N": "NH1", "O": "O"}, + "z-angles": [], + "z-dihedrals": [], + "z-lengths": [], + }, + "HIS": { + "atoms": ["CB", "CG", "ND1", "CD2", "CE1", 
"NE2"], + "chi_indices": [1, 2], + "parents": [ + ["N", "C", "CA"], + ["N", "CA", "CB"], + ["CA", "CB", "CG"], + ["ND1", "CB", "CG"], + ["CB", "CG", "ND1"], + ["CB", "CG", "CD2"], + ], + "types": { + "C": "C", + "CA": "CT1", + "CB": "CT2", + "CD2": "CPH1", + "CE1": "CPH2", + "CG": "CPH1", + "N": "NH1", + "ND1": "NR1", + "NE2": "NR2", + "O": "O", + }, + "z-angles": [109.99, 114.05, 124.1, 129.6, 107.03, 110.03], + "z-dihedrals": [122.46, 180.0, 90.0, -171.29, -173.21, 171.99], + "z-lengths": [1.55, 1.5, 1.38, 1.36, 1.35, 1.38], + }, + "HSD": { + "atoms": ["CB", "CG", "ND1", "CD2", "CE1", "NE2"], + "chi_indices": [1, 2], + "parents": [ + ["N", "C", "CA"], + ["N", "CA", "CB"], + ["CA", "CB", "CG"], + ["ND1", "CB", "CG"], + ["CB", "CG", "ND1"], + ["CB", "CG", "CD2"], + ], + "types": { + "C": "C", + "CA": "CT1", + "CB": "CT2", + "CD2": "CPH1", + "CE1": "CPH2", + "CG": "CPH1", + "N": "NH1", + "ND1": "NR1", + "NE2": "NR2", + "O": "O", + }, + "z-angles": [109.99, 114.05, 124.1, 129.6, 107.03, 110.03], + "z-dihedrals": [122.46, 180.0, 90.0, -171.29, -173.21, 171.99], + "z-lengths": [1.55, 1.5, 1.38, 1.36, 1.35, 1.38], + }, + "HSE": { + "atoms": ["CB", "CG", "ND1", "CD2", "CE1", "NE2"], + "chi_indices": [], + "parents": [ + ["N", "C", "CA"], + ["N", "CA", "CB"], + ["CA", "CB", "CG"], + ["ND1", "CB", "CG"], + ["CB", "CG", "ND1"], + ["CB", "CG", "CD2"], + ], + "types": { + "C": "C", + "CA": "CT1", + "CB": "CT2", + "CD2": "CPH1", + "CE1": "CPH2", + "CG": "CPH1", + "N": "NH1", + "ND1": "NR2", + "NE2": "NR1", + "O": "O", + }, + "z-angles": [111.67, 116.94, 120.17, 129.71, 105.2, 105.8], + "z-dihedrals": [123.52, 180.0, 90.0, -178.26, -179.2, 178.66], + "z-lengths": [1.56, 1.51, 1.39, 1.36, 1.32, 1.38], + }, + "HSP": { + "atoms": ["CB", "CG", "ND1", "CD2", "CE1", "NE2"], + "chi_indices": [], + "parents": [ + ["N", "C", "CA"], + ["N", "CA", "CB"], + ["CA", "CB", "CG"], + ["ND1", "CB", "CG"], + ["CB", "CG", "ND1"], + ["CB", "CG", "CD2"], + ], + "types": { + "C": "C", + "CA": "CT1", + "CB": "CT2A", + "CD2": "CPH1", + "CE1": "CPH2", + "CG": "CPH1", + "N": "NH1", + "ND1": "NR3", + "NE2": "NR3", + "O": "O", + }, + "z-angles": [109.38, 114.18, 122.94, 128.93, 108.9, 106.93], + "z-dihedrals": [125.13, 180.0, 90.0, -165.26, -167.62, 167.13], + "z-lengths": [1.55, 1.52, 1.37, 1.35, 1.33, 1.37], + }, + "ILE": { + "atoms": ["CB", "CG1", "CG2", "CD1"], + "chi_indices": [1, 3], + "parents": [ + ["N", "C", "CA"], + ["N", "CA", "CB"], + ["CG1", "CA", "CB"], + ["CA", "CB", "CG1"], + ], + "types": { + "C": "C", + "CA": "CT1", + "CB": "CT1", + "CD": "CT3", + "CG1": "CT2", + "CG2": "CT3", + "N": "NH1", + "O": "O", + }, + "z-angles": [112.93, 113.63, 113.93, 114.09], + "z-dihedrals": [124.22, 180.0, -130.04, 180.0], + "z-lengths": [1.57, 1.55, 1.55, 1.54], + }, + "LEU": { + "atoms": ["CB", "CG", "CD1", "CD2"], + "chi_indices": [1, 2], + "parents": [ + ["N", "C", "CA"], + ["N", "CA", "CB"], + ["CA", "CB", "CG"], + ["CD1", "CB", "CG"], + ], + "types": { + "C": "C", + "CA": "CT1", + "CB": "CT2", + "CD1": "CT3", + "CD2": "CT3", + "CG": "CT1", + "N": "NH1", + "O": "O", + }, + "z-angles": [112.12, 117.46, 110.48, 112.57], + "z-dihedrals": [121.52, 180.0, 180.0, 120.0], + "z-lengths": [1.55, 1.55, 1.54, 1.54], + }, + "LYS": { + "atoms": ["CB", "CG", "CD", "CE", "NZ"], + "chi_indices": [1, 2, 3, 4], + "parents": [ + ["N", "C", "CA"], + ["N", "CA", "CB"], + ["CA", "CB", "CG"], + ["CB", "CG", "CD"], + ["CG", "CD", "CE"], + ], + "types": { + "C": "C", + "CA": "CT1", + "CB": "CT2", + "CD": "CT2", + "CE": "CT2", + "CG": "CT2", + "N": 
"NH1", + "NZ": "NH3", + "O": "O", + }, + "z-angles": [111.36, 115.76, 113.28, 112.33, 110.46], + "z-dihedrals": [122.23, 180.0, 180.0, 180.0, 180.0], + "z-lengths": [1.56, 1.54, 1.54, 1.53, 1.46], + }, + "MET": { + "atoms": ["CB", "CG", "SD", "CE"], + "chi_indices": [1, 2, 3], + "parents": [ + ["N", "C", "CA"], + ["N", "CA", "CB"], + ["CA", "CB", "CG"], + ["CB", "CG", "SD"], + ], + "types": { + "C": "C", + "CA": "CT1", + "CB": "CT2", + "CE": "CT3", + "CG": "CT2", + "N": "NH1", + "O": "O", + "SD": "S", + }, + "z-angles": [111.88, 115.92, 110.28, 98.94], + "z-dihedrals": [121.62, 180.0, 180.0, 180.0], + "z-lengths": [1.55, 1.55, 1.82, 1.82], + }, + "PHE": { + "atoms": ["CB", "CG", "CD1", "CD2", "CE1", "CE2", "CZ"], + "chi_indices": [1, 2], + "parents": [ + ["N", "C", "CA"], + ["N", "CA", "CB"], + ["CA", "CB", "CG"], + ["CD1", "CB", "CG"], + ["CB", "CG", "CD1"], + ["CB", "CG", "CD2"], + ["CG", "CD1", "CE1"], + ], + "types": { + "C": "C", + "CA": "CT1", + "CB": "CT2", + "CD1": "CA", + "CD2": "CA", + "CE1": "CA", + "CE2": "CA", + "CG": "CA", + "CZ": "CA", + "N": "NH1", + "O": "O", + }, + "z-angles": [112.45, 112.76, 120.32, 120.76, 120.63, 120.62, 119.93], + "z-dihedrals": [122.49, 180.0, 90.0, -177.96, -177.37, 177.2, -0.12], + "z-lengths": [1.56, 1.51, 1.41, 1.41, 1.4, 1.4, 1.4], + }, + "PRO": { + "atoms": ["CB", "CG", "CD"], + "chi_indices": [1, 2], + "parents": [["N", "C", "CA"], ["N", "CA", "CB"], ["CA", "CB", "CG"]], + "types": { + "C": "C", + "CA": "CP1", + "CB": "CP2", + "CD": "CP3", + "CG": "CP2", + "N": "N", + "O": "O", + }, + "z-angles": [111.74, 104.39, 103.21], + "z-dihedrals": [113.74, 31.61, -34.59], + "z-lengths": [1.54, 1.53, 1.53], + }, + "SER": { + "atoms": ["CB", "OG"], + "chi_indices": [1], + "parents": [["N", "C", "CA"], ["N", "CA", "CB"]], + "types": { + "C": "C", + "CA": "CT1", + "CB": "CT2", + "N": "NH1", + "O": "O", + "OG": "OH1", + }, + "z-angles": [111.4, 112.45], + "z-dihedrals": [124.75, 180.0], + "z-lengths": [1.56, 1.43], + }, + "THR": { + "atoms": ["CB", "OG1", "CG2"], + "chi_indices": [1], + "parents": [["N", "C", "CA"], ["N", "CA", "CB"], ["OG1", "CA", "CB"]], + "types": { + "C": "C", + "CA": "CT1", + "CB": "CT1", + "CG2": "CT3", + "N": "NH1", + "O": "O", + "OG1": "OH1", + }, + "z-angles": [112.74, 112.16, 115.91], + "z-dihedrals": [126.46, 180.0, -124.13], + "z-lengths": [1.57, 1.43, 1.53], + }, + "TRP": { + "atoms": ["CB", "CG", "CD2", "CD1", "CE2", "NE1", "CE3", "CZ3", "CH2", "CZ2"], + "chi_indices": [1, 2], + "parents": [ + ["N", "C", "CA"], + ["N", "CA", "CB"], + ["CA", "CB", "CG"], + ["CD2", "CB", "CG"], + ["CD1", "CG", "CD2"], + ["CG", "CD2", "CE2"], + ["CE2", "CG", "CD2"], + ["CE2", "CD2", "CE3"], + ["CD2", "CE3", "CZ3"], + ["CE3", "CZ3", "CH2"], + ], + "types": { + "C": "C", + "CA": "CT1", + "CB": "CT2", + "CD1": "CA", + "CD2": "CPT", + "CE2": "CPT", + "CE3": "CAI", + "CG": "CY", + "CH2": "CA", + "CZ2": "CAI", + "CZ3": "CA", + "N": "NH1", + "NE1": "NY", + "O": "O", + }, + "z-angles": [ + 111.23, + 115.14, + 123.95, + 129.18, + 106.65, + 107.87, + 132.54, + 118.16, + 120.97, + 120.87, + ], + "z-dihedrals": [ + 122.68, + 180.0, + 90.0, + -172.81, + -0.08, + 0.14, + 179.21, + -0.2, + 0.1, + 0.01, + ], + "z-lengths": [1.56, 1.52, 1.44, 1.37, 1.41, 1.37, 1.4, 1.4, 1.4, 1.4], + }, + "TYR": { + "atoms": ["CB", "CG", "CD1", "CD2", "CE1", "CE2", "CZ", "OH"], + "chi_indices": [1, 2], + "parents": [ + ["N", "C", "CA"], + ["N", "CA", "CB"], + ["CA", "CB", "CG"], + ["CD1", "CB", "CG"], + ["CB", "CG", "CD1"], + ["CB", "CG", "CD2"], + ["CG", "CD1", "CE1"], + ["CE1", 
"CE2", "CZ"], + ], + "types": { + "C": "C", + "CA": "CT1", + "CB": "CT2", + "CD1": "CA", + "CD2": "CA", + "CE1": "CA", + "CE2": "CA", + "CG": "CA", + "CZ": "CA", + "N": "NH1", + "O": "O", + "OH": "OH1", + }, + "z-angles": [112.34, 112.94, 120.49, 120.46, 120.4, 120.56, 120.09, 120.25], + "z-dihedrals": [122.27, 180.0, 90.0, -176.46, -175.49, 175.32, -0.19, -178.98], + "z-lengths": [1.56, 1.51, 1.41, 1.41, 1.4, 1.4, 1.4, 1.41], + }, + "VAL": { + "atoms": ["CB", "CG1", "CG2"], + "chi_indices": [1], + "parents": [["N", "C", "CA"], ["N", "CA", "CB"], ["CG1", "CA", "CB"]], + "types": { + "C": "C", + "CA": "CT1", + "CB": "CT1", + "CG1": "CT3", + "CG2": "CT3", + "N": "NH1", + "O": "O", + }, + "z-angles": [111.23, 113.97, 112.17], + "z-dihedrals": [122.95, 180.0, 123.99], + "z-lengths": [1.57, 1.54, 1.54], + }, +} + + +# our constants +# elements +periodic_table = [ # Periodic Table + # 1 + 'H', 'He', + # 2 + 'Li', 'Be', 'B', 'C', 'N', 'O', 'F', 'Ne', + # 3 + 'Na', 'Mg', 'Al', 'Si', 'P', 'S', 'Cl', 'Ar', + # 4 + 'K', 'Ca', 'Sc', 'Ti', 'V', 'Cr', 'Mn', 'Fe', 'Co', 'Ni', 'Cu', 'Zn', + 'Ga', 'Ge', 'As', 'Se', 'Br', 'Kr', + # 5 + 'Rb', 'Sr', 'Y', 'Zr', 'Nb', 'Mo', 'Tc', 'Ru', 'Rh', 'Pd', 'Ag', 'Cd', + 'In', 'Sn', 'Sb', 'Te', 'I', 'Xe', + # 6 + 'Cs', 'Ba', 'La', 'Ce', 'Pr', 'Nd', 'Pm', 'Sm', 'Eu', 'Gd', 'Tb', 'Dy', + 'Ho', 'Er', 'Tm', 'Yb', 'Lu', + 'Hf', 'Ta', 'W', 'Re', 'Os', 'Ir', 'Pt', 'Au', 'Hg', 'Tl', 'Pb', 'Bi', + 'Po', 'At', 'Rn', + # 7 + 'Fr', 'Ra', 'Ac', 'Th', 'Pa', 'U', 'Np', 'Pu', 'Am', 'Cm', 'Bk', + 'Cf', 'Es', 'Fm', 'Md', 'No', 'Lr', + 'Rf', 'Db', 'Sg', 'Bh', 'Hs', 'Mt', 'Ds', 'Rg', 'Cn', 'Nh', 'Fl', 'Mc', + 'Lv', 'Ts', 'Og' +] + +# amino acids +aas = [ + ('G', 'GLY'), ('A', 'ALA'), ('V', 'VAL'), ('L', 'LEU'), + ('I', 'ILE'), ('F', 'PHE'), ('W', 'TRP'), ('Y', 'TYR'), + ('D', 'ASP'), ('H', 'HIS'), ('N', 'ASN'), ('E', 'GLU'), + ('K', 'LYS'), ('Q', 'GLN'), ('M', 'MET'), ('R', 'ARG'), + ('S', 'SER'), ('T', 'THR'), ('C', 'CYS'), ('P', 'PRO') +] + +# backbone atoms +backbone_atoms = ['N', 'CA', 'C', 'O'] + +# backbone bonds +# 1: single bond +# 2: double bond +# 3: triple bond +# 4: conjugate system (e.g. 
aromatic) +backbone_bonds = [ + ('N', 'CA', 1), + ('CA', 'C', 1), + ('C', 'O', 4) # conjugate with adjacent N +] + +# side-chain atoms +sidechain_atoms = { symbol: AA_GEOMETRY[aa]['atoms'] for symbol, aa in aas } +# sidechain_atoms = { +# 'G': [], # -H +# 'A': ['CB'], # -CH3 +# 'V': ['CB', 'CG1', 'CG2'], # -CH-(CH3)2 +# 'L': ['CB', 'CG', 'CD1', 'CD2'], # -CH2-CH(CH3)2 +# 'I': ['CB', 'CG1', 'CG2', 'CD1'], # -CH(CH3)-CH2-CH3 +# 'F': ['CB', 'CG', 'CD1', 'CD2', 'CE1', 'CE2', 'CZ'], # -CH2-C6H5 +# 'W': ['CB', 'CG', 'CD1', 'CD2', 'NE1', 'CE2', 'CE3', 'CZ2', 'CZ3', 'CH2'], # -CH2-C8NH6 +# 'Y': ['CB', 'CG', 'CD1', 'CD2', 'CE1', 'CE2', 'CZ', 'OH'], # -CH2-C6H4-OH +# 'D': ['CB', 'CG', 'OD1', 'OD2'], # -CH2-COOH +# 'H': ['CB', 'CG', 'ND1', 'CD2', 'CE1', 'NE2'], # -CH2-C3H3N2 +# 'N': ['CB', 'CG', 'OD1', 'ND2'], # -CH2-CONH2 +# 'E': ['CB', 'CG', 'CD', 'OE1', 'OE2'], # -(CH2)2-COOH +# 'K': ['CB', 'CG', 'CD', 'CE', 'NZ'], # -(CH2)4-NH2 +# 'Q': ['CB', 'CG', 'CD', 'OE1', 'NE2'], # -(CH2)-CONH2 +# 'M': ['CB', 'CG', 'SD', 'CE'], # -(CH2)2-S-CH3 +# 'R': ['CB', 'CG', 'CD', 'NE', 'CZ', 'NH1', 'NH2'], # -(CH2)3-NHC(NH)NH2 +# 'S': ['CB', 'OG'], # -CH2-OH +# 'T': ['CB', 'OG1', 'CG2'], # -CH(CH3)-OH +# 'C': ['CB', 'SG'], # -CH2-SH +# 'P': ['CB', 'CG', 'CD'], # -C3H6 +# } + +# side-chain bonds +sidechain_bonds = { + 'G': [], + 'A': [('CA', 'CB', 1)], + 'V': [('CA', 'CB', 1), ('CB', 'CG1', 1), ('CB', 'CG2', 1)], + 'L': [('CA', 'CB', 1), ('CB', 'CG', 1), ('CG', 'CD1', 1), ('CG', 'CD2', 1)], + 'I': [('CA', 'CB', 1), ('CB', 'CG1', 1), ('CB', 'CG2', 1), ('CG1', 'CD1', 1)], + 'F': [('CA', 'CB', 1), ('CB', 'CG', 1), ('CG', 'CD1', 4), ('CG', 'CD2', 4), ('CD1', 'CE1', 4), ('CD2', 'CE2', 4), ('CE1', 'CZ', 4), ('CE2', 'CZ', 4)], + 'W': [('CA', 'CB', 1), ('CB', 'CG', 1), ('CG', 'CD1', 4), ('CG', 'CD2', 4), ('CD1', 'NE1', 4), ('CD2', 'CE2', 4), ('CD2', 'CE3', 4), ('CE2', 'NE1', 4), + ('CE2', 'CZ2', 4), ('CZ2', 'CH2', 4), ('CE3', 'CZ3', 4), ('CZ3', 'CH2', 4)], + 'Y': [('CA', 'CB', 1), ('CB', 'CG', 1), ('CG', 'CD1', 4), ('CG', 'CD2', 4), ('CD1', 'CE1', 4), ('CD2', 'CE2', 4), ('CE1', 'CZ', 4), ('CE2', 'CZ', 4), ('CZ', 'OH', 1)], + 'D': [('CA', 'CB', 1), ('CB', 'CG', 1), ('CG', 'OD1', 4), ('CG', 'OD2', 4)], + 'H': [('CA', 'CB', 1), ('CB', 'CG', 1), ('CG', 'ND1', 4), ('CG', 'CD2', 4), ('ND1', 'CE1', 4), ('CD2', 'NE2', 4), ('CE1', 'NE2', 4)], + 'N': [('CA', 'CB', 1), ('CB', 'CG', 1), ('CG', 'OD1', 4), ('CG', 'ND2', 4)], + 'E': [('CA', 'CB', 1), ('CB', 'CG', 1), ('CG', 'CD', 1), ('CD', 'OE1', 4), ('CD', 'OE2', 4)], + 'K': [('CA', 'CB', 1), ('CB', 'CG', 1), ('CG', 'CD', 1), ('CD', 'CE', 1), ('CE', 'NZ', 1)], + 'Q': [('CA', 'CB', 1), ('CB', 'CG', 1), ('CG', 'CD', 1), ('CD', 'OE1', 4), ('CD', 'NE2', 4)], + 'M': [('CA', 'CB', 1), ('CB', 'CG', 1), ('CG', 'SD', 1), ('SD', 'CE', 1)], + 'R': [('CA', 'CB', 1), ('CB', 'CG', 1), ('CG', 'CD', 1), ('CD', 'NE', 1), ('NE', 'CZ', 4), ('CZ', 'NH1', 4), ('CZ', 'NH2', 4)], + 'S': [('CA', 'CB', 1), ('CB', 'OG', 1)], + 'T': [('CA', 'CB', 1), ('CB', 'OG1', 1), ('CB', 'CG2', 1)], + 'C': [('CA', 'CB', 1), ('CB', 'SG', 1)], + 'P': [('CA', 'CB', 1), ('CB', 'CG', 1), ('CG', 'CD', 1), ('CD', 'N', 1)] +} + +# atoms for defining chi angles on the side chains +chi_angles_atoms = { + "ALA": [], + # Chi5 in arginine is always 0 +- 5 degrees, so ignore it. 
+ "ARG": [ + ["N", "CA", "CB", "CG"], + ["CA", "CB", "CG", "CD"], + ["CB", "CG", "CD", "NE"], + ["CG", "CD", "NE", "CZ"], + ], + "ASN": [["N", "CA", "CB", "CG"], ["CA", "CB", "CG", "OD1"]], + "ASP": [["N", "CA", "CB", "CG"], ["CA", "CB", "CG", "OD1"]], + "CYS": [["N", "CA", "CB", "SG"]], + "GLN": [ + ["N", "CA", "CB", "CG"], + ["CA", "CB", "CG", "CD"], + ["CB", "CG", "CD", "OE1"], + ], + "GLU": [ + ["N", "CA", "CB", "CG"], + ["CA", "CB", "CG", "CD"], + ["CB", "CG", "CD", "OE1"], + ], + "GLY": [], + "HIS": [["N", "CA", "CB", "CG"], ["CA", "CB", "CG", "ND1"]], + "ILE": [["N", "CA", "CB", "CG1"], ["CA", "CB", "CG1", "CD1"]], + "LEU": [["N", "CA", "CB", "CG"], ["CA", "CB", "CG", "CD1"]], + "LYS": [ + ["N", "CA", "CB", "CG"], + ["CA", "CB", "CG", "CD"], + ["CB", "CG", "CD", "CE"], + ["CG", "CD", "CE", "NZ"], + ], + "MET": [ + ["N", "CA", "CB", "CG"], + ["CA", "CB", "CG", "SD"], + ["CB", "CG", "SD", "CE"], + ], + "PHE": [["N", "CA", "CB", "CG"], ["CA", "CB", "CG", "CD1"]], + "PRO": [["N", "CA", "CB", "CG"], ["CA", "CB", "CG", "CD"]], + "SER": [["N", "CA", "CB", "OG"]], + "THR": [["N", "CA", "CB", "OG1"]], + "TRP": [["N", "CA", "CB", "CG"], ["CA", "CB", "CG", "CD1"]], + "TYR": [["N", "CA", "CB", "CG"], ["CA", "CB", "CG", "CD1"]], + "VAL": [["N", "CA", "CB", "CG1"]], +} + +# amino acid smiles +aa_smiles = { + 'G': 'C(C(=O)O)N', + 'A': 'O=C(O)C(N)C', + 'V': 'CC(C)[C@@H](C(=O)O)N', + 'L': 'CC(C)C[C@@H](C(=O)O)N', + 'I': 'CC[C@H](C)[C@@H](C(=O)O)N', + 'F': 'NC(C(=O)O)Cc1ccccc1', + 'W': 'c1ccc2c(c1)c(c[nH]2)C[C@@H](C(=O)O)N', + 'Y': 'N[C@@H](Cc1ccc(O)cc1)C(O)=O', + 'D': 'O=C(O)CC(N)C(=O)O', + 'H': 'O=C([C@H](CC1=CNC=N1)N)O', + 'N': 'NC(=O)CC(N)C(=O)O', + 'E': 'OC(=O)CCC(N)C(=O)O', + 'K': 'NCCCC(N)C(=O)O', + 'Q': 'O=C(N)CCC(N)C(=O)O', + 'M': 'CSCC[C@H](N)C(=O)O', + 'R': 'NC(=N)NCCCC(N)C(=O)O', + 'S': 'C([C@@H](C(=O)O)N)O', + 'T': 'C[C@H]([C@@H](C(=O)O)N)O', + 'C': 'C([C@@H](C(=O)O)N)S', + 'P': 'OC(=O)C1CCCN1' +} \ No newline at end of file diff --git a/utils/decorators.py b/utils/decorators.py new file mode 100644 index 0000000000000000000000000000000000000000..89cf0b8eab51019e1e9c523f9ba47d61dec295ac --- /dev/null +++ b/utils/decorators.py @@ -0,0 +1,49 @@ +#!/usr/bin/python +# -*- coding:utf-8 -*- +''' +Currently support the following decorators: +1. @singleton +2. @timeout(seconds) +''' +import functools +from concurrent import futures +TimeoutError = futures.TimeoutError + + +''' +Singleton class: + +@singleton +class A: + ... +''' + +def singleton(cls): + _instance = {} + + def inner(*args, **kwargs): + if cls not in _instance: + _instance[cls] = cls(*args, **kwargs) + return _instance[cls] + return inner + + +''' +Throw TimeoutError when a function exceeds 1 second: + +@timeout(1) +def func(...): + ... 
+''' +class timeout: + __executor = futures.ThreadPoolExecutor(1) + + def __init__(self, seconds): + self.seconds = seconds + + def __call__(self, func): + @functools.wraps(func) + def wrapper(*args, **kw): + future = timeout.__executor.submit(func, *args, **kw) + return future.result(timeout=self.seconds) + return wrapper \ No newline at end of file diff --git a/utils/file_utils.py b/utils/file_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..f65f6aa2f48c024ea6f74b050a8d71dca7259a2b --- /dev/null +++ b/utils/file_utils.py @@ -0,0 +1,19 @@ +#!/usr/bin/python +# -*- coding:utf-8 -*- +import os +from os.path import basename, splitext + + +def get_filename(path): + return basename(splitext(path)[0]) + + +def cnt_num_files(directory, recursive=False): + cnt = 0 + for sub in os.listdir(directory): + sub = os.path.join(directory, sub) + if os.path.isfile(sub): + cnt += 1 + elif os.path.isdir(sub) and recursive: + cnt += cnt_num_files(sub, recursive=True) + return cnt \ No newline at end of file diff --git a/utils/logger.py b/utils/logger.py new file mode 100644 index 0000000000000000000000000000000000000000..3d3ee264cdeec67ad1878e80e23a481cd5087704 --- /dev/null +++ b/utils/logger.py @@ -0,0 +1,35 @@ +#!/usr/bin/python +# -*- coding:utf-8 -*- +import os +import sys +import datetime + + +LEVELS = ['TRACE', 'DEBUG', 'INFO', 'WARN', 'ERROR'] +LEVELS_MAP = None + + +def init_map(): + global LEVELS_MAP, LEVELS + LEVELS_MAP = {} + for idx, level in enumerate(LEVELS): + LEVELS_MAP[level] = idx + + +def get_prio(level): + global LEVELS_MAP + if LEVELS_MAP is None: + init_map() + return LEVELS_MAP[level.upper()] + + +def print_log(s, level='INFO', end='\n', no_prefix=False): + pth_prio = get_prio(os.getenv('LOG', 'INFO')) + prio = get_prio(level) + if prio >= pth_prio: + if not no_prefix: + now = datetime.datetime.now() + prefix = now.strftime("%Y-%m-%d %H:%M:%S") + f'::{level.upper()}::' + print(prefix, end='') + print(s, end=end) + sys.stdout.flush() diff --git a/utils/network.py b/utils/network.py new file mode 100644 index 0000000000000000000000000000000000000000..d179f9a4aace4925990e5bb990e4f7a2508737ee --- /dev/null +++ b/utils/network.py @@ -0,0 +1,19 @@ +#!/usr/bin/python +# -*- coding:utf-8 -*- +import requests + +from .logger import print_log + + +def url_get(url, tries=3): + for i in range(tries): + if i > 0: + print_log(f'Retrying ({i + 1}/{tries})', level='WARN') + try: + res = requests.get(url) + except requests.exceptions.ConnectionError: # catch requests' ConnectionError (not the built-in one) + continue + if res.status_code == 200: + return res + print_log(f'Get {url} failed', level='WARN') + return None \ No newline at end of file diff --git a/utils/nn_utils.py b/utils/nn_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..8d22797f87a12defdddec8965808c09058d8b6b4 --- /dev/null +++ b/utils/nn_utils.py @@ -0,0 +1,345 @@ +#!/usr/bin/python +# -*- coding:utf-8 -*- +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch_scatter import scatter_sum + + +def count_parameters(model): + return sum(p.numel() for p in model.parameters() if p.requires_grad) + + +def sequential_and(*tensors): + res = tensors[0] + for mat in tensors[1:]: + res = torch.logical_and(res, mat) + return res + + +def sequential_or(*tensors): + res = tensors[0] + for mat in tensors[1:]: + res = torch.logical_or(res, mat) + return res + + +def graph_to_batch(tensor, batch_id, padding_value=0, mask_is_pad=True): + ''' + :param tensor: [N, D1, D2, ...] 
+ :param batch_id: [N] + :param mask_is_pad: 1 in the mask indicates padding if set to True + ''' + lengths = scatter_sum(torch.ones_like(batch_id), batch_id) # [bs] + bs, max_n = lengths.shape[0], torch.max(lengths) + batch = torch.ones((bs, max_n, *tensor.shape[1:]), dtype=tensor.dtype, device=tensor.device) * padding_value + # generate pad mask: 1 for pad and 0 for data + pad_mask = torch.zeros((bs, max_n + 1), dtype=torch.long, device=tensor.device) + pad_mask[(torch.arange(bs, device=tensor.device), lengths)] = 1 + pad_mask = (torch.cumsum(pad_mask, dim=-1)[:, :-1]).bool() + data_mask = torch.logical_not(pad_mask) + # fill data + batch[data_mask] = tensor + mask = pad_mask if mask_is_pad else data_mask + return batch, mask + + +def variadic_arange(size): + """ + from https://torchdrug.ai/docs/_modules/torchdrug/layers/functional/functional.html#variadic_arange + + Return a 1-D tensor that contains integer intervals of variadic sizes. + This is a variadic variant of ``torch.arange(stop).expand(batch_size, -1)``. + + Suppose there are :math:`N` intervals. + + Parameters: + size (LongTensor): size of intervals of shape :math:`(N,)` + """ + starts = size.cumsum(0) - size + + range = torch.arange(size.sum(), device=size.device) + range = range - starts.repeat_interleave(size) + return range + + +def variadic_meshgrid(input1, size1, input2, size2): + """ + from https://torchdrug.ai/docs/_modules/torchdrug/layers/functional/functional.html#variadic_meshgrid + Compute the Cartesian product for two batches of sets with variadic sizes. + + Suppose there are :math:`N` sets in each input, + and the sizes of all sets are summed to :math:`B_1` and :math:`B_2` respectively. + + Parameters: + input1 (Tensor): input of shape :math:`(B_1, ...)` + size1 (LongTensor): size of :attr:`input1` of shape :math:`(N,)` + input2 (Tensor): input of shape :math:`(B_2, ...)` + size2 (LongTensor): size of :attr:`input2` of shape :math:`(N,)` + + Returns + (Tensor, Tensor): the first and the second elements in the Cartesian product + """ + grid_size = size1 * size2 + local_index = variadic_arange(grid_size) + local_inner_size = size2.repeat_interleave(grid_size) + offset1 = (size1.cumsum(0) - size1).repeat_interleave(grid_size) + offset2 = (size2.cumsum(0) - size2).repeat_interleave(grid_size) + index1 = torch.div(local_index, local_inner_size, rounding_mode="floor") + offset1 + index2 = local_index % local_inner_size + offset2 + return input1[index1], input2[index2] + + +@torch.no_grad() +def length_to_batch_id(S, lengths): + # generate batch id + batch_id = torch.zeros_like(S) # [N] + batch_id[torch.cumsum(lengths, dim=0)[:-1]] = 1 + batch_id.cumsum_(dim=0) # [N], item idx in the batch + return batch_id + + +def scatter_sort(src: torch.Tensor, index: torch.Tensor, dim=0, descending=False, eps=1e-12): + ''' + from https://github.com/rusty1s/pytorch_scatter/issues/48 + WARN: the range between src.max() and src.min() should not be too wide for numerical stability + + reproducible + ''' + # f_src = src.float() + # f_min, f_max = f_src.min(dim)[0], f_src.max(dim)[0] + # norm = (f_src - f_min)/(f_max - f_min + eps) + index.float()*(-1)**int(descending) + # perm = norm.argsort(dim=dim, descending=descending) + + # return src[perm], perm + src, src_perm = torch.sort(src, dim=dim, descending=descending) + index = index.take_along_dim(src_perm, dim=dim) + index, index_perm = torch.sort(index, dim=dim, stable=True) + src = src.take_along_dim(index_perm, dim=dim) + perm = src_perm.take_along_dim(index_perm, dim=0) + return 
src, perm + + +def scatter_topk(src: torch.Tensor, index: torch.Tensor, k: int, dim=0, largest=True): + indices = torch.arange(src.shape[dim], device=src.device) + src, perm = scatter_sort(src, index, dim, descending=largest) + index, indices = index[perm], indices[perm] + mask = torch.ones_like(index).bool() + mask[k:] = index[k:] != index[:-k] + return src[mask], indices[mask] + + +@torch.no_grad() +def knn_edges(all_edges, k_neighbors, X=None, atom_mask=None, given_dist=None): + ''' + :param all_edges: [2, E], (row, col) + :param X: [N, n_channel, 3], coordinates + :param atom_mask: [N, n_channel], 1 for having atom + :param given_dist: [E], given distance of edges + IMPORTANT: either given_dist should be given, or both X and atom_mask should be given + ''' + assert (given_dist is not None) or (X is not None and atom_mask is not None), \ + 'either given_dist should be given, or both X and atom_mask should be given' + + row, col = all_edges # row is also needed for the top-k selection below, even when given_dist is provided + # get distance on each edge + if given_dist is None: + dist = torch.norm(X[row][:, :, None, :] - X[col][:, None, :, :], dim=-1) # [E, n_channel, n_channel] + dist_mask = atom_mask[row][:, :, None] & atom_mask[col][:, None, :] # [E, n_channel, n_channel] + dist = torch.where(dist_mask, dist, torch.ones_like(dist) * float('inf')) # [E, n_channel, n_channel] + dist, _ = dist.view(dist.shape[0], -1).min(axis=-1) # [E] + else: + dist = given_dist + + # get topk for each node + _, indices = scatter_topk(dist, row, k=k_neighbors, largest=False) + edges = torch.stack([all_edges[0][indices], all_edges[1][indices]], dim=0) # [2, k*N] + return edges # [2, E] + + +class EdgeConstructor: + def __init__(self, cor_idx, col_idx, atom_pos_pad_idx, rec_seg_id) -> None: + self.cor_idx, self.col_idx = cor_idx, col_idx + self.atom_pos_pad_idx = atom_pos_pad_idx + self.rec_seg_id = rec_seg_id + + # buffer + self._reset_buffer() + + def _reset_buffer(self): + self.row = None + self.col = None + self.row_global = None + self.col_global = None + self.row_seg = None + self.col_seg = None + self.offsets = None + self.max_n = None + self.gni2lni = None + self.not_global_edges = None + + def get_batch_edges(self, batch_id): + # construct tensors to map between global / local node index + lengths = scatter_sum(torch.ones_like(batch_id), batch_id) # [bs] + N, max_n = batch_id.shape[0], torch.max(lengths) + offsets = F.pad(torch.cumsum(lengths, dim=0)[:-1], pad=(1, 0), value=0) # [bs] + # global node index to local index. 
lni2gni can be implemented as lni + offsets[batch_id] + gni = torch.arange(N, device=batch_id.device) + gni2lni = gni - offsets[batch_id] # [N] + + # all possible edges (within the same graph) + # same bid (get rid of self-loop and none edges) + same_bid = torch.zeros(N, max_n, device=batch_id.device) + same_bid[(gni, lengths[batch_id] - 1)] = 1 + same_bid = 1 - torch.cumsum(same_bid, dim=-1) + # shift right and pad 1 to the left + same_bid = F.pad(same_bid[:, :-1], pad=(1, 0), value=1) + same_bid[(gni, gni2lni)] = 0 # delete self loop + row, col = torch.nonzero(same_bid).T # [2, n_edge_all] + col = col + offsets[batch_id[row]] # mapping from local to global node index + return (row, col), (offsets, max_n, gni2lni) + + def _prepare(self, S, batch_id, segment_ids) -> None: + (row, col), (offsets, max_n, gni2lni) = self.get_batch_edges(batch_id) + + # not global edges + is_global = sequential_or(S == self.cor_idx, S == self.col_idx) # [N] + row_global, col_global = is_global[row], is_global[col] + not_global_edges = torch.logical_not(torch.logical_or(row_global, col_global)) + + # segment ids + row_seg, col_seg = segment_ids[row], segment_ids[col] + + # add to buffer + self.row, self.col = row, col + self.offsets, self.max_n, self.gni2lni = offsets, max_n, gni2lni + self.row_global, self.col_global = row_global, col_global + self.not_global_edges = not_global_edges + self.row_seg, self.col_seg = row_seg, col_seg + + def _construct_inner_edges(self, X, batch_id, k_neighbors, atom_pos): + row, col = self.row, self.col + # all possible ctx edges: same seg, not global + select_edges = torch.logical_and(self.row_seg == self.col_seg, self.not_global_edges) + ctx_all_row, ctx_all_col = row[select_edges], col[select_edges] + # ctx edges + inner_edges = _knn_edges( + X, atom_pos, torch.stack([ctx_all_row, ctx_all_col]).T, + self.atom_pos_pad_idx, k_neighbors, + (self.offsets, batch_id, self.max_n, self.gni2lni)) + return inner_edges + + def _construct_outer_edges(self, X, batch_id, k_neighbors, atom_pos): + row, col = self.row, self.col + # all possible inter edges: not same seg, not global + select_edges = torch.logical_and(self.row_seg != self.col_seg, self.not_global_edges) + inter_all_row, inter_all_col = row[select_edges], col[select_edges] + outer_edges = _knn_edges( + X, atom_pos, torch.stack([inter_all_row, inter_all_col]).T, + self.atom_pos_pad_idx, k_neighbors, + (self.offsets, batch_id, self.max_n, self.gni2lni)) + return outer_edges + + def _construct_global_edges(self): + row, col = self.row, self.col + # edges between global and normal nodes + select_edges = torch.logical_and(self.row_seg == self.col_seg, torch.logical_not(self.not_global_edges)) + global_normal = torch.stack([row[select_edges], col[select_edges]]) # [2, nE] + # edges between global and global nodes + select_edges = torch.logical_and(self.row_global, self.col_global) # self-loop has been deleted + global_global = torch.stack([row[select_edges], col[select_edges]]) # [2, nE] + return global_normal, global_global + + def _construct_seq_edges(self): + row, col = self.row, self.col + # add additional edge to neighbors in 1D sequence (except epitope) + select_edges = sequential_and( + torch.logical_or((row - col) == 1, (row - col) == -1), # adjacent in the graph + self.not_global_edges, # not global edges (also ensure the edges are in the same segment) + self.row_seg != self.rec_seg_id # not epitope + ) + seq_adj = torch.stack([row[select_edges], col[select_edges]]) # [2, nE] + return seq_adj + + @torch.no_grad() + def 
construct_edges(self, X, S, batch_id, k_neighbors, atom_pos, segment_ids): + ''' + Memory efficient with complexity of O(Nn) where n is the largest number of nodes in the batch + ''' + # prepare inputs + self._prepare(S, batch_id, segment_ids) + + ctx_edges, inter_edges = [], [] + + # edges within chains + inner_edges = self._construct_inner_edges(X, batch_id, k_neighbors, atom_pos) + # edges between global nodes and normal/global nodes + global_normal, global_global = self._construct_global_edges() + # edges on the 1D sequence + seq_edges = self._construct_seq_edges() + + # construct context edges + ctx_edges = torch.cat([inner_edges, global_normal, global_global, seq_edges], dim=1) # [2, E] + + # construct interaction edges + inter_edges = self._construct_outer_edges(X, batch_id, k_neighbors, atom_pos) + + self._reset_buffer() + return ctx_edges, inter_edges + + +class GMEdgeConstructor(EdgeConstructor): + ''' + Edge constructor for graph matching (kNN internal edges and all bipartite edges) + ''' + def _construct_inner_edges(self, X, batch_id, k_neighbors, atom_pos): + row, col = self.row, self.col + # all possible ctx edges: both in ag or ab, not global + row_is_rec = self.row_seg == self.rec_seg_id + col_is_rec = self.col_seg == self.rec_seg_id + select_edges = torch.logical_and(row_is_rec == col_is_rec, self.not_global_edges) + ctx_all_row, ctx_all_col = row[select_edges], col[select_edges] + # ctx edges + inner_edges = _knn_edges( + X, atom_pos, torch.stack([ctx_all_row, ctx_all_col]).T, + self.atom_pos_pad_idx, k_neighbors, + (self.offsets, batch_id, self.max_n, self.gni2lni)) + return inner_edges + + def _construct_global_edges(self): + row, col = self.row, self.col + # edges between global and normal nodes + select_edges = torch.logical_and(self.row_seg == self.col_seg, torch.logical_not(self.not_global_edges)) + global_normal = torch.stack([row[select_edges], col[select_edges]]) # [2, nE] + # edges between global and global nodes + select_edges = sequential_and(self.row_global, self.col_global) # self-loop has been deleted + global_global = torch.stack([row[select_edges], col[select_edges]]) # [2, nE] + return global_normal, global_global + + def _construct_outer_edges(self, X, batch_id, k_neighbors, atom_pos): + row, col = self.row, self.col + # all possible inter edges: one in ag and one in ab, not global + row_is_rec = self.row_seg == self.rec_seg_id + col_is_rec = self.col_seg == self.rec_seg_id + select_edges = torch.logical_and(row_is_rec != col_is_rec, self.not_global_edges) + inter_all_row, inter_all_col = row[select_edges], col[select_edges] + return torch.stack([inter_all_row, inter_all_col]) # [2, E] + + +class SinusoidalPositionEmbedding(nn.Module): + """ + Sin-Cos Positional Embedding + """ + def __init__(self, output_dim): + super(SinusoidalPositionEmbedding, self).__init__() + self.output_dim = output_dim + + def forward(self, position_ids): + device = position_ids.device + position_ids = position_ids[None] # [1, N] + indices = torch.arange(self.output_dim // 2, device=device, dtype=torch.float) + indices = torch.pow(10000.0, -2 * indices / self.output_dim) + embeddings = torch.einsum('bn,d->bnd', position_ids, indices) + embeddings = torch.stack([torch.sin(embeddings), torch.cos(embeddings)], dim=-1) + embeddings = embeddings.reshape(-1, self.output_dim) + return embeddings \ No newline at end of file diff --git a/utils/oom_decorator.py b/utils/oom_decorator.py new file mode 100644 index 
0000000000000000000000000000000000000000..6d494377d6a581988e387a6de836f223c4e4e913 --- /dev/null +++ b/utils/oom_decorator.py @@ -0,0 +1,44 @@ +#!/usr/bin/python +# -*- coding:utf-8 -*- +from collections import namedtuple +from functools import wraps + +import torch + + +OOMReturn = namedtuple('OOMReturn', ['fake_loss']) + + +def oom_decorator(forward): + @wraps(forward) + + def deco_func(self, *args, **kwargs): + try: + output = forward(self, *args, **kwargs) + return output + except RuntimeError as e: + if 'out of memory' in str(e): + output = sum([p.norm() for p in self.parameters() if p.dtype == torch.float]) * 0.0 + return OOMReturn(output) + else: + raise e + + return deco_func + + +def safe_backward(loss, model): + if isinstance(model, torch.nn.parallel.DistributedDataParallel): + loss.backward() # regrettably, we cannot handle backward OOM in distributed training + return True + + try: + loss.backward() + return True + except RuntimeError as e: + if 'out of memory' in str(e): + fake_loss = sum([p.norm() for p in model.parameters() if p.dtype == torch.float]) * 0.0 + fake_loss.backward() + torch.cuda.empty_cache() + return False + else: + raise e \ No newline at end of file diff --git a/utils/random_seed.py b/utils/random_seed.py new file mode 100644 index 0000000000000000000000000000000000000000..8d68bacc8bc4bee9d3a79da9e9fcb2e346ff3ec5 --- /dev/null +++ b/utils/random_seed.py @@ -0,0 +1,16 @@ +#!/usr/bin/python +# -*- coding:utf-8 -*- +import torch +import numpy as np +import random + + +def setup_seed(seed): + torch.manual_seed(seed) + torch.cuda.manual_seed_all(seed) + np.random.seed(seed) + random.seed(seed) + torch.backends.cudnn.deterministic = True + + +SEED = 12 \ No newline at end of file diff --git a/utils/register.py b/utils/register.py new file mode 100644 index 0000000000000000000000000000000000000000..0725e4c47894d1b2fd0235cf7a424e9ae33e5b25 --- /dev/null +++ b/utils/register.py @@ -0,0 +1,27 @@ +#!/usr/bin/python +# -*- coding:utf-8 -*- +from typing import Dict +from copy import deepcopy + +_NAMESPACE = {} + +def register(name): + def decorator(cls): + assert name not in _NAMESPACE, f'Class {name} already registered' + _NAMESPACE[name] = cls + return cls + return decorator + + +def get(name): + if name not in _NAMESPACE: + raise ValueError(f'Class {name} not registered') + return _NAMESPACE[name] + + +def construct(config: Dict, **kwargs): + config = deepcopy(config) + cls_name = config.pop('class') + cls = get(cls_name) + config.update(kwargs) + return cls(**config) \ No newline at end of file
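+
+# Illustrative usage (a minimal sketch added for documentation, not part of the original module;
+# the class name `MyTrainer` and the variable `trainer_config` are hypothetical, while the
+# decorator pattern mirrors `@R.register('LDMTrainer')` in trainer/ldm_trainer.py):
+#
+#   from utils import register as R
+#
+#   @R.register('MyTrainer')
+#   class MyTrainer(Trainer):
+#       def __init__(self, model, train_loader, valid_loader, config, save_config): ...
+#
+#   # given a config dict such as {'class': 'MyTrainer', ...}, an instance can be built with
+#   trainer = R.construct(trainer_config, model=model, train_loader=train_loader, valid_loader=valid_loader)
+#
+# construct() pops the 'class' key, looks up the registered type, merges the extra keyword
+# arguments into the remaining config entries, and forwards everything to the constructor.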