import subprocess

# flash-attn is not preinstalled on the Space, so install it at startup; the
# env flag makes its setup skip compiling the CUDA kernels during the build.
subprocess.run(
    "pip install flash-attn --no-build-isolation",
    env={"FLASH_ATTENTION_SKIP_CUDA_BUILD": "TRUE"},
    shell=True,
)

import spaces  # imported early, as recommended for ZeroGPU Spaces

from pathlib import Path
from threading import Thread

import gradio as gr
import torch
from gradio_pdf import PDF
from pdf2image import convert_from_path
from transformers import (
    NougatProcessor,
    TextIteratorStreamer,
    VisionEncoderDecoderModel,
)
# Map each model name to its (processor, model) pair. All three are loaded at
# import time; the two larger models run in bfloat16 with FlashAttention-2 on
# the text decoder, while their image encoders stay in eager mode.
models_supported = {
    "arabic-small-nougat": [
        NougatProcessor.from_pretrained("MohamedRashad/arabic-small-nougat"),
        VisionEncoderDecoderModel.from_pretrained("MohamedRashad/arabic-small-nougat"),
    ],
    "arabic-base-nougat": [
        NougatProcessor.from_pretrained("MohamedRashad/arabic-base-nougat"),
        VisionEncoderDecoderModel.from_pretrained(
            "MohamedRashad/arabic-base-nougat",
            torch_dtype=torch.bfloat16,
            attn_implementation={"decoder": "flash_attention_2", "encoder": "eager"},
        ),
    ],
    "arabic-large-nougat": [
        NougatProcessor.from_pretrained("MohamedRashad/arabic-large-nougat"),
        VisionEncoderDecoderModel.from_pretrained(
            "MohamedRashad/arabic-large-nougat",
            torch_dtype=torch.bfloat16,
            attn_implementation={"decoder": "flash_attention_2", "encoder": "eager"},
        ),
    ],
}
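
# For reference, a minimal non-streaming sketch of how one of these
# (processor, model) pairs is used end to end. This helper is illustrative
# only and is not called anywhere in the app; the streaming functions below
# are what the UI actually uses.
def _ocr_image_once(image, model_name="arabic-small-nougat"):
    processor, model = models_supported[model_name]
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model.to(device)
    # Preprocess the PIL image into pixel values matching the model's dtype.
    pixel_values = (
        processor(image, return_tensors="pt").pixel_values.to(model.dtype).to(device)
    )
    # Decode in one shot, bounded by the decoder's context window.
    output_ids = model.generate(
        pixel_values,
        max_new_tokens=model.decoder.config.max_position_embeddings,
        repetition_penalty=1.5,
    )
    return processor.tokenizer.decode(output_ids[0], skip_special_tokens=True)
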
@spaces.GPU
def extract_text_from_image(image, model_name):
    print(f"Extracting text from image using model: {model_name}")
    processor, model = models_supported[model_name]
    context_length = model.decoder.config.max_position_embeddings
    torch_dtype = model.dtype
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model.to(device)

    pixel_values = (
        processor(image, return_tensors="pt").pixel_values.to(torch_dtype).to(device)
    )
    streamer = TextIteratorStreamer(processor.tokenizer, skip_special_tokens=True)

    # Start generation in a separate thread so the streamer can be consumed here
    generation_kwargs = {
        "pixel_values": pixel_values,
        "min_length": 1,
        "max_new_tokens": context_length,
        "repetition_penalty": 1.5,
        "streamer": streamer,
    }
    thread = Thread(target=model.generate, kwargs=generation_kwargs)
    thread.start()

    # Yield the accumulated text as new tokens become available
    output = ""
    for token in streamer:
        output += token
        yield output
    thread.join()

@spaces.GPU
def extract_text_from_pdf(pdf_path, model_name):
    processor, model = models_supported[model_name]
    context_length = model.decoder.config.max_position_embeddings
    torch_dtype = model.dtype
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model.to(device)

    print(f"Extracting text from PDF: {pdf_path}")
    images = convert_from_path(pdf_path)
    pdf_output = ""
    for image in images:
        pixel_values = (
            processor(image, return_tensors="pt")
            .pixel_values.to(torch_dtype)
            .to(device)
        )
        # A fresh streamer per page avoids carrying decode state over from the
        # previous page's generation.
        streamer = TextIteratorStreamer(processor.tokenizer, skip_special_tokens=True)

        # Start generation in a separate thread
        generation_kwargs = {
            "pixel_values": pixel_values,
            "min_length": 1,
            "max_new_tokens": context_length,
            "repetition_penalty": 1.5,
            "streamer": streamer,
        }
        thread = Thread(target=model.generate, kwargs=generation_kwargs)
        thread.start()

        # Yield tokens as they become available
        for token in streamer:
            pdf_output += token
            yield pdf_output
        thread.join()

        # Separate pages with a blank line
        pdf_output += "\n\n"
        yield pdf_output
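
# Usage sketch (illustrative; the Gradio UI below is what actually drives these
# functions): both extractors are generators, so a caller can consume partial
# results as they stream in, e.g. with a hypothetical local file:
#
#     for partial_markdown in extract_text_from_pdf("book.pdf", "arabic-base-nougat"):
#         ...  # render or log the growing Markdown string
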
model_description = """This is the official demo for the Arabic Nougat models: end-to-end structured OCR models that extract text from images or PDFs and write it out as Markdown.

There are three models available:

- [arabic-small-nougat](https://huggingface.co/MohamedRashad/arabic-small-nougat): A small model that is faster but less accurate (a finetune of [facebook/nougat-small](https://huggingface.co/facebook/nougat-small)).
- [arabic-base-nougat](https://huggingface.co/MohamedRashad/arabic-base-nougat): A base model that is more accurate but slower (a finetune of [facebook/nougat-base](https://huggingface.co/facebook/nougat-base)).
- [arabic-large-nougat](https://huggingface.co/MohamedRashad/arabic-large-nougat): The largest of the three (trained from scratch with the [riotu-lab/Aranizer-PBE-86k](https://huggingface.co/riotu-lab/Aranizer-PBE-86k) tokenizer and a larger transformer decoder).

**Disclaimer**: These models can hallucinate text and are not perfect. They were trained on a mix of synthetic and real data and may not work well on all types of images.
"""

example_images = [str(p) for p in Path(__file__).parent.glob("*.jpeg")]
example_pdfs = [str(p) for p in Path(__file__).parent.glob("*.pdf")]

with gr.Blocks(title="Arabic Nougat") as demo:
    gr.HTML(
        "<h1 style='text-align: center'>Arabic End-to-End Structured OCR for textbooks</h1>"
    )
    gr.Markdown(model_description)

    with gr.Tab("Extract Text from Image"):
        with gr.Row():
            with gr.Column():
                input_image = gr.Image(label="Input Image", type="pil")
                image_model_dropdown = gr.Dropdown(
                    label="Model",
                    choices=list(models_supported.keys()),
                    # Preselect a model so Submit works without an explicit choice
                    value="arabic-base-nougat",
                )
                image_submit_button = gr.Button(value="Submit", variant="primary")
            image_output = gr.Markdown(label="Output Markdown", rtl=True)
        image_submit_button.click(
            extract_text_from_image,
            inputs=[input_image, image_model_dropdown],
            outputs=image_output,
        )
        gr.Examples(
            example_images,
            [input_image],
            image_output,
            extract_text_from_image,
            cache_examples=False,
        )

    with gr.Tab("Extract Text from PDF"):
        with gr.Row():
            with gr.Column():
                input_pdf = PDF(label="Input PDF")
                pdf_model_dropdown = gr.Dropdown(
                    label="Model",
                    choices=list(models_supported.keys()),
                    value="arabic-base-nougat",
                )
                pdf_submit_button = gr.Button(value="Submit", variant="primary")
            pdf_output = gr.Markdown(label="Output Markdown", rtl=True)
        pdf_submit_button.click(
            extract_text_from_pdf,
            inputs=[input_pdf, pdf_model_dropdown],
            outputs=pdf_output,
        )
        gr.Examples(
            example_pdfs,
            [input_pdf],
            pdf_output,
            extract_text_from_pdf,
            cache_examples=False,
        )

demo.queue().launch(share=False)
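
# Note for running outside Hugging Face Spaces: pdf2image requires the poppler
# utilities to be installed on the system (e.g. `apt-get install poppler-utils`),
# and the `spaces` GPU decorator is documented to be a no-op outside a ZeroGPU
# environment, so the app falls back to whatever device torch detects.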