DiffLinker / src /visualizer.py
igashov
DiffLinker code
95ba5bc
raw
history blame
No virus
7.63 kB
import torch
import os
import imageio
import matplotlib.pyplot as plt
import numpy as np
import glob
import random
from sklearn.decomposition import PCA
from src import const
from src.molecule_builder import get_bond_order
def save_xyz_file(path, one_hot, positions, node_mask, names, is_geom, suffix=''):
    """Write every molecule of a batch to its own XYZ file.

    One file named ``{names[i]}_{suffix}.xyz`` is created in ``path`` per batch
    element. Each file holds the atom count, an empty comment line, and one
    ``symbol x y z`` line per atom selected by the node mask.

    Args:
        path: Directory the files are written to.
        one_hot: Batched atom-type scores; the arg-max over dim 1 of each
            batch element is used as the atom-type index.
        positions: Batched coordinates indexed as ``positions[batch, atom, 0..2]``.
        node_mask: Batched mask; non-zero entries of ``node_mask[batch].squeeze()``
            select the atoms that are written.
        names: Per-molecule file-name stems, indexed by batch position.
        is_geom: Selects ``const.GEOM_IDX2ATOM`` instead of ``const.IDX2ATOM``
            for translating atom-type indices into symbols.
        suffix: String appended (after an underscore) to every file stem.
    """
    idx2atom = const.GEOM_IDX2ATOM if is_geom else const.IDX2ATOM

    for batch_i in range(one_hot.size(0)):
        mask = node_mask[batch_i].squeeze()
        n_atoms = mask.sum()
        atom_idx = torch.where(mask)[0]
        atoms = torch.argmax(one_hot[batch_i], dim=1)

        file_path = os.path.join(path, f'{names[batch_i]}_{suffix}.xyz')
        # The context manager guarantees the handle is closed even when a
        # lookup or write fails half-way through a molecule.
        with open(file_path, "w", encoding='utf8') as f:
            f.write("%d\n\n" % n_atoms)
            for atom_i in atom_idx:
                atom = idx2atom[atoms[atom_i].item()]
                f.write("%s %.9f %.9f %.9f\n" % (
                    atom,
                    positions[batch_i, atom_i, 0],
                    positions[batch_i, atom_i, 1],
                    positions[batch_i, atom_i, 2],
                ))
def load_xyz_files(path, suffix=''):
    """Return the paths of the ``*_{suffix}.xyz`` files found in ``path``.

    The integer that ends the file name (the last ``_``-separated token once
    ``_{suffix}.xyz`` has been removed) is used as the sort key, and the
    result is ordered from the largest index to the smallest.

    Args:
        path: Directory to scan (non-recursively).
        suffix: Suffix the file stem must end with, after an underscore.

    Returns:
        List of ``os.path.join(path, name)`` strings in decreasing index order.

    Raises:
        ValueError: If a matching file name does not end with an integer token.
    """
    tail = f'_{suffix}.xyz'

    def trailing_index(fname):
        # Numeric token that closes the stem once the tail has been removed.
        return int(fname.replace(tail, '').split('_')[-1])

    matching = [fname for fname in os.listdir(path) if fname.endswith(tail)]
    matching.sort(key=trailing_index, reverse=True)
    return [os.path.join(path, fname) for fname in matching]
def load_molecule_xyz(file, is_geom):
    """Read one molecule from an XYZ file into tensors.

    The first line must hold the atom count and the second line is skipped.
    Every following atom line is ``symbol x y z`` with the fields separated by
    any amount of whitespace; columns after the third coordinate are ignored.

    Args:
        file: Path of the XYZ file (read as UTF-8).
        is_geom: Selects the ``const.GEOM_*`` atom vocabularies instead of the
            default ``const.ATOM2IDX`` / ``const.IDX2ATOM``.

    Returns:
        Tuple ``(positions, one_hot, charges)`` of shapes ``(n_atoms, 3)``,
        ``(n_atoms, len(idx2atom))`` and ``(n_atoms, 1)``. ``charges`` is
        always all zeros; nothing in the file is read into it.

    Raises:
        KeyError: If an atom symbol is not part of the selected vocabulary.
        IndexError: If the file has fewer atom lines than the declared count.
    """
    atom2idx = const.GEOM_ATOM2IDX if is_geom else const.ATOM2IDX
    idx2atom = const.GEOM_IDX2ATOM if is_geom else const.IDX2ATOM

    with open(file, encoding='utf8') as f:
        n_atoms = int(f.readline())
        f.readline()  # second line of the XYZ format is a free-form comment
        atom_lines = f.readlines()

    one_hot = torch.zeros(n_atoms, len(idx2atom))
    charges = torch.zeros(n_atoms, 1)
    positions = torch.zeros(n_atoms, 3)

    for i in range(n_atoms):
        # split() without an argument tolerates tabs, repeated spaces and
        # leading/trailing whitespace, which split(' ') turned into empty
        # fields that failed to parse.
        fields = atom_lines[i].split()
        one_hot[i, atom2idx[fields[0]]] = 1
        positions[i, :] = torch.tensor([float(value) for value in fields[1:4]])

    return positions, one_hot, charges
def draw_sphere(ax, x, y, z, size, color, alpha):
    """Render a sphere of radius ``size`` centred at ``(x, y, z)`` on ``ax``.

    The surface is sampled on a 100 x 100 grid of angles and handed to
    ``ax.plot_surface`` with a row/column stride of 2.

    Args:
        ax: Axes object exposing ``plot_surface``.
        x, y, z: Coordinates of the sphere centre.
        size: Sphere radius.
        color: Surface colour forwarded to ``plot_surface``.
        alpha: Surface opacity forwarded to ``plot_surface``.
    """
    azimuth = np.linspace(0, 2 * np.pi, 100)
    polar = np.linspace(0, np.pi, 100)

    # Parametric sphere: the outer products expand the two angle vectors into
    # the full (azimuth, polar) grid.
    offset_x = size * np.outer(np.cos(azimuth), np.sin(polar))
    offset_y = size * np.outer(np.sin(azimuth), np.sin(polar))
    offset_z = size * np.outer(np.ones(np.size(azimuth)), np.cos(polar))

    ax.plot_surface(
        x + offset_x,
        y + offset_y,
        z + offset_z,
        rstride=2,
        cstride=2,
        color=color,
        alpha=alpha,
    )
def plot_molecule(ax, positions, atom_type, alpha, spheres_3d, hex_bg_color, is_geom, fragment_mask=None):
    """Draw the bonds and atoms of one molecule on a 3D axes.

    A line is drawn between every pair of atoms for which ``get_bond_order``
    returns a positive value. Atoms are then rendered either as individual
    spheres (``spheres_3d=True``) or with a single ``ax.scatter`` call.

    Args:
        ax: Axes object exposing ``plot``, ``scatter`` and ``plot_surface``.
        positions: Atom coordinates indexed as ``positions[:, 0..2]``. In
            sphere mode the individual coordinates must provide ``.item()``.
        atom_type: Integer atom-type indices used to look up the entries of
            ``const.COLORS`` / ``const.RADII`` and the atom symbols.
        alpha: Opacity of the bond lines and of the atoms (scaled by 0.9 in
            scatter mode).
        spheres_3d: Draw each atom with ``draw_sphere`` instead of scatter.
        hex_bg_color: Colour used for the bond lines.
        is_geom: Selects ``const.GEOM_IDX2ATOM`` instead of ``const.IDX2ATOM``.
        fragment_mask: Optional per-atom mask. In sphere mode, atoms whose
            entry equals 1 are drawn fully opaque and all others use ``alpha``.
            Defaults to all ones.
    """
    x = positions[:, 0]
    y = positions[:, 1]
    z = positions[:, 2]
    n_atoms = len(x)

    # Lookup tables are indexed by atom type
    # (Hydrogen, Carbon, Nitrogen, Oxygen, Fluorine, ...).
    idx2atom = const.GEOM_IDX2ATOM if is_geom else const.IDX2ATOM
    color_table = np.array(const.COLORS)
    radius_table = np.array(const.RADII)
    radii = radius_table[atom_type]
    colors = color_table[atom_type]

    if fragment_mask is None:
        fragment_mask = torch.ones(n_atoms)

    # Loop-invariant base width, hoisted out of the pair loop.
    line_width = (3 - 2) * 2 * 2

    for i in range(n_atoms):
        for j in range(i + 1, n_atoms):
            p1 = np.array([x[i], y[i], z[i]])
            p2 = np.array([x[j], y[j], z[j]])
            dist = np.sqrt(np.sum((p1 - p2) ** 2))
            atom1, atom2 = idx2atom[atom_type[i]], idx2atom[atom_type[j]]
            bond_order = get_bond_order(atom1, atom2, dist)
            if bond_order > 0:
                # Bond order 4 gets a line 1.5x wider than the other orders.
                linewidth_factor = 1.5 if bond_order == 4 else 1
                linewidth_factor *= 0.5
                ax.plot(
                    [x[i], x[j]], [y[i], y[j]], [z[i], z[j]],
                    linewidth=line_width * linewidth_factor * 2,
                    c=hex_bg_color,
                    alpha=alpha
                )

    if spheres_3d:
        for xi, yi, zi, radius, color, in_fragment in zip(x, y, z, radii, colors, fragment_mask):
            # Decide the opacity per atom. Assigning 1.0 back to `alpha` made
            # every atom after the first fragment atom opaque as well.
            sphere_alpha = 1.0 if in_fragment == 1 else alpha
            draw_sphere(ax, xi.item(), yi.item(), zi.item(), 0.5 * radius, color, sphere_alpha)
    else:
        # Scatter marker areas grow with the squared atom radius.
        areas = (1500 * radius_table ** 2)[atom_type]
        ax.scatter(x, y, z, s=areas, alpha=0.9 * alpha, c=colors)
def plot_data3d(positions, atom_type, is_geom, camera_elev=0, camera_azim=0, save_path=None, spheres_3d=False,
                bg='black', alpha=1., fragment_mask=None):
    """Render one molecule in a 10x10 inch 3D figure and save or show it.

    Args:
        positions: Atom coordinates; must support ``.abs().max().item()`` and
            ``positions[:, k]`` indexing.
        atom_type: Integer atom-type indices forwarded to ``plot_molecule``.
        is_geom: Forwarded to ``plot_molecule`` to select the atom vocabulary.
        camera_elev: Elevation angle passed to ``ax.view_init``.
        camera_azim: Azimuth angle passed to ``ax.view_init``.
        save_path: Image path. When ``None`` the figure is shown with
            ``plt.show()`` instead of being saved.
        spheres_3d: Draw atoms as spheres. Also raises the output dpi from 50
            to 120 and brightens the saved image by a factor of 1.4.
        bg: ``'black'`` for a black background with white bonds; any other
            value gives a white background with black bonds.
        alpha: Opacity forwarded to ``plot_molecule``.
        fragment_mask: Optional per-atom mask forwarded to ``plot_molecule``.
    """
    black = (0, 0, 0)
    white = (1, 1, 1)
    dark_bg = bg == 'black'
    # Bonds are drawn in the colour that contrasts with the background.
    hex_bg_color = '#FFFFFF' if dark_bg else '#000000'

    fig = plt.figure(figsize=(10, 10))
    try:
        ax = fig.add_subplot(projection='3d')
        ax.set_aspect('auto')
        ax.view_init(elev=camera_elev, azim=camera_azim)
        ax.set_facecolor(black if dark_bg else white)

        ax.xaxis.pane.set_alpha(0)
        ax.yaxis.pane.set_alpha(0)
        ax.zaxis.pane.set_alpha(0)
        ax._axis3don = False  # private Matplotlib flag that hides the 3D axes

        # `ax.w_xaxis` is a deprecated alias of `ax.xaxis` that recent
        # Matplotlib releases no longer provide; use the attribute directly.
        ax.xaxis.line.set_color("black" if dark_bg else "white")

        plot_molecule(
            ax, positions, atom_type, alpha, spheres_3d, hex_bg_color, is_geom=is_geom, fragment_mask=fragment_mask
        )

        # Symmetric cubic view box derived from the largest absolute
        # coordinate, clamped to the range [3.2, 40].
        max_value = positions.abs().max().item()
        axis_lim = min(40, max(max_value / 1.5 + 0.3, 3.2))
        ax.set_xlim(-axis_lim, axis_lim)
        ax.set_ylim(-axis_lim, axis_lim)
        ax.set_zlim(-axis_lim, axis_lim)

        dpi = 120 if spheres_3d else 50

        if save_path is not None:
            plt.savefig(save_path, bbox_inches='tight', pad_inches=0.0, dpi=dpi)

            if spheres_3d:
                # Re-open the saved image and scale it by 1.4, clipped to the
                # 8-bit range, to brighten the shaded spheres.
                img = imageio.imread(save_path)
                img_brighter = np.clip(img * 1.4, 0, 255).astype('uint8')
                imageio.imsave(save_path, img_brighter)
        else:
            plt.show()
    finally:
        # Always release the figure, including when plotting or saving fails,
        # so repeated calls do not accumulate open figures.
        plt.close(fig)
def visualize_chain(
        path, spheres_3d=False, bg="black", alpha=1.0, wandb=None, mode="chain", is_geom=False, fragment_mask=None
):
    """Render every ``*_.xyz`` frame in ``path`` and combine them into a GIF.

    Each frame is saved as a PNG next to its XYZ file (same name, ``.png``
    extension) and all frames are written to ``output.gif`` in the same
    directory, in the order returned by ``load_xyz_files``.

    Args:
        path: Directory holding the XYZ frames.
        spheres_3d: Forwarded to ``plot_data3d``.
        bg: Background colour name forwarded to ``plot_data3d``.
        alpha: Opacity forwarded to ``plot_data3d``.
        wandb: Optional object exposing ``log`` and ``Video``; when given, the
            GIF is logged under the key ``mode``.
        mode: Key used for the ``wandb.log`` entry.
        is_geom: Selects the atom vocabulary used to parse and draw the frames.
        fragment_mask: Optional per-atom mask forwarded to ``plot_data3d``.

    Raises:
        IndexError: If ``path`` contains no matching XYZ file.
    """
    files = load_xyz_files(path)

    # Fit PCA to the final molecule – to obtain the best orientation for visualization
    positions, one_hot, _ = load_molecule_xyz(files[-1], is_geom=is_geom)
    pca = PCA(n_components=3)
    pca.fit(positions)

    save_paths = []
    for file in files:
        positions, one_hot, _ = load_molecule_xyz(file, is_geom=is_geom)
        atom_type = torch.argmax(one_hot, dim=1).numpy()

        # Transform positions of each frame according to the best orientation of the last frame
        positions = torch.tensor(pca.transform(positions))

        fn = file[:-4] + '.png'  # swap the '.xyz' extension for '.png'
        plot_data3d(
            positions, atom_type,
            save_path=fn,
            spheres_3d=spheres_3d,
            alpha=alpha,
            bg=bg,
            camera_elev=90,
            camera_azim=90,
            is_geom=is_geom,
            fragment_mask=fragment_mask,
        )
        save_paths.append(fn)

    imgs = [imageio.imread(fn) for fn in save_paths]
    # os.path.join keeps the GIF next to the frames even when the directory
    # part is empty; plain '/' concatenation pointed at the filesystem root.
    gif_path = os.path.join(os.path.dirname(save_paths[0]), 'output.gif')
    imageio.mimsave(gif_path, imgs, subrectangles=True)

    if wandb is not None:
        wandb.log({mode: [wandb.Video(gif_path, caption=gif_path)]})