File size: 7,089 Bytes
5de53c3
 
7782ac2
b0ab0d5
7782ac2
95ba5bc
 
52bf9df
95ba5bc
 
 
 
 
 
 
7782ac2
5de53c3
53f22d0
5de53c3
7782ac2
95ba5bc
 
49021fb
95ba5bc
ff9d86b
 
 
 
 
95ba5bc
 
 
ff9d86b
 
 
 
 
95ba5bc
 
 
 
 
d1da608
95ba5bc
 
 
 
 
 
 
 
 
 
 
0673854
 
 
 
95ba5bc
 
 
 
 
 
 
 
 
 
 
 
52bf9df
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7782ac2
52bf9df
 
 
 
 
 
 
 
 
7782ac2
0673854
53f22d0
7c181a3
f9310fd
 
 
 
95ba5bc
 
 
b0ab0d5
 
 
f9310fd
 
 
 
95ba5bc
b0ab0d5
 
 
95ba5bc
 
 
 
 
 
 
 
b0ab0d5
 
 
95ba5bc
 
 
 
 
 
 
 
 
 
f9310fd
95ba5bc
f9310fd
95ba5bc
 
 
 
52bf9df
4f94923
52bf9df
f9310fd
4f94923
7782ac2
 
 
 
 
 
 
711f689
 
52bf9df
ce8384d
52bf9df
 
 
 
 
 
 
 
 
 
 
 
 
7782ac2
 
 
4f94923
7782ac2
 
0ce499b
 
 
 
 
 
 
 
 
5de53c3
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
import argparse

import gradio as gr
import numpy as np
import os
import torch
import subprocess
import output

from rdkit import Chem
from src import const
from src.visualizer import save_xyz_file
from src.datasets import get_dataloader, collate_with_fragment_edges, parse_molecule
from src.lightning import DDPM
from src.linker_size_lightning import SizeClassifier

parser = argparse.ArgumentParser()
parser.add_argument('--ip', type=str, default=None)
args = parser.parse_args()

# Run on GPU when available; both checkpoints below are moved to this device.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
os.makedirs("results", exist_ok=True)
os.makedirs("models", exist_ok=True)


def _ensure_checkpoint(path, link, label):
    """Download the checkpoint at ``link`` to ``path`` unless it already exists.

    ``wget`` is invoked with an argument list (no shell) and ``check=True`` so a
    failed download raises instead of being silently ignored. On failure the
    target file is deleted: ``wget -O`` may leave a partial/empty file behind,
    and the next start would otherwise skip the download and try to load it.

    Raises:
        OSError: ``wget`` could not be executed (e.g. not installed).
        subprocess.CalledProcessError: ``wget`` exited with a non-zero status.
    """
    if os.path.exists(path):
        return
    print(f'Downloading {label} model...')
    try:
        subprocess.run(['wget', link, '-O', path], check=True)
    except (OSError, subprocess.CalledProcessError):
        if os.path.exists(path):
            os.remove(path)
        raise


size_gnn_path = 'models/geom_size_gnn.ckpt'
_ensure_checkpoint(
    size_gnn_path,
    'https://zenodo.org/record/7121300/files/geom_size_gnn.ckpt?download=1',
    'SizeGNN',
)
size_nn = SizeClassifier.load_from_checkpoint(size_gnn_path, map_location=device).eval().to(device)
print('Loaded SizeGNN model')

diffusion_path = 'models/geom_difflinker.ckpt'
_ensure_checkpoint(
    diffusion_path,
    'https://zenodo.org/record/7121300/files/geom_difflinker.ckpt?download=1',
    'Diffusion',
)
ddpm = DDPM.load_from_checkpoint(diffusion_path, map_location=device).eval().to(device)
print('Loaded diffusion model')


def sample_fn(_data):
    """Sample one linker size per batch element using the SizeGNN classifier.

    The classifier's logits are softmax-normalised into a categorical
    distribution, one class id is drawn per row, and each id is translated
    through ``size_nn.linker_id2size``. Returns a ``torch.long`` tensor placed
    on the same device as the drawn samples.
    """
    logits, _ = size_nn.forward(_data, return_loss=False)
    class_ids = torch.distributions.Categorical(probs=torch.softmax(logits, dim=1)).sample()
    id2size = size_nn.linker_id2size
    linker_sizes = [id2size[class_id] for class_id in class_ids.detach().cpu().numpy()]
    return torch.tensor(linker_sizes, device=class_ids.device, dtype=torch.long)


def read_molecule_content(path):
    """Return the full text of the file at ``path`` as a single string."""
    with open(path, "r") as handle:
        content = handle.read()
    return content


def read_molecule(path):
    """Load a molecule with RDKit, choosing the reader from the file suffix.

    Supported suffixes are .sdf (first record only), .mol2, .mol and .pdb.
    Every reader is called with ``sanitize=False`` and ``removeHs=True``.
    Any other suffix raises ``Exception('Unknown file extension')``.
    """
    # The suffixes are mutually exclusive ('.mol2' does not end with '.mol'),
    # so the order of these guards does not matter.
    if path.endswith('.sdf'):
        supplier = Chem.SDMolSupplier(path, sanitize=False, removeHs=True)
        return supplier[0]
    if path.endswith('.mol2'):
        return Chem.MolFromMol2File(path, sanitize=False, removeHs=True)
    if path.endswith('.mol'):
        return Chem.MolFromMolFile(path, sanitize=False, removeHs=True)
    if path.endswith('.pdb'):
        return Chem.MolFromPDBFile(path, sanitize=False, removeHs=True)
    raise Exception('Unknown file extension')


def show_input(input_file):
    """Build the preview HTML for an uploaded fragments file.

    Returns '' when no file is uploaded. For an unsupported extension, returns
    ``output.INVALID_FORMAT_MSG`` wrapped in ``output.IFRAME_TEMPLATE``. If the
    file cannot be read, returns a plain error string. Otherwise returns the raw
    file text embedded in ``output.HTML_TEMPLATE`` and wrapped in the iframe.
    """
    if input_file is None:
        return ''

    path = input_file.name
    # Text after the last dot (the whole path when there is no dot).
    extension = path.rsplit('.', 1)[-1]
    if extension in ['sdf', 'pdb', 'mol', 'mol2']:
        try:
            content = read_molecule_content(path)
        except Exception as err:
            return f'Could not read the molecule: {err}'
        body = output.HTML_TEMPLATE.format(molecule=content, fmt=extension)
    else:
        body = output.INVALID_FORMAT_MSG.format(extension=extension)
    return output.IFRAME_TEMPLATE.format(html=body)


def generate(input_file):
    """Generate a linker for the uploaded fragments and return the UI outputs.

    This handler feeds two Gradio outputs (the visualization HTML and the
    downloadable files), so every return path yields a two-element list
    ``[html_or_message, files_or_None]``.

    Pipeline:
    1. Read the fragments with RDKit and strip hydrogens.
    2. Write the fragments to ``results/<name>_input.{sdf,xyz}``.
    3. Build a one-item dataset in which every atom is a fragment atom.
    4. Sample a chain with ``ddpm.sample_chain`` (passing ``sample_fn``).
    5. Save the result as XYZ, convert it to SDF with Open Babel, and render
       the SDF.
    """
    if input_file is None:
        return ['', None]

    path = input_file.name
    extension = path.split('.')[-1]
    if extension not in ['sdf', 'pdb', 'mol', 'mol2']:
        msg = output.INVALID_FORMAT_MSG.format(extension=extension)
        return [output.IFRAME_TEMPLATE.format(html=msg), None]

    # Result files are named after the uploaded file's base name (extension dropped).
    name = '.'.join(path.split('/')[-1].split('.')[:-1])
    inp_sdf = f'results/{name}_input.sdf'
    inp_xyz = f'results/{name}_input.xyz'
    out_sdf = f'results/{name}_output.sdf'
    out_xyz = f'results/{name}_output.xyz'

    try:
        molecule = read_molecule(path)
        if molecule is None:
            # RDKit's file readers return None on a parse failure instead of raising.
            raise ValueError('RDKit could not parse the file')
        molecule = Chem.RemoveAllHs(molecule)
    except Exception as e:
        return [f'Could not read the molecule: {e}', None]

    if molecule.GetNumAtoms() > 50:
        return ['Too large molecule: upper limit is 50 heavy atoms', None]

    with Chem.SDWriter(inp_sdf) as w:
        w.write(molecule)
    Chem.MolToXYZFile(molecule, inp_xyz)

    positions, one_hot, charges = parse_molecule(molecule, is_geom=True)
    # Every input atom belongs to the fragments: no anchors, no linker atoms.
    anchors = np.zeros_like(charges)
    fragment_mask = np.ones_like(charges)
    linker_mask = np.zeros_like(charges)
    print('Read and parsed molecule')

    def to_tensor(values):
        # All per-atom arrays share the same dtype and device.
        return torch.tensor(values, dtype=const.TORCH_FLOAT, device=device)

    dataset = [{
        'uuid': '0',
        'name': '0',
        'positions': to_tensor(positions),
        'one_hot': to_tensor(one_hot),
        'charges': to_tensor(charges),
        'anchors': to_tensor(anchors),
        'fragment_mask': to_tensor(fragment_mask),
        'linker_mask': to_tensor(linker_mask),
        'num_atoms': len(positions),
    }]
    dataloader = get_dataloader(dataset, batch_size=1, collate_fn=collate_with_fragment_edges)
    print('Created dataloader')

    # The dataset holds exactly one item, so only the first batch is used.
    data = next(iter(dataloader))
    chain, node_mask = ddpm.sample_chain(data, sample_fn=sample_fn, keep_frames=1)
    print('Generated linker')
    x = chain[0][:, :, :ddpm.n_dims]
    h = chain[0][:, :, ddpm.n_dims:]
    save_xyz_file('results', h, x, node_mask, names=[name], is_geom=True, suffix='output')
    print('Saved XYZ file')

    # Argument list instead of a shell string: ``name`` is derived from the
    # uploaded file name and must never be interpreted by a shell.
    try:
        subprocess.run(['obabel', out_xyz, '-O', out_sdf], check=True)
        print('Converted to SDF')
        generated_molecule = read_molecule_content(out_sdf)
    except (OSError, subprocess.CalledProcessError) as e:
        return [f'Could not convert the generated molecule: {e}', None]

    html = output.HTML_TEMPLATE.format(molecule=generated_molecule, fmt='sdf')
    return [
        output.IFRAME_TEMPLATE.format(html=html),
        [inp_sdf, inp_xyz, out_sdf, out_xyz],
    ]


demo = gr.Blocks()
with demo:
    gr.Markdown('# DiffLinker: Equivariant 3D-Conditional Diffusion Model for Molecular Linker Design')
    with gr.Box():
        with gr.Row():
            with gr.Column():
                gr.Markdown('## Input Fragments')
                # Keep this list in sync with the extensions accepted by show_input/generate.
                gr.Markdown('Upload the file with 3D-coordinates of the input fragments in .pdb, .mol, .mol2 or .sdf format:')
                input_file = gr.File(file_count='single', label='Input Fragments')
                button = gr.Button('Generate Linker!')
                gr.Markdown('')
                gr.Markdown('## Output Files')
                gr.Markdown('Download files with the generated molecules here:')
                output_files = gr.File(file_count='multiple', label='Output Files')
            with gr.Column():
                visualization = gr.HTML()

    # Re-render the preview whenever the uploaded file changes.
    input_file.change(
        fn=show_input,
        inputs=[input_file],
        outputs=[visualization],
    )
    # generate must return one value per output: [visualization HTML, file list].
    button.click(
        fn=generate,
        inputs=[input_file],
        outputs=[visualization, output_files],
    )

    # Clicking an example loads it into input_file and runs show_input (preview only).
    examples = gr.Examples(
        examples=[['examples/example_1.sdf'], ['examples/example_2.sdf']],
        inputs=[input_file],
        outputs=[visualization],
        fn=show_input,
        run_on_click=True,
        cache_examples=False,
    )

# --ip (default None) is forwarded unchanged as the server bind address.
demo.launch(server_name=args.ip)