jannisborn committed on
Commit
188d00f
1 Parent(s): 7d76d6f
app.py CHANGED
@@ -3,7 +3,11 @@ import pathlib
3
 
4
  import gradio as gr
5
  import pandas as pd
6
- from gt4sd.algorithms.generation.moler import MoLeR, MoLeRDefaultGenerator
 
 
 
 
7
 
8
  from gt4sd.algorithms.registry import ApplicationsRegistry
9
  from utils import draw_grid_generate
@@ -14,26 +18,19 @@ logger.addHandler(logging.NullHandler())
14
  TITLE = "MoLeR"
15
 
16
 
17
- def run_inference(
18
- algorithm_version: str,
19
- scaffolds: str,
20
- beam_size: int,
21
- number_of_samples: int,
22
- seed: int,
23
- ):
24
- config = MoLeRDefaultGenerator(
25
- algorithm_version=algorithm_version,
26
- scaffolds=scaffolds,
27
- beam_size=beam_size,
28
- num_samples=4,
29
- seed=seed,
30
- num_workers=1,
31
- )
32
- model = MoLeR(configuration=config)
33
  samples = list(model.sample(number_of_samples))
34
 
35
- seed_mols = [] if scaffolds == "" else scaffolds.split(".")
36
- return draw_grid_generate(seed_mols, samples)
37
 
38
 
39
  if __name__ == "__main__":
@@ -42,7 +39,7 @@ if __name__ == "__main__":
42
  all_algos = ApplicationsRegistry.list_available()
43
  algos = [
44
  x["algorithm_version"]
45
- for x in list(filter(lambda x: TITLE in x["algorithm_name"], all_algos))
46
  ]
47
 
48
  # Load metadata
@@ -59,19 +56,15 @@ if __name__ == "__main__":
59
 
60
  demo = gr.Interface(
61
  fn=run_inference,
62
- title="MoLeR (MOlecule-LEvel Representation)",
63
  inputs=[
64
- gr.Dropdown(algos, label="Algorithm version", value="v0"),
65
- gr.Textbox(
66
- label="Scaffolds",
67
- placeholder="CC(C#C)N(C)C(=O)NC1=CC=C(Cl)C=C1",
68
- lines=1,
69
  ),
70
- gr.Slider(minimum=1, maximum=5, value=1, step=1, label="Beam_size"),
71
  gr.Slider(
72
  minimum=1, maximum=50, value=10, label="Number of samples", step=1
73
  ),
74
- gr.Number(value=42, label="Seed", precision=0),
75
  ],
76
  outputs=gr.HTML(label="Output"),
77
  article=article,
 
3
 
4
  import gradio as gr
5
  import pandas as pd
6
+ from gt4sd.algorithms.generation.torchdrug import (
7
+ TorchDrugGenerator,
8
+ TorchDrugGCPN,
9
+ TorchDrugGraphAF,
10
+ )
11
 
12
  from gt4sd.algorithms.registry import ApplicationsRegistry
13
  from utils import draw_grid_generate
 
18
  TITLE = "MoLeR"
19
 
20
 
21
+ def run_inference(algorithm: str, algorithm_version: str, number_of_samples: int):
22
+
23
+ if algorithm == "GCPN":
24
+ config = TorchDrugGCPN(algorithm_version=algorithm_version)
25
+ elif algorithm == "GraphAF":
26
+ config = TorchDrugGraphAF(algorithm_version=algorithm_version)
27
+ else:
28
+ raise ValueError(f"Unsupported model {algorithm}.")
29
+
30
+ model = TorchDrugGenerator(configuration=config)
 
 
 
 
 
 
31
  samples = list(model.sample(number_of_samples))
32
 
33
+ return draw_grid_generate(samples=samples, n_cols=5)
 
34
 
35
 
36
  if __name__ == "__main__":
 
39
  all_algos = ApplicationsRegistry.list_available()
40
  algos = [
41
  x["algorithm_version"]
42
+ for x in list(filter(lambda x: "TorchDrug" in x["algorithm_name"], all_algos))
43
  ]
44
 
45
  # Load metadata
 
56
 
57
  demo = gr.Interface(
58
  fn=run_inference,
59
+ title="TorchDrug (GCPN and GraphAF)",
60
  inputs=[
61
+ gr.Dropdown(["GCPN", "GraphAF"], label="Algorithm", value="GCPN"),
62
+ gr.Dropdown(
63
+ list(set(algos)), label="Algorithm version", value="zinc250k_v0"
 
 
64
  ),
 
65
  gr.Slider(
66
  minimum=1, maximum=50, value=10, label="Number of samples", step=1
67
  ),
 
68
  ],
69
  outputs=gr.HTML(label="Output"),
70
  article=article,
model_cards/article.md CHANGED
@@ -1,37 +1,37 @@
1
  # Model documentation & parameters
2
 
3
- **Algorithm Version**: Which model checkpoint to use (trained on different datasets).
4
 
5
- **Scaffolds**: One or multiple scaffolds (or seed molecules), provided as '.'-separated SMILES. If empty, no scaffolds are used.
6
 
7
  **Number of samples**: How many samples should be generated (between 1 and 50).
8
 
9
- **Beam size**: Beam size used in beam search decoding (the higher the slower but better).
10
-
11
- **Seed**: The random seed used for initialization.
12
-
13
 
14
- # Model card
15
 
16
- **Model Details**: MoLeR is a graph-based molecular generative model that can be conditioned (primed) on scaffolds. The model decorates scaffolds with realistic structural motifs.
17
 
18
- **Developers**: Krzysztof Maziarz and co-authors from Microsoft Research and Novartis (full reference at bottom).
19
 
20
- **Distributors**: Developer's code wrapped and distributed by GT4SD Team (2023) from IBM Research.
21
 
22
- **Model date**: Released around March 2022.
23
 
24
- **Model version**: Model provided by original authors, see [their GitHub repo](https://github.com/microsoft/molecule-generation).
 
 
 
25
 
26
- **Model type**: An encoder-decoder-based GNN for molecular generation.
27
 
28
- **Information about training algorithms, parameters, fairness constraints or other applied approaches, and features**: Trained by the original authors with the default parameters provided [on GitHub](https://github.com/microsoft/molecule-generation).
29
 
30
- **Paper or other resource for more information**: Learning to Extend Molecular Scaffolds with Structural Motifs (ICLR 2022).
 
31
 
32
- **License**: MIT
33
 
34
- **Where to send questions or comments about the model**: Open an issue on original author's [GitHub repository](https://github.com/microsoft/molecule-generation).
35
 
36
  **Intended Use. Use cases that were envisioned during development**: Chemical research, in particular drug discovery.
37
 
@@ -41,9 +41,9 @@
41
 
42
  **Factors**: Not applicable.
43
 
44
- **Metrics**: Validation loss on decoding correct molecules. Evaluated on several downstream tasks.
45
 
46
- **Datasets**: 1.5M drug-like molecules from GuacaMol benchmark. Finetuning on 20 molecular optimization tasks from GuacaMol.
47
 
48
  **Ethical Considerations**: Unclear, please consult with original authors in case of questions.
49
 
@@ -54,12 +54,12 @@ Model card prototype inspired by [Mitchell et al. (2019)](https://dl.acm.org/doi
54
  ## Citation
55
 
56
  ```bib
57
- @inproceedings{maziarz2021learning,
58
- author={Krzysztof Maziarz and Henry Richard Jackson{-}Flux and Pashmina Cameron and
59
- Finton Sirockin and Nadine Schneider and Nikolaus Stiefl and Marwin H. S. Segler and Marc Brockschmidt},
60
- title = {Learning to Extend Molecular Scaffolds with Structural Motifs},
61
- booktitle = {The Tenth International Conference on Learning Representations, {ICLR}},
62
- year = {2022}
63
  }
64
  ```
65
 
 
1
  # Model documentation & parameters
2
 
3
+ **Algorithm**: Which model to use (GCPN or GraphAF).
4
 
5
+ **Algorithm Version**: Which model checkpoint to use (trained on different datasets).
6
 
7
  **Number of samples**: How many samples should be generated (between 1 and 50).
8
 
 
 
 
 
9
 
10
+ # Model card -- GCPN
11
 
12
+ **Model Details**: GCPN is a graph-based molecular generative model that can be optimized with RL for goal-directed graph generation.
13
 
14
+ **Developers**: Jiaxuan You and co-authors from Stanford.
15
 
16
+ **Distributors**: Code provided by TorchDrug developers, wrapped and distributed by GT4SD Team (2023) from IBM Research.
17
 
18
+ **Model date**: Published in 2018.
19
 
20
+ **Model version**: Models trained by GT4SD team on the tasks provided by TorchDrug repo [(see their tutorial)](https://torchdrug.ai/docs/tutorials/generation.html).
21
+ - **ZINC_250k**: 250,000 drug-like molecules with a maximum atom number of 38, taken from [ZINC](https://zinc.docking.org).
22
+ - **QED**: ZINC dataset, but the model was optimized with Proximal Policy Optimization (PPO) to generate molecules with high QED scores.
23
+ - **pLogP**: ZINC dataset, but the model was optimized with Proximal Policy Optimization (PPO) to generate molecules with high pLogP scores.
24
 
25
+ **Model type**: A graph-based molecular generative model that can be optimized with RL for goal-directed graph generation.
26
 
27
+ **Information about training algorithms, parameters, fairness constraints or other applied approaches, and features**: Default parameters as provided in [(TorchDrug tutorial)](https://torchdrug.ai/docs/tutorials/generation.html).
28
 
29
+ **Paper or other resource for more information**: [Graph Convolutional Policy Network for
30
+ Goal-Directed Molecular Graph Generation (NeurIPS 2018)](https://proceedings.neurips.cc/paper/2018/file/d60678e8f2ba9c540798ebbde31177e8-Paper.pdf).
31
 
32
+ **License**: TorchDrug: Apache-2.0 license.
33
 
34
+ **Where to send questions or comments about the model**: Open an issue on the [TorchDrug repository](https://github.com/DeepGraphLearning/torchdrug).
35
 
36
  **Intended Use. Use cases that were envisioned during development**: Chemical research, in particular drug discovery.
37
 
 
41
 
42
  **Factors**: Not applicable.
43
 
44
+ **Metrics**: Validation loss on decoding correct molecules.
45
 
46
+ **Datasets**: 250,000 drug-like molecules from [ZINC](https://zinc.docking.org) (with a maximum atom number of 38).
47
 
48
  **Ethical Considerations**: Unclear, please consult with original authors in case of questions.
49
 
 
54
  ## Citation
55
 
56
  ```bib
57
+ @article{you2018graph,
58
+ title={Graph convolutional policy network for goal-directed molecular graph generation},
59
+ author={You, Jiaxuan and Liu, Bowen and Ying, Zhitao and Pande, Vijay and Leskovec, Jure},
60
+ journal={Advances in neural information processing systems},
61
+ volume={31},
62
+ year={2018}
63
  }
64
  ```
65
 
model_cards/description.md CHANGED
@@ -1,6 +1,10 @@
1
  <img align="right" src="https://raw.githubusercontent.com/GT4SD/gt4sd-core/main/docs/_static/gt4sd_logo.png" alt="logo" width="120" >
2
 
3
- MoLeR (Maziarz et al., (2022), *ICLR*) is a graph-based molecular generative model that can be conditioned (primed) on scaffolds. This model is provided and distributed by the **GT4SD** (Generative Toolkit for Scientific Discovery).
 
 
 
 
4
 
5
  For **examples** and **documentation** of the model parameters, please see below.
6
  Moreover, we provide a **model card** ([Mitchell et al. (2019)](https://dl.acm.org/doi/abs/10.1145/3287560.3287596?casa_token=XD4eHiE2cRUAAAAA:NL11gMa1hGPOUKTAbtXnbVQBDBbjxwcjGECF_i-WC_3g1aBgU1Hbz_f2b4kI_m1in-w__1ztGeHnwHs)) at the bottom of this page.
 
1
  <img align="right" src="https://raw.githubusercontent.com/GT4SD/gt4sd-core/main/docs/_static/gt4sd_logo.png" alt="logo" width="120" >
2
 
3
+
4
+ [TorchDrug](https://github.com/DeepGraphLearning/torchdrug) is a PyTorch toolbox on graph models for drug discovery.
5
+ We, the developers of **GT4SD** (Generative Toolkit for Scientific Discovery), provide access to two graph-based molecular generative models distributed by TorchDrug:
6
+ - **GCPN**: Graph Convolutional Policy Network ([You et al. (2018), *NeurIPS*](https://proceedings.neurips.cc/paper/2018/hash/d60678e8f2ba9c540798ebbde31177e8-Abstract.html))
7
+ - **GraphAF**: GraphAF: a Flow-based Autoregressive Model for Molecular Graph Generation ([Shi et al. (2020), *ICLR*](https://openreview.net/forum?id=S1esMkHYPr))
8
 
9
  For **examples** and **documentation** of the model parameters, please see below.
10
  Moreover, we provide a **model card** ([Mitchell et al. (2019)](https://dl.acm.org/doi/abs/10.1145/3287560.3287596?casa_token=XD4eHiE2cRUAAAAA:NL11gMa1hGPOUKTAbtXnbVQBDBbjxwcjGECF_i-WC_3g1aBgU1Hbz_f2b4kI_m1in-w__1ztGeHnwHs)) at the bottom of this page.
model_cards/examples.csv CHANGED
@@ -1,5 +1,4 @@
1
- v0,,1,4,0
2
- v0,CC(=O)NC1=NC2=CC(OCC3=CC=CN(CC4=CC=C(Cl)C=C4)C3=O)=CC=C2N1,1,10,0
3
- v0,C12C=CC=NN1C(C#CC1=C(C)C=CC3C(NC4=CC(C(F)(F)F)=CC=C4)=NOC1=3)=CN=2.CCO,3,5,5
4
-
5
 
 
1
+ GCPN_zinc250k_v0,5
2
+ GCPN_qed_v0,10
3
+ GraphAF_plogp_v0,5
 
4
 
utils.py CHANGED
@@ -1,21 +1,17 @@
1
- import json
2
  import logging
3
- import os
4
  from collections import defaultdict
5
- from typing import Dict, List, Tuple
6
 
7
  import mols2grid
8
  import pandas as pd
9
- from rdkit import Chem
10
- from terminator.selfies import decoder
11
 
12
  logger = logging.getLogger(__name__)
13
  logger.addHandler(logging.NullHandler())
14
 
15
 
16
  def draw_grid_generate(
17
- seeds: List[str],
18
  samples: List[str],
 
19
  n_cols: int = 3,
20
  size=(140, 200),
21
  ) -> str:
 
 
1
  import logging
 
2
  from collections import defaultdict
3
+ from typing import List
4
 
5
  import mols2grid
6
  import pandas as pd
 
 
7
 
8
  logger = logging.getLogger(__name__)
9
  logger.addHandler(logging.NullHandler())
10
 
11
 
12
  def draw_grid_generate(
 
13
  samples: List[str],
14
+ seeds: List[str] = [],
15
  n_cols: int = 3,
16
  size=(140, 200),
17
  ) -> str: