|
import functools
import logging
import pathlib

import gradio as gr
import pandas as pd
from gt4sd.algorithms.generation.hugging_face import (
    HuggingFaceGenerationAlgorithm,
    HuggingFaceSeq2SeqGenerator,
)
from transformers import AutoTokenizer
|
|
|
# Module-level logger. The NullHandler discards records, so this module
# produces no log output through this handler on its own.
logger = logging.getLogger(__name__)
logger.addHandler(logging.NullHandler())

# Maps each supported task name to the instruction text that is placed in
# front of the user's prompt. The strings are passed to the generator
# verbatim, so they must not be reworded or re-cased.
task2prefix = {
    "forward": "Predict the product of the following reaction: ",
    "retrosynthesis": (
        "Predict the reaction that produces the following product: "
    ),
    "paragraph to actions": (
        "Which actions are described in the following paragraph: "
    ),
    "molecular captioning": "Caption the following smile: ",
    "text-conditional de novo generation": (
        "Write in SMILES the described molecule: "
    ),
}
|
|
|
|
|
@functools.lru_cache(maxsize=1)
def _special_tokens():
    """Return the ``(eos_token, pad_token)`` strings of the ``t5-small`` tokenizer.

    The tokenizer is needed only for these two marker strings, so it is
    loaded once and the pair is reused by every later inference call instead
    of re-instantiating the tokenizer per request.
    """
    tokenizer = AutoTokenizer.from_pretrained("t5-small")
    return tokenizer.eos_token, tokenizer.pad_token


def run_inference(
    model_name_or_path: str,
    task: str,
    prompt: str,
    num_beams: int,
) -> str:
    """Generate one completion for ``prompt`` and return it as cleaned text.

    Args:
        model_name_or_path: Forwarded to the generator configuration as
            ``algorithm_version``.
        task: Key of ``task2prefix``; selects the instruction that is used
            as the generation prefix.
        prompt: User-supplied input text.
        num_beams: Number of beams forwarded to the generator. Coerced to
            ``int`` before use.

    Returns:
        The first sampled text after removing the echoed
        ``instruction + prompt``, everything from the first end-of-sequence
        token onwards, any padding tokens, and surrounding whitespace.

    Raises:
        KeyError: If ``task`` is not a key of ``task2prefix``.
    """
    instruction = task2prefix[task]

    config = HuggingFaceSeq2SeqGenerator(
        algorithm_version=model_name_or_path,
        prefix=instruction,
        prompt=prompt,
        # NOTE(review): the value originates from a UI slider and may arrive
        # as a float such as 10.0 — normalise it so the generator always
        # receives an integer beam count.
        num_beams=int(num_beams),
    )
    model = HuggingFaceGenerationAlgorithm(config)
    eos_token, pad_token = _special_tokens()

    # Exactly one sample is requested; keep the first item.
    text = list(model.sample(1))[0]

    # Drop the echoed input, cut at end-of-sequence, then remove padding.
    text = text.replace(instruction + prompt, "")
    text = text.split(eos_token)[0]
    text = text.replace(pad_token, "")
    return text.strip()
|
|
|
|
|
if __name__ == "__main__":
    # Model versions offered in the "Language model" dropdown.
    models = [
        "text-chem-t5-small-standard",
        "text-chem-t5-small-augm",
        "text-chem-t5-base-standard",
        "text-chem-t5-base-augm",
    ]

    # Static assets (examples, article, description) live next to this file.
    metadata_root = pathlib.Path(__file__).parent.joinpath("model_cards")

    # The examples file has no header row; empty cells become "" rather
    # than NaN.
    examples = pd.read_csv(
        metadata_root.joinpath("examples.csv"), header=None
    ).fillna("")
    example_rows = examples.values.tolist()
    print("Examples: ", example_rows)

    # Read the markdown with an explicit encoding so non-ASCII characters
    # decode the same way regardless of the platform's locale default.
    article = metadata_root.joinpath("article.md").read_text(encoding="utf-8")
    description = metadata_root.joinpath("description.md").read_text(
        encoding="utf-8"
    )

    demo = gr.Interface(
        fn=run_inference,
        title="Text+chem-T5 model",
        # Input order must match the positional parameters of run_inference.
        inputs=[
            gr.Dropdown(
                models,
                label="Language model",
                value="text-chem-t5-base-augm",
            ),
            gr.Radio(
                # Derived from task2prefix so every selectable task is
                # guaranteed to have an instruction prefix.
                choices=list(task2prefix),
                label="Task",
                value="paragraph to actions",
            ),
            gr.Textbox(
                label="Text prompt",
                placeholder="I'm a stochastic parrot.",
                lines=1,
            ),
            gr.Slider(
                minimum=1, maximum=50, value=10, label="num_beams", step=1
            ),
        ],
        outputs=gr.Textbox(label="Output"),
        article=article,
        description=description,
        examples=example_rows,
    )
    demo.launch(debug=True, show_error=True)
|
|