igashov commited on
Commit
92263a6
1 Parent(s): 8fd5e3f

Added an option to select anchor atoms

Browse files
Files changed (3) hide show
  1. app.py +102 -55
  2. output.py +91 -3
  3. src/generation.py +38 -0
app.py CHANGED
@@ -9,12 +9,30 @@ import output
9
 
10
  from rdkit import Chem
11
  from src import const
12
- from src.visualizer import save_xyz_file
13
  from src.datasets import get_dataloader, collate_with_fragment_edges, parse_molecule
14
  from src.lightning import DDPM
15
  from src.linker_size_lightning import SizeClassifier
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
16
 
17
- N_SAMPLES = 5
18
 
19
  parser = argparse.ArgumentParser()
20
  parser.add_argument('--ip', type=str, default=None)
@@ -33,13 +51,22 @@ if not os.path.exists(size_gnn_path):
33
  size_nn = SizeClassifier.load_from_checkpoint('models/geom_size_gnn.ckpt', map_location=device).eval().to(device)
34
  print('Loaded SizeGNN model')
35
 
36
- diffusion_path = 'models/geom_difflinker.ckpt'
37
- if not os.path.exists(diffusion_path):
38
- print('Downloading Diffusion model...')
39
- link = 'https://zenodo.org/record/7121300/files/geom_difflinker.ckpt?download=1'
40
- subprocess.run(f'wget {link} -O {diffusion_path}', shell=True)
41
- ddpm = DDPM.load_from_checkpoint('models/geom_difflinker.ckpt', map_location=device).eval().to(device)
42
- print('Loaded diffusion model')
 
 
 
 
 
 
 
 
 
43
 
44
  def read_molecule_content(path):
45
  with open(path, "r") as f:
@@ -60,7 +87,7 @@ def read_molecule(path):
60
 
61
  def show_input(input_file):
62
  if input_file is None:
63
- return ['', gr.Radio.update(visible=False, value='Sample 1')]
64
  if isinstance(input_file, str):
65
  path = input_file
66
  else:
@@ -70,7 +97,8 @@ def show_input(input_file):
70
  msg = output.INVALID_FORMAT_MSG.format(extension=extension)
71
  return [
72
  output.IFRAME_TEMPLATE.format(html=msg),
73
- gr.Radio.update(visible=False)
 
74
  ]
75
 
76
  try:
@@ -78,17 +106,22 @@ def show_input(input_file):
78
  except Exception as e:
79
  return [
80
  f'Could not read the molecule: {e}',
81
- gr.Radio.update(visible=False)
 
82
  ]
83
 
84
  html = output.INITIAL_RENDERING_TEMPLATE.format(molecule=molecule, fmt=extension)
85
  return [
86
  output.IFRAME_TEMPLATE.format(html=html),
87
- gr.Radio.update(visible=False)
 
88
  ]
89
 
90
 
91
  def draw_sample(idx, out_files):
 
 
 
92
  in_file = out_files[0]
93
  in_sdf = in_file if isinstance(in_file, str) else in_file.name
94
 
@@ -97,24 +130,43 @@ def draw_sample(idx, out_files):
97
 
98
  input_fragments_content = read_molecule_content(in_sdf)
99
  generated_molecule_content = read_molecule_content(out_sdf)
 
 
 
 
100
  html = output.SAMPLES_RENDERING_TEMPLATE.format(
101
  fragments=input_fragments_content,
102
- fragments_fmt='sdf',
103
  molecule=generated_molecule_content,
104
- molecule_fmt='sdf',
105
  )
106
  return output.IFRAME_TEMPLATE.format(html=html)
107
 
108
 
109
- def generate(input_file, n_steps, n_atoms):
 
 
 
 
 
 
 
 
 
 
 
 
 
110
  if input_file is None:
111
- return ''
112
 
 
 
113
  path = input_file.name
114
  extension = path.split('.')[-1]
115
  if extension not in ['sdf', 'pdb', 'mol', 'mol2']:
116
  msg = output.INVALID_FORMAT_MSG.format(extension=extension)
117
- return output.IFRAME_TEMPLATE.format(html=msg)
118
 
119
  try:
120
  molecule = read_molecule(path)
@@ -122,16 +174,22 @@ def generate(input_file, n_steps, n_atoms):
122
  name = '.'.join(path.split('/')[-1].split('.')[:-1])
123
  inp_sdf = f'results/input_{name}.sdf'
124
  except Exception as e:
125
- return f'Could not read the molecule: {e}'
 
 
126
 
127
  if molecule.GetNumAtoms() > 50:
128
- return f'Too large molecule: upper limit is 50 heavy atoms'
 
 
129
 
130
  with Chem.SDWriter(inp_sdf) as w:
131
  w.write(molecule)
132
 
133
  positions, one_hot, charges = parse_molecule(molecule, is_geom=True)
134
  anchors = np.zeros_like(charges)
 
 
135
  fragment_mask = np.ones_like(charges)
136
  linker_mask = np.zeros_like(charges)
137
  print('Read and parsed molecule')
@@ -151,7 +209,6 @@ def generate(input_file, n_steps, n_atoms):
151
  print('Created dataloader')
152
 
153
  ddpm.edm.T = n_steps
154
- assert ddpm.center_of_mass == 'fragments'
155
 
156
  if n_atoms == 0:
157
  def sample_fn(_data):
@@ -169,34 +226,21 @@ def generate(input_file, n_steps, n_atoms):
169
  return torch.ones(_data['positions'].shape[0], device=device, dtype=torch.long) * n_atoms
170
 
171
  for data in dataloader:
172
- chain, node_mask = ddpm.sample_chain(data, sample_fn=sample_fn, keep_frames=1)
173
- print('Generated linker')
174
- x = chain[0][:, :, :ddpm.n_dims]
175
- h = chain[0][:, :, ddpm.n_dims:]
176
-
177
- # Put the molecule back to the initial orientation
178
- pos_masked = data['positions'] * data['fragment_mask']
179
- N = data['fragment_mask'].sum(1, keepdims=True)
180
- mean = torch.sum(pos_masked, dim=1, keepdim=True) / N
181
- x = x + mean * node_mask
182
-
183
- names = [f'output_{i+1}_{name}' for i in range(N_SAMPLES)]
184
- save_xyz_file('results', h, x, node_mask, names=names, is_geom=True, suffix='')
185
- print('Saved XYZ files')
186
- break
187
-
188
- out_files = []
189
- for i in range(N_SAMPLES):
190
- out_xyz = f'results/output_{i+1}_{name}_.xyz'
191
- out_sdf = f'results/output_{i+1}_{name}_.sdf'
192
- subprocess.run(f'obabel {out_xyz} -O {out_sdf}', shell=True)
193
- out_files.append(out_sdf)
194
- print('Converted to SDF')
195
 
196
  return [
197
- draw_sample(0, out_files),
198
- [inp_sdf] + out_files,
199
- gr.Radio.update(visible=True, value='Sample 1')
 
200
  ]
201
 
202
 
@@ -215,6 +259,7 @@ with demo:
215
  )
216
  with gr.Box():
217
  with gr.Row():
 
218
  with gr.Column():
219
  gr.Markdown('## Input Fragments')
220
  gr.Markdown('Upload the file with 3D-coordinates of the input fragments in .pdb, .mol2 or .sdf format:')
@@ -238,11 +283,11 @@ with demo:
238
  output_files = gr.File(file_count='multiple', label='Output Files', interactive=False)
239
  with gr.Column():
240
  gr.Markdown('## Visualization')
241
- # gr.Markdown('Below you will see input and output molecules')
242
  samples = gr.Radio(
243
  choices=['Sample 1', 'Sample 2', 'Sample 3', 'Sample 4', 'Sample 5'],
244
  value='Sample 1',
245
- type='index',
246
  show_label=False,
247
  visible=False,
248
  interactive=True,
@@ -252,27 +297,29 @@ with demo:
252
  input_file.change(
253
  fn=show_input,
254
  inputs=[input_file],
255
- outputs=[visualization, samples],
256
  )
257
  input_file.clear(
258
- fn=lambda: [None, '', gr.Radio.update(visible=False)],
259
  inputs=[],
260
- outputs=[input_file, visualization, samples],
261
  )
262
  examples.click(
263
  fn=lambda idx: [f'examples/example_{idx+1}.sdf', 10, 0] + show_input(f'examples/example_{idx+1}.sdf'),
264
  inputs=[examples],
265
- outputs=[input_file, n_steps, n_atoms, visualization, samples]
266
  )
267
  button.click(
268
  fn=generate,
269
- inputs=[input_file, n_steps, n_atoms],
270
- outputs=[visualization, output_files, samples],
 
271
  )
272
  samples.change(
273
  fn=draw_sample,
274
  inputs=[samples, output_files],
275
  outputs=[visualization],
276
  )
 
277
 
278
  demo.launch(server_name=args.ip)
 
9
 
10
  from rdkit import Chem
11
  from src import const
 
12
  from src.datasets import get_dataloader, collate_with_fragment_edges, parse_molecule
13
  from src.lightning import DDPM
14
  from src.linker_size_lightning import SizeClassifier
15
+ from src.generation import N_SAMPLES, generate_linkers, try_to_convert_to_sdf
16
+
17
+ MODELS_METADATA = {
18
+ 'geom_difflinker': {
19
+ 'link': 'https://zenodo.org/record/7121300/files/geom_difflinker.ckpt?download=1',
20
+ 'path': 'models/geom_difflinker.ckpt',
21
+ },
22
+ 'geom_difflinker_given_anchors': {
23
+ 'link': 'https://zenodo.org/record/7775568/files/geom_difflinker_given_anchors.ckpt?download=1',
24
+ 'path': 'models/geom_difflinker_given_anchors.ckpt',
25
+ },
26
+ 'pockets_difflinker': {
27
+ 'link': 'https://zenodo.org/record/7775568/files/pockets_difflinker_full_no_anchors.ckpt?download=1',
28
+ 'path': 'models/pockets_difflinker.ckpt',
29
+ },
30
+ 'pockets_difflinker_given_anchors': {
31
+ 'link': 'https://zenodo.org/record/7775568/files/pockets_difflinker_full.ckpt?download=1',
32
+ 'path': 'models/pockets_difflinker_given_anchors.ckpt',
33
+ },
34
+ }
35
 
 
36
 
37
  parser = argparse.ArgumentParser()
38
  parser.add_argument('--ip', type=str, default=None)
 
51
  size_nn = SizeClassifier.load_from_checkpoint('models/geom_size_gnn.ckpt', map_location=device).eval().to(device)
52
  print('Loaded SizeGNN model')
53
 
54
+
55
+ diffusion_models = {}
56
+ for model_name, metadata in MODELS_METADATA.items():
57
+ link = metadata['link']
58
+ diffusion_path = metadata['path']
59
+ if not os.path.exists(diffusion_path):
60
+ print(f'Downloading {model_name}...')
61
+ subprocess.run(f'wget {link} -O {diffusion_path}', shell=True)
62
+ diffusion_models[model_name] = DDPM.load_from_checkpoint(diffusion_path, map_location=device).eval().to(device)
63
+ print(f'Loaded model {model_name}')
64
+
65
+
66
+ print(os.curdir)
67
+ print(os.path.abspath(os.curdir))
68
+ print(os.listdir(os.curdir))
69
+
70
 
71
  def read_molecule_content(path):
72
  with open(path, "r") as f:
 
87
 
88
  def show_input(input_file):
89
  if input_file is None:
90
+ return ['', gr.Radio.update(visible=False, value='Sample 1'), None]
91
  if isinstance(input_file, str):
92
  path = input_file
93
  else:
 
97
  msg = output.INVALID_FORMAT_MSG.format(extension=extension)
98
  return [
99
  output.IFRAME_TEMPLATE.format(html=msg),
100
+ gr.Radio.update(visible=False),
101
+ None,
102
  ]
103
 
104
  try:
 
106
  except Exception as e:
107
  return [
108
  f'Could not read the molecule: {e}',
109
+ gr.Radio.update(visible=False),
110
+ None,
111
  ]
112
 
113
  html = output.INITIAL_RENDERING_TEMPLATE.format(molecule=molecule, fmt=extension)
114
  return [
115
  output.IFRAME_TEMPLATE.format(html=html),
116
+ gr.Radio.update(visible=False),
117
+ None,
118
  ]
119
 
120
 
121
  def draw_sample(idx, out_files):
122
+ if isinstance(idx, str):
123
+ idx = int(idx.strip().split(' ')[-1]) - 1
124
+
125
  in_file = out_files[0]
126
  in_sdf = in_file if isinstance(in_file, str) else in_file.name
127
 
 
130
 
131
  input_fragments_content = read_molecule_content(in_sdf)
132
  generated_molecule_content = read_molecule_content(out_sdf)
133
+
134
+ fragments_fmt = in_sdf.split('.')[-1]
135
+ molecule_fmt = out_sdf.split('.')[-1]
136
+
137
  html = output.SAMPLES_RENDERING_TEMPLATE.format(
138
  fragments=input_fragments_content,
139
+ fragments_fmt=fragments_fmt,
140
  molecule=generated_molecule_content,
141
+ molecule_fmt=molecule_fmt,
142
  )
143
  return output.IFRAME_TEMPLATE.format(html=html)
144
 
145
 
146
+ def generate(input_file, n_steps, n_atoms, radio_samples, selected_atoms):
147
+ # Parsing selected atoms (javascript output)
148
+ selected_atoms = selected_atoms.strip()
149
+ if selected_atoms == '':
150
+ selected_atoms = []
151
+ else:
152
+ selected_atoms = list(map(int, selected_atoms.split(',')))
153
+
154
+ # Selecting model
155
+ if len(selected_atoms) == 0:
156
+ selected_model_name = 'geom_difflinker'
157
+ else:
158
+ selected_model_name = 'geom_difflinker_given_anchors'
159
+
160
  if input_file is None:
161
+ return [None, None, None, None]
162
 
163
+ print(f'Start generating with model {selected_model_name}, selected_atoms:', selected_atoms)
164
+ ddpm = diffusion_models[selected_model_name]
165
  path = input_file.name
166
  extension = path.split('.')[-1]
167
  if extension not in ['sdf', 'pdb', 'mol', 'mol2']:
168
  msg = output.INVALID_FORMAT_MSG.format(extension=extension)
169
+ return [output.IFRAME_TEMPLATE.format(html=msg), None, None, None]
170
 
171
  try:
172
  molecule = read_molecule(path)
 
174
  name = '.'.join(path.split('/')[-1].split('.')[:-1])
175
  inp_sdf = f'results/input_{name}.sdf'
176
  except Exception as e:
177
+ error = f'Could not read the molecule: {e}'
178
+ msg = output.ERROR_FORMAT_MSG.format(message=error)
179
+ return [output.IFRAME_TEMPLATE.format(html=msg), None, None, None]
180
 
181
  if molecule.GetNumAtoms() > 50:
182
+ error = f'Too large molecule: upper limit is 50 heavy atoms'
183
+ msg = output.ERROR_FORMAT_MSG.format(message=error)
184
+ return [output.IFRAME_TEMPLATE.format(html=msg), None, None, None]
185
 
186
  with Chem.SDWriter(inp_sdf) as w:
187
  w.write(molecule)
188
 
189
  positions, one_hot, charges = parse_molecule(molecule, is_geom=True)
190
  anchors = np.zeros_like(charges)
191
+ anchors[selected_atoms] = 1
192
+
193
  fragment_mask = np.ones_like(charges)
194
  linker_mask = np.zeros_like(charges)
195
  print('Read and parsed molecule')
 
209
  print('Created dataloader')
210
 
211
  ddpm.edm.T = n_steps
 
212
 
213
  if n_atoms == 0:
214
  def sample_fn(_data):
 
226
  return torch.ones(_data['positions'].shape[0], device=device, dtype=torch.long) * n_atoms
227
 
228
  for data in dataloader:
229
+ try:
230
+ generate_linkers(ddpm=ddpm, data=data, sample_fn=sample_fn, name=name)
231
+ except Exception as e:
232
+ error = f'Caught exception while generating linkers: {e}'
233
+ msg = output.ERROR_FORMAT_MSG.format(message=error)
234
+ return [output.IFRAME_TEMPLATE.format(html=msg), None, None, None]
235
+
236
+ out_files = try_to_convert_to_sdf(name)
237
+ out_files = [inp_sdf] + out_files
 
 
 
 
 
 
 
 
 
 
 
 
 
 
238
 
239
  return [
240
+ draw_sample(radio_samples, out_files),
241
+ out_files,
242
+ gr.Radio.update(visible=True),
243
+ None
244
  ]
245
 
246
 
 
259
  )
260
  with gr.Box():
261
  with gr.Row():
262
+ hidden = gr.Textbox(visible=False)
263
  with gr.Column():
264
  gr.Markdown('## Input Fragments')
265
  gr.Markdown('Upload the file with 3D-coordinates of the input fragments in .pdb, .mol2 or .sdf format:')
 
283
  output_files = gr.File(file_count='multiple', label='Output Files', interactive=False)
284
  with gr.Column():
285
  gr.Markdown('## Visualization')
286
+ gr.Markdown('**Hint:** click on atoms to select anchor points (optionally)')
287
  samples = gr.Radio(
288
  choices=['Sample 1', 'Sample 2', 'Sample 3', 'Sample 4', 'Sample 5'],
289
  value='Sample 1',
290
+ type='value',
291
  show_label=False,
292
  visible=False,
293
  interactive=True,
 
297
  input_file.change(
298
  fn=show_input,
299
  inputs=[input_file],
300
+ outputs=[visualization, samples, hidden],
301
  )
302
  input_file.clear(
303
+ fn=lambda: [None, '', gr.Radio.update(visible=False), None],
304
  inputs=[],
305
+ outputs=[input_file, visualization, samples, hidden],
306
  )
307
  examples.click(
308
  fn=lambda idx: [f'examples/example_{idx+1}.sdf', 10, 0] + show_input(f'examples/example_{idx+1}.sdf'),
309
  inputs=[examples],
310
+ outputs=[input_file, n_steps, n_atoms, visualization, samples, hidden]
311
  )
312
  button.click(
313
  fn=generate,
314
+ inputs=[input_file, n_steps, n_atoms, samples, hidden],
315
+ outputs=[visualization, output_files, samples, hidden],
316
+ _js=output.RETURN_SELECTION_JS,
317
  )
318
  samples.change(
319
  fn=draw_sample,
320
  inputs=[samples, output_files],
321
  outputs=[visualization],
322
  )
323
+ demo.load(_js=output.STARTUP_JS)
324
 
325
  demo.launch(server_name=args.ip)
output.py CHANGED
@@ -22,9 +22,42 @@ INITIAL_RENDERING_TEMPLATE = """<!DOCTYPE html>
22
  $(document).ready(function() {{
23
  let element = $("#container");
24
  let config = {{ backgroundColor: "white" }};
25
- let viewer = $3Dmol.createViewer( element, config );
26
- viewer.addModel(`{molecule}`, "{fmt}")
27
- viewer.getModel().setStyle({{ stick: {{ colorscheme:"greenCarbon" }} }})
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
28
  viewer.zoomTo();
29
  viewer.zoom(0.7);
30
  viewer.render();
@@ -113,8 +146,63 @@ INVALID_FORMAT_MSG = """
113
  </html>
114
  """
115
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
116
 
117
  IFRAME_TEMPLATE = """<iframe style="width: 100%; height: 700px" name="result" allow="midi; geolocation; microphone; camera;
118
  display-capture; encrypted-media;" sandbox="allow-modals allow-forms allow-scripts allow-same-origin allow-popups
119
  allow-top-navigation-by-user-activation allow-downloads" allowfullscreen=""
120
  allowpaymentrequest="" frameborder="0" srcdoc='{html}'></iframe>"""
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
22
  $(document).ready(function() {{
23
  let element = $("#container");
24
  let config = {{ backgroundColor: "white" }};
25
+ let viewer = $3Dmol.createViewer(element, config);
26
+ let defaultStyle = {{ stick: {{ colorscheme: "greenCarbon" }} }};
27
+ viewer.addModel(`{molecule}`, "{fmt}");
28
+ viewer.getModel(0).setStyle(defaultStyle);
29
+ // document.cookie = document.cookie + "|selected_atoms:";
30
+
31
+ viewer.getModel(0).setClickable(
32
+ {{}},
33
+ true,
34
+ function (_atom, _viewer, _event, _container) {{
35
+ if (!_atom.isClicked) {{
36
+ _atom.isClicked = true;
37
+ _viewer.addStyle(
38
+ {{"serial": _atom.serial, "model": 0}},
39
+ {{"sphere": {{"color": "magenta", "radius": 0.4}} }}
40
+ );
41
+ // document.cookie = document.cookie + "atom_" + String(_atom.serial) + "-";
42
+ window.parent.postMessage({{
43
+ name: "atom_selection",
44
+ data: {{"atom": _atom.serial, "add": true}}
45
+ // data: JSON.stringify({{"add": _atom.serial}})
46
+ }}, "*");
47
+ }} else {{
48
+ delete _atom.isClicked;
49
+ _viewer.setStyle({{"serial": _atom.serial, "model": 0}}, defaultStyle);
50
+ // document.cookie = document.cookie.replace("atom_" + String(_atom.serial) + "-", "");
51
+ window.parent.postMessage({{
52
+ name: "atom_selection",
53
+ data: {{"atom": _atom.serial, "add": false}}
54
+ // data: JSON.stringify({{"remove": _atom.serial}})
55
+ }}, "*");
56
+ }}
57
+ _viewer.render();
58
+ }}
59
+ );
60
+
61
  viewer.zoomTo();
62
  viewer.zoom(0.7);
63
  viewer.render();
 
146
  </html>
147
  """
148
 
149
+ ERROR_FORMAT_MSG = """
150
+ <!DOCTYPE html>
151
+ <html>
152
+ <head>
153
+ <meta http-equiv="content-type" content="text/html; charset=UTF-8" />
154
+ <style>
155
+ body{{
156
+ font-family:sans-serif
157
+ }}
158
+ </style>
159
+ </head>
160
+
161
+ <body>
162
+ <h3>Error:</h3>
163
+ {message}
164
+ </body>
165
+ </html>
166
+ """
167
+
168
 
169
  IFRAME_TEMPLATE = """<iframe style="width: 100%; height: 700px" name="result" allow="midi; geolocation; microphone; camera;
170
  display-capture; encrypted-media;" sandbox="allow-modals allow-forms allow-scripts allow-same-origin allow-popups
171
  allow-top-navigation-by-user-activation allow-downloads" allowfullscreen=""
172
  allowpaymentrequest="" frameborder="0" srcdoc='{html}'></iframe>"""
173
+
174
+
175
+ STARTUP_JS = """
176
+ () => {
177
+ window.selected_elements = {}
178
+
179
+ function handleMessage(event) {
180
+ // console.log("New message: ", event.data)
181
+ let atom = event.data.data["atom"];
182
+ let add = event.data.data["add"];
183
+ console.log("add: ", add, " atom: ", atom);
184
+ window.selected_elements[atom] = add;
185
+ }
186
+
187
+ window.addEventListener("message", handleMessage);
188
+ console.log("Listener Added");
189
+ }
190
+ """
191
+
192
+ RETURN_SELECTION_JS = """
193
+ (input_file, n_steps, n_atoms, samples, hidden) => {
194
+ let selected = []
195
+ for (const [atom, add] of Object.entries(window.selected_elements)) {
196
+ if (add) {
197
+ console.log("Adding atom ", atom);
198
+ selected.push(String(atom));
199
+ window.parent.postMessage({
200
+ name: "atom_selection",
201
+ data: {"atom": parseInt(atom), "add": false}
202
+ }, "*");
203
+ }
204
+ }
205
+ console.log("Finished parsing");
206
+ return [input_file, n_steps, n_atoms, samples, selected.join(",")];
207
+ }
208
+ """
src/generation.py ADDED
@@ -0,0 +1,38 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os.path
2
+ import subprocess
3
+ import torch
4
+
5
+ from src.visualizer import save_xyz_file
6
+
7
+ N_SAMPLES = 5
8
+
9
+
10
+ def generate_linkers(ddpm, data, sample_fn, name):
11
+ chain, node_mask = ddpm.sample_chain(data, sample_fn=sample_fn, keep_frames=1)
12
+ print('Generated linker')
13
+ x = chain[0][:, :, :ddpm.n_dims]
14
+ h = chain[0][:, :, ddpm.n_dims:]
15
+
16
+ # Put the molecule back to the initial orientation
17
+ pos_masked = data['positions'] * data['fragment_mask']
18
+ N = data['fragment_mask'].sum(1, keepdims=True)
19
+ mean = torch.sum(pos_masked, dim=1, keepdim=True) / N
20
+ x = x + mean * node_mask
21
+
22
+ names = [f'output_{i + 1}_{name}' for i in range(N_SAMPLES)]
23
+ save_xyz_file('results', h, x, node_mask, names=names, is_geom=True, suffix='')
24
+ print('Saved XYZ files')
25
+
26
+
27
+ def try_to_convert_to_sdf(name):
28
+ out_files = []
29
+ for i in range(N_SAMPLES):
30
+ out_xyz = f'results/output_{i + 1}_{name}_.xyz'
31
+ out_sdf = f'results/output_{i + 1}_{name}_.sdf'
32
+ subprocess.run(f'obabel {out_xyz} -O {out_sdf}', shell=True)
33
+ if os.path.exists(out_sdf):
34
+ out_files.append(out_sdf)
35
+ else:
36
+ out_files.append(out_xyz)
37
+
38
+ return out_files