3demo / app.py
sadgaj's picture
Update app.py
15611ba
raw
history blame contribute delete
No virus
2.46 kB
import gradio as gr
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
from rdkit.Chem import Draw
from rdkit import Chem
import selfies as sf
sf_output="zju"
def greet1(name):
tokenizer = AutoTokenizer.from_pretrained("zjunlp/MolGen")
model = AutoModelForSeq2SeqLM.from_pretrained("zjunlp/MolGen")
sf_input = tokenizer(name, return_tensors="pt")
# beam search
molecules = model.generate(input_ids=sf_input["input_ids"],
attention_mask=sf_input["attention_mask"],
max_length=15,
min_length=5,
num_return_sequences=4,
num_beams=5)
sf_output = [tokenizer.decode(g, skip_special_tokens=True, clean_up_tokenization_spaces=True).replace(" ","") for g in molecules]
return sf_output
def greet2(name):
tokenizer = AutoTokenizer.from_pretrained("zjunlp/MolGen")
model = AutoModelForSeq2SeqLM.from_pretrained("zjunlp/MolGen")
sf_input = tokenizer(name, return_tensors="pt")
# beam search
molecules = model.generate(input_ids=sf_input["input_ids"],
attention_mask=sf_input["attention_mask"],
max_length=15,
min_length=5,
num_return_sequences=4,
num_beams=5)
sf_output = [tokenizer.decode(g, skip_special_tokens=True, clean_up_tokenization_spaces=True).replace(" ","") for g in molecules]
smis = [sf.decoder(i) for i in sf_output]
mols = []
for smi in smis:
mol = Chem.MolFromSmiles(smi)
mols.append(mol)
img = Draw.MolsToGridImage(
mols,
molsPerRow=4,
subImgSize=(200,200),
legends=['' for x in mols]
)
return img
def greet3(name):
return name
examples = [
['[C][=C][C][=C][C][=C][Ring1][=Branch1]'],['[C]']
]
greeter_1 = gr.Interface(greet1, inputs="textbox", outputs="text")
greeter_2 = gr.Interface(greet2 , inputs="textbox", outputs="image")
#greeter_2.launch()
demo = gr.Parallel(greeter_1, greeter_2,title="Molecular Language Model as Multi-task Generator",
examples=examples)
demo.launch()
#iface = gr.Interface(fn=greet2, inputs="text", outputs="image", title="Molecular Language Model as Multi-task Generator",
# )
#iface.launch()