# math2latex / app.py
# Commit bf9d0ba ("Initial Commit") by msdkhairi.
import os
import re
import tempfile

import torch
import torchvision.transforms as transforms
import matplotlib
import matplotlib.pyplot as plt
import gradio as gr

from math2latex.data import Tokenizer
from math2latex.model import ResNetTransformer
# Module-level state filled in by setup(): the loaded ResNetTransformer
# network and the Tokenizer built from the formula list. Both stay None
# until setup() has run.
model = None
tokenizer = None
def get_formulas(filename, encoding="utf-8"):
    """Read a formula list file and return its lines.

    Args:
        filename: Path to a text file; each line is treated as one formula.
        encoding: Text encoding used to decode the file. Defaults to UTF-8
            rather than the platform's locale encoding, so the same file yields
            the same strings (and hence the same tokenizer input) everywhere.

    Returns:
        The file's lines exactly as ``readlines()`` produces them. Trailing
        newlines are kept and no stripping is applied.
    """
    # Without an explicit encoding, open() uses a locale-dependent default
    # (e.g. cp1252 on Windows), which can mis-decode non-ASCII characters.
    with open(filename, "r", encoding=encoding) as f:
        return f.readlines()
def latex2image(latex_expression, image_name, image_size_in=(3, 0.6), fontsize=12, dpi=200):
    """Draw ``latex_expression`` centred on a figure and save it to ``image_name``.

    The text is rendered by matplotlib with the mathtext font set to Computer
    Modern (``"cm"``). Wrap the expression in ``$...$`` to have it typeset as
    math.

    Args:
        latex_expression: Text passed to ``Figure.text``.
        image_name: Output path handed to ``Figure.savefig``.
        image_size_in: Figure ``(width, height)`` in inches.
        fontsize: Font size passed to ``Figure.text``.
        dpi: Figure resolution.

    Raises:
        Any exception raised while rendering or saving (for example, a
        mathtext parse error) propagates after the figure has been closed.
    """
    # Scope the font choice to this call instead of mutating the process-wide
    # rcParams. Text is laid out when the figure is saved, so the savefig call
    # must also happen inside the context.
    with matplotlib.rc_context({"mathtext.fontset": "cm"}):
        fig = plt.figure(figsize=image_size_in, dpi=dpi)
        try:
            fig.text(
                x=0.5,
                y=0.5,
                s=latex_expression,
                horizontalalignment="center",
                verticalalignment="center",
                fontsize=fontsize,
            )
            # Save this figure explicitly rather than pyplot's implicit
            # "current" figure.
            fig.savefig(image_name)
        finally:
            # pyplot keeps figures alive until closed. Without this, every
            # failed render would leak a figure in the long-running server.
            plt.close(fig)
def setup(checkpoint_path='checkpoints/model.ckpt',
          formulas_path='dataset/im2latex_formulas.norm.processed.lst'):
    """Load the model checkpoint and tokenizer into the module globals.

    Sets the module-level ``model`` (a ``ResNetTransformer`` on CPU in eval
    mode) and ``tokenizer`` (a ``Tokenizer`` built from the formula list).

    Args:
        checkpoint_path: Checkpoint file whose ``'state_dict'`` entry holds
            the network weights, with keys prefixed by ``"model."``.
        formulas_path: Formula list file read with ``get_formulas`` and passed
            to ``Tokenizer``.
    """
    global model, tokenizer

    # NOTE(review): torch.load unpickles the file, so only load trusted checkpoints.
    checkpoint = torch.load(checkpoint_path, map_location='cpu')

    # Saved keys carry a leading "model." prefix (presumably the attribute name
    # of a wrapping training module; TODO confirm). Strip only that leading
    # prefix. The previous str.replace also removed "model." occurring *inside*
    # a key (e.g. "encoder.model.conv" -> "encoder.conv"), which would no
    # longer match the network's parameter names.
    prefix = 'model.'
    state_dict = {
        (key[len(prefix):] if key.startswith(prefix) else key): value
        for key, value in checkpoint['state_dict'].items()
    }

    net = ResNetTransformer()
    net.load_state_dict(state_dict)
    net.to("cpu")
    net.eval()
    # Publish the global only once the weights are fully loaded.
    model = net

    # The Tokenizer is constructed from the full formula list (presumably to
    # derive its vocabulary; TODO confirm against math2latex.data).
    tokenizer = Tokenizer(get_formulas(formulas_path))
def predict_image(image):
    """Predict LaTeX source for a single formula image.

    Calls ``setup()`` first if the model or tokenizer has not been loaded yet.
    The image is converted with torchvision's ``ToTensor``, given a leading
    batch dimension, and run through ``model.predict`` with gradient tracking
    disabled. The first sequence of the output is then decoded by the tokenizer.

    Args:
        image: Anything ``ToTensor`` accepts (presumably the array or PIL
            image supplied by Gradio; TODO confirm).

    Returns:
        The string produced by ``tokenizer.decode``.
    """
    # model/tokenizer are only read here; setup() rebinds the module globals,
    # so the fresh values are picked up after the lazy initialisation.
    if model is None or tokenizer is None:
        setup()

    batch = transforms.ToTensor()(image).unsqueeze(0)
    with torch.no_grad():
        predicted = model.predict(batch)

    first_sequence = predicted[0].tolist()
    return tokenizer.decode(first_sequence)
# Matches either a control word followed by spaces and then a letter (group 1
# captures the control word), or any other run of spaces not directly preceded
# by a backslash.
_LATEX_SPACES = re.compile(r"(\\[A-Za-z]+) +(?=[A-Za-z])|(?<!\\) +")


def _compact_latex(code):
    r"""Remove the spaces between LaTeX tokens, keeping those that matter.

    One space is kept after a control word that is followed by a letter, so
    "\sin x" stays "\sin x" instead of becoming the unknown command "\sinx".
    A space directly after a backslash (the control space "\ ") is also kept.
    All other spaces are dropped, e.g. "\frac { a } { b }" -> "\frac{a}{b}".
    """
    def _replacement(match):
        command = match.group(1)
        return (command + " ") if command else ""

    return _LATEX_SPACES.sub(_replacement, code)


def predict_and_convert_to_image(image):
    """Predict LaTeX for ``image`` and render the prediction to a PNG.

    Args:
        image: Input passed unchanged to ``predict_image``.

    Returns:
        ``(latex_code, image_path)``. ``latex_code`` is the raw prediction from
        ``predict_image``. ``image_path`` is a fresh temporary PNG holding the
        rendered formula, or ``None`` if matplotlib could not render the
        predicted expression. In that case the LaTeX text is still returned.
    """
    latex_code = predict_image(image)

    # A unique file per call: a fixed name such as 'temp.png' is shared by all
    # concurrent requests, so one user could be served another user's render.
    # The file is intentionally left in place because the UI reads it after
    # this function returns.
    with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as tmp:
        image_name = tmp.name

    try:
        latex2image(f"${_compact_latex(latex_code)}$", image_name)
    except ValueError:
        # Predictions are not guaranteed to be valid mathtext, and a parse
        # failure here is assumed to surface as ValueError. Keep the predicted
        # text and drop only the image.
        os.remove(image_name)
        return latex_code, None

    return latex_code, image_name
def main():
    """Load the model, then build and launch the Gradio demo.

    ``setup()`` is called eagerly, so the checkpoint and tokenizer are loaded
    before the UI starts instead of on the first request.
    """
    setup()

    example_paths = [
        "dataset/formula_images_processed/78228211ca.png",
        "dataset/formula_images_processed/2b891b21ac.png",
        "dataset/formula_images_processed/a8ec0c091c.png",
    ]
    # The description advertises examples, but this list used to be built and
    # never passed to the interface. Offer only the images that actually
    # exist. When none are present, examples stays None as before, which is
    # presumably why the argument had been commented out (TODO confirm).
    examples = [[path] for path in example_paths if os.path.exists(path)]

    demo = gr.Interface(
        fn=predict_and_convert_to_image,
        inputs='image',
        outputs=['text', 'image'],
        examples=examples or None,
        title='Image to LaTeX code',
        description='Convert an image of a mathematical formula to LaTeX code and view the result as an image. Upload an image of a formula to get both the LaTeX code and the corresponding image or use the examples provided.'
    )
    demo.launch()
# Start the demo only when this file is executed as a script, not when it is
# imported as a module.
if __name__ == "__main__":
    main()