# Captured from a Hugging Face Space page (Space status at capture time: Paused).
import os
import re
import tempfile

import torch
import torchvision.transforms as transforms
import matplotlib
import matplotlib.pyplot as plt
import gradio as gr

from math2latex.data import Tokenizer
from math2latex.model import ResNetTransformer
# Shared inference state. Both start out as None; setup() fills them in, and
# predict_image() calls setup() itself if either one is still None.
model = None
tokenizer = None
def get_formulas(filename):
    """Read a formula list file and return its lines.

    Args:
        filename: Path to a text file, presumably holding one formula per line.

    Returns:
        The file's lines as a list. Each line still ends in its newline
        character, exactly as ``readlines`` returns it.
    """
    # Decode explicitly as UTF-8 (a superset of ASCII). Without an explicit
    # encoding, the locale's default is used (e.g. cp1252 on Windows), which
    # would make the tokenizer built from these lines platform-dependent.
    with open(filename, 'r', encoding='utf-8') as f:
        return f.readlines()
def latex2image(latex_expression, image_name, image_size_in=(3, 0.6), fontsize=12, dpi=200):
    """Render text (typically a ``$...$`` mathtext expression) centred on a figure and save it.

    Args:
        latex_expression: String passed to ``Figure.text``. Wrap it in ``$...$``
            so matplotlib's mathtext parses it as math.
        image_name: Destination path passed to ``Figure.savefig``.
        image_size_in: Figure size ``(width, height)`` in inches.
        fontsize: Font size of the rendered text.
        dpi: Figure resolution in dots per inch.

    Raises:
        ValueError: Propagated from matplotlib when the expression cannot be
            parsed as mathtext. The figure is still closed in that case.
    """
    # Computer Modern gives a LaTeX-like look. Note that this sets the
    # process-wide matplotlib rcParams, not just this figure's settings.
    matplotlib.rcParams["mathtext.fontset"] = "cm"
    fig = plt.figure(figsize=image_size_in, dpi=dpi)
    try:
        fig.text(
            x=0.5,
            y=0.5,
            s=latex_expression,
            horizontalalignment="center",
            verticalalignment="center",
            fontsize=fontsize,
        )
        # Save this specific figure. plt.savefig would save pyplot's *current*
        # figure, which other code (or another request thread) may have changed.
        fig.savefig(image_name)
    finally:
        # Mathtext is parsed at draw time, so savefig is where a bad expression
        # raises. Close unconditionally so pyplot does not keep the figure alive.
        plt.close(fig)
def setup(checkpoint_path='checkpoints/model.ckpt',
          formulas_path='dataset/im2latex_formulas.norm.processed.lst'):
    """Load the model weights and build the tokenizer into the module globals.

    Sets the module-level ``model`` (a ``ResNetTransformer`` in eval mode on
    CPU) and ``tokenizer`` (a ``Tokenizer`` built from the formula list).

    Args:
        checkpoint_path: Checkpoint file whose ``'state_dict'`` entry holds the
            weights.
        formulas_path: Formula list file handed to ``get_formulas`` to build the
            tokenizer.
    """
    global model, tokenizer
    prefix = 'model.'
    # NOTE(review): torch.load unpickles the file; only point this at trusted checkpoints.
    checkpoint = torch.load(checkpoint_path, map_location='cpu')
    # Checkpoint keys carry a leading 'model.' prefix (presumably saved from a
    # wrapper module whose attribute is `model` -- confirm). Strip it only at
    # the start of each key, so parameter names that merely *contain*
    # 'model.' are left intact.
    state_dict = {
        (key[len(prefix):] if key.startswith(prefix) else key): value
        for key, value in checkpoint['state_dict'].items()
    }
    net = ResNetTransformer()
    net.load_state_dict(state_dict)
    net.to("cpu")
    net.eval()
    # Publish the model only once it is fully loaded.
    model = net

    # Set up the tokenizer from the formula corpus.
    formulas = get_formulas(formulas_path)
    tokenizer = Tokenizer(formulas)
def predict_image(image):
    """Predict the LaTeX token string for a single formula image.

    Calls setup() first if the model or tokenizer has not been loaded yet.
    The image is converted with torchvision's ``ToTensor``, so it is presumably
    a PIL image or an HxWxC ndarray (what Gradio's 'image' input supplies) --
    confirm against the model's expected channel count.

    Returns:
        The tokenizer's decoding of the first (only) sequence in the model's
        prediction.
    """
    if tokenizer is None or model is None:
        setup()
    # Prepend a size-1 batch dimension: the model receives a batch of one image.
    batch = transforms.ToTensor()(image).unsqueeze(0)
    with torch.no_grad():
        predicted = model.predict(batch)
    return tokenizer.decode(predicted[0].tolist())
# Matches either (1) a control word followed by spaces and then a letter -- the
# space is required there, otherwise e.g. "\sin x" would become the undefined
# "\sinx" -- or (2) any other run of spaces not preceded by a backslash, so a
# control space "\ " is left untouched.
_LATEX_SPACE_RE = re.compile(r"(\\[A-Za-z]+) +(?=[A-Za-z])|(?<!\\) +")


def _compact_latex(latex_code):
    """Strip token-separating spaces from LaTeX without merging control words into following letters."""
    return _LATEX_SPACE_RE.sub(
        lambda m: m.group(1) + " " if m.group(1) else "", latex_code
    )


def predict_and_convert_to_image(image):
    """Predict LaTeX for an image and render the prediction to a PNG.

    Args:
        image: Input image, forwarded unchanged to predict_image().

    Returns:
        A ``(latex_code, image_name)`` tuple: the raw predicted LaTeX string and
        the path of a freshly created PNG containing its mathtext rendering.
        The PNG is left on disk for the caller (Gradio) to serve.
    """
    latex_code = predict_image(image)
    # $...$ makes matplotlib's mathtext parse the string as math.
    rendered = f"${_compact_latex(latex_code)}$"
    # Use a unique file per call. A fixed 'temp.png' in the working directory is
    # shared by every request, so one render could overwrite another's result.
    with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as tmp:
        image_name = tmp.name
    latex2image(rendered, image_name)
    # Return both the LaTeX code and the path of the generated image.
    return latex_code, image_name
def main():
    """Load the model, then build and launch the Gradio image-to-LaTeX demo."""
    setup()
    candidate_examples = [
        "dataset/formula_images_processed/78228211ca.png",
        "dataset/formula_images_processed/2b891b21ac.png",
        "dataset/formula_images_processed/a8ec0c091c.png",
    ]
    # Only offer example images that are actually present. Deployments that
    # ship without the image directory then still start, with no examples.
    examples = [[path] for path in candidate_examples if os.path.isfile(path)]
    description = (
        'Convert an image of a mathematical formula to LaTeX code and view the result as an image. '
        'Upload an image of a formula to get both the LaTeX code and the corresponding image'
    )
    # Only promise examples when some are really shown.
    description += ' or use the examples provided.' if examples else '.'
    demo = gr.Interface(
        fn=predict_and_convert_to_image,
        inputs='image',
        outputs=['text', 'image'],
        examples=examples or None,
        title='Image to LaTeX code',
        description=description,
    )
    demo.launch()
if __name__ == "__main__":
    # Launch the demo only when this file is executed as a script, not when imported.
    main()