christofid commited on
Commit
89e8857
·
1 Parent(s): dc87e4f

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +31 -12
app.py CHANGED
@@ -4,25 +4,35 @@ import gradio as gr
4
  import pandas as pd
5
  from gt4sd.algorithms.generation.hugging_face import (
6
  HuggingFaceSeq2SeqGenerator,
7
- HuggingFaceGenerationAlgorithm
8
  )
9
  from transformers import AutoTokenizer
10
 
11
  logger = logging.getLogger(__name__)
12
  logger.addHandler(logging.NullHandler())
13
 
 
 
 
 
 
 
 
 
 
14
  def run_inference(
15
  model_name_or_path: str,
16
- prefix: str,
17
  prompt: str,
18
  num_beams: int,
19
  ):
 
20
 
21
  config = HuggingFaceSeq2SeqGenerator(
22
  algorithm_version=model_name_or_path,
23
- prefix=prefix,
24
  prompt=prompt,
25
- num_beams=num_beams
26
  )
27
 
28
  model = HuggingFaceGenerationAlgorithm(config)
@@ -30,22 +40,23 @@ def run_inference(
30
 
31
  text = list(model.sample(1))[0]
32
 
33
- text = text.replace(prefix+prompt,"")
34
  text = text.split(tokenizer.eos_token)[0]
35
  text = text.replace(tokenizer.pad_token, "")
36
  text = text.strip()
37
 
38
-
39
  return text
40
 
41
 
42
  if __name__ == "__main__":
43
 
44
- # Preparation (retrieve all available algorithms)
45
- models = ["text-chem-t5-small-standard", "text-chem-t5-small-augm",
46
- "text-chem-t5-base-standard", "text-chem-t5-base-augm"]
 
 
 
47
 
48
- # Load metadata
49
  metadata_root = pathlib.Path(__file__).parent.joinpath("model_cards")
50
 
51
  examples = pd.read_csv(metadata_root.joinpath("examples.csv"), header=None).fillna(
@@ -67,8 +78,16 @@ if __name__ == "__main__":
67
  label="Language model",
68
  value="text-chem-t5-base-augm",
69
  ),
70
- gr.Textbox(
71
- label="Prefix", placeholder="A task-specific prefix", lines=1
 
 
 
 
 
 
 
 
72
  ),
73
  gr.Textbox(
74
  label="Text prompt",
 
4
  import pandas as pd
5
  from gt4sd.algorithms.generation.hugging_face import (
6
  HuggingFaceSeq2SeqGenerator,
7
+ HuggingFaceGenerationAlgorithm,
8
  )
9
  from transformers import AutoTokenizer
10
 
11
  logger = logging.getLogger(__name__)
12
  logger.addHandler(logging.NullHandler())
13
 
14
+ task2prefix = {
15
+ "forward": "Predict the product of the following reaction: ",
16
+ "retrosynthesis": "Predict the reaction that produces the following product: ",
17
+ "paragraph to actions": "Which actions are described in the following paragraph: ",
18
+ "molecular captioning": "Caption the following SMILES: ",
19
+ "text-conditional de novo generation": "Write in SMILES the described molecule: ",
20
+ }
21
+
22
+
23
  def run_inference(
24
  model_name_or_path: str,
25
+ task: str,
26
  prompt: str,
27
  num_beams: int,
28
  ):
29
+ instruction = task2prefix[task]
30
 
31
  config = HuggingFaceSeq2SeqGenerator(
32
  algorithm_version=model_name_or_path,
33
+ prefix=instruction,
34
  prompt=prompt,
35
+ num_beams=num_beams,
36
  )
37
 
38
  model = HuggingFaceGenerationAlgorithm(config)
 
40
 
41
  text = list(model.sample(1))[0]
42
 
43
+ text = text.replace(instruction + prompt, "")
44
  text = text.split(tokenizer.eos_token)[0]
45
  text = text.replace(tokenizer.pad_token, "")
46
  text = text.strip()
47
 
 
48
  return text
49
 
50
 
51
  if __name__ == "__main__":
52
 
53
+ models = [
54
+ "text-chem-t5-small-standard",
55
+ "text-chem-t5-small-augm",
56
+ "text-chem-t5-base-standard",
57
+ "text-chem-t5-base-augm",
58
+ ]
59
 
 
60
  metadata_root = pathlib.Path(__file__).parent.joinpath("model_cards")
61
 
62
  examples = pd.read_csv(metadata_root.joinpath("examples.csv"), header=None).fillna(
 
78
  label="Language model",
79
  value="text-chem-t5-base-augm",
80
  ),
81
+ gr.Radio(
82
+ choices=[
83
+ "forward",
84
+ "retrosynthesis",
85
+ "paragraph to actions",
86
+ "molecular captioning",
87
+ "text-conditional de novo generation",
88
+ ],
89
+ label="Task",
90
+ value="paragraph to actions",
91
  ),
92
  gr.Textbox(
93
  label="Text prompt",