from rdkit.Chem.rdmolfiles import MolToPDBBlock, MolToPDBFile
import rdkit.Chem
from rdkit import Geometry
from collections import defaultdict
import copy
import numpy as np
import torch


class PDBFile:
    """Collect molecule snapshots and serialise them as a multi-MODEL PDB.

    Frames are stored in ``self.parts[part][order]`` as dictionaries with a
    ``'block'`` entry (list of PDB text lines) and a ``'repeat'`` entry (how
    many times the frame is emitted by :meth:`write`).
    """

    def __init__(self, mol):
        """Keep a private deep copy of *mol* used as the coordinate template.

        Conformers with IDs ``1 .. GetNumConformers() - 1`` are removed from
        the copy; the caller's molecule is left untouched.
        """
        self.parts = defaultdict(dict)
        self.mol = copy.deepcopy(mol)
        # Plain loop instead of a throw-away list comprehension; the set of
        # IDs removed is the same as before (every ID in the range except 0).
        for conf_id in range(1, mol.GetNumConformers()):
            self.mol.RemoveConformer(conf_id)

    @staticmethod
    def _pdb_lines(mol):
        """Return the PDB block of *mol* as lines, minus the last two entries."""
        # The final two entries of the split are dropped so that write() can
        # wrap the remaining lines in its own MODEL/ENDMDL records.
        # NOTE(review): presumably these are the END record and the empty
        # string after the trailing newline -- confirm against RDKit output.
        return MolToPDBBlock(mol).split('\n')[:-2]

    def add(self, coords, order, part=0, repeat=1):
        """Register one frame under ``self.parts[part][order]``.

        Args:
            coords: Either an RDKit molecule (written out as-is), or a
                coordinate array with one row per atom whose first three
                columns are x, y, z. ``numpy`` arrays, ``torch`` tensors and
                anything ``numpy.asarray`` can convert are accepted.
            order: Key of the frame inside its part; :meth:`write` emits
                non-negative keys first (ascending), then negative keys
                (ascending).
            part: Key of the part the frame belongs to.
            repeat: Number of times the frame is emitted by :meth:`write`.
        """
        # isinstance (rather than an exact type() match) also accepts
        # subclasses of the RDKit molecule classes.
        if isinstance(coords, (rdkit.Chem.Mol, rdkit.Chem.RWMol)):
            self.parts[part][order] = {
                'block': self._pdb_lines(coords),
                'repeat': repeat,
            }
            return

        if isinstance(coords, torch.Tensor):
            # detach() and cpu() make the conversion work for tensors that
            # track gradients or live on an accelerator; a bare .numpy()
            # raises in both situations.
            coords = coords.detach().cpu().double().numpy()
        else:
            # Covers ndarrays of any dtype as well as nested sequences.
            coords = np.asarray(coords, dtype=np.float64)

        # Only touch the conformer when there is at least one row to write,
        # so an empty coordinate array never requires conformer 0 to exist.
        if coords.shape[0]:
            conformer = self.mol.GetConformer(0)
            for atom_idx in range(coords.shape[0]):
                conformer.SetAtomPosition(
                    atom_idx,
                    Geometry.Point3D(
                        coords[atom_idx, 0],
                        coords[atom_idx, 1],
                        coords[atom_idx, 2],
                    ),
                )

        self.parts[part][order] = {
            'block': self._pdb_lines(self.mol),
            'repeat': repeat,
        }

    def write(self, path=None, limit_parts=None):
        """Serialise all stored frames as consecutive MODEL/ENDMDL sections.

        Args:
            path: Destination file. When falsy, the PDB text is returned
                instead of being written.
            limit_parts: When truthy, parts whose key is ``>= limit_parts``
                are skipped.

        Returns:
            The PDB text when *path* is falsy, otherwise ``None``.
        """
        is_first = True
        chunks = []
        for part_key in sorted(self.parts.keys()):
            if limit_parts and part_key >= limit_parts:
                break
            frames = self.parts[part_key]
            # Non-negative keys first, then negative keys, each ascending.
            ordered_keys = (
                sorted(key for key in frames if key >= 0)
                + sorted(key for key in frames if key < 0)
            )
            for key in ordered_keys:
                block = frames[key]['block']
                times = frames[key]['repeat']
                if times <= 0:
                    continue
                # CONECT records are kept only in the very first model that
                # is emitted; every later model gets the filtered lines.
                filtered = '\n'.join(
                    line for line in block if 'CONECT' not in line
                )
                for _ in range(times):
                    if is_first:
                        body = '\n'.join(block)
                        is_first = False
                    else:
                        body = filtered
                    chunks.append('MODEL\n')
                    chunks.append(body)
                    chunks.append('\nENDMDL\n')
        str_ = ''.join(chunks)
        if not path:
            return str_
        with open(path, 'w') as f:
            f.write(str_)