# RxnScribe / app.py
# thomas0809 — "change color" (commit 1d479e0)
# raw | history | blame — 3.08 kB
import gradio as gr
import os
import glob
import cv2
import numpy as np
import torch
from rxnscribe import RxnScribe
from huggingface_hub import hf_hub_download
# The demo performs inference on the CPU.
device = torch.device("cpu")

# Location of the pretrained RxnScribe checkpoint on the Hugging Face Hub.
REPO_ID = "yujieq/RxnScribe"
FILENAME = "pix2seq_reaction_full.ckpt"

# hf_hub_download yields a local filesystem path to the checkpoint file,
# which is then handed to the RxnScribe constructor.
ckpt_path = hf_hub_download(repo_id=REPO_ID, filename=FILENAME)
model = RxnScribe(ckpt_path, device)
def get_markdown(reaction):
    """Render one predicted reaction as three markdown table cells.

    Args:
        reaction: A prediction dict with ``'reactants'``, ``'conditions'`` and
            ``'products'`` entries, each an iterable of entity dicts.

    Returns:
        A list of three markdown strings, ordered reactants, conditions,
        products. Within a cell each entity is rendered as:

        * a fenced code block holding its SMILES, if it has ``'smiles'``;
        * otherwise its ``'text'`` tokens joined by single spaces, if it has
          ``'text'``;
        * otherwise its bare ``'category'`` label.

        Text and category entries end with ``<br>`` so that consecutive
        entries render on separate lines.
    """
    cells = []
    for role in ('reactants', 'conditions', 'products'):
        parts = []
        for ent in reaction[role]:
            if 'smiles' in ent:
                parts.append("\n```\n" + ent['smiles'] + "\n```\n")
            elif 'text' in ent:
                parts.append(' '.join(ent['text']) + '<br>')
            else:
                # Terminate the bare label like a text entry; without the
                # separator the next entry would be glued onto it.
                parts.append(ent['category'] + '<br>')
        # Join once instead of growing a string with += in the loop.
        cells.append(''.join(parts))
    return cells
def predict(image, molscribe=False, ocr=False):
    """Run RxnScribe on an uploaded diagram and format the results for the UI.

    Args:
        image: The uploaded diagram (a PIL image, as configured by the
            ``type='pil'`` upload component), or ``None`` if nothing has been
            uploaded yet.
        molscribe: Whether to also run MolScribe to recognize molecule
            structures. Defaults to ``False`` so callers that pass only the
            image (e.g. ``gr.Examples(inputs=[image], fn=predict)``) work.
        ocr: Whether to also run OCR to recognize text. Defaults to ``False``.

    Returns:
        ``(pred_image, rows)``. ``pred_image`` is the visualization from
        ``model.draw_predictions_combined``. ``rows`` has one row per
        predicted reaction, shaped ``[index, reactants_md, conditions_md,
        products_md]``. When ``image`` is ``None`` the model is not called
        and ``(None, [])`` is returned.
    """
    if image is None:
        # Submit clicked before uploading: clear the outputs rather than
        # passing None into the model. An empty row list is what the UI
        # already receives when no reactions are predicted.
        return None, []
    predictions = model.predict_image(image, molscribe=molscribe, ocr=ocr)
    pred_image = model.draw_predictions_combined(predictions, image=image)
    rows = [[i] + get_markdown(reaction) for i, reaction in enumerate(predictions)]
    return pred_image, rows
# Build the Gradio UI: a description header, the diagram upload, option
# checkboxes, a submit button, and the two outputs filled in by predict().
# Component creation order inside each `with` determines the layout.
with gr.Blocks() as demo:
    gr.Markdown("""
<center> <h1>RxnScribe</h1> </center>
Extract chemical reactions from a diagram. Please upload a reaction diagram, RxnScribe will predict the reaction structures in the diagram.
The predicted reactions are visualized in separate images.
<b style="color:red">Red boxes are <i><u style="color:red">reactants</u></i>.</b>
<b style="color:green">Green boxes are <i><u style="color:green">reaction conditions</u></i>.</b>
<b style="color:blue">Blue boxes are <i><u style="color:blue">products</u></i>.</b>
It usually takes 5-10 seconds to process a diagram with this demo.
Check the options to run [MolScribe](https://huggingface.co/spaces/yujieq/MolScribe) and [OCR](https://huggingface.co/spaces/tomofi/EasyOCR) (it will take a longer time, of course).
Code: https://github.com/thomas0809/RxnScribe
Authors: [Yujie Qian](mailto:yujieq@csail.mit.edu), Jiang Guo, Zhengkai Tu, Connor W. Coley, Regina Barzilay. _MIT CSAIL_.
""")
    with gr.Column():
        with gr.Row():
            # type='pil': the upload reaches predict() as a PIL image.
            image = gr.Image(label="Upload reaction diagram", show_label=False, type='pil').style(height=256)
        with gr.Row():
            # Optional extra steps forwarded to predict(); the description
            # above warns they make processing take longer.
            molscribe = gr.Checkbox(label="Run MolScribe to recognize molecule structures")
            ocr = gr.Checkbox(label="Run OCR to recognize text")
        btn = gr.Button("Submit").style(full_width=False)
        with gr.Row():
            # Shows the visualization predict() gets from
            # model.draw_predictions_combined.
            gallery = gr.Image(label='Predicted reactions', show_label=True).style(height="auto")
        # One row per predicted reaction: the reaction index, then the
        # markdown produced by get_markdown() for each role, matching the
        # rows predict() returns.
        markdown = gr.Dataframe(
            headers=['#', 'reactant', 'condition', 'product'],
            datatype=['number'] + ['markdown'] * 3,
            wrap=False
        )
    btn.click(predict, inputs=[image, molscribe, ocr], outputs=[gallery, markdown])
    # Clickable sample diagrams found under examples/*.png.
    # NOTE(review): only the image is listed in `inputs`. If Gradio invokes
    # `fn` for these examples (e.g. when caching them), predict() receives a
    # single argument, so its molscribe/ocr parameters must have defaults.
    # Verify.
    gr.Examples(
        examples=sorted(glob.glob('examples/*.png')),
        inputs=[image],
        outputs=[gallery, markdown],
        fn=predict,
    )
# Launch the Gradio app.
demo.launch()