""" Copyright (c) Meta Platforms, Inc. and affiliates. This source code is licensed under the MIT license found in the LICENSE file in the root directory of this source tree. """ import os os.environ["CUDA_VISIBLE_DEVICES"] = "1" # just use one GPU on big machine import torch print ('Available devices ', torch.cuda.device_count()) print ('Current cuda device ', torch.cuda.current_device()) assert torch.cuda.device_count() == 1 print('GPU Device name:', torch.cuda.get_device_name(torch.cuda.current_device())) import sys from functools import partial from http import HTTPStatus from fastapi import FastAPI, File, UploadFile from PIL import Image from pathlib import Path import hashlib from fastapi.middleware.cors import CORSMiddleware import fitz import torch from nougat import NougatModel from nougat.postprocessing import markdown_compatible, close_envs from nougat.utils.dataset import ImageDataset from nougat.utils.checkpoint import get_checkpoint from nougat.dataset.rasterize import rasterize_paper from tqdm import tqdm import logging import pypdfium2 import io from typing import Optional, List, Union SAVE_DIR = Path("./pdfs") BATCHSIZE = os.environ.get("NOUGAT_BATCHSIZE", 6) NOUGAT_CHECKPOINT = get_checkpoint() if NOUGAT_CHECKPOINT is None: print( "Set environment variable 'NOUGAT_CHECKPOINT' with a path to the model checkpoint!." ) sys.exit(1) app = FastAPI(title="Nougat API") origins = ["http://localhost", "http://127.0.0.1","http://43.155.187.132"] app.add_middleware( CORSMiddleware, allow_origins=origins, allow_credentials=True, allow_methods=["*"], allow_headers=["*"], ) model = None def rasterize_paper( pdf: Union[Path, bytes], outpath: Optional[Path] = None, dpi: int = 96, return_pil=False, pages=None, ) -> Optional[List[io.BytesIO]]: """ Rasterize a PDF file to PNG images. Args: pdf (Path): The path to the PDF file. outpath (Optional[Path], optional): The output directory. If None, the PIL images will be returned instead. Defaults to None. dpi (int, optional): The output DPI. Defaults to 96. return_pil (bool, optional): Whether to return the PIL images instead of writing them to disk. Defaults to False. pages (Optional[List[int]], optional): The pages to rasterize. If None, all pages will be rasterized. Defaults to None. Returns: Optional[List[io.BytesIO]]: The PIL images if `return_pil` is True, otherwise None. """ pils = [] if outpath is None: return_pil = True try: if isinstance(pdf, (str, Path)): pdf = pypdfium2.PdfDocument(pdf) if pages is None: pages = range(len(pdf)) renderer = pdf.render( pypdfium2.PdfBitmap.to_pil, page_indices=pages, scale=dpi / 72, ) for i, image in zip(pages, renderer): if return_pil: page_bytes = io.BytesIO() image.save(page_bytes, "bmp") pils.append(page_bytes) else: image.save((outpath / ("%02d.png" % (i + 1))), "png") except Exception as e: logging.error(e) if return_pil: return pils @app.on_event("startup") async def load_model( checkpoint: str = NOUGAT_CHECKPOINT, ): global model if model is None: model = NougatModel.from_pretrained(checkpoint).to(torch.bfloat16) if torch.cuda.is_available(): model.to("cuda") model.eval() @app.get("/") def root(): """Health check.""" response = { "status-code": HTTPStatus.OK, "data": {}, } return response import shutil @app.post("/predict/") async def predict( file: UploadFile = File(), start: int = None, stop: int = None ): """ Perform predictions on a PDF document and return the extracted text in Markdown format. Args: file (UploadFile): The uploaded PDF file to process. start (int, optional): The starting page number for prediction. stop (int, optional): The ending page number for prediction. Returns: str: The extracted text in Markdown format. """ ##test code upload_dir = os.path.join(os.getcwd(), "uploads") # Create the upload directory if it doesn't exist if not os.path.exists(upload_dir): os.makedirs(upload_dir) dest = os.path.join(upload_dir, file.filename) try: with open(dest, 'wb') as f: pdfbin = file.file.read() shutil.copyfileobj(file.file, f) pdf = fitz.open("pdf", pdfbin) print(pdf) md5 = hashlib.md5(pdfbin).hexdigest() print(md5) except Exception: return {"message": "There was an error uploading the file"} finally: file.file.close() save_path = SAVE_DIR / md5 print(save_path) if start is not None and stop is not None: pages = list(range(start - 1, stop)) else: pages = list(range(len(pdf))) predictions = [""] * len(pages) dellist = [] if save_path.exists(): for computed in (save_path / "pages").glob("*.mmd"): try: idx = int(computed.stem) - 1 if idx in pages: i = pages.index(idx) print("skip page", idx + 1) predictions[i] = computed.read_text(encoding="utf-8") dellist.append(idx) except Exception as e: print(e) compute_pages = pages.copy() for el in dellist: compute_pages.remove(el) images = rasterize_paper(pdf, pages=compute_pages) global model dataset = ImageDataset( images, partial(model.encoder.prepare_input, random_padding=False), ) dataloader = torch.utils.data.DataLoader( dataset, batch_size=BATCHSIZE, pin_memory=True, shuffle=False, ) for idx, sample in tqdm(enumerate(dataloader), total=len(dataloader)): if sample is None: continue model_output = model.inference(image_tensors=sample) for j, output in enumerate(model_output["predictions"]): if model_output["repeats"][j] is not None: if model_output["repeats"][j] > 0: disclaimer = "\n\n+++ ==WARNING: Truncated because of repetitions==\n%s\n+++\n\n" else: disclaimer = ( "\n\n+++ ==ERROR: No output for this page==\n%s\n+++\n\n" ) rest = close_envs(model_output["repetitions"][j]).strip() if len(rest) > 0: disclaimer = disclaimer % rest else: disclaimer = "" else: disclaimer = "" predictions[pages.index(compute_pages[idx * BATCHSIZE + j])] = ( markdown_compatible(output) + disclaimer ) (save_path / "pages").mkdir(parents=True, exist_ok=True) pdf.save(save_path / "doc.pdf") if len(images) > 0: thumb = Image.open(images[0]) thumb.thumbnail((400, 400)) thumb.save(save_path / "thumb.jpg") for idx, page_num in enumerate(pages): (save_path / "pages" / ("%02d.mmd" % (page_num + 1))).write_text( predictions[idx], encoding="utf-8" ) final = "".join(predictions).strip() (save_path / "doc.mmd").write_text(final, encoding="utf-8") return final def main(): import uvicorn uvicorn.run("app:app", host="0.0.0.0", port=8866) if __name__ == "__main__": main()