# nougat-latex / nougat_api_app.py
# author: zphilip48
# commit: "small updating env" (af18b7a)
import gradio as gr
import os
#os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
#os.environ["CUDA_VISIBLE_DEVICES"] = "3" # just use one GPU on big machine
import torch
# Report the CUDA devices seen by this process. The service expects exactly
# one visible GPU; on a CPU-only host the GPU-specific queries are skipped so
# the CPU fallback configured further below stays reachable.
print('Available devices ', torch.cuda.device_count())
if torch.cuda.is_available():
    print('Current cuda device ', torch.cuda.current_device())
    if torch.cuda.device_count() != 1:
        # Raise instead of assert: asserts are stripped under `python -O`.
        raise RuntimeError(
            f"Expected exactly one visible CUDA device, found {torch.cuda.device_count()}"
        )
    print('GPU Device name:', torch.cuda.get_device_name(torch.cuda.current_device()))
import sys
from functools import partial
from http import HTTPStatus
import cv2
from fastapi import FastAPI, File, UploadFile, Request,Response, BackgroundTasks, HTTPException
from fastapi import APIRouter, Depends
import os
import requests
import uuid
#for the user register and login
#from fastapi.templating import Jinja2Templates
#from schemas.users import UserCreate
#from sqlalchemy.orm import Session
#from sqlalchemy.exc import IntegrityError
#from webapps.users.forms import UserCreateForm
from PIL import Image
from pathlib import Path
import hashlib
from fastapi.middleware.cors import CORSMiddleware
import fitz
import torch
from torch.utils.data import ConcatDataset
from nougat import NougatModel
from nougat.postprocessing import markdown_compatible, close_envs
from nougat.utils.dataset import ImageDataset,LazyDataset
from nougat.utils.checkpoint import get_checkpoint
#from nougat.dataset.rasterize import rasterize_paper
from tqdm import tqdm
import uvicorn
import shutil
import io
import numpy as np
import logging
import pypdfium2
from typing import Optional, List, Union
from PIL import ImageOps,Image
import re
from contextlib import asynccontextmanager
from starlette.types import Message
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.background import BackgroundTask
from urllib.parse import urlparse, unquote
from utils import binarization
from sql_app import models, schemas
from sql_app import crud
from sql_app.db import get_db, engine, async_engine, async_session
from sqlalchemy.orm import Session
from sqlalchemy.ext.asyncio import AsyncSession
from sql_app.schemas import (
InferenceCreate,
InferenceUpdate,
InferenceRead,
#Inference
)
from sql_app.models import Inference,User
from hashlib import sha1,md5
from datetime import datetime
from sql_app.db import Base
import psycopg
import numpy as np
logging.basicConfig(filename='info.log', level=logging.INFO)
# Per-document cache/output directory root (one sub-directory per file hash).
SAVE_DIR = Path("./pdfs")
# Environment variables are always strings; coerce so the value can be used as
# a DataLoader batch size and in page-index arithmetic.
BATCHSIZE = int(os.environ.get("NOUGAT_BATCHSIZE", 6))
NOUGAT_CHECKPOINT = get_checkpoint()
#########################################
##
## init the fastapi server
##
#########################################
# Load the ML model
def loadModel(checkpoint):
    """Load a Nougat checkpoint for inference.

    The model is cast to bfloat16, moved to CUDA when available and put in
    eval mode.

    Args:
        checkpoint: Path (or path string) of the checkpoint directory. If it
            does not exist, the module-level ``default_checkpoint_path`` is
            used instead.

    Returns:
        The loaded NougatModel.
    """
    checkpoint = Path(checkpoint)
    if not checkpoint.exists():
        logging.info(
            f"request checkpoint {checkpoint} is not exist, using default {default_checkpoint_path}"
        )
        checkpoint = default_checkpoint_path
    model = NougatModel.from_pretrained(checkpoint).to(torch.bfloat16)
    if torch.cuda.is_available():
        model.to("cuda")
    model.eval()
    return model
# Registry of selectable checkpoints: directory basename -> checkpoint path.
checkpoint_list = [Path("./nougat_middle_ocr/"),
                   Path("./nougat_middle_cn/"),
                   Path("./nougat_small_facebook/"),
                   Path("./nougat_small_ocr_ocrpadded/"),
                   Path("./nougat_small_cn/"),
                   Path("./nougat_big_facebook/")]
default_checkpoint_path = Path("./nougat_middle_cn/")
default_model_name = "nougat_middle_cn"
selected_model_name = None
model_list = dict()
logging.info(f"Start up and init the Nougat Model with {checkpoint_list} ")
for checkpoint in checkpoint_list:
    checkpoint_name = os.path.basename(checkpoint)
    model_list[checkpoint_name] = checkpoint
    logging.info(f"registering model {checkpoint_name}")

# Fail fast before spending time loading weights.
if NOUGAT_CHECKPOINT is None:
    logging.info(
        "Set environment variable 'NOUGAT_CHECKPOINT' with a path to the model checkpoint!."
    )
    sys.exit(1)

# Only the default checkpoint is loaded eagerly; other entries are loaded on
# demand via loadModel(model_list[...]).
nougatModel = loadModel(model_list.get(default_model_name, default_checkpoint_path))

if torch.cuda.is_available():
    # Heuristic: roughly 0.3 pages per (approximate) GB of total VRAM.
    BATCH_SIZE = int(
        torch.cuda.get_device_properties(0).total_memory / 1024 / 1024 / 1000 * 0.3
    )
    if BATCH_SIZE == 0:
        # A batch size of 0 would make torch.utils.data.DataLoader raise.
        logging.warning("GPU VRAM is too small. Falling back to batch size 1.")
        BATCH_SIZE = 1
else:
    # don't know what a good value is here. Would not recommend to run on CPU
    BATCH_SIZE = 1
    logging.warning("No GPU found. Conversion on CPU is very slow.")
@asynccontextmanager
async def lifespan(app: FastAPI):
    """FastAPI lifespan hook.

    Models are loaded at import time, so nothing happens on startup. On
    shutdown the checkpoint registry is emptied and cached CUDA memory freed.
    """
    yield
    # model_list maps names to checkpoint *paths*; its keys are strings with no
    # .clear(), so clear the dict itself.
    model_list.clear()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
# The FastAPI application; `lifespan` handles shutdown cleanup.
app = FastAPI(
    title="AIWorm Application",
    description="AIWorm Application with FastAPI and Gradio",
    version="1.0.0",
    lifespan=lifespan,
)
# Create any missing ORM tables at import time (checkfirst skips existing ones).
Base.metadata.create_all(bind=engine, checkfirst=True)
def create_db_schema(engine):
    """Create every table registered on the SQLAlchemy ORM metadata.

    The ORM base is imported locally because later in this module the global
    name ``Base`` is rebound to a pydantic model, which has no ``metadata``.
    """
    from sql_app.db import Base as OrmBase

    OrmBase.metadata.create_all(engine, checkfirst=True)
def init_db():
    """Seed the inference table with one placeholder row and commit it."""
    # Keep a handle on the dependency generator and close it explicitly once
    # done. With `next(get_db())` the temporary generator may be finalized
    # (running any cleanup inside get_db) before the session is even used.
    session_gen = get_db()
    db_session = next(session_gen)
    try:
        db_session.add(Inference(fingerPrint="12124isefuadsf"))
        db_session.commit()
    finally:
        session_gen.close()
    logging.info("Initialized the db")
# Allowed CORS origins. Browsers send the Origin header as
# scheme://host[:port] with no path or trailing slash, so entries must be
# written the same way or they never match.
origins = [
    "http://localhost",
    "http://tec1.aiworm.cn:8866",
    "http://127.0.0.1",
    "https://82s56k6681.zicp.fun",
    "http://192.168.1.34",
]
#@app.on_event("shutdown")
#def shutdown_event():
# for model in model_list:
# model.clear()
#@app.on_event("startup")
#async def load_model(
# checkpoint: str = NOUGAT_CHECKPOINT,
#):
# global model
# if model is None:
# model = NougatModel.from_pretrained(checkpoint).to(torch.bfloat16)
# if torch.cuda.is_available():
# model.to("cuda")
# model.eval()
#########################################
##
## middleware for fastapi app
##
#########################################
class MyMiddleware(BaseHTTPMiddleware):
    """Middleware that re-installs the request body so it can be read again downstream."""

    async def set_body(self, request: Request, body: Optional[bytes] = None):
        """Replace ``request._receive`` with a callable that replays one body message.

        Args:
            request: The incoming request.
            body: Body bytes that were already read. If None, the first ASGI
                receive message is pulled from the request and replayed.
        """
        if body is None:
            message = await request._receive()
        else:
            message = {"type": "http.request", "body": body, "more_body": False}

        async def receive() -> Message:
            return message

        request._receive = receive

    async def get_body(self, request: Request) -> bytes:
        """Read the whole body and re-install it so later readers still see it."""
        body = await request.body()
        # The stream is consumed at this point; replay the bytes we already have
        # instead of pulling another message from the ASGI channel.
        await self.set_body(request, body)
        return body

    async def dispatch(self, request, call_next):
        if request.method == 'POST':
            # set_body is a coroutine; without `await` it never runs.
            await self.set_body(request)
        response = await call_next(request)
        return response
''' debug code (not tested yet)
def log_info(req_body, res_body):
logging.debug(req_body)
logging.debug(res_body)
async def set_body(request: Request, body: bytes):
async def receive() -> Message:
return {'type': 'http.request', 'body': body}
request._receive = receive
@app.middleware('http')
async def app_middleware(request: Request, call_next):
req_body = await request.body()
#await set_body(request, req_body)
response = await call_next(request)
res_body = b''
async for chunk in response.body_iterator:
res_body += chunk
#logger.debug(f"response: {res_body} -- {response.status_code}")
#logger.debug(f"response.headers :{response.headers}")
#logger.debug(f"response.media_type: {response.media_type}")
task = BackgroundTask(log_info, req_body, res_body)
return Response(content=res_body, status_code=response.status_code,
headers=dict(response.headers), media_type=response.media_type, background=task)
# Exception handlers
def add_exception_handlers(_app: FastAPI):
@_app.exception_handler(ApiAuthException)
async def api_auth_exception_handler(request: Request, exc: ApiAuthException):
return await handler.api_auth_exception_handler(request, exc)
@_app.exception_handler(ApiException)
async def api_exception_handler(request: Request, exc: ApiException):
return await handler.api_exception_handler(request, exc)
add_exception_handlers(main_app)
add_exception_handlers(sub_app)
from loguru import logger
from starlette.routing import Match
logger.remove()
logger.add(sys.stdout, colorize=True, format="<green>{time:HH:mm:ss}</green> | {level} | <level>{message}</level>")
app = FastAPI()
@app.middleware("http")
async def log_middle(request: Request, call_next):
logger.debug(f"{request.method} {request.url}")
routes = request.app.router.routes
logger.debug("Params:")
for route in routes:
match, scope = route.matches(request)
if match == Match.FULL:
for name, value in scope["path_params"].items():
logger.debug(f"\t{name}: {value}")
logger.debug("Headers:")
for name, value in request.headers.items():
logger.debug(f"\t{name}: {value}")
response = await call_next(request)
return response
@app.get("/{param1}/{param2}")
async def path_operation(param1: str, param2: str):
return {'param1': param1, 'param2': param2}
'''
# Let the listed browser origins (with credentials) call every route, using
# any HTTP method and any request header.
app.add_middleware(
    CORSMiddleware,
    allow_origins=origins,
    allow_methods=["*"],
    allow_headers=["*"],
    allow_credentials=True,
)
#########################################
##
## main predict nougat function
##
#########################################
def rasterize_paper(
    pdf: Union[Path, bytes],
    outpath: Optional[Path] = None,
    dpi: int = 96,
    return_pil=False,
    pages=None,
) -> Optional[List[io.BytesIO]]:
    """
    Rasterize PDF pages to images.

    Args:
        pdf: A path (str/Path, rendered with pypdfium2), raw PDF bytes, or an
            already-opened fitz document (the latter two are rendered with PyMuPDF).
        outpath (Optional[Path], optional): Output directory. If None, in-memory images are returned instead. Defaults to None.
        dpi (int, optional): The output DPI. Defaults to 96.
        return_pil (bool, optional): Return in-memory images instead of writing them to disk. Forced to True when `outpath` is None. Defaults to False.
        pages (Optional[List[int]], optional): 0-based page indices to rasterize. If None, all pages are rasterized. Defaults to None.

    Returns:
        Optional[List[io.BytesIO]]: One buffer per rendered page (BMP for path
        input, PNG otherwise) if returning in memory, otherwise None. Errors are
        logged and the pages rendered before the error are returned.
    """
    pils = []
    if outpath is None:
        return_pil = True
    try:
        if isinstance(pdf, (str, Path)):
            logging.debug("input pdf as str")
            pdf = pypdfium2.PdfDocument(pdf)
            if pages is None:
                pages = range(len(pdf))
            renderer = pdf.render(
                pypdfium2.PdfBitmap.to_pil,
                page_indices=pages,
                scale=dpi / 72,
            )
            for i, image in zip(pages, renderer):
                if return_pil:
                    page_bytes = io.BytesIO()
                    image.save(page_bytes, "bmp")
                    pils.append(page_bytes)
                else:
                    image.save((Path(outpath) / ("%02d.png" % (i + 1))), "png")
        else:
            logging.debug("input pdf as bytes")
            if isinstance(pdf, (bytes, bytearray)):
                pdf = fitz.open("pdf", pdf)
            if pages is None:
                pages = range(len(pdf))
            # Zoom factor for the requested dpi (PDF user space is 72 dpi).
            mat = fitz.Matrix(dpi / 72, dpi / 72)
            # Honour `pages`: callers map results back by position in `pages`.
            for page_index in pages:
                page = pdf[page_index]
                pix = page.get_pixmap(matrix=mat)
                if return_pil:
                    # A fresh BytesIO starts at position 0, ready to be read.
                    pils.append(io.BytesIO(pix.pil_tobytes("PNG")))
                else:
                    img_filename = Path(outpath) / ("%02d.png" % (page.number + 1))
                    pix.pil_save(img_filename, format="PNG", dpi=(dpi, dpi))
    except Exception:
        # Best effort: keep whatever was rendered so far.
        logging.exception("rasterize_paper failed")
    if return_pil:
        return pils
    return None
def resize_with_padding(img, expected_size, fill='white'):
    """Fit `img` into `expected_size` keeping its aspect ratio, then pad to that exact size.

    Args:
        img: PIL image. It is not modified; a copy is resized.
        expected_size: Target (width, height).
        fill: Colour of the padding. Defaults to 'white'.

    Returns:
        A new image of exactly `expected_size`. The image is centred; when the
        padding is odd, the extra pixel goes on the right/bottom.
    """
    # thumbnail() works in place, so operate on a copy to leave the caller's image intact.
    img = img.copy()
    img.thumbnail((expected_size[0], expected_size[1]))
    delta_width = expected_size[0] - img.size[0]
    delta_height = expected_size[1] - img.size[1]
    pad_width = delta_width // 2
    pad_height = delta_height // 2
    padding = (pad_width, pad_height, delta_width - pad_width, delta_height - pad_height)
    return ImageOps.expand(img, padding, fill=fill)
def predict_image(model_name, images, isBinarized = False, batchsize=1, markdown=True, out_path_root="./output"):
    """Run Nougat OCR on a single PIL image and write the result to `<out_path_root>/test.mmd`.

    Args:
        model_name: Key into `model_list`. None or the default name reuses the
            preloaded global `nougatModel`; any other name loads that checkpoint
            (falling back to the default path if the name is unknown).
        images: A single PIL image, or None.
        isBinarized: Binarize the image with sauvolaBinarize before inference.
        batchsize: Only consulted when the global model has to be created
            lazily (a value > 0 allows moving it to CUDA).
        markdown: Post-process each prediction with markdown_compatible.
        out_path_root: Directory the .mmd file is written to.

    Returns:
        (model_output, [out_path]): the raw dict from `model.inference` (None
        when no image was given) and the written file path (None when the
        prediction is empty).
    """
    logging.info('*** nougat predict with input image ***')
    global nougatModel, model_list
    if nougatModel is None:
        nougatModel = NougatModel.from_pretrained(default_checkpoint_path).to(torch.float16)
        if batchsize > 0:
            if torch.cuda.is_available():
                nougatModel.to("cuda")
        else:
            # set batch size to 1. Need to check if there are benefits for CPU conversion for >1
            batchsize = 1
        nougatModel.eval()
    if model_name is None or model_name == default_model_name:
        # The default checkpoint is already in memory; don't reload it from disk.
        logging.info(f"Using {model_name or default_model_name} for example predicting.")
        model = nougatModel
    else:
        logging.info(f"Using {model_name} for example predicting.")
        model = loadModel(model_list.get(model_name, default_checkpoint_path))
    prepare = model.encoder.prepare_input
    predictions = []
    model_output = None
    if images is not None:
        logging.info("we are under image to mmd convertiong")
        sample = images.convert('RGB')
        if isBinarized:
            binarized = sauvolaBinarize(img=sample)
            # Feed the binarized image to the model (previously it was discarded).
            if binarized is not None:
                sample = binarized
        im_new = resize_with_padding(sample, (672, 896))
        img_tensor = prepare(im_new, random_padding=False).unsqueeze(0)
        model_output = model.inference(image_tensors=img_tensor)
        for output in model_output["predictions"]:
            # Convert before storing so the markdown post-processing takes effect.
            if markdown:
                output = markdown_compatible(output)
            predictions.append(output)
    out = "".join(predictions).strip()
    out = re.sub(r"\n{3,}", "\n\n", out).strip()
    out_path = None
    if out:
        out_path = Path(out_path_root) / "test.mmd"
        out_path.parent.mkdir(parents=True, exist_ok=True)
        # switch math delimiters to $ / $$
        out = out.replace(r"\(", "$").replace(r'\)', '$').replace(r'\[', '$$').replace(r'\]', '$$')
        out_path.write_text(out, encoding="utf-8")
    else:
        logging.debug(f"the out is {out}")
    return model_output, [out_path]
#########################################
##
## fastapi interface
##
#########################################
@app.get("/check")
def root():
    """Health check: always reports HTTP 200 with an empty data payload."""
    return {"status-code": HTTPStatus.OK, "data": {}}
#templates = Jinja2Templates(directory="templates")
#@router.get("/register/")
#def register(request: Request):
# return templates.TemplateResponse("users/register.html", {"request": request})
#@router.post("/register/")
#async def register(request: Request, db: Session = Depends(get_db)):
# form = UserCreateForm(request)
# await form.load_data()
# if await form.is_valid():
# user = UserCreate(
# username=form.username, email=form.email, password=form.password
# )
# try:
# user = create_new_user(user=user, db=db)
# return responses.RedirectResponse(
# "/?msg=Successfully-Registered", status_code=status.HTTP_302_FOUND
# ) # default is post request, to use get request added status code 302
# except IntegrityError:
# form.__dict__.get("errors").append("Duplicate username or email")
# return templates.TemplateResponse("users/register.html", form.__dict__)
# return templates.TemplateResponse("users/register.html", form.__dict__)
@app.post("/multi")
async def check_multi_files(files: List[UploadFile]):
    """Echo back the names of all uploaded files (multi-upload smoke test)."""
    return {"filenames": [upload.filename for upload in files]}
@app.post("/single")
async def check(file: bytes = File() ):
    """Accept one raw file upload and acknowledge it (single-upload smoke test)."""
    acknowledgement = {"state": 200}
    return acknowledgement
def convertImageFormat(imgObj, outputFormat="PNG"):
    """Return `imgObj` re-encoded as `outputFormat`.

    The image is returned unchanged when `outputFormat` is falsy or already
    matches `imgObj.format`. Otherwise it is saved to an in-memory buffer and
    reopened.
    """
    if not outputFormat or imgObj.format == outputFormat:
        return imgObj
    buffer = io.BytesIO()
    imgObj.save(buffer, outputFormat)
    return Image.open(buffer)
from urllib.parse import parse_qs
from pydantic import BaseModel, model_validator, ValidationError
from typing import Optional, List, Dict
from fastapi import Form
class SubmitGeneral(BaseModel):
    """Request payload carrying the name of the model the client picked."""

    selected_model_name: str
class Base(BaseModel):
    """Small pydantic model that can also be built from a JSON-encoded string.

    NOTE: this rebinds the module-level name ``Base``, shadowing the SQLAlchemy
    ``Base`` imported from ``sql_app.db`` earlier in this file. Code that runs
    after this definition and uses the global ``Base.metadata`` will fail.
    """

    name: str
    point: Optional[float] = None
    is_accepted: Optional[bool] = False

    @model_validator(mode='before')
    @classmethod
    def validate_to_json(cls, value):
        """Accept a JSON string (e.g. a multipart form field) as the model input."""
        if isinstance(value, str):
            # Imported locally so the validator does not rely on a module-level import.
            import json

            # Return the decoded mapping and let pydantic validate it as usual.
            return json.loads(value)
        return value
@app.post("/users/", response_model=schemas.User)
def create_user(user: schemas.UserCreate, db: Session = Depends(get_db)):
    """Register a new user; reply 400 if the email address is already taken."""
    if crud.get_user_by_email(db, email=user.email):
        raise HTTPException(status_code=400, detail="Email already registered")
    return crud.create_user(db=db, user=user)
@app.get("/users/", response_model=List[schemas.User])
def read_users(skip: int = 0, limit: int = 100, db: Session = Depends(get_db)):
    """Return a page of users: skip the first `skip` rows, return at most `limit`."""
    return crud.get_users(db, skip=skip, limit=limit)
def stringToBool(string:str) -> bool:
    """Return True only for a case-insensitive "true"; any other string maps to False."""
    return string.lower() == "true"
def sauvolaBinarize(img:Image = None, dest:str = None):
    """Binarize an image with OpenCV adaptive Gaussian thresholding.

    Despite the name, the Sauvola implementation (binarization.sauvola) is
    disabled as too slow. cv2.adaptiveThreshold (block size 5, C 5) is used
    instead.

    Args:
        img: PIL image to binarize (assumed RGB — it is converted with
            COLOR_RGB2BGR). Takes precedence over `dest`.
        dest: Path of an image file to read with OpenCV when `img` is None.

    Returns:
        The binarized image from binarization.convert_from_cv2_to_image, or
        None when no readable input was given. A debug copy is also written to
        ./uploads/testBinarization.png.
    """
    cv2_image = None
    if img is not None:
        # convert PIL (RGB) to OpenCV's BGR layout
        cv2_image = cv2.cvtColor(np.array(img), cv2.COLOR_RGB2BGR)
    elif dest is not None:
        # cv2.imread already returns BGR, or None if the file can't be read.
        cv2_image = cv2.imread(dest)
    if cv2_image is None:
        return None
    gray_image = cv2.cvtColor(cv2_image, cv2.COLOR_BGR2GRAY)
    adaptive_threshold_img = cv2.adaptiveThreshold(
        gray_image, 255, cv2.ADAPTIVE_THRESH_GAUSSIAN_C, cv2.THRESH_BINARY, 5, 5
    )
    #too slow
    #adaptive_threshold_img = binarization.sauvola(gray_image)
    # convert back to a 3-channel RGB image
    img = cv2.cvtColor(adaptive_threshold_img, cv2.COLOR_GRAY2RGB)
    img = binarization.convert_from_cv2_to_image(img)
    os.makedirs('./uploads', exist_ok=True)
    img.save('./uploads/testBinarization' + ".png", "PNG")
    return img
def _format_page_prediction(model_output, j, output):
    """Return prediction `j` of a Nougat batch as markdown plus an optional disclaimer.

    When ``model_output["repeats"][j]`` is set, a WARNING block (repeats > 0,
    output truncated) or an ERROR block (otherwise) quoting the closed-off
    repeated text is appended. No block is added if that text is empty.
    """
    disclaimer = ""
    repeats = model_output["repeats"][j]
    if repeats is not None:
        if repeats > 0:
            template = "\n\n+++ ==WARNING: Truncated because of repetitions==\n%s\n+++\n\n"
        else:
            template = (
                "\n\n+++ ==ERROR: No output for this page==\n%s\n+++\n\n"
            )
        rest = close_envs(model_output["repetitions"][j]).strip()
        if len(rest) > 0:
            disclaimer = template % rest
    return markdown_compatible(output) + disclaimer


def _run_pdf_inference(model, pdf, pages):
    """Rasterize `pages` (0-based) of an open fitz document and OCR them with `model`.

    Returns:
        (predictions, images): one formatted markdown string per entry of
        `pages` ("" for pages that could not be rendered), and the rendered
        page buffers.
    """
    images = rasterize_paper(pdf, pages=pages) or []
    # Defensive: some rasterize_paper versions render every page regardless
    # of `pages`; keep only the requested ones so indices line up.
    if len(images) == len(pdf) and len(pages) != len(pdf):
        images = [images[p] for p in pages]
    # BATCHSIZE may come from the environment as a string.
    batch_size = int(BATCHSIZE)
    predictions = [""] * len(pages)
    dataset = ImageDataset(
        images,
        partial(model.encoder.prepare_input, random_padding=False),
    )
    dataloader = torch.utils.data.DataLoader(
        dataset,
        batch_size=batch_size,
        pin_memory=True,
        shuffle=False,
    )
    for batch_idx, sample in tqdm(enumerate(dataloader), total=len(dataloader)):
        if sample is None:
            continue
        model_output = model.inference(image_tensors=sample)
        for j, output in enumerate(model_output["predictions"]):
            predictions[batch_idx * batch_size + j] = _format_page_prediction(model_output, j, output)
    return predictions, images


@app.post("/predict")
async def predict(
    request: Request,
    selectedModel: str = Form(...),
    isBinarized: str = Form(...),
    file: UploadFile = File(...),
    start: int = None, stop: int = None,
):
    """
    Run Nougat OCR on an uploaded PDF or image and return the extracted Markdown.

    Args:
        request: The incoming request (used only for debug logging).
        selectedModel: Checkpoint name to use (a key of `model_list`).
        isBinarized: "true"/"false"; binarize the image before OCR (image uploads only).
        file (UploadFile): The upload. ``application/pdf`` is rasterized page
            by page; ``application/image`` or any ``image/*`` type is OCRed as
            a single page.
        start (int, optional): First page (1-based) to predict, PDF only.
        stop (int, optional): Last page (1-based, inclusive), PDF only. Both
            must be given for the range to apply; it is clamped to the document.

    Returns:
        str: The extracted text in Markdown format, or a ``{"message": ...}``
        dict when the upload could not be saved or opened.

    Raises:
        HTTPException: 415 for any other content type.
    """
    global nougatModel
    logging.info(f'Content type of file: {file.content_type}')
    logging.info(f'selectedModel values : {selectedModel}')
    logging.debug(f'Content uploaded file: {file.filename}')
    logging.debug(f'request query params: {dict(request.query_params)}')
    logging.debug(f'request headers: {request.headers}')
    model_name = selectedModel
    is_binarized = stringToBool(isBinarized)

    upload_dir = os.path.join(os.getcwd(), "uploads")
    os.makedirs(upload_dir, exist_ok=True)
    # Keep only the base name of the client-supplied filename so it cannot
    # point outside the upload directory (e.g. "../x" or an absolute path).
    dest = os.path.join(upload_dir, os.path.basename(file.filename or "upload"))
    blob_file = dest
    content_type = file.content_type or ""

    if content_type == "application/image" or content_type.startswith("image/"):
        try:
            with open(dest, 'wb') as f:
                logging.info(f"save uploading files to {dest}")
                imgbin = await file.read()
                f.write(imgbin)
        except Exception:
            logging.exception("could not save the uploaded image")
            return {"message": "There was an error uploading the file"}
        finger_printer = hashlib.md5(imgbin).hexdigest()
        save_path = SAVE_DIR / finger_printer
        img = Image.open(io.BytesIO(imgbin))
        logging.info(f"uploading Image Type: {type(img)}")
        if img.format != "PNG":
            dest_filename, _ = os.path.splitext(dest)
            img = convertImageFormat(img)
            img.save(dest_filename + ".png", "PNG")
        model_output, _ = predict_image(model_name=model_name, images=img, isBinarized=is_binarized)
        logging.debug(f"predict output as: {model_output}")
        predictions = [""]
        for j, output in enumerate(model_output["predictions"]):
            predictions[0] = _format_page_prediction(model_output, j, output)
        (save_path / "pages").mkdir(parents=True, exist_ok=True)
    elif content_type == "application/pdf":
        try:
            with open(dest, 'wb') as f:
                logging.debug(f"save files to {dest}")
                pdfbin = await file.read()
                f.write(pdfbin)
            pdf = fitz.open("pdf", pdfbin)
        except Exception:
            logging.exception("could not save/open the uploaded pdf")
            return {"message": "There was an error uploading the file"}
        finger_printer = hashlib.md5(pdfbin).hexdigest()
        save_path = SAVE_DIR / finger_printer
        logging.info(save_path)
        logging.info(f"{file.filename} uploaded successfully, length {len(pdfbin)}")
        # Only the PDF path needs a model here; image uploads select theirs in predict_image.
        if model_name is None or model_name == default_model_name:
            model = nougatModel
        else:
            model = loadModel(model_list.get(model_name, default_checkpoint_path))
        try:
            page_count = len(pdf)
            if start is not None and stop is not None:
                pages = [p for p in range(start - 1, stop) if 0 <= p < page_count]
            else:
                pages = list(range(page_count))
            # Pages are always recomputed. Results depend on the selected model,
            # but the cache directory is keyed only by the file's MD5, so reusing
            # pages/*.mmd could return another model's output.
            predictions, images = _run_pdf_inference(model, pdf, pages)
            pages_dir = save_path / "pages"
            pages_dir.mkdir(parents=True, exist_ok=True)
            pdf.save(save_path / "doc.pdf")
            if len(images) > 0:
                thumb = Image.open(images[0])
                thumb.thumbnail((400, 400))
                thumb.save(save_path / "thumb.jpg")
            for idx, page_num in enumerate(pages):
                (pages_dir / ("%02d.mmd" % (page_num + 1))).write_text(
                    predictions[idx], encoding="utf-8"
                )
        finally:
            pdf.close()
    else:
        raise HTTPException(
            status_code=415, detail=f"Unsupported content type: {file.content_type}"
        )

    final = "".join(predictions).strip()
    (save_path / "doc.mmd").write_text(final, encoding="utf-8")
    #saving to database inference table in async mode
    await aync_save2database(blob_file=Path(blob_file), finger_printer=finger_printer, model_name=model_name, result=final)
    logging.info("***********************************")
    logging.info("save successfully")
    logging.info("***********************************")
    logging.info(f'save path is {save_path}')
    logging.info(f"final output content length is {len(final)}")
    return final
#########################################
##
## gradio web
##
#########################################
def nougat_predict(input_files, output_path, model_name, batchsize, markdown,recompute):
    """Convert PDF files to Mathpix-markdown (.mmd) files with Nougat.

    Args:
        input_files: Iterable of PDF paths; missing files are skipped.
        output_path: Directory for the .mmd files. If falsy, results are only logged.
        model_name: Key into `model_list`. None or the default name reuses the
            global `nougatModel`; otherwise that checkpoint is loaded.
        batchsize: Pages per inference batch. Values <= 0 are treated as 1.
        markdown: Post-process predictions with markdown_compatible.
        recompute: Re-convert files whose .mmd output already exists.

    Returns:
        List of the .mmd paths written. When nothing needed converting, the
        already-existing .mmd paths that were skipped are returned instead
        (possibly an empty list).
    """
    logging.info(f'*** nougat predict with input :{input_files} ***')
    global nougatModel, model_list, selected_model_name
    if nougatModel is None:
        nougatModel = NougatModel.from_pretrained(default_checkpoint_path).to(torch.float16)
        if batchsize > 0 and torch.cuda.is_available():
            nougatModel.to("cuda")
        nougatModel.eval()
    if batchsize <= 0:
        # set batch size to 1. Need to check if there are benefits for CPU conversion for >1
        batchsize = 1
    logging.info(f"Using {model_name} for example predicting.")
    if model_name is None or model_name == default_model_name:
        model = nougatModel
    else:
        model = loadModel(model_list.get(model_name, default_checkpoint_path))
    if output_path:
        output_path = Path(output_path)

    datasets = []
    skipped = []
    for pdf in input_files:
        if not os.path.exists(pdf):
            continue
        pdf = Path(pdf)
        if output_path:
            out_path = output_path / pdf.with_suffix(".mmd").name
            if out_path.exists() and not recompute:
                logging.info(
                    f"Skipping {pdf.name}, already computed. Run with --recompute to convert again."
                )
                skipped.append(out_path)
                continue
        try:
            dataset = LazyDataset(
                pdf, partial(model.encoder.prepare_input, random_padding=False)
            )
        except fitz.FileDataError:
            logging.info(f"Could not load file {str(pdf)}.")
            continue
        datasets.append(dataset)
    if len(datasets) == 0:
        logging.info(f'*** nougat out files :{skipped} ***')
        return skipped

    dataloader = torch.utils.data.DataLoader(
        ConcatDataset(datasets),
        batch_size=batchsize,
        shuffle=False,
        collate_fn=LazyDataset.ignore_none_collate,
    )
    predictions = []
    output_file_list = []
    file_index = 0
    page_num = 0  # 1-based page counter within the current file
    for sample, is_last_page in tqdm(dataloader):
        model_output = model.inference(image_tensors=sample)
        # check if model output is faulty
        for j, output in enumerate(model_output["predictions"]):
            if page_num == 0:
                logging.info(
                    "Processing file %s with %i pages",
                    datasets[file_index].name,
                    datasets[file_index].size,
                )
            page_num += 1
            if output.strip() == "[MISSING_PAGE_POST]":
                # uncaught repetitions -- most likely empty page
                predictions.append(f"\n\n[MISSING_PAGE_EMPTY:{page_num}]\n\n")
            elif model_output["repeats"][j] is not None:
                if model_output["repeats"][j] > 0:
                    # If we end up here, it means the output is most likely not complete and was truncated.
                    logging.warning(f"Skipping page {page_num} due to repetitions.")
                    predictions.append(f"\n\n[MISSING_PAGE_FAIL:{page_num}]\n\n")
                else:
                    # If we end up here, it means the document page is too different from the training domain.
                    # This can happen e.g. for cover pages.
                    predictions.append(f"\n\n[MISSING_PAGE_EMPTY:{page_num}]\n\n")
            else:
                if markdown:
                    output = markdown_compatible(output)
                predictions.append(output)
            # is_last_page[j] holds the file name on a file's last page (falsy otherwise).
            if is_last_page[j]:
                out = "".join(predictions).strip()
                out = re.sub(r"\n{3,}", "\n\n", out).strip()
                out = out.replace(r"\(", "$").replace(r'\)', '$').replace(r'\[', '$$').replace(r'\]', '$$')
                if output_path:
                    out_path = output_path / Path(is_last_page[j]).with_suffix(".mmd").name
                    out_path.parent.mkdir(parents=True, exist_ok=True)
                    out_path.write_text(out, encoding="utf-8")
                    output_file_list.append(out_path)
                else:
                    logging.debug("%s\n\n", out)
                predictions = []
                page_num = 0
                file_index += 1
    logging.info(f'the generated markdown files are : {output_file_list}')
    return output_file_list
def get_image(url_list, timeout=60, chunk_size=8192):
    """Download each URL into ./input/, named after the last component of the URL path.

    Args:
        url_list: Iterable of image URLs.
        timeout: Per-request socket timeout in seconds.
        chunk_size: Bytes per streamed chunk written to disk.

    Responses with an error status (``response.ok`` is False) and URLs without
    a file name are logged and skipped.
    """
    new_path = Path("./input")
    new_path.mkdir(parents=True, exist_ok=True)
    for url in url_list:
        file_name = os.path.basename(urlparse(url).path)
        if not file_name:
            logging.info(f"Skipping {url}: no file name in URL path")
            continue
        new_file = new_path / file_name
        with requests.get(url, stream=True, timeout=timeout) as response:
            if not response.ok:
                logging.info(f"Failed to download {url} (HTTP {response.status_code})")
                continue
            with open(new_file, mode="wb") as file:
                # The default iter_content() chunk size is 1 byte; use real chunks.
                for data in tqdm(response.iter_content(chunk_size=chunk_size)):
                    file.write(data)
# Download pdf from a given link
def get_pdf(pdf_link, timeout=60):
    """Download `pdf_link` to input/downloaded_paper_<uuid>.pdf.

    Args:
        pdf_link: URL of the PDF.
        timeout: Per-request socket timeout in seconds.

    Returns:
        The local file name. It is returned even when the download failed
        (non-200), in which case no file is written; callers must check that
        the file exists.
    """
    os.makedirs("input", exist_ok=True)
    # Generate a unique filename
    unique_filename = f"input/downloaded_paper_{uuid.uuid4().hex}.pdf"
    response = requests.get(pdf_link, timeout=timeout)
    if response.status_code == 200:
        with open(unique_filename, 'wb') as pdf_file:
            pdf_file.write(response.content)
        logging.info("PDF downloaded successfully.")
    else:
        logging.info(f"Failed to download the PDF (HTTP {response.status_code}).")
    return unique_filename
def sha256sum(filename):
    """Return the hex SHA-256 digest of the file at `filename`, read in 128 KiB chunks."""
    digest = hashlib.sha256()
    window = memoryview(bytearray(128 * 1024))
    with open(filename, 'rb', buffering=0) as stream:
        while True:
            count = stream.readinto(window)
            if not count:
                break
            digest.update(window[:count])
    return digest.hexdigest()
def readImage():
    """Return the raw bytes of ./woman.jpg (relative to the current directory).

    On an I/O error the error is logged and the process exits with status 1.
    """
    try:
        # `with` closes the file; the old `finally: if fin` raised
        # UnboundLocalError when open() itself failed, masking the exit.
        with open("woman.jpg", "rb") as fin:
            return fin.read()
    except IOError as e:
        logging.info("Error %s: %s" % (e.errno, e.strerror))
        sys.exit(1)
# converts from python to postgres
def convert_To_Binary(filename):
    """Read `filename` and wrap its bytes in psycopg.Binary for use as a query parameter."""
    with open(filename, 'rb') as source:
        return psycopg.Binary(source.read())
def _adapt_array(filename):
with open(filename, 'rb') as file:
out = io.BytesIO(file.read())
out.seek(0)
return bytearray(out.read())
# converts from postgres to python
def _typecast_array(value, cur):
    # Apparently meant as the inverse of _adapt_array: turn a value fetched
    # from Postgres back into a numpy array. Returns None for SQL NULL.
    # NOTE(review): this looks ported from a psycopg2 typecaster recipe and
    # probably does not work as written — TODO confirm before relying on it:
    #   * psycopg.Binary(value, cur) passes the cursor as the second argument,
    #     which (if psycopg 3's Binary(obj, format) signature applies) is a
    #     format, not a cursor;
    #   * the wrapper is then sliced with data[1:-1];
    #   * np.loadtxt parses *text*, whereas _adapt_array stores raw file bytes.
    # No caller is visible in this part of the file.
    if value is None:
        return None
    data = psycopg.Binary(value, cur)
    bdata = io.BytesIO(data[1:-1])
    bdata.seek(0)
    return np.loadtxt(bdata)
async def aync_save2database(blob_file, finger_printer, model_name, result):
    """Persist one inference result through crud.inference.create_async_inference.

    Args:
        blob_file: Path-like of the source file whose bytes are stored, or
            None. When None, no blob is stored and fileName is None (the
            InferenceCreate schema must accept that — TODO confirm).
        finger_printer: Content hash of the source file.
        model_name: Name of the checkpoint used.
        result: Converted markdown text.

    Returns:
        Whatever create_async_inference returns (presumably the created row).
    """
    blob = None
    file_name = None
    if blob_file is not None:
        blob = _adapt_array(blob_file)
        file_name = blob_file.name
    inference = InferenceCreate(
        blobContent=blob,
        textContent="",
        fileName=file_name,
        fingerPrint=finger_printer,
        selectedModel=model_name,
        result=result,
    )
    async with async_session() as db_session:
        return await crud.inference.create_async_inference(db_session, obj_in=inference)
def save2database(blob_file, finger_printer, model_name, result):
    """Persist one inference result through crud.inference.create_inference (sync).

    Args:
        blob_file: Path-like of the source file whose bytes are stored, or
            None. When None, no blob is stored and fileName is None (the
            InferenceCreate schema must accept that — TODO confirm).
        finger_printer: Content hash of the source file.
        model_name: Name of the checkpoint used.
        result: Converted markdown text.

    Returns:
        Whatever create_inference returns (presumably the created row).
    """
    blob = None
    file_name = None
    if blob_file is not None:
        blob = _adapt_array(blob_file)
        file_name = blob_file.name
    inference = InferenceCreate(
        blobContent=blob,
        textContent="",
        fileName=file_name,
        fingerPrint=finger_printer,
        selectedModel=model_name,
        result=result,
    )
    # Hold the dependency generator so its cleanup (if any) runs after we are
    # done with the session, not whenever the temporary happens to be collected.
    session_gen = get_db()
    db_session = next(session_gen)
    try:
        return crud.inference.create_inference(db_session=db_session, obj_in=inference)
    finally:
        session_gen.close()
# predict function / driver function
# Strong references to fire-and-forget DB-save tasks so they are not
# garbage-collected before completion.
_pending_db_tasks = set()


def perapare_reader(pdf_file, pdf_link, img_file, model_name, async_model=False):
    """Gradio driver: OCR an uploaded image, an uploaded PDF or a PDF link, then store the result.

    Input priority is `img_file`, then `pdf_file`, then `pdf_link`.

    Args:
        pdf_file: Uploaded PDF (object with `.name` holding its path) or None.
        pdf_link: URL of a PDF, used when no file was uploaded.
        img_file: Uploaded image (path-like with `.name`) or None.
        model_name: Checkpoint name to use.
        async_model: Save to the database with the async CRUD helper.

    Returns:
        (content, gr.Markdown, gr.DownloadButton): the converted markdown (or an
        error text), a label naming the first output file, and a download button
        for it (hidden when there is no output file).
    """
    global model_list, selected_model_name
    logging.info(f'*** paper_read ****')
    logging.info(f'*** using model ****{model_name}')
    output_path = Path("./output")
    markdown = True
    batchsize = BATCH_SIZE
    blob_file = None
    finger_printer = ""

    def _error(message):
        # Same 3-tuple shape as the normal return so callers can always unpack it.
        return message, gr.Markdown("Converted file list :"), gr.DownloadButton(visible=False)

    if img_file is not None:
        logging.info(f'*** handing image file : {img_file} ****')
        blob_file = img_file
        img = Image.open(img_file)
        logging.info(f"Image Type: {type(img)}")
        finger_printer = sha256sum(img_file)
        _, output_files = predict_image(model_name, img, markdown=True, out_path_root=Path("output"))
    else:
        if pdf_file is None:
            logging.info(f'*** handing pdf link paper :{pdf_link} ****')
            if not pdf_link:
                logging.info("No file is uploaded and No link is provided")
                return _error("No data provided. Upload a pdf file or provide a pdf link and try again!")
            file_name = get_pdf(pdf_link)
            blob_file = Path(file_name)
        else:
            logging.info(f'*** handing pdf paper :{pdf_file}***')
            file_name = pdf_file.name
            blob_file = pdf_file
        input_files = (file_name if isinstance(file_name, os.PathLike) else Path(file_name),)
        if not os.path.exists(input_files[0]):
            # e.g. the link download failed; get_pdf still returns a name.
            logging.info(f"Input pdf {input_files[0]} does not exist")
            return _error(f"Could not read the input pdf: {input_files[0]}")
        finger_printer = sha256sum(input_files[0])
        output_files = nougat_predict(input_files=input_files, output_path=output_path, model_name=model_name,
                                      batchsize=batchsize, markdown=markdown, recompute=True)
        logging.info(f'the generated markdown file is : {output_files}')

    fileList = []
    content = None
    for output_file in output_files:
        if output_file is not None:
            with open(output_file, 'r', encoding="utf-8") as file:
                content = file.read()
            # switch math delimiters
            content = content.replace(r"\\(", "$").replace(r'\\)', '$').replace(r'\\[', '$$').replace(r'\\]', '$$')
            content = content.replace(r"\(", "$").replace(r'\)', '$').replace(r'\[', '$$').replace(r'\]', '$$')
            fileList.append(str(output_file))
        else:
            fileList.append("")

    if content:
        logging.info("***********************************")
        logging.info("convert successfully")
        logging.info("***********************************")
        if async_model:
            coro = aync_save2database(blob_file, finger_printer, model_name, content)
            try:
                loop = asyncio.get_running_loop()
            except RuntimeError:
                # No loop running in this thread: run the save to completion.
                asyncio.run(coro)
            else:
                # A running loop cannot be blocked on (run_until_complete would
                # raise); schedule the save on it instead.
                task = loop.create_task(coro)
                _pending_db_tasks.add(task)
                task.add_done_callback(_pending_db_tasks.discard)
        else:
            save2database(blob_file, finger_printer, model_name, content)
        logging.info("***********************************")
        logging.info("save successfully")
        logging.info("***********************************")
    else:
        content = "convert failed"
        logging.info("***********************************")
        logging.info("convert failed")
        logging.info("***********************************")

    first_file = fileList[0] if fileList else ""
    first_output = output_files[0] if output_files else None
    return (
        content,
        gr.Markdown("Converted file list :" + first_file),
        gr.DownloadButton(value="/file=" + str(first_output), visible=first_output is not None),
    )
import asyncio
# Handling examples in Gradio app
def process_example(pdf_file, pdf_link, img_file, model_name):
    """Run the OCR reader for one ``gr.Examples`` row and adapt its result.

    Falls back to ``default_model_name`` when no model is selected.

    Returns:
        A 3-tuple ``(markdown_update, file_list_output, download_button)``
        matching the outputs wired to the Examples component.
    """
    logging.info('*** process_example ****')
    if model_name is None:
        model_name = default_model_name
    result = perapare_reader(pdf_file, pdf_link, img_file, model_name)
    if isinstance(result, str):
        # The reader may return a bare message string (e.g. when no input is
        # provided) instead of a 3-tuple; unpacking that string would raise.
        return gr.update(value=result), gr.update(), gr.update()
    ocr_content, output_files, download_btn = result
    return gr.update(value=ocr_content), output_files, download_btn
# Stylesheet for the element with elem_id="mkd" (the OCR output Markdown):
# a fixed-height, scrollable, bordered box.
# NOTE(review): the gr.Blocks(...) that builds the UI below does not pass
# css=css, so this stylesheet appears unused — confirm whether that is intended.
css = """
#mkd {
height: 500px;
overflow: auto;
border: 1px solid #ccc;
}
"""
# Sample image URLs on the Mathpix CDN, built from their snip IDs.
url_list = [
    f"https://cdn.mathpix.com/snip/images/{snip_id}.original.fullsize.png"
    for snip_id in (
        "Hm62Ib-dDZOseYuVNN8k34IhBY18KglOrM7qETOqXZI",
        "lSL07DYTL1bdjzL2mpNyVg17JmqKwgugMLyGuxkLgLg",
    )
]
# Function for generating some examples
def generate_examples():
    """Return a Dataset update that restores the default example rows.

    Each row supplies one value per input wired to ``gr.Examples``:
    ``[pdf_file, pdf_link, img_file, model_name]``. The model column is
    required because the Examples dataset has four components.
    """
    extracted_list = [
        ["./input/test.pdf", "", None, "nougat_big_facebook"],
        [None, "https://arxiv.org/pdf/2308.08316.pdf", None, "nougat_middle_cn"],
        [None, "", "./input/Hm62Ib-dDZOseYuVNN8k34IhBY18KglOrM7qETOqXZI.original.fullsize.png", "nougat_small_cn"],
    ]
    # Gradio 4 removed the per-component ``.update()`` classmethods; returning a
    # component instance from an event handler updates the existing component.
    return gr.Dataset(samples=extracted_list)
#get_pdf("https://arxiv.org/pdf/2308.13418.pdf")
#get_image(url_list)
# No generated files yet; the UI block reads this to pick the download
# button's initial path and visibility.
output_files = []
# Serve generated outputs and UI images directly from disk. Register both
# directories in ONE call: Gradio documents set_static_paths as setting the
# paths "until it is called again", so a second call may drop the first list.
gr.set_static_paths(paths=["output/", "images/"])
# Build the Gradio UI.
# NOTE: the module-level `css` (which styles elem_id="mkd") is not passed here;
# pass css=css to gr.Blocks to get the fixed-height scrollable output box.
with gr.Blocks(theme='reilnuud/polite') as demo:
    # Header / branding.
    gr.HTML("""<h1><left>AIWorm.CN <img src="file/images/ic_worm_ai.png" style="display: inline-block; margin: 0;" width="100" height="50" /></left></h1>""")
    gr.HTML("""<h1><center>Using Nougat: Neural Optical Understanding for Academic Documents</center></h1>""")
    gr.HTML("<h3><center>the original Nougat-OCR Project is done by Lukas Blecher et al. <a href='https://arxiv.org/pdf/2308.13418.pdf' target='_blank'>Paper</a>, <a href='https://facebookresearch.github.io/nougat/'>Project</a></center></h3>")
    # Model descriptions. Implicit string concatenation keeps source
    # indentation out of the HTML (a backslash continuation inside the
    # literal would embed each continuation line's leading whitespace).
    gr.HTML(
        "<h3>The model below: <br>"
        "- model nougat_big_facebook is the facebook nougat project base model (only support pdf/english input), refer to <a href='https://huggingface.co/facebook/nougat-base'> facebook nougat-base</a> <br>"
        "- model nougat_small_facebook is the facebook nougat model (only support pdf/english input), refer to <a href='https://huggingface.co/facebook/nougat-small'> facebook nougat-small</a> <br>"
        "- model nougat_middle_cn is the fine-tuned chinese version adopt to pdf and image input , refer to <a href='https://huggingface.co/zphilip48/nougat-middle-cn'> zphilip48's nougat-middle-cn</a> <br>"
        "- model nougat_small_cn is the fine-tuned chinese version adopt to pdf and image input </h3>"
    )
    model_name_list = ["nougat_middle_ocr",
                       "nougat_middle_cn",
                       "nougat_small_facebook",
                       "nougat_small_ocr_ocrpadded",
                       "nougat_big_facebook",
                       "nougat_small_cn"]
    selected_model_name = gr.Dropdown(choices=model_name_list, label="Select models ", value=model_name_list[1], interactive=True)

    # Input section: three alternative sources (PDF upload, PDF link, image).
    with gr.Row():
        mkd = gr.Markdown('<h4><center>Upload a PDF</center></h4>')
        mkd = gr.Markdown('<h4><center><i>OR</i></center></h4>')
        mkd = gr.Markdown('<h4><center>Provide a PDF link</center></h4>')
        mkd = gr.Markdown('<h4><center><i>OR</i></center></h4>')
        mkd = gr.Markdown('<h4><center>Upload an image (PNG or JPEG)</center></h4>')
    with gr.Row(equal_height=True):
        pdf_file = gr.File(label='PDF📃', file_count='single')
        pdf_link = gr.Textbox(placeholder='Enter an Arxiv link here', label='PDF link🔗🌐')
        img_file = gr.File(label='IMG📃', file_count='single')
    with gr.Row():
        btn = gr.Button('Run NOUGAT🍫')
        clr = gr.Button('Clear🚿')

    # `output_files` is still the empty module-level list here (it is rebound
    # to a Markdown component further down), so the download button starts
    # hidden; the conversion callback returns a visible button after a run.
    if not output_files:
        download_btn_visible = False
        download_file_path = "./output/test.mmd"
    else:
        download_btn_visible = True
        download_file_path = output_files[0]

    # Output section.
    with gr.Row():
        with gr.Column(scale=1):
            gr.Markdown("""<h3><center><font color="red">The converting will take some time due to the GPU P100 as backend </font></center></h3>""")
            gr.Markdown("<h3><center>Gradio Markdown might not render complex latex equation correctly, try download it👇:</center></h3>")
            # Visibility follows the flag computed above.
            download_btn = gr.DownloadButton('Download The Converted Markdown File ', value="/file=" + download_file_path, visible=download_btn_visible)
            output_headline = gr.Markdown("<h3><center>PDF converted into markup language through Nougat-OCR👇:</center></h3>")
            # NOTE: rebinds the module-level name `output_files` (a list above)
            # to this component; the event wiring below uses the component.
            output_files = gr.Markdown("Converted file list")
            # Raw strings for the \( \) \[ \] delimiters: they are invalid escape
            # sequences in normal literals (SyntaxWarning on Python 3.12+);
            # the resulting string values are unchanged.
            parsed_output = gr.Markdown(
                r'# OCR Output📃🔤',
                elem_id='mkd',
                latex_delimiters=[
                    {"left": "$", "right": "$", "display": False},
                    {"left": "$$", "right": "$$", "display": True},
                    {"left": r"\(", "right": r"\)", "display": False},
                    {"left": r"\[", "right": r"\]", "display": True},
                ],
            )

    # Event wiring.
    btn.click(perapare_reader, inputs=[pdf_file, pdf_link, img_file, selected_model_name], outputs=[parsed_output, output_files, download_btn])
    # Reset the three inputs and the rendered output.
    clr.click(lambda: (gr.update(value=None),
                       gr.update(value=""),
                       gr.update(value=None),
                       gr.update(value=None)),
              [],
              [pdf_file, pdf_link, img_file, parsed_output]
              )

    # NOTE(review): not wired to any event in this section. Assigning
    # `.visible` on an already-built component likely does not update the
    # rendered page; a handler would need to return gr.update(visible=...).
    def update_visibility(radio, text):  # Accept the event argument, even if not used
        value = radio  # Get the selected value from the radio button
        if value == "show":
            text.visible = True
        else:
            text.visible = False

    example_btn = gr.Button('re-Run Example🍫')
    examples = gr.Examples(
        [["./input/test.pdf", "", None, "nougat_big_facebook"],
         [None, "https://arxiv.org/pdf/2308.08316.pdf", None, "nougat_middle_cn"],
         [None, "", "./input/Hm62Ib-dDZOseYuVNN8k34IhBY18KglOrM7qETOqXZI.original.fullsize.png", "nougat_small_cn"]],
        inputs=[pdf_file, pdf_link, img_file, selected_model_name],
        outputs=[parsed_output, output_files, download_btn],
        fn=process_example,
        cache_examples=True,
        run_on_click=True,
        label='Click on any Examples below to get Nougat OCR results quickly:'
    )
    example_btn.click(generate_examples, outputs=[examples.dataset])
# Enable Gradio's request queue (Blocks.queue() returns the Blocks itself) and
# mount the UI at the root of `app` (presumably the FastAPI app created
# earlier in this module).
app = gr.mount_gradio_app(app, demo.queue(), path="/")

if __name__ == "__main__":
    # Alternative TLS-enabled launch, kept for reference:
    #   uvicorn.run(app, host="0.0.0.0", port=8866, log_level="debug",
    #               ssl_keyfile='/workspace/nougat-latex/lzs.chrdw.ml.key',
    #               ssl_certfile='/workspace/nougat-latex/fullchain.cer')
    # Standalone Gradio alternative:
    #   demo.launch(debug=True, share=True, server_name="0.0.0.0", server_port=8866)
    server_options = {"host": "0.0.0.0", "port": 8503, "log_level": "debug", "workers": 1}
    uvicorn.run("__main__:app", **server_options)