import gradio as gr
import os
import glob
import cv2
import numpy as np
import torch
from rxnscribe import RxnScribe
from huggingface_hub import hf_hub_download
# Hugging Face Hub coordinates of the checkpoint file to load.
REPO_ID = "yujieq/RxnScribe"
FILENAME = "pix2seq_reaction_full.ckpt"

# Resolve the checkpoint through the Hub client; the returned value is handed
# to RxnScribe as the checkpoint path.
ckpt_path = hf_hub_download(repo_id=REPO_ID, filename=FILENAME)

# The model is placed on the CPU device.
device = torch.device("cpu")
model = RxnScribe(ckpt_path, device)
def get_markdown(reaction):
    """Render one predicted reaction as three table-cell strings.

    Args:
        reaction: Mapping with 'reactants', 'conditions' and 'products'
            keys, each holding an iterable of entity dicts.

    Returns:
        A list of three strings, in the order reactants, conditions,
        products. Within a cell each entity contributes, in order:

        * its 'smiles' value followed by a newline, if it has 'smiles';
        * otherwise its 'text' items joined by single spaces and followed
          by a newline, if it has 'text';
        * otherwise its 'category' value with no trailing separator.

    Raises:
        KeyError: If a role key is missing from ``reaction``, or an entity
            has none of 'smiles', 'text' or 'category'.
    """
    cells = []
    for role in ('reactants', 'conditions', 'products'):
        parts = []
        for ent in reaction[role]:
            # The separator is spelled as the '\n' escape: a raw line break
            # inside a single-quoted literal does not parse.
            if 'smiles' in ent:
                parts.append(ent['smiles'] + '\n')
            elif 'text' in ent:
                parts.append(' '.join(ent['text']) + '\n')
            else:
                parts.append(ent['category'])
        cells.append(''.join(parts))
    return cells
def predict(image, molscribe, ocr):
    """Run the reaction model on one image and format the results for the UI.

    Args:
        image: The input image, or None when nothing was supplied.
        molscribe: Forwarded to ``model.predict_image`` as ``molscribe=``.
        ocr: Forwarded to ``model.predict_image`` as ``ocr=``.

    Returns:
        A ``(pred_images, rows)`` tuple. ``pred_images`` is whatever
        ``model.draw_predictions`` returns; ``rows`` has one list per
        predicted reaction: its 0-based index followed by the three strings
        produced by ``get_markdown``. Both are empty lists when ``image`` is
        None.
    """
    # Nothing to analyse: return empty outputs instead of passing None to
    # the model.
    if image is None:
        return [], []

    predictions = model.predict_image(image, molscribe=molscribe, ocr=ocr)
    pred_images = model.draw_predictions(predictions, image=image)
    rows = [[i, *get_markdown(reaction)] for i, reaction in enumerate(predictions)]
    return pred_images, rows
# Demo UI: an image input, two option checkboxes and a submit button on top;
# a gallery and a results table below.
# NOTE(review): the nesting of the rows/column was reconstructed from the
# statement order — confirm the intended layout (in particular whether the
# button belongs inside the checkbox row and whether the table shares a row
# with the gallery).
# NOTE(review): `.style(...)` is tied to the installed Gradio release —
# confirm the pinned version before upgrading Gradio.
with gr.Blocks() as demo:
    with gr.Column():
        with gr.Row():
            image = gr.Image(label="Upload reaction diagram", show_label=False, type='pil').style(height=256)
        with gr.Row():
            molscribe = gr.Checkbox(label="Run MolScribe to recognize molecule structures")
            ocr = gr.Checkbox(label="Run OCR to recognize text")
        btn = gr.Button("Submit").style(full_width=False)
        with gr.Row():
            gallery = gr.Gallery(
                label="Predicted reactions", show_label=False, elem_id="gallery"
            ).style(height="auto")
            # One row per reaction: an index column plus three markdown cells.
            markdown = gr.Dataframe(
                headers=['#', 'reactant', 'condition', 'product'],
                datatype=['number'] + ['markdown'] * 3,
                wrap=False
            )

    # Event listeners are registered while the Blocks context is still open.
    btn.click(predict, inputs=[image, molscribe, ocr], outputs=[gallery, markdown])

demo.launch()