# nougat-latex / nougat_api.py
# zphilip48's picture
# update the application api
# 7566b08
"""
Copyright (c) Meta Platforms, Inc. and affiliates.
This source code is licensed under the MIT license found in the
LICENSE file in the root directory of this source tree.
"""
import os

# Just use one GPU on a big machine: only device 1 is made visible. This is set
# before torch is imported so the restriction is in place when torch queries CUDA.
os.environ["CUDA_VISIBLE_DEVICES"] = "1"

import torch

print('Available devices ', torch.cuda.device_count())
# Validate with an explicit raise rather than `assert` (stripped under `python -O`),
# and do it before querying the current device.
if torch.cuda.device_count() != 1:
    raise RuntimeError(
        "Expected exactly one visible CUDA device (CUDA_VISIBLE_DEVICES=%r), found %d"
        % (os.environ.get("CUDA_VISIBLE_DEVICES"), torch.cuda.device_count())
    )
print('Current cuda device ', torch.cuda.current_device())
print('GPU Device name:', torch.cuda.get_device_name(torch.cuda.current_device()))
import sys
from functools import partial
from http import HTTPStatus
from fastapi import FastAPI, File, UploadFile
from PIL import Image
from pathlib import Path
import hashlib
from fastapi.middleware.cors import CORSMiddleware
import fitz
import torch
from nougat import NougatModel
from nougat.postprocessing import markdown_compatible, close_envs
from nougat.utils.dataset import ImageDataset
from nougat.utils.checkpoint import get_checkpoint
from nougat.dataset.rasterize import rasterize_paper
from tqdm import tqdm
import logging
import pypdfium2
import io
from typing import Optional, List, Union
# Root directory for per-document results; predict() stores each upload under SAVE_DIR / <md5>.
SAVE_DIR = Path("./pdfs")
# Environment variables are always strings, so coerce to int: the DataLoader's
# batch_size and the batch index arithmetic in predict() need an integer.
BATCHSIZE = int(os.environ.get("NOUGAT_BATCHSIZE", 6))
NOUGAT_CHECKPOINT = get_checkpoint()
if NOUGAT_CHECKPOINT is None:
    # load_model() uses this as its checkpoint, so there is nothing to serve
    # without it: abort at import time, reporting the problem on stderr.
    print(
        "Set environment variable 'NOUGAT_CHECKPOINT' with a path to the model checkpoint!.",
        file=sys.stderr,
    )
    sys.exit(1)
app = FastAPI(title="Nougat API")

# Browser origins allowed to call the API cross-origin. Can be overridden with a
# comma-separated NOUGAT_CORS_ORIGINS environment variable; defaults to the list below.
_DEFAULT_CORS_ORIGINS = "http://localhost,http://127.0.0.1,http://43.155.187.132"
origins = [
    origin.strip()
    for origin in os.environ.get("NOUGAT_CORS_ORIGINS", _DEFAULT_CORS_ORIGINS).split(",")
    if origin.strip()
]
app.add_middleware(
    CORSMiddleware,
    allow_origins=origins,
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Shared NougatModel instance; set by load_model() when the app starts.
model = None
def rasterize_paper(
    pdf: Union[Path, bytes],
    outpath: Optional[Path] = None,
    dpi: int = 96,
    return_pil=False,
    pages=None,
) -> Optional[List[io.BytesIO]]:
    """
    Rasterize pages of a PDF with pypdfium2.

    This local definition replaces the ``rasterize_paper`` imported from
    ``nougat.dataset.rasterize`` earlier in the module.

    Args:
        pdf (Union[str, Path, bytes, pypdfium2.PdfDocument]): A path to a PDF file,
            the raw PDF bytes, or an already opened ``pypdfium2.PdfDocument``.
            Documents opened here are closed before returning; a document passed
            in by the caller is left open.
        outpath (Optional[Path], optional): Directory to write ``NN.png`` files into
            (NN is the 1-based page number). If None, in-memory images are returned
            instead. Defaults to None.
        dpi (int, optional): The output DPI. Defaults to 96.
        return_pil (bool, optional): Return in-memory images instead of writing files.
            Forced to True when ``outpath`` is None. Defaults to False.
        pages (Optional[Iterable[int]], optional): 0-based page indices to rasterize.
            If None, all pages are rasterized. Defaults to None.

    Returns:
        Optional[List[io.BytesIO]]: One BMP-encoded buffer per rendered page if
        ``return_pil`` is True, otherwise None. Errors are logged (with traceback)
        and the pages rendered before the failure are returned.
    """
    pils = []
    if outpath is None:
        return_pil = True
    owned = None  # document opened by this function, closed in ``finally``
    try:
        if isinstance(pdf, (str, Path, bytes)):
            pdf = owned = pypdfium2.PdfDocument(pdf)
        if pages is None:
            pages = range(len(pdf))
        scale = dpi / 72  # render scale relative to 72 DPI
        for i in pages:
            # Render one page at a time (the whole-document render helper is deprecated).
            image = pdf[i].render(scale=scale).to_pil()
            if return_pil:
                page_bytes = io.BytesIO()
                image.save(page_bytes, "bmp")
                pils.append(page_bytes)
            else:
                image.save(outpath / ("%02d.png" % (i + 1)), "png")
    except Exception:
        logging.exception("Failed to rasterize PDF")
    finally:
        if owned is not None:
            owned.close()
    if return_pil:
        return pils
@app.on_event("startup")
async def load_model(
    checkpoint: str = NOUGAT_CHECKPOINT,
):
    """Load the Nougat model into the module-level ``model`` when the app starts.

    The model is cast to bfloat16, moved to CUDA when it is available and put in
    evaluation mode. Nothing happens if ``model`` has already been set.

    Args:
        checkpoint: Passed to ``NougatModel.from_pretrained``; defaults to
            ``NOUGAT_CHECKPOINT``.
    """
    global model
    if model is not None:
        return
    model = NougatModel.from_pretrained(checkpoint).to(torch.bfloat16)
    if torch.cuda.is_available():
        model.to("cuda")
    model.eval()
@app.get("/")
def root():
    """Health check: report HTTP 200 with an empty ``data`` payload."""
    return {"status-code": HTTPStatus.OK, "data": {}}
import shutil
def _store_upload_copy(filename: Optional[str], data: bytes) -> None:
    """Write a copy of an uploaded file into ``<cwd>/uploads`` (debugging aid).

    Only the final component of the client-supplied ``filename`` is used, so a
    name such as ``../../x.pdf`` cannot place the copy outside that directory.
    """
    upload_dir = os.path.join(os.getcwd(), "uploads")
    os.makedirs(upload_dir, exist_ok=True)
    name = os.path.basename((filename or "").replace("\\", "/"))
    if name in ("", ".", ".."):
        name = "upload.pdf"
    with open(os.path.join(upload_dir, name), "wb") as f:
        f.write(data)


def _page_indices(num_pages: int, start: Optional[int], stop: Optional[int]) -> List[int]:
    """Map a 1-based inclusive page range to 0-based page indices.

    A missing ``start`` means the first page and a missing ``stop`` the last one;
    both ends are clamped to ``[1, num_pages]``.
    """
    first = max(1 if start is None else start, 1) - 1
    last = min(num_pages if stop is None else stop, num_pages)
    return list(range(first, last))


def _load_cached_pages(save_path: Path, pages: List[int]) -> dict:
    """Read page markdown saved by an earlier request for the same document.

    Returns a dict mapping 0-based page index to text, restricted to ``pages``.
    Files in ``save_path / "pages"`` whose name is not a page number, or that
    cannot be read, are skipped.
    """
    pages_dir = save_path / "pages"
    if not pages_dir.is_dir():
        return {}
    wanted = set(pages)
    cached = {}
    for computed in pages_dir.glob("*.mmd"):
        try:
            idx = int(computed.stem) - 1
        except ValueError:
            continue  # stem is not a page number
        if idx not in wanted:
            continue
        try:
            cached[idx] = computed.read_text(encoding="utf-8")
        except (OSError, UnicodeDecodeError) as e:
            logging.warning("Ignoring unreadable cached page %s: %s", computed, e)
            continue
        logging.info("skip page %d", idx + 1)
    return cached


def _format_prediction(model_output: dict, j: int) -> str:
    """Return the markdown for item ``j`` of a ``model.inference`` result.

    When the model reports a repeat count for the item, the text left by
    ``close_envs`` on its repetition is appended in a WARNING block (count > 0)
    or an ERROR block (otherwise); nothing is appended if that text is empty.
    """
    disclaimer = ""
    repeats = model_output["repeats"][j]
    if repeats is not None:
        rest = close_envs(model_output["repetitions"][j]).strip()
        if rest:
            if repeats > 0:
                template = "\n\n+++ ==WARNING: Truncated because of repetitions==\n%s\n+++\n\n"
            else:
                template = "\n\n+++ ==ERROR: No output for this page==\n%s\n+++\n\n"
            disclaimer = template % rest
    return markdown_compatible(model_output["predictions"][j]) + disclaimer


def _save_results(
    save_path: Path,
    pdfbin: bytes,
    images: List[io.BytesIO],
    pages: List[int],
    predictions: List[str],
) -> str:
    """Write the PDF, a thumbnail, per-page and whole-document markdown under ``save_path``.

    Returns the whole-document markdown (all predictions joined, then stripped).
    """
    pages_dir = save_path / "pages"
    pages_dir.mkdir(parents=True, exist_ok=True)
    (save_path / "doc.pdf").write_bytes(pdfbin)
    if images:
        # Thumbnail of the first page rendered in this request.
        thumb = Image.open(images[0])
        thumb.thumbnail((400, 400))
        thumb.save(save_path / "thumb.jpg")
    for page_num, text in zip(pages, predictions):
        (pages_dir / ("%02d.mmd" % (page_num + 1))).write_text(text, encoding="utf-8")
    final = "".join(predictions).strip()
    (save_path / "doc.mmd").write_text(final, encoding="utf-8")
    return final


@app.post("/predict/")
async def predict(
    file: UploadFile = File(), start: int = None, stop: int = None
):
    """
    Perform predictions on a PDF document and return the extracted text in Markdown format.

    Pages already predicted for the same document (identified by the MD5 of its
    bytes) are read back from ``SAVE_DIR`` instead of being recomputed.

    Args:
        file (UploadFile): The uploaded PDF file to process.
        start (int, optional): First page to predict (1-based, inclusive). Defaults to the first page.
        stop (int, optional): Last page to predict (1-based, inclusive). Defaults to the last page.

    Returns:
        str: The extracted text in Markdown format, or a ``{"message": ...}`` dict
        if the upload could not be read as a PDF.
    """
    try:
        pdfbin = file.file.read()
        _store_upload_copy(file.filename, pdfbin)
        # rasterize_paper expects a pypdfium2 document (it uses pypdfium2's
        # rendering API), so open the upload with pypdfium2 rather than fitz.
        pdf = pypdfium2.PdfDocument(pdfbin)
        md5 = hashlib.md5(pdfbin).hexdigest()
    except Exception:
        logging.exception("Could not read uploaded file %r", file.filename)
        return {"message": "There was an error uploading the file"}
    finally:
        file.file.close()

    # Results are cached per document content, so repeated uploads reuse pages.
    save_path = SAVE_DIR / md5
    logging.info("Results directory: %s", save_path)
    try:
        pages = _page_indices(len(pdf), start, stop)
        # Slot of each requested page within ``pages`` / ``predictions``.
        position = {page: i for i, page in enumerate(pages)}
        predictions = [""] * len(pages)
        cached = _load_cached_pages(save_path, pages)
        for page, text in cached.items():
            predictions[position[page]] = text
        compute_pages = [page for page in pages if page not in cached]
        images = rasterize_paper(pdf, pages=compute_pages) if compute_pages else []
    finally:
        pdf.close()

    if images:
        batch_size = int(BATCHSIZE)  # an environment override may be a string
        dataset = ImageDataset(
            images,
            partial(model.encoder.prepare_input, random_padding=False),
        )
        dataloader = torch.utils.data.DataLoader(
            dataset,
            batch_size=batch_size,
            pin_memory=True,
            shuffle=False,
        )
        for idx, sample in tqdm(enumerate(dataloader), total=len(dataloader)):
            if sample is None:
                continue
            model_output = model.inference(image_tensors=sample)
            for j, _ in enumerate(model_output["predictions"]):
                page = compute_pages[idx * batch_size + j]
                predictions[position[page]] = _format_prediction(model_output, j)

    return _save_results(save_path, pdfbin, images, pages, predictions)
def main(host: str = "0.0.0.0", port: int = 8866):
    """Serve this module's FastAPI ``app`` with uvicorn.

    The app object is passed directly. The import string "app:app" would resolve a
    module named ``app``, which is not necessarily this file.

    Args:
        host: Interface to bind. Defaults to all interfaces.
        port: TCP port to listen on. Defaults to 8866.
    """
    import uvicorn

    uvicorn.run(app, host=host, port=port)


if __name__ == "__main__":
    main()