jannisborn commited on
Commit
b68abc1
1 Parent(s): 5da68a0
Files changed (3) hide show
  1. README.md +1 -1
  2. app.py +38 -27
  3. model_cards/mol_dct.pkl +0 -0
README.md CHANGED
@@ -1,5 +1,5 @@
1
  ---
2
- title: GT4SD - Diffusers (image)
3
  emoji: 💡
4
  colorFrom: green
5
  colorTo: blue
 
1
  ---
2
+ title: GT4SD - GeoDiff
3
  emoji: 💡
4
  colorFrom: green
5
  colorTo: blue
app.py CHANGED
@@ -1,36 +1,46 @@
1
  import logging
2
  import pathlib
 
3
  import gradio as gr
 
4
  import pandas as pd
5
  from gt4sd.algorithms.generation.diffusion import (
6
  DiffusersGenerationAlgorithm,
7
- DDPMGenerator,
8
- DDIMGenerator,
9
- ScoreSdeGenerator,
10
- LDMTextToImageGenerator,
11
- LDMGenerator,
12
- StableDiffusionGenerator,
13
  )
14
  from gt4sd.algorithms.registry import ApplicationsRegistry
 
 
15
 
16
  logger = logging.getLogger(__name__)
17
  logger.addHandler(logging.NullHandler())
18
 
19
 
20
- def run_inference(model_type: str, prompt: str):
 
 
 
 
 
21
 
22
- if prompt == "":
23
- config = eval(f"{model_type}()")
 
 
 
 
24
  else:
25
- config = eval(f"{model_type}(prompt={prompt})")
26
- if config.modality != "token2image" and prompt != "":
27
- raise ValueError(
28
- f"{model_type} is an unconditional generative model, please remove prompt (not={prompt})"
29
- )
 
30
  model = DiffusersGenerationAlgorithm(config)
31
- image = list(model.sample(1))[0]
 
32
 
33
- return image
34
 
35
 
36
  if __name__ == "__main__":
@@ -38,17 +48,16 @@ if __name__ == "__main__":
38
  # Preparation (retrieve all available algorithms)
39
  all_algos = ApplicationsRegistry.list_available()
40
  algos = [
41
- x["algorithm_application"]
42
- for x in list(filter(lambda x: "Diff" in x["algorithm_name"], all_algos))
 
 
43
  ]
44
- algos = [a for a in algos if not "GeoDiff" in a]
45
 
46
  # Load metadata
47
  metadata_root = pathlib.Path(__file__).parent.joinpath("model_cards")
48
 
49
- examples = pd.read_csv(metadata_root.joinpath("examples.csv"), header=None).fillna(
50
- ""
51
- )
52
 
53
  with open(metadata_root.joinpath("article.md"), "r") as f:
54
  article = f.read()
@@ -57,16 +66,18 @@ if __name__ == "__main__":
57
 
58
  demo = gr.Interface(
59
  fn=run_inference,
60
- title="Diffusion-based image generators",
61
  inputs=[
62
  gr.Dropdown(
63
- algos, label="Diffusion model", value="StableDiffusionGenerator"
64
  ),
65
- gr.Textbox(label="Text prompt", placeholder="A blue tree", lines=1),
 
 
66
  ],
67
- outputs=gr.outputs.Image(type="pil"),
68
  article=article,
69
  description=description,
70
- examples=examples.values.tolist(),
71
  )
72
  demo.launch(debug=True, show_error=True)
 
1
  import logging
2
  import pathlib
3
+ import pickle
4
  import gradio as gr
5
+ from typing import Dict, Any
6
  import pandas as pd
7
  from gt4sd.algorithms.generation.diffusion import (
8
  DiffusersGenerationAlgorithm,
9
+ GeoDiffGenerator,
 
 
 
 
 
10
  )
11
  from gt4sd.algorithms.registry import ApplicationsRegistry
12
+ from utils import draw_grid_generate
13
+ from rdkit import Chem
14
 
15
  logger = logging.getLogger(__name__)
16
  logger.addHandler(logging.NullHandler())
17
 
18
 
19
+ def run_inference(
20
+ algorithm_version: str,
21
+ prompt_file: str,
22
+ prompt_id: int,
23
+ number_of_samples: int,
24
+ ):
25
 
26
+ # Read file:
27
+ with open(prompt_file.name, "rb") as f:
28
+ prompts = pickle.load(f)
29
+
30
+ if all(isinstance(x, str) for x in prompts.keys()):
31
+ prompt = prompts[prompt_id]
32
  else:
33
+ prompt = prompts
34
+
35
+ config = GeoDiffGenerator(
36
+ algorithm_version=algorithm_version,
37
+ prompt=prompt,
38
+ )
39
  model = DiffusersGenerationAlgorithm(config)
40
+ results = list(model.sample(number_of_samples))
41
+ smiles = [Chem.MolToSmiles(m) for m in results]
42
 
43
+ return draw_grid_generate(samples=smiles, n_cols=5)
44
 
45
 
46
  if __name__ == "__main__":
 
48
  # Preparation (retrieve all available algorithms)
49
  all_algos = ApplicationsRegistry.list_available()
50
  algos = [
51
+ x["algorithm_version"]
52
+ for x in list(
53
+ filter(lambda x: "GeoDiff" in x["algorithm_application"], all_algos)
54
+ )
55
  ]
 
56
 
57
  # Load metadata
58
  metadata_root = pathlib.Path(__file__).parent.joinpath("model_cards")
59
 
60
+ examples = [[algos[0], metadata_root.joinpath("mol_dct.pkl"), 2]]
 
 
61
 
62
  with open(metadata_root.joinpath("article.md"), "r") as f:
63
  article = f.read()
 
66
 
67
  demo = gr.Interface(
68
  fn=run_inference,
69
+ title="GeoDiff",
70
  inputs=[
71
  gr.Dropdown(
72
+ algos, label="GeoDiff version", value="fusing/gfn-molecule-gen-drugs"
73
  ),
74
+ gr.File(file_types=[".pkl"], label="GeoDiff prompt"),
75
+ gr.Number(value=0, label="Prompt ID", precision=0),
76
+ gr.Slider(minimum=1, maximum=5, value=2, label="Number of samples", step=1),
77
  ],
78
+ outputs=gr.HTML(label="Output"),
79
  article=article,
80
  description=description,
81
+ examples=examples,
82
  )
83
  demo.launch(debug=True, show_error=True)
model_cards/mol_dct.pkl ADDED
Binary file (129 kB). View file