diffdock / utils /visualise.py
gcorso's picture
first commit
4a3f787
raw history blame
No virus
2.13 kB
from rdkit.Chem.rdmolfiles import MolToPDBBlock, MolToPDBFile
import rdkit.Chem
from rdkit import Geometry
from collections import defaultdict
import copy
import numpy as np
import torch
class PDBFile:
    """Collect PDB text blocks for a molecule and emit them as MODEL records.

    Blocks are stored in ``self.parts[part][order]`` as dictionaries with the
    keys ``'block'`` (list of PDB lines) and ``'repeat'`` (how many times the
    block is emitted by :meth:`write`).
    """

    def __init__(self, mol):
        """Keep a private deep copy of *mol* used as a coordinate template.

        Args:
            mol: Molecule object; only ``GetNumConformers`` is called on it
                here, the copy is later passed to ``MolToPDBBlock``.
        """
        self.parts = defaultdict(dict)
        self.mol = copy.deepcopy(mol)
        # Remove conformers 1..n-1 from the copy so that only conformer 0
        # remains; ``add`` overwrites the positions of that conformer.
        for conf_id in range(1, mol.GetNumConformers()):
            self.mol.RemoveConformer(conf_id)

    def _store_block(self, mol, order, part, repeat):
        """Render *mol* to PDB lines and record them under ``parts[part][order]``."""
        # The last two elements of the split are dropped; presumably the
        # trailing END record and the empty string after the final newline.
        block = MolToPDBBlock(mol).split('\n')[:-2]
        self.parts[part][order] = {'block': block, 'repeat': repeat}

    def add(self, coords, order, part=0, repeat=1):
        """Register one frame, given either a molecule or atom coordinates.

        Args:
            coords: An RDKit ``Mol``/``RWMol`` (rendered as is), or an
                array-like / ``torch.Tensor`` indexed as ``coords[i, 0..2]``
                whose rows are written into conformer 0 of the template copy.
            order: Key of the frame inside its part; see :meth:`write` for
                how keys are ordered.
            part: Key of the part the frame belongs to.
            repeat: Number of times the frame is emitted by :meth:`write`.
        """
        if isinstance(coords, (rdkit.Chem.Mol, rdkit.Chem.RWMol)):
            self._store_block(coords, order, part, repeat)
            return

        if isinstance(coords, torch.Tensor):
            # detach()/cpu() make the conversion work for tensors that
            # require grad or live on an accelerator as well.
            coords = coords.detach().cpu().double().numpy()
        else:
            # Covers numpy arrays (and subclasses) and plain nested sequences.
            coords = np.asarray(coords, dtype=np.float64)

        conformer = self.mol.GetConformer(0)
        for atom_idx, row in enumerate(coords):
            position = Geometry.Point3D(float(row[0]), float(row[1]), float(row[2]))
            conformer.SetAtomPosition(atom_idx, position)
        self._store_block(self.mol, order, part, repeat)

    def write(self, path=None, limit_parts=None):
        """Serialise all stored frames as consecutive MODEL/ENDMDL sections.

        Parts are visited in ascending key order. Within a part, frames with
        non-negative ``order`` come first (ascending), followed by frames
        with negative ``order`` (ascending). Lines containing ``'CONECT'``
        are kept only in the very first emitted model.

        Args:
            path: Destination file. When falsy, nothing is written and the
                text is returned instead.
            limit_parts: When truthy, iteration stops at the first part key
                that is ``>= limit_parts``.

        Returns:
            The PDB text when *path* is falsy, otherwise ``None``.
        """
        models = []
        for part_id in sorted(self.parts):
            if limit_parts and part_id >= limit_parts:
                break
            frames = self.parts[part_id]
            # (k < 0, k): non-negative keys first, each group ascending.
            for order in sorted(frames, key=lambda k: (k < 0, k)):
                frame = frames[order]
                lines = frame['block']
                for _ in range(frame['repeat']):
                    if models:
                        # Every model after the first drops CONECT lines.
                        lines = [line for line in lines if 'CONECT' not in line]
                    models.append('MODEL\n' + '\n'.join(lines) + '\nENDMDL\n')
        text = ''.join(models)

        if not path:
            return text
        with open(path, 'w') as f:
            f.write(text)