Spaces:
Runtime error
Runtime error
import gradio as gr | |
import os | |
#os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID" | |
#os.environ["CUDA_VISIBLE_DEVICES"] = "3" # just use one GPU on big machine | |
import torch | |
# Report the visible CUDA devices at import time.
# The CUDA-only queries are guarded so the module can still be imported on a
# CPU-only host (the BATCH_SIZE setup further down has an explicit CPU branch);
# previously torch.cuda.current_device() raised there before anything ran.
print ('Available devices ', torch.cuda.device_count())
if torch.cuda.is_available():
    print ('Current cuda device ', torch.cuda.current_device())
    # The service expects exactly one visible GPU. Raise explicitly instead of
    # using `assert`, which is stripped when Python runs with -O.
    if torch.cuda.device_count() != 1:
        raise RuntimeError(
            "Expected exactly one visible CUDA device, found %d"
            % torch.cuda.device_count()
        )
    print('GPU Device name:', torch.cuda.get_device_name(torch.cuda.current_device()))
import sys | |
from functools import partial | |
from http import HTTPStatus | |
import cv2 | |
from fastapi import FastAPI, File, UploadFile, Request,Response, BackgroundTasks, HTTPException | |
from fastapi import APIRouter, Depends | |
import os | |
import requests | |
import uuid | |
#for the user register and login | |
#from fastapi.templating import Jinja2Templates | |
#from schemas.users import UserCreate | |
#from sqlalchemy.orm import Session | |
#from sqlalchemy.exc import IntegrityError | |
#from webapps.users.forms import UserCreateForm | |
from PIL import Image | |
from pathlib import Path | |
import hashlib | |
from fastapi.middleware.cors import CORSMiddleware | |
import fitz | |
import torch | |
from torch.utils.data import ConcatDataset | |
from nougat import NougatModel | |
from nougat.postprocessing import markdown_compatible, close_envs | |
from nougat.utils.dataset import ImageDataset,LazyDataset | |
from nougat.utils.checkpoint import get_checkpoint | |
#from nougat.dataset.rasterize import rasterize_paper | |
from tqdm import tqdm | |
import uvicorn | |
import shutil | |
import io | |
import numpy as np | |
import logging | |
import pypdfium2 | |
from typing import Optional, List, Union | |
from PIL import ImageOps,Image | |
import re | |
from contextlib import asynccontextmanager | |
from starlette.types import Message | |
from starlette.middleware.base import BaseHTTPMiddleware | |
from starlette.background import BackgroundTask | |
from urllib.parse import urlparse, unquote | |
from utils import binarization | |
from sql_app import models, schemas | |
from sql_app import crud | |
from sql_app.db import get_db, engine, async_engine, async_session | |
from sqlalchemy.orm import Session | |
from sqlalchemy.ext.asyncio import AsyncSession | |
from sql_app.schemas import ( | |
InferenceCreate, | |
InferenceUpdate, | |
InferenceRead, | |
#Inference | |
) | |
from sql_app.models import Inference,User | |
from hashlib import sha1,md5 | |
from datetime import datetime | |
from sql_app.db import Base | |
import psycopg | |
import numpy as np | |
#logging.basicConfig(level=logging.INFO) | |
logging.basicConfig(filename='info.log', level=logging.INFO) | |
#logger = logging.getLogger() | |
#logger.setLevel(logging.INFO) | |
SAVE_DIR = Path("./pdfs") | |
# Environment variables are always strings; coerce to int so BATCHSIZE can be
# used both as a DataLoader batch size and in page-index arithmetic.
BATCHSIZE = int(os.environ.get("NOUGAT_BATCHSIZE", 6))
NOUGAT_CHECKPOINT = get_checkpoint() | |
######################################### | |
## | |
## init the fastapi server | |
## | |
######################################### | |
global selected_model_name | |
# Load the ML model | |
def loadModel(checkpoint):
    """Load a Nougat checkpoint and return it ready for inference.

    Args:
        checkpoint (Path): Directory of the checkpoint to load. When it does
            not exist, the module-level ``default_checkpoint_path`` is loaded
            instead.

    Returns:
        The ``NougatModel`` cast to bfloat16, moved to CUDA when a GPU is
        available, and switched to eval mode.
    """
    if not checkpoint.exists():
        checkpoint = default_checkpoint_path
        # Log the path that is actually loaded. The previous message
        # interpolated `checkpoint_name`, an unrelated module-level loop
        # variable holding the last registered checkpoint name.
        logging.info(f"request checkpoint is not exist, using default {checkpoint}")
    model = NougatModel.from_pretrained(checkpoint).to(torch.bfloat16)
    if torch.cuda.is_available():
        model.to("cuda")
    model.eval()
    return model
# Load the ML model | |
# Checkpoint directories that can be selected by name at request time.
checkpoint_list = [
    Path("./nougat_middle_ocr/"),
    Path("./nougat_middle_cn/"),
    Path("./nougat_small_facebook/"),
    Path("./nougat_small_ocr_ocrpadded/"),
    Path("./nougat_small_cn/"),
    Path("./nougat_big_facebook/"),
]
default_checkpoint_path = Path("./nougat_middle_cn/")
default_model_name = "nougat_middle_cn"
selected_model_name = None

# Registry mapping a checkpoint's directory name to its path.
model_list = {}
if not model_list:
    logging.info(f"Start up and init the Nougat Model with {checkpoint_list} ")
    for checkpoint in checkpoint_list:
        checkpoint_name = os.path.basename(checkpoint)
        model_list[checkpoint_name] = checkpoint
        logging.info(f"loading model {checkpoint_name}")

# Only the default model is loaded eagerly; other models are loaded on demand.
nougatModel = loadModel(
    model_list.get(default_model_name, default_checkpoint_path)
)

# Abort start-up when no Nougat checkpoint could be resolved.
if NOUGAT_CHECKPOINT is None:
    logging.info(
        "Set environment variable 'NOUGAT_CHECKPOINT' with a path to the model checkpoint!."
    )
    sys.exit(1)
# Derive the default batch size: 1 on CPU, otherwise scaled from the total
# memory reported for CUDA device 0.
if not torch.cuda.is_available():
    # don't know what a good value is here. Would not recommend to run on CPU
    BATCH_SIZE = 1
    logging.warning("No GPU found. Conversion on CPU is very slow.")
else:
    BATCH_SIZE = int(
        torch.cuda.get_device_properties(0).total_memory / 1024 / 1024 / 1000 * 0.3
    )
    if BATCH_SIZE == 0:
        logging.warning("GPU VRAM is too small. Computing on CPU.")
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Application lifespan hook passed to ``FastAPI(lifespan=...)``.

    Nothing is done on start-up (the default model is loaded at import time).
    On shutdown the checkpoint registry is emptied.
    """
    yield
    # `model_list` maps checkpoint names to paths. Iterating it yields the
    # string keys, so the former `model.clear()` call raised AttributeError
    # on shutdown. Clear the registry itself instead.
    model_list.clear()
# The ASGI application; `lifespan` handles shutdown clean-up.
app = FastAPI(
    lifespan=lifespan,
    title="AIWorm Application",
    description="AIWorm Application with FastAPI and Gradio",
    version="1.0.0",
)

# Create any tables that do not exist yet (checkfirst skips existing ones).
Base.metadata.create_all(bind=engine, checkfirst=True)
def create_db_schema(engine):
    """Create all ORM tables on ``engine``, skipping tables that already exist.

    Args:
        engine: SQLAlchemy engine to create the schema on.
    """
    # Import the declarative base locally: the module-level name `Base` is
    # rebound further down this file to a pydantic model, which has no
    # `.metadata`, so relying on the global fails once the module is loaded.
    from sql_app.db import Base as orm_base

    orm_base.metadata.create_all(engine, checkfirst=True)
def init_db():
    """Insert a placeholder ``Inference`` row and commit it."""
    # Keep a handle on the generator so it can be closed afterwards.
    # `next(get_db())` alone left it suspended and never released the session.
    db_gen = get_db()
    db_session = next(db_gen)
    try:
        db_session.add(
            Inference(
                fingerPrint="12124isefuadsf"
            )
        )
        db_session.commit()
    finally:
        # Assumes get_db is a generator whose clean-up runs on close();
        # plain iterators without close() are left alone.
        close = getattr(db_gen, "close", None)
        if close is not None:
            close()
    logging.info("Initialized the db")
# Origins allowed by the CORS middleware. Browsers send the Origin header as
# scheme://host[:port] without a trailing slash and it is matched literally,
# so entries must not end with "/".
origins = [
    "http://localhost",
    "http://tec1.aiworm.cn:8866",
    "http://127.0.0.1",
    "https://82s56k6681.zicp.fun",
    "http://192.168.1.34",
]
#@app.on_event("shutdown") | |
#def shutdown_event(): | |
# for model in model_list: | |
# model.clear() | |
#@app.on_event("startup") | |
#async def load_model( | |
# checkpoint: str = NOUGAT_CHECKPOINT, | |
#): | |
# global model | |
# if model is None: | |
# model = NougatModel.from_pretrained(checkpoint).to(torch.bfloat16) | |
# if torch.cuda.is_available(): | |
# model.to("cuda") | |
# model.eval() | |
######################################### | |
## | |
## middleware for fastapi app | |
## | |
######################################### | |
class MyMiddleware(BaseHTTPMiddleware):
    """Middleware that buffers the request body so it can be read downstream.

    The previous version called the ``set_body`` coroutine without awaiting
    it, so nothing ever ran. It would also have replayed only the first body
    chunk forever, which hangs on multi-chunk uploads.
    """

    async def set_body(self, request: Request):
        """Pre-read the whole request body and arrange for it to be replayed.

        Messages are pulled from the original receive channel until the body
        is complete or a non-body message (e.g. a disconnect) arrives. They
        are replayed once, in order, and later calls fall through to the
        original channel.
        """
        upstream = request._receive
        buffered = []
        while True:
            message = await upstream()
            buffered.append(message)
            if message.get("type") != "http.request" or not message.get("more_body", False):
                break
        # Make all changes to the buffered body messages here.
        buffered.reverse()  # so pop() hands them back in arrival order

        async def receive() -> Message:
            if buffered:
                return buffered.pop()
            return await upstream()

        request._receive = receive

    async def get_body(self, request: Request) -> bytes:
        """Return the full request body while keeping it readable downstream."""
        body = await request.body()
        upstream = request._receive
        pending = [{"type": "http.request", "body": body, "more_body": False}]

        async def receive() -> Message:
            # Hand the already-consumed body back once, then defer to the
            # real channel so disconnects are still observed.
            if pending:
                return pending.pop()
            return await upstream()

        request._receive = receive
        return body

    async def dispatch(self, request, call_next):
        """Buffer the body of POST requests, then continue the chain."""
        if request.method == 'POST':
            await self.set_body(request)
        response = await call_next(request)
        return response
''' debug code for , not test yet | |
def log_info(req_body, res_body): | |
logging.debug(req_body) | |
logging.debug(res_body) | |
async def set_body(request: Request, body: bytes): | |
async def receive() -> Message: | |
return {'type': 'http.request', 'body': body} | |
request._receive = receive | |
@app.middleware('http') | |
async def app_middleware(request: Request, call_next): | |
req_body = await request.body() | |
#await set_body(request, req_body) | |
response = await call_next(request) | |
res_body = b'' | |
async for chunk in response.body_iterator: | |
res_body += chunk | |
#logger.debug(f"response: {res_body} -- {response.status_code}") | |
#logger.debug(f"response.headers :{response.headers}") | |
#logger.debug(f"response.media_type: {response.media_type}") | |
task = BackgroundTask(log_info, req_body, res_body) | |
return Response(content=res_body, status_code=response.status_code, | |
headers=dict(response.headers), media_type=response.media_type, background=task) | |
# Exception handlers | |
def add_exception_handlers(_app: FastAPI): | |
@_app.exception_handler(ApiAuthException) | |
async def api_auth_exception_handler(request: Request, exc: ApiAuthException): | |
return await handler.api_auth_exception_handler(request, exc) | |
@_app.exception_handler(ApiException) | |
async def api_exception_handler(request: Request, exc: ApiException): | |
return await handler.api_exception_handler(request, exc) | |
add_exception_handlers(main_app) | |
add_exception_handlers(sub_app) | |
from loguru import logger | |
from starlette.routing import Match | |
logger.remove() | |
logger.add(sys.stdout, colorize=True, format="<green>{time:HH:mm:ss}</green> | {level} | <level>{message}</level>") | |
app = FastAPI() | |
@app.middleware("http") | |
async def log_middle(request: Request, call_next): | |
logger.debug(f"{request.method} {request.url}") | |
routes = request.app.router.routes | |
logger.debug("Params:") | |
for route in routes: | |
match, scope = route.matches(request) | |
if match == Match.FULL: | |
for name, value in scope["path_params"].items(): | |
logger.debug(f"\t{name}: {value}") | |
logger.debug("Headers:") | |
for name, value in request.headers.items(): | |
logger.debug(f"\t{name}: {value}") | |
response = await call_next(request) | |
return response | |
@app.get("/{param1}/{param2}") | |
async def path_operation(param1: str, param2: str): | |
return {'param1': param1, 'param2': param2} | |
''' | |
# CORS policy: only the hosts in `origins` may call the API from a browser;
# credentials are allowed, and every method and header is accepted.
_cors_options = {
    "allow_origins": origins,
    "allow_credentials": True,
    "allow_methods": ["*"],
    "allow_headers": ["*"],
}
app.add_middleware(CORSMiddleware, **_cors_options)
######################################### | |
## | |
## main predict nougat function | |
## | |
######################################### | |
def rasterize_paper(
    pdf: Union[Path, bytes],
    outpath: Optional[Path] = None,
    dpi: int = 96,
    return_pil=False,
    pages=None,
) -> Optional[List[io.BytesIO]]:
    """
    Rasterize a PDF to page images.

    Args:
        pdf: A path (str or Path, rendered with pypdfium2), raw PDF bytes, or
            an already opened PyMuPDF document (both rendered with PyMuPDF).
        outpath (Optional[Path], optional): The output directory. If None, the
            images are returned instead of written. Defaults to None.
        dpi (int, optional): The output DPI. Defaults to 96.
        return_pil (bool, optional): Whether to return the images instead of
            writing them to disk. Defaults to False.
        pages (Optional[List[int]], optional): Zero-based page indices to
            rasterize. If None, all pages are rasterized. Defaults to None.

    Returns:
        Optional[List[io.BytesIO]]: One in-memory image per requested page
        (BMP for path input, PNG otherwise) if images are returned, otherwise
        None. Rendering errors are logged, not raised, so the list may be
        shorter than requested.
    """
    pils = []
    if outpath is None:
        return_pil = True
    try:
        if isinstance(pdf, (str, Path)):
            logging.debug("input pdf as str")
            pdf = pypdfium2.PdfDocument(pdf)
            if pages is None:
                pages = range(len(pdf))
            renderer = pdf.render(
                pypdfium2.PdfBitmap.to_pil,
                page_indices=pages,
                scale=dpi / 72,
            )
            for i, image in zip(pages, renderer):
                if return_pil:
                    page_bytes = io.BytesIO()
                    image.save(page_bytes, "bmp")
                    pils.append(page_bytes)
                else:
                    image.save((Path(outpath) / ("%02d.png" % (i + 1))), "png")
        else:
            logging.debug("input pdf as bytes")
            if isinstance(pdf, (bytes, bytearray)):
                # Raw bytes must be opened first; len() of bytes is a byte
                # count, not a page count.
                pdf = fitz.open("pdf", pdf)
            if pages is None:
                pages = range(len(pdf))
            # PDF user space is 72 units per inch, so dpi / 72 is the zoom.
            mat = fitz.Matrix(dpi / 72, dpi / 72)
            # Render only the requested pages. The previous loop iterated the
            # whole document and ignored `pages`, misaligning results with
            # the caller's page list.
            for page_index in pages:
                pix = pdf[page_index].get_pixmap(matrix=mat)
                if return_pil:
                    page_iobytes = io.BytesIO(pix.pil_tobytes('PNG'))
                    page_iobytes.seek(0)
                    pils.append(page_iobytes)
                else:
                    # Path join instead of `outpath + "/"`, which raised
                    # TypeError for Path; write real PNG data to the .png
                    # name (it used to be TIFF).
                    img_filename = Path(outpath) / ("%02d.png" % (page_index + 1))
                    pix.pil_save(img_filename, format="PNG", dpi=(dpi, dpi))
    except Exception as e:
        # Best effort: report the failure and return what was rendered so far.
        logging.error(e)
    if return_pil:
        return pils
def resize_with_padding(img, expected_size):
    """Fit ``img`` inside ``expected_size`` and pad it with white to that size.

    ``img.thumbnail`` is called on the image that is passed in, then the
    remaining space is split between the two sides of each axis, with the odd
    pixel going to the right/bottom.

    Args:
        img: PIL image to fit.
        expected_size: ``(width, height)`` of the result.

    Returns:
        The padded image produced by ``ImageOps.expand``.
    """
    target_w, target_h = expected_size[0], expected_size[1]
    img.thumbnail((target_w, target_h))
    spare_w = target_w - img.size[0]
    spare_h = target_h - img.size[1]
    left = spare_w // 2
    top = spare_h // 2
    border = (left, top, spare_w - left, spare_h - top)
    return ImageOps.expand(img, border, fill='white')
def predict_image(model_name, images, isBinarized = False, batchsize=1, markdown=True, out_path_root="./output"):
    """Run Nougat on a single image and write the result to ``test.mmd``.

    Args:
        model_name: Key into ``model_list``; ``None`` uses the preloaded
            default model. Unknown names fall back to the default checkpoint.
        images: A PIL image (or ``None`` to skip inference).
        isBinarized: When true, the image is binarized before inference.
        batchsize: Only consulted when the default model has to be created.
        markdown: When true, each prediction is passed through
            ``markdown_compatible`` before being written.
        out_path_root: Directory that receives ``test.mmd``.

    Returns:
        ``(model_output, [out_path])``. ``model_output`` is the raw result of
        ``model.inference`` (``None`` when ``images`` is ``None``), and
        ``out_path`` is the written file, or ``None`` when the text is empty.
    """
    logging.info('*** nougat predict with input image ***')
    global nougatModel, model_list
    if nougatModel is None:
        nougatModel = NougatModel.from_pretrained(default_checkpoint_path).to(torch.float16)
        if batchsize > 0:
            if torch.cuda.is_available():
                nougatModel.to("cuda")
        else:
            # set batch size to 1. Need to check if there are benefits for CPU conversion for >1
            batchsize = 1
        nougatModel.eval()
    if model_name is None:
        logging.info(f"Using {default_model_name} for example predicting.")
        model = nougatModel
    else:
        logging.info(f"Using {model_name} for example predicting.")
        model = loadModel(model_list.get(model_name, default_checkpoint_path))
    prepare = model.encoder.prepare_input
    predictions = []
    # Defined up front so the return below cannot hit an unbound name when no
    # image is supplied.
    model_output = None
    if images is not None:
        logging.info("we are under image to mmd convertiong")
        sample = images.convert('RGB')
        if isBinarized:
            # Use the binarized image for inference. The result used to be
            # stored in `images` and ignored.
            binarized = sauvolaBinarize(img=sample)
            if binarized is not None:
                sample = binarized
        im_new = resize_with_padding(sample, (672,896))
        img_tensor = prepare(im_new,random_padding=False)
        img_tensor = img_tensor.unsqueeze(0)
        model_output = model.inference(image_tensors=img_tensor)
        for output in model_output["predictions"]:
            # Convert before collecting. Converting after the append had no
            # effect on what was written.
            if markdown:
                output = markdown_compatible(output)
            predictions.append(output)
    out = "".join(predictions).strip()
    out = re.sub(r"\n{3,}", "\n\n", out).strip()
    out_path = None
    if out:
        out_path = Path(out_path_root) / Path("test").with_suffix(".mmd").name
        out_path.parent.mkdir(parents=True, exist_ok=True)
        if out_path.exists():
            os.remove(out_path)
        with open(out_path,mode="w",encoding="utf-8") as f:
            # Rewrite LaTeX \( \) / \[ \] delimiters as $ / $$.
            out = out.replace(r"\(", "$").replace(r'\)', '$').replace(r'\[', '$$').replace(r'\]', '$$')
            f.write(out)
    else:
        logging.debug(f"the out is {out}")
    return model_output, [out_path]
######################################### | |
## | |
## fastapi interface | |
## | |
######################################### | |
def root():
    """Health check."""
    # Fixed payload: an OK status and an empty data object.
    return {"status-code": HTTPStatus.OK, "data": {}}
#templates = Jinja2Templates(directory="templates") | |
#@router.get("/register/") | |
#def register(request: Request): | |
# return templates.TemplateResponse("users/register.html", {"request": request}) | |
#@router.post("/register/") | |
#async def register(request: Request, db: Session = Depends(get_db)): | |
# form = UserCreateForm(request) | |
# await form.load_data() | |
# if await form.is_valid(): | |
# user = UserCreate( | |
# username=form.username, email=form.email, password=form.password | |
# ) | |
# try: | |
# user = create_new_user(user=user, db=db) | |
# return responses.RedirectResponse( | |
# "/?msg=Successfully-Registered", status_code=status.HTTP_302_FOUND | |
# ) # default is post request, to use get request added status code 302 | |
# except IntegrityError: | |
# form.__dict__.get("errors").append("Duplicate username or email") | |
# return templates.TemplateResponse("users/register.html", form.__dict__) | |
# return templates.TemplateResponse("users/register.html", form.__dict__) | |
async def check_multi_files(files: List[UploadFile]):
    """Echo back the file names of the uploaded files."""
    names = []
    for upload in files:
        names.append(upload.filename)
    return {"filenames": names}
async def check(file: bytes = File() ):
    """Accept an uploaded file and acknowledge it; the content is not used."""
    return dict(state=200)
def convertImageFormat(imgObj, outputFormat="PNG"):
    """Return ``imgObj`` re-encoded as ``outputFormat``.

    The image is returned unchanged when no format is requested or when it
    already reports that format. Otherwise it is saved into an in-memory
    buffer and reopened from there.
    """
    if not outputFormat or imgObj.format == outputFormat:
        return imgObj
    buffer = io.BytesIO()
    imgObj.save(buffer, outputFormat)
    return Image.open(buffer)
from urllib.parse import parse_qs | |
from pydantic import BaseModel, model_validator, ValidationError | |
from typing import Optional, List, Dict | |
from fastapi import Form | |
class SubmitGeneral(BaseModel):
    # Payload with a single required field: the name of the selected model.
    selected_model_name: str
class Base(BaseModel):
    # NOTE(review): this class rebinds the module-level name `Base`, which is
    # also imported from sql_app.db (the ORM declarative base) earlier in the
    # file. The name is kept for compatibility; consider renaming it.
    name: str
    point: Optional[float] = None
    is_accepted: Optional[bool] = False

    # Registered as a "before" validator so a JSON string (e.g. a form field)
    # is parsed into the model. Without the decorators the method was never
    # invoked by pydantic.
    @model_validator(mode='before')
    @classmethod
    def validate_to_json(cls, value):
        # Local import: `json` is not imported at file level.
        import json

        if isinstance(value, str):
            return cls(**json.loads(value))
        return value
def create_user(user: schemas.UserCreate, db: Session = Depends(get_db)):
    """Create a user, rejecting e-mail addresses that are already registered.

    Raises:
        HTTPException: 400 when a user with the same e-mail exists.
    """
    existing = crud.get_user_by_email(db, email=user.email)
    if not existing:
        return crud.create_user(db=db, user=user)
    raise HTTPException(status_code=400, detail="Email already registered")
def read_users(skip: int = 0, limit: int = 100, db: Session = Depends(get_db)):
    """Return a page of users, forwarding ``skip`` and ``limit`` to crud.get_users."""
    return crud.get_users(db, skip=skip, limit=limit)
def stringToBool(string:str) -> bool:
    """Return True only for the text "true" (any letter case); anything else is False."""
    return string.lower() == "true"
def sauvolaBinarize(img:Image = None, dest:str = None):
    """Binarize an image with OpenCV's Gaussian adaptive threshold.

    Despite the name, the Sauvola implementation is not used (it was too
    slow); ``cv2.adaptiveThreshold`` is applied instead.

    Args:
        img: PIL image in RGB mode. Takes precedence over ``dest``.
        dest: Path of an image file to read when ``img`` is not given.

    Returns:
        The binarized image converted back through
        ``binarization.convert_from_cv2_to_image``, or ``None`` when no image
        could be obtained. A copy is also written to
        ``./uploads/testBinarization.png``.
    """
    cv2_image = None
    # Test the argument, not the PIL `Image` module (which is never None).
    # The old check made the file-only path unusable.
    if img is not None:
        # convert to cv2 format
        cv2_image = cv2.cvtColor(np.array(img), cv2.COLOR_RGB2BGR)
    elif dest is not None:
        # cv2.imread already returns BGR, and None when the file is unreadable.
        cv2_image = cv2.imread(dest)
    if cv2_image is None:
        return None
    # convert to gray format
    gray_image = cv2.cvtColor(cv2_image, cv2.COLOR_BGR2GRAY)
    # binarize with a Gaussian adaptive threshold (block size 5, constant 5)
    adaptive_threshold_img= cv2.adaptiveThreshold(gray_image, 255, cv2.ADAPTIVE_THRESH_GAUSSIAN_C, cv2.THRESH_BINARY, 5, 5)
    # convert back to RGB format
    img = cv2.cvtColor(adaptive_threshold_img,cv2.COLOR_GRAY2RGB)
    img = binarization.convert_from_cv2_to_image(img)
    # Make sure the debug output directory exists before saving into it.
    os.makedirs('./uploads', exist_ok=True)
    img.save( './uploads/testBinarization' + ".png", "PNG")
    return img
def _page_disclaimer(model_output, j):
    """Return the warning/error text to append to prediction ``j`` ("" if none)."""
    repeats = model_output["repeats"][j]
    if repeats is None:
        return ""
    if repeats > 0:
        template = "\n\n+++ ==WARNING: Truncated because of repetitions==\n%s\n+++\n\n"
    else:
        template = "\n\n+++ ==ERROR: No output for this page==\n%s\n+++\n\n"
    rest = close_envs(model_output["repetitions"][j]).strip()
    return template % rest if len(rest) > 0 else ""


async def predict(
    request: Request,
    selectedModel: str = Form(...),
    isBinarized: str = Form(...),
    file: UploadFile = File(...),
    start: int = None, stop: int = None,
):
    """
    Perform predictions on an uploaded PDF or image and return the extracted
    text in Markdown format.

    Args:
        request (Request): The incoming request (used for logging only).
        selectedModel (str): Name of the model to use (key of ``model_list``).
        isBinarized (str): "true"/"false"; binarize images before inference.
        file (UploadFile): The uploaded PDF or image to process.
        start (int, optional): The starting page number for prediction.
        stop (int, optional): The ending page number for prediction.

    Returns:
        str: The extracted text in Markdown format, or a ``{"message": ...}``
        dict when the upload could not be stored.

    Raises:
        HTTPException: 415 when the content type is neither a PDF nor an image.
    """
    # NOTE(review): model inference below is synchronous and blocks the event
    # loop for the duration of the request.
    global nougatModel
    logging.info(f'Content type of file: {file.content_type}')
    logging.info(f'selectedModel values : {selectedModel}')
    logging.debug(f'Content uploaded file: {file.filename}')
    # Starlette's Request has no `.values`; the old f-string raised
    # AttributeError on every call.
    logging.debug(f'request values: {request.query_params}')
    logging.debug(f'request headers: {request.headers}')
    model_name = selectedModel
    is_binarized = stringToBool(isBinarized)

    content_type = file.content_type or ""
    is_pdf = content_type == "application/pdf"
    # "application/image" is kept for existing clients; real image uploads
    # arrive as image/png, image/jpeg, ...
    is_image = content_type == "application/image" or content_type.startswith("image/")
    if not (is_pdf or is_image):
        # Previously this fell through to an unbound-name error further down.
        raise HTTPException(
            status_code=415, detail=f"Unsupported content type: {content_type}"
        )

    upload_dir = os.path.join(os.getcwd(), "uploads")
    # Create the upload directory if it doesn't exist
    os.makedirs(upload_dir, exist_ok=True)
    # The filename is client-controlled: keep only its last path component so
    # the upload cannot be written outside `upload_dir`.
    safe_name = os.path.basename(file.filename or "")
    if safe_name in ("", ".", ".."):
        safe_name = "upload"
    dest = os.path.join(upload_dir, safe_name)
    blob_file = dest

    pdf = None
    try:
        payload = await file.read()
        # `with` closes the file; the old extra `finally: f.close()` raised
        # NameError when open() itself failed.
        with open(dest, 'wb') as f:
            if is_image:
                logging.info(f"save uploading files to {dest}")
            else:
                logging.debug(f"save files to {dest}")
            f.write(payload)
        if is_pdf:
            pdf = fitz.open("pdf", payload)
    except Exception:
        logging.exception("failed to store the uploaded file")
        return {"message": "There was an error uploading the file"}
    # md5 is used as a content fingerprint / cache key only.
    finger_printer = hashlib.md5(payload).hexdigest()
    save_path = SAVE_DIR / finger_printer

    if is_image:
        img = Image.open(io.BytesIO(payload))
        logging.info(f"uploading Image Type: {type(img)}")
        if img.format != "PNG":
            dest_filename, _ = os.path.splitext(dest)
            img = convertImageFormat(img)
            img.save( dest_filename + ".png", "PNG")
        # predict_image resolves and loads the model itself, so it is not
        # loaded here as well.
        model_output,_ = predict_image(model_name = model_name, images=img, isBinarized=is_binarized)
        logging.debug(f"predict output as: {model_output}")
        predictions = [""]
        for j, output in enumerate(model_output["predictions"]):
            predictions[0] = ( markdown_compatible(output) + _page_disclaimer(model_output, j) )
        (save_path / "pages").mkdir(parents=True, exist_ok=True)
    else:
        if model_name is None:
            model = nougatModel
        else:
            model = loadModel(model_list.get(model_name, default_checkpoint_path))
        logging.info(save_path)
        logging.info(f"{file.filename} uploaded successfully, length {len(payload)}")
        page_count = len(pdf)
        if start is not None and stop is not None:
            # Page numbers are 1-based; ignore anything outside the document.
            pages = [p for p in range(start - 1, stop) if 0 <= p < page_count]
        else:
            pages = list(range(page_count))
        predictions = [""] * len(pages)

        # Cached pages are kept per model so that choosing another model does
        # not return text produced by a different checkpoint.
        model_key = model_name if model_name in model_list else default_model_name
        pages_dir = save_path / "pages" / model_key
        dellist = []
        if pages_dir.exists():
            for computed in pages_dir.glob("*.mmd"):
                logging.debug(f"computed pdf content is {computed}")
                try:
                    idx = int(computed.stem) - 1
                    # Reuse pages that were requested. This test used to be
                    # inverted, so the cache was never hit.
                    if idx in pages:
                        i = pages.index(idx)
                        logging.debug("skip page %d", idx + 1)
                        predictions[i] = computed.read_text(encoding="utf-8")
                        dellist.append(idx)
                except Exception as e:
                    logging.debug(e)
        compute_pages = [p for p in pages if p not in dellist]
        images = rasterize_paper(pdf, pages=compute_pages)
        if len(images) != len(compute_pages) and len(images) == page_count:
            # Defensive: if the rasterizer rendered the whole document, keep
            # only the requested pages so the indices below stay aligned.
            images = [images[p] for p in compute_pages]
        # BATCHSIZE may come from the environment as a string.
        batch_size = int(BATCHSIZE)
        dataset = ImageDataset(
            images,
            partial(model.encoder.prepare_input, random_padding=False),
        )
        dataloader = torch.utils.data.DataLoader(
            dataset,
            batch_size=batch_size,
            pin_memory=True,
            shuffle=False,
        )
        for idx, sample in tqdm(enumerate(dataloader), total=len(dataloader)):
            if sample is None:
                continue
            model_output = model.inference(image_tensors=sample)
            for j, output in enumerate(model_output["predictions"]):
                predictions[pages.index(compute_pages[idx * batch_size + j])] = (
                    markdown_compatible(output) + _page_disclaimer(model_output, j)
                )
        pages_dir.mkdir(parents=True, exist_ok=True)
        pdf.save(save_path / "doc.pdf")
        if len(images) > 0:
            thumb = Image.open(images[0])
            thumb.thumbnail((400, 400))
            thumb.save(save_path / "thumb.jpg")
        for idx, page_num in enumerate(pages):
            (pages_dir / ("%02d.mmd" % (page_num + 1))).write_text(
                predictions[idx], encoding="utf-8"
            )

    final = "".join(predictions).strip()
    (save_path / "doc.mmd").write_text(final, encoding="utf-8")
    #saving to database inference table in async mode
    await aync_save2database(blob_file = Path(blob_file), finger_printer=finger_printer, model_name = model_name, result=final)
    logging.info("***********************************")
    logging.info("save successfully")
    logging.info("***********************************")
    logging.info(f'save path is {save_path}')
    logging.info(f"final output content length is {len(final)}")
    return final
######################################### | |
## | |
## gradio web | |
## | |
######################################### | |
def nougat_predict(input_files, output_path, model_name, batchsize, markdown,recompute):
    """Convert PDF files to ``.mmd`` files with Nougat.

    Args:
        input_files: Iterable of ``Path`` objects pointing at PDFs. Missing
            files are skipped.
        output_path: Directory (``Path``) for the ``.mmd`` files. When falsy,
            results are only logged at debug level.
        model_name: Key into ``model_list``; ``None`` uses the preloaded
            default model.
        batchsize: DataLoader batch size; values below 1 are treated as 1.
        markdown: Pass each page through ``markdown_compatible``.
        recompute: Convert again even if the output file already exists.

    Returns:
        The list of written ``.mmd`` paths. When nothing had to be converted,
        the last candidate output path is returned instead (``None`` if there
        was none), as before.
    """
    logging.info(f'*** nougat predict with input :{input_files} ***')
    global nougatModel, model_list, selected_model_name
    if nougatModel is None:
        nougatModel = NougatModel.from_pretrained(default_checkpoint_path).to(torch.float16)
        if batchsize > 0:
            if torch.cuda.is_available():
                nougatModel.to("cuda")
        else:
            # set batch size to 1. Need to check if there are benefits for CPU conversion for >1
            batchsize = 1
        nougatModel.eval()
    # The module-level default can be 0 on very small GPUs; DataLoader rejects
    # a batch size of 0.
    if not batchsize or batchsize < 1:
        batchsize = 1
    logging.info(f"Using {model_name} for example predicting.")
    if model_name is None:
        model = nougatModel
    else:
        model = loadModel(model_list.get(model_name, default_checkpoint_path))
    datasets = []
    # Bound up front: it is returned below even when no file produced a
    # candidate output path (this used to raise NameError).
    out_path = None
    for pdf in input_files:
        if not os.path.exists(pdf):
            continue
        if output_path:
            out_path = output_path / pdf.with_suffix(".mmd").name
            if out_path.exists() and not recompute:
                logging.info(
                    f"Skipping {pdf.name}, already computed. Run with --recompute to convert again."
                )
                continue
        try:
            dataset = LazyDataset(
                pdf, partial(model.encoder.prepare_input, random_padding=False)
            )
        except fitz.fitz.FileDataError:
            logging.info(f"Could not load file {str(pdf)}.")
            continue
        datasets.append(dataset)
    if len(datasets) == 0:
        logging.info(f'*** nougat out files :{out_path} ***')
        return out_path
    dataloader = torch.utils.data.DataLoader(
        ConcatDataset(datasets),
        batch_size=batchsize,
        shuffle=False,
        collate_fn=LazyDataset.ignore_none_collate,
    )
    predictions = []
    output_file_list = []
    file_index = 0
    page_num = 0
    for i, (sample, is_last_page) in enumerate(tqdm(dataloader)):
        model_output = model.inference(image_tensors=sample)
        # check if model output is faulty
        for j, output in enumerate(model_output["predictions"]):
            if page_num == 0:
                logging.info(
                    "Processing file %s with %i pages"
                    % (datasets[file_index].name, datasets[file_index].size)
                )
            page_num += 1
            if output.strip() == "[MISSING_PAGE_POST]":
                # uncaught repetitions -- most likely empty page
                predictions.append(f"\n\n[MISSING_PAGE_EMPTY:{page_num}]\n\n")
            elif model_output["repeats"][j] is not None:
                if model_output["repeats"][j] > 0:
                    # If we end up here, it means the output is most likely not complete and was truncated.
                    logging.warning(f"Skipping page {page_num} due to repetitions.")
                    predictions.append(f"\n\n[MISSING_PAGE_FAIL:{page_num}]\n\n")
                else:
                    # If we end up here, it means the document page is too different from the training domain.
                    # This can happen e.g. for cover pages.
                    predictions.append(
                        f"\n\n[MISSING_PAGE_EMPTY:{i*batchsize+j+1}]\n\n"
                    )
            else:
                if markdown:
                    output = markdown_compatible(output)
                predictions.append(output)
            if is_last_page[j]:
                # `is_last_page[j]` is truthy on a document's final page and is
                # used below as the file name of that document.
                out = "".join(predictions).strip()
                out = re.sub(r"\n{3,}", "\n\n", out).strip()
                out = out.replace(r"\(", "$").replace(r'\)', '$').replace(r'\[', '$$').replace(r'\]', '$$')
                if output_path:
                    out_path = output_path / Path(is_last_page[j]).with_suffix(".mmd").name
                    out_path.parent.mkdir(parents=True, exist_ok=True)
                    out_path.write_text(out, encoding="utf-8")
                    output_file_list.append(out_path)
                else:
                    # logging needs a format string; passing `out` plus an
                    # extra argument produced a logging error.
                    logging.debug("%s\n\n", out)
                predictions = []
                page_num = 0
                file_index += 1
    logging.info(f'the generated markdown files are : {output_file_list}')
    return output_file_list
def get_image(url_list):
    """Download every URL in ``url_list`` into ``./input``.

    Each file is stored under the last path component of its URL. URLs whose
    request does not succeed are skipped silently.

    Args:
        url_list: Iterable of URL strings.
    """
    new_path = Path("./input")
    for url in url_list:
        file_name = os.path.basename(urlparse(url).path)
        new_file = os.path.join(new_path, file_name)
        with requests.get(url, stream=True) as response:
            if not response.ok:
                continue
            # The target directory may not exist on a fresh checkout.
            new_path.mkdir(parents=True, exist_ok=True)
            with open(new_file, mode="wb") as file:
                # iter_content() defaults to 1-byte chunks; use a real buffer
                # size.
                for data in tqdm(response.iter_content(chunk_size=8192)):
                    file.write(data)
# Download pdf from a given link | |
def get_pdf(pdf_link):
    """Download a PDF into ``input/`` under a unique file name.

    Args:
        pdf_link: URL of the PDF.

    Returns:
        str: The generated file name. It is returned even when the download
        failed, in which case the file does not exist (a message is logged).
    """
    # Generate a unique filename
    unique_filename = f"input/downloaded_paper_{uuid.uuid4().hex}.pdf"
    # Send a GET request to the PDF link
    response = requests.get(pdf_link)
    if response.status_code == 200:
        # The directory may not exist yet; open() would fail without it.
        os.makedirs(os.path.dirname(unique_filename), exist_ok=True)
        # Save the PDF content to a local file
        with open(unique_filename, 'wb') as pdf_file:
            pdf_file.write(response.content)
        logging.info("PDF downloaded successfully.")
    else:
        logging.info("Failed to download the PDF.")
    return unique_filename
def sha256sum(filename):
    """Return the hex SHA-256 digest of the file at ``filename``.

    The file is read unbuffered into one reusable 128 KiB buffer.
    """
    digest = hashlib.sha256()
    view = memoryview(bytearray(128*1024))
    with open(filename, 'rb', buffering=0) as stream:
        while True:
            count = stream.readinto(view)
            if not count:
                break
            digest.update(view[:count])
    return digest.hexdigest()
def readImage():
    """Return the bytes of ``woman.jpg`` from the current directory.

    Logs the error and exits the process with status 1 when the file cannot
    be read.
    """
    try:
        # `with` closes the file. The old `finally: if fin:` raised
        # UnboundLocalError when open() failed, masking the intended exit.
        with open("woman.jpg", "rb") as fin:
            return fin.read()
    except IOError as e:
        # errno/strerror are always present on OSError (possibly None),
        # unlike args[0]/args[1].
        logging.info ("Error %s: %s" % (e.errno, e.strerror))
        sys.exit(1)
# converts from python to postgres | |
def convert_To_Binary(filename):
    """Read ``filename`` and wrap its bytes in ``psycopg.Binary``."""
    with open(filename, 'rb') as handle:
        return psycopg.Binary(handle.read())
def _adapt_array(filename):
    """Return the content of ``filename`` as a ``bytearray``."""
    with open(filename, 'rb') as handle:
        payload = handle.read()
    return bytearray(payload)
# converts from postgres to python | |
def _typecast_array(value, cur):
    """Convert a database value to an array via ``np.loadtxt``; ``None`` passes through.

    The value is wrapped with ``psycopg.Binary``, its first and last elements
    are dropped, and the rest is parsed as text.
    """
    if value is not None:
        wrapped = psycopg.Binary(value, cur)
        stream = io.BytesIO(wrapped[1:-1])
        stream.seek(0)
        return np.loadtxt(stream)
    return None
async def aync_save2database(blob_file, finger_printer, model_name, result):
    """Store one inference result through the async session.

    Args:
        blob_file: Path-like object with a ``.name`` pointing at the processed
            file, or ``None`` when there is no file to store.
        finger_printer: Content fingerprint of the file.
        model_name: Name of the model that produced ``result``.
        result: The generated text.

    Returns:
        Whatever ``crud.inference.create_async_inference`` returns.
    """
    #saving to database inference table
    blob = None
    file_name = ""
    if blob_file is not None:
        blob = _adapt_array(blob_file)
        file_name = blob_file.name
    # `blob_file.name` used to be read unconditionally and raised
    # AttributeError in the None case this function explicitly allows.
    inference = InferenceCreate(
        blobContent= blob,
        textContent= "",
        fileName = file_name,
        fingerPrint=finger_printer,
        selectedModel= model_name,
        result=result,
    )
    async with async_session() as db_session:
        my_inference_on_db = await crud.inference.create_async_inference(db_session, obj_in=inference)
    return my_inference_on_db
def save2database(blob_file, finger_printer, model_name, result):
    """Store one inference record through a synchronous database session.

    Args:
        blob_file: Source file (an object with a ``name`` attribute that
            ``open`` accepts), or ``None`` when there is no file to store.
        finger_printer: Fingerprint string of the input.
        model_name: Name of the model that produced ``result``.
        result: Converted text to persist.

    Returns:
        Whatever ``crud.inference.create_inference`` returns.
    """
    # saving to database inference table
    blob = None
    file_name = ""
    if blob_file is not None:
        blob = _adapt_array(blob_file)
        # Read ``name`` only when a file exists; it used to be dereferenced
        # unconditionally and crashed for blob_file=None.
        file_name = blob_file.name
    inference = InferenceCreate(
        blobContent=blob,
        textContent="",
        fileName=file_name,
        fingerPrint=finger_printer,
        selectedModel=model_name,
        result=result,
    )
    # NOTE(review): only the first value of get_db() is consumed and the
    # generator is never closed here; confirm its cleanup is not skipped.
    db_session = next(get_db())
    my_inference_on_db = crud.inference.create_inference(db_session=db_session, obj_in=inference)
    return my_inference_on_db
# predict function / driver function
def perapare_reader(pdf_file, pdf_link, img_file, model_name, async_model=False):
    """Run OCR on an image, an uploaded PDF or a PDF link and store the result.

    Inputs are considered in this order: ``img_file``, then ``pdf_file``,
    then ``pdf_link``.

    Args:
        pdf_file: Uploaded PDF (object exposing ``name``) or ``None``.
        pdf_link: URL of a PDF to download; empty or ``None`` when unused.
        img_file: Uploaded image or ``None``.
        model_name: Name of the model to run.
        async_model: When true, persist through the async database helper.

    Returns:
        A 3-tuple ``(content, file_list_markdown, download_button)`` on every
        path, so it always matches the three Gradio outputs it is wired to.
        ``content`` is the converted text of the last readable output file,
        ``"convert failed"`` when nothing was produced, or a hint message
        when no input was supplied.
    """
    logging.info(f'*** paper_read ****')
    logging.info(f'*** using model ****{model_name}')
    output_path = Path("./output")
    markdown = True
    batchsize = BATCH_SIZE
    blob_file = None
    finger_printer = ""

    if img_file is not None:
        logging.info(f'*** handing image file : {img_file} ****')
        blob_file = img_file
        img = Image.open(img_file)
        logging.info(f"Image Type: {type(img)}")
        finger_printer = sha256sum(img_file)
        _, output_files = predict_image(model_name, img, markdown=True, out_path_root=Path("output"))
    else:
        if pdf_file is not None:
            logging.info(f'*** handing pdf paper :{pdf_file}***')
            file_name = pdf_file.name
            blob_file = pdf_file
        else:
            logging.info(f'*** handing pdf link paper :{pdf_link} ****')
            if not pdf_link:
                logging.info("No file is uploaded and No link is provided")
                # Callers unpack three values, so the hint is returned in the
                # same shape as a normal result (it used to be a bare string).
                return (
                    "No data provided. Upload a pdf file or provide a pdf link and try again!",
                    gr.Markdown("Converted file list :"),
                    gr.DownloadButton(visible=False),
                )
            file_name = get_pdf(pdf_link)
            blob_file = Path(file_name)
        # One-element tuple, as the original trailing comma produced.
        input_files = (file_name if isinstance(file_name, os.PathLike) else Path(file_name),)
        finger_printer = sha256sum(input_files[0])
        output_files = nougat_predict(input_files=input_files, output_path=output_path, model_name=model_name,
                                      batchsize=batchsize, markdown=markdown, recompute=True)
    logging.info(f'the generated markdown file is : {output_files}')

    # Treat a missing result like an empty one so the code below can index safely.
    output_files = output_files or []
    fileList = []
    content = None
    for output_file in output_files:
        if output_file is None:
            fileList.append("")
            continue
        with open(output_file, 'r', encoding="utf-8") as file:
            content = file.read()
        # switch math delimiters
        content = content.replace(r"\\(", "$").replace(r'\\)', '$').replace(r'\\[', '$$').replace(r'\\]', '$$')
        content = content.replace(r"\(", "$").replace(r'\)', '$').replace(r'\[', '$$').replace(r'\]', '$$')
        fileList.append(str(output_file))

    if content:
        logging.info("***********************************")
        logging.info("convert successfully")
        logging.info("***********************************")
        if async_model:
            try:
                # see if there is a loop already running. If there is, reuse it.
                loop = asyncio.get_running_loop()
            except RuntimeError:
                # Create new event loop if one is not running
                loop = asyncio.new_event_loop()
                asyncio.set_event_loop(loop)
            # NOTE(review): run_until_complete raises if ``loop`` is already
            # running; this path only works when called outside an event loop.
            loop.run_until_complete(aync_save2database(blob_file, finger_printer, model_name, content))
        else:
            save2database(blob_file, finger_printer, model_name, content)
        logging.info("***********************************")
        logging.info("save successfully")
        logging.info("***********************************")
    else:
        content = "convert failed"
        logging.info("***********************************")
        logging.info("convert failed")
        logging.info("***********************************")

    first_listed = fileList[0] if fileList else ""
    first_output = output_files[0] if output_files else None
    if first_output is not None:
        download_button = gr.DownloadButton(value="/file=" + str(first_output), visible=True)
    else:
        # Nothing to download: hide the button instead of pointing it at "/file=None".
        download_button = gr.DownloadButton(visible=False)
    return content, gr.Markdown("Converted file list :" + first_listed), download_button
import asyncio | |
# Handling examples in Gradio app
def process_example(pdf_file, pdf_link, img_file, model_name):
    """Run one Gradio example through ``perapare_reader``.

    Args:
        pdf_file: Example PDF path or ``None``.
        pdf_link: Example PDF URL or an empty string.
        img_file: Example image path or ``None``.
        model_name: Model to use; ``None`` selects ``default_model_name``.

    Returns:
        Three values for the example outputs: an update carrying the OCR
        text, the converted-file list component and the download button.
    """
    logging.info('*** process_example ****')
    if model_name is None:
        model_name = default_model_name
    outcome = perapare_reader(pdf_file, pdf_link, img_file, model_name)
    if isinstance(outcome, str):
        # A bare message (e.g. "no data provided") cannot be unpacked into
        # three outputs; show it and leave the other two components unchanged.
        return gr.update(value=outcome), gr.update(), gr.update()
    ocr_content, output_files, download_btn = outcome
    return gr.update(value=ocr_content), output_files, download_btn
# Stylesheet for the element with id "mkd": fixed height, scrollable, bordered.
css = """
#mkd {
height: 500px;
overflow: auto;
border: 1px solid #ccc;
}
"""

# Sample image URLs.
url_list = [
    "https://cdn.mathpix.com/snip/images/Hm62Ib-dDZOseYuVNN8k34IhBY18KglOrM7qETOqXZI.original.fullsize.png",
    "https://cdn.mathpix.com/snip/images/lSL07DYTL1bdjzL2mpNyVg17JmqKwgugMLyGuxkLgLg.original.fullsize.png",
]
# Function for generating some examples
def generate_examples():
    """Return a ``gr.Dataset`` update that reloads the built-in example rows.

    Each row is ``[pdf_file, pdf_link, img_file, model_name]``, one value per
    input component of the examples table this update is applied to.
    """
    extracted_list = [
        ["./input/test.pdf", "", None, "nougat_big_facebook"],
        [None, "https://arxiv.org/pdf/2308.08316.pdf", None, "nougat_middle_cn"],
        [None, "", "./input/Hm62Ib-dDZOseYuVNN8k34IhBY18KglOrM7qETOqXZI.original.fullsize.png", "nougat_small_cn"],
    ]
    # Gradio 4 removed the ``Component.update`` classmethods; returning a
    # component instance from an event handler is the supported way to update.
    return gr.Dataset(samples=extracted_list)
#get_pdf("https://arxiv.org/pdf/2308.13418.pdf")
#get_image(url_list)

# Register both static directories in a single call so neither registration
# can be replaced by a later one.
gr.set_static_paths(paths=["output/", "images/"])

# Nothing has been converted when the UI is built, so the download button
# starts on a placeholder path; the click handlers replace it with the real
# output file.
download_btn_visible = False
download_file_path = "./output/test.mmd"

#with gr.Blocks(css=css) as demo:
with gr.Blocks(theme='reilnuud/polite') as demo:
    # ---- page header -------------------------------------------------------
    gr.HTML("""<h1><left>AIWorm.CN <img src="file/images/ic_worm_ai.png" style="display: inline-block; margin: 0;" width="100" height="50" /><left><h1>""")
    gr.HTML("""<h1><center>Using Nougat: Neural Optical Understanding for Academic Documents<center><h1>""")
    gr.HTML("<h3><center>the orignal Nougat-OCR Prject is done by Lukas Blecher et al. <a href='https://arxiv.org/pdf/2308.13418.pdf' target='_blank'>Paper</a>, <a href='https://facebookresearch.github.io/nougat/'>Project</a><center></h3>")
    gr.HTML(
        "<h3>The model below: <br>"
        "- model nougat_big_facebook is the facebook nougat project base model (only support pdf/english input), refer to <a href='https://huggingface.co/facebook/nougat-base '> facebook nougat-base</a> <br>"
        "- model nougat_small_facebook is the facebook nougat model (only support pdf/english input), refer to <a href='https://huggingface.co/facebook/nougat-small'> facebook nougat-smalle</a> <br> "
        "- model nougat_middel_cn is the fine-tuned chinese version adopt to pdf and image input , refer to <a href='https://huggingface.co/zphilip48/nougat-middle-cn'> zphilip48's nougat-middle-cn</a> <br>"
        "- model nougat_small_cn is the fine-tuned chinese version adopt to pdf and image input </h3>"
    )

    # ---- model selection ---------------------------------------------------
    model_name_list = [
        "nougat_middle_ocr",
        "nougat_middle_cn",
        "nougat_small_facebook",
        "nougat_small_ocr_ocrpadded",
        "nougat_big_facebook",
        "nougat_small_cn",
    ]
    selected_model_name = gr.Dropdown(choices=model_name_list, label="Select models ", value=model_name_list[1], interactive=True)

    # ---- inputs ------------------------------------------------------------
    with gr.Row():
        # Caption components are never referenced again, so they are not bound to a name.
        gr.Markdown('<h4><center>Upload a PDF</center></h4>')
        gr.Markdown('<h4><center><i>OR</i></center></h4>')
        gr.Markdown('<h4><center>Provide a PDF link</center></h4>')
        gr.Markdown('<h4><center><i>OR</i></center></h4>')
        gr.Markdown('<h4><center>Upload a image(PNG or JPGE)</center></h4>')
    with gr.Row(equal_height=True):
        pdf_file = gr.File(label='PDF📃', file_count='single')
        pdf_link = gr.Textbox(placeholder='Enter an Arxiv link here', label='PDF link🔗🌐')
        img_file = gr.File(label='IMG📃', file_count='single')
    with gr.Row():
        btn = gr.Button('Run NOUGAT🍫')
        clr = gr.Button('Clear🚿')

    # ---- outputs -----------------------------------------------------------
    with gr.Row():
        with gr.Column(scale=1):
            gr.Markdown("""<h3><center><font color="red">The converting will take some time due to the GPU P100 as backend </font></center></h3>""")
            gr.Markdown("<h3><center>Gradio Markdown might not render complex latex equation correctly, try download it👇:</center></h3>")
            download_btn = gr.DownloadButton('Download The Converted Markdown File ', value="/file=" + download_file_path, visible=True)
    output_headline = gr.Markdown("<h3><center>PDF converted into markup language through Nougat-OCR👇:</center></h3>")
    output_files = gr.Markdown("Converted file list")
    # Raw strings keep the backslashes literal without relying on invalid
    # escape sequences such as "\(".
    parsed_output = gr.Markdown(
        r'# OCR Output📃🔤',
        elem_id='mkd',
        latex_delimiters=[
            {"left": "$", "right": "$", "display": False},
            {"left": "$$", "right": "$$", "display": True},
            {"left": r"\(", "right": r"\)", "display": False},
            {"left": r"\[", "right": r"\]", "display": True},
        ],
    )

    # ---- event wiring ------------------------------------------------------
    btn.click(perapare_reader, inputs=[pdf_file, pdf_link, img_file, selected_model_name], outputs=[parsed_output, output_files, download_btn])
    #download_btn.click(download_file_path)
    clr.click(
        lambda: (
            gr.update(value=None),
            gr.update(value=""),
            gr.update(value=None),
            gr.update(value=None),
        ),
        [],
        [pdf_file, pdf_link, img_file, parsed_output],
    )

    def update_visibility(radio, text):
        """Return an update that shows the target only when ``radio`` is "show".

        ``text`` is accepted for signature compatibility. An event handler has
        to return the update; setting ``.visible`` on its argument has no
        effect on the page.
        """
        return gr.update(visible=(radio == "show"))

    # ---- examples ----------------------------------------------------------
    example_btn = gr.Button('re-Run Example🍫')
    examples = gr.Examples(
        [
            ["./input/test.pdf", "", None, "nougat_big_facebook"],
            [None, "https://arxiv.org/pdf/2308.08316.pdf", None, "nougat_middle_cn"],
            [None, "", "./input/Hm62Ib-dDZOseYuVNN8k34IhBY18KglOrM7qETOqXZI.original.fullsize.png", "nougat_small_cn"],
        ],
        inputs=[pdf_file, pdf_link, img_file, selected_model_name],
        outputs=[parsed_output, output_files, download_btn],
        fn=process_example,
        cache_examples=True,
        run_on_click=True,
        label='Click on any Examples below to get Nougat OCR results quickly:',
    )
    example_btn.click(generate_examples, outputs=[examples.dataset])
    #demo.load(parsed_output, None, text_out, every=3)

demo.queue()
# Serve the Gradio UI from the root path of the existing ``app`` application.
app = gr.mount_gradio_app(app, demo, path="/")
if __name__ == "__main__":
    # Alternative launch configurations kept for reference:
    #uvicorn.run(app, host="0.0.0.0", port=8866,log_level="debug",
    #            ssl_keyfile='/workspace/nougat-latex/lzs.chrdw.ml.key',
    #            ssl_certfile='/workspace/nougat-latex/fullchain.cer')
    #demo.launch(debug=True,share=True, server_name="0.0.0.0",server_port=8866)
    # Start the combined FastAPI + Gradio application on all interfaces.
    uvicorn.run(
        "__main__:app",
        host="0.0.0.0",
        port=8503,
        log_level="debug",
        workers=1,
    )