# diffdock / utils / visualise.py
# gcorso — first commit (4a3f787)
from rdkit.Chem.rdmolfiles import MolToPDBBlock, MolToPDBFile
import rdkit.Chem
from rdkit import Geometry
from collections import defaultdict
import copy
import numpy as np
import torch
class PDBFile:
    """Accumulate conformations of one molecule and emit them as a multi-MODEL PDB.

    Snapshots are stored in ``self.parts[part][order]`` as
    ``{'block': <list of PDB lines>, 'repeat': <int>}``. ``write`` emits parts in
    ascending ``part`` order; within a part, non-negative ``order`` keys come
    first (ascending), then negative keys (ascending). CONECT records are kept
    only in the first MODEL written.
    """

    def __init__(self, mol):
        """Keep a private deep copy of ``mol`` with only conformer id 0 left.

        Args:
            mol: RDKit molecule used as the topology template for ``add``.
        """
        self.parts = defaultdict(dict)
        self.mol = copy.deepcopy(mol)
        # Drop conformers with ids 1..N-1 so only conformer 0 remains, presumably
        # so MolToPDBBlock exports a single conformation per snapshot.
        # NOTE(review): assumes conformer ids are 0..N-1 — confirm for inputs.
        for conf_id in range(1, mol.GetNumConformers()):
            self.mol.RemoveConformer(conf_id)

    @staticmethod
    def _as_float64_array(coords):
        """Convert ``coords`` to a float64 NumPy array (torch tensors are detached and moved to CPU)."""
        if isinstance(coords, torch.Tensor):
            # .numpy() fails on CUDA tensors and on tensors that require grad.
            return coords.detach().cpu().double().numpy()
        return np.asarray(coords, dtype=np.float64)

    @staticmethod
    def _ordered_keys(models):
        """Return model keys: non-negative ascending, then negative ascending."""
        return sorted(models, key=lambda key: (key < 0, key))

    def add(self, coords, order, part=0, repeat=1):
        """Record one snapshot under ``self.parts[part][order]``.

        Args:
            coords: either an RDKit ``Mol`` (exported as-is), or per-atom
                coordinates (NumPy array, torch tensor, or nested sequence)
                whose row ``i`` holds the x, y, z position of atom ``i`` of
                the template molecule.
            order: key ordering this snapshot inside its part (see ``write``).
            part: group id; parts are written in ascending order.
            repeat: number of times this MODEL is emitted by ``write``.

        Coordinate input overwrites conformer 0 of the internal copy in place;
        atoms with index >= ``len(coords)`` keep their previous positions.
        """
        if isinstance(coords, rdkit.Chem.Mol):
            pdb_text = MolToPDBBlock(coords)
        else:
            coords = self._as_float64_array(coords)
            conformer = self.mol.GetConformer(0)
            for atom_idx, row in enumerate(coords):
                conformer.SetAtomPosition(
                    atom_idx,
                    Geometry.Point3D(float(row[0]), float(row[1]), float(row[2])),
                )
            pdb_text = MolToPDBBlock(self.mol)
        # Drop the last two pieces of the split — presumably the trailing END
        # record and the empty string after the final newline.
        block = pdb_text.split('\n')[:-2]
        self.parts[part][order] = {'block': block, 'repeat': repeat}

    def write(self, path=None, limit_parts=None):
        """Render all stored snapshots as MODEL/ENDMDL records.

        Args:
            path: if truthy, the output is written to this file; otherwise it
                is returned.
            limit_parts: if truthy, parts with id >= ``limit_parts`` are skipped.

        Returns:
            The PDB text when ``path`` is falsy, else ``None``.
        """
        chunks = []
        is_first = True
        for part_id in sorted(self.parts):
            if limit_parts and part_id >= limit_parts:
                break
            models = self.parts[part_id]
            for key in self._ordered_keys(models):
                entry = models[key]
                block = entry['block']
                for _ in range(entry['repeat']):
                    # Bond (CONECT) records are only emitted with the first model.
                    if not is_first:
                        block = [line for line in block if 'CONECT' not in line]
                    is_first = False
                    chunks.append('MODEL\n' + '\n'.join(block) + '\nENDMDL\n')
        text = ''.join(chunks)
        if not path:
            return text
        with open(path, 'w') as f:
            f.write(text)