Edit model card

You need to agree to share your contact information to access this model

This repository is publicly accessible, but you have to accept the conditions to access its files and content.

Log in or sign up to review the conditions and access this model's content.

GitHub repository: https://github.com/chongzicbo/MathImg2Latex/tree/main

import os

import torch
from PIL import Image
from transformers import VisionEncoderDecoderModel
from transformers.models.nougat import NougatTokenizerFast

from nougat_latex import NougatLaTexProcessor
from nougat_latex.util import process_raw_latex_code
# Example: transcribe a single image of a math formula into LaTeX with the
# chongzicbo/MathImg2Latex VisionEncoderDecoder checkpoint.

# NOTE(review): nothing visible in this snippet reads "RUN_ON_GPU_IDs"; it is
# kept unchanged in case nougat_latex or surrounding tooling consults it.
# It must be set before any CUDA initialisation to have an effect.
os.environ["RUN_ON_GPU_IDs"] = "1"
# Inference runs on CPU; change to torch.device("cuda") to use a GPU.
device = torch.device("cpu")

model_path = "chongzicbo/MathImg2Latex"
tokenizer = NougatTokenizerFast.from_pretrained(model_path)
latex_processor = NougatLaTexProcessor.from_pretrained(model_path)
model = VisionEncoderDecoderModel.from_pretrained(model_path)
model.to(device)

# Path to the formula image to transcribe; replace with your own file.
img_path = "/data/code/MathOCR/MathImg2Latex/examples/test_data/test86_screenshot_bigger.jpg"

# Open inside a context manager so the underlying file handle is released.
# convert("RGB") returns a fully loaded copy (also when the image is already
# RGB), so `image` stays valid after the file is closed.
with Image.open(img_path) as raw_image:
    image = raw_image.convert("RGB")

pixel_values = latex_processor(image, return_tensors="pt").pixel_values

# Seed the decoder with only the BOS token; the model generates the rest.
task_prompt = tokenizer.bos_token
decoder_input_ids = tokenizer(
    task_prompt, add_special_tokens=False, return_tensors="pt"
).input_ids

with torch.no_grad():
    outputs = model.generate(
        pixel_values.to(device),
        decoder_input_ids=decoder_input_ids.to(device),
        max_length=model.decoder.config.max_length,
        pad_token_id=tokenizer.pad_token_id,
        eos_token_id=tokenizer.eos_token_id,
        use_cache=True,
        # Greedy decoding. `early_stopping` only applies to beam search, so it
        # is omitted here: it had no effect and only triggered a warning.
        num_beams=1,
        # Forbid the unknown token so the output is always valid vocabulary.
        bad_words_ids=[[tokenizer.unk_token_id]],
        return_dict_in_generate=True,
    )

sequence = tokenizer.batch_decode(outputs.sequences)[0]
# Strip EOS, PAD and BOS markers (in that order). Skip any token the tokenizer
# does not define: str.replace(None, "") would raise TypeError.
for special_token in (tokenizer.eos_token, tokenizer.pad_token, tokenizer.bos_token):
    if special_token:
        sequence = sequence.replace(special_token, "")
sequence = process_raw_latex_code(sequence)
print(sequence)
Downloads last month
0
Safetensors
Model size
349M params
Tensor type
I64
·
F32
·
Inference API
Inference API (serverless) does not yet support transformers models for this pipeline type.