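# NOTE(review): the text "Spaces: Runtime error / Runtime error" preceded this
# file; it looks like copied page status text rather than part of the program.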
""" | |
Copyright (c) Meta Platforms, Inc. and affiliates. | |
This source code is licensed under the MIT license found in the | |
LICENSE file in the root directory of this source tree. | |
""" | |
import os | |
# Expose only the GPU with index 1 to this process (single GPU on a multi-GPU
# host). The assignment is placed ahead of `import torch` on purpose; keep
# that ordering.
os.environ["CUDA_VISIBLE_DEVICES"] = "1" # just use one GPU on big machine
import torch
# Start-up diagnostics: report what torch can see and insist on exactly one
# visible device before the service continues.
print ('Available devices ', torch.cuda.device_count())
print ('Current cuda device ', torch.cuda.current_device())
# NOTE(review): `assert` is skipped under `python -O`; presumably this check
# is meant to always run -- confirm.
assert torch.cuda.device_count() == 1
print('GPU Device name:', torch.cuda.get_device_name(torch.cuda.current_device()))
import sys | |
from functools import partial | |
from http import HTTPStatus | |
from fastapi import FastAPI, File, UploadFile | |
from PIL import Image | |
from pathlib import Path | |
import hashlib | |
from fastapi.middleware.cors import CORSMiddleware | |
import fitz | |
import torch | |
from nougat import NougatModel | |
from nougat.postprocessing import markdown_compatible, close_envs | |
from nougat.utils.dataset import ImageDataset | |
from nougat.utils.checkpoint import get_checkpoint | |
from nougat.dataset.rasterize import rasterize_paper | |
from tqdm import tqdm | |
import logging | |
import pypdfium2 | |
import io | |
from typing import Optional, List, Union | |
# Directory under which per-document results are stored (keyed by content hash).
SAVE_DIR = Path("./pdfs")

# Environment variables are always strings, so coerce to int: the value is used
# both as the DataLoader batch size and in index arithmetic.
BATCHSIZE = int(os.environ.get("NOUGAT_BATCHSIZE", 6))

NOUGAT_CHECKPOINT = get_checkpoint()
if NOUGAT_CHECKPOINT is None:
    print(
        "Set environment variable 'NOUGAT_CHECKPOINT' with a path to the model checkpoint!."
    )
    sys.exit(1)

app = FastAPI(title="Nougat API")

# Browser origins that are allowed to call this API.
origins = ["http://localhost", "http://127.0.0.1", "http://43.155.187.132"]
app.add_middleware(
    CORSMiddleware,
    allow_origins=origins,
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Shared model instance; stays None until `load_model` assigns it.
model = None
def rasterize_paper(
    pdf: Union[Path, bytes],
    outpath: Optional[Path] = None,
    dpi: int = 96,
    return_pil=False,
    pages=None,
) -> Optional[List[io.BytesIO]]:
    """
    Rasterize the pages of a PDF document.

    Args:
        pdf: A path (``str`` or ``Path``) or the raw PDF bytes; either is opened
            with ``pypdfium2.PdfDocument``. Any other object is used as-is and
            must support ``len()`` and the ``render(...)`` call made below.
        outpath (Optional[Path], optional): Directory that receives one
            ``NN.png`` file per page (1-based, zero padded). If None, the images
            are returned instead. Defaults to None.
        dpi (int, optional): The output DPI; rendering scale is ``dpi / 72``.
            Defaults to 96.
        return_pil (bool, optional): Return the images instead of writing them
            to disk. Forced to True when ``outpath`` is None. Defaults to False.
        pages (Optional[List[int]], optional): Zero-based page indices to
            rasterize. If None, all pages are rasterized. Defaults to None.

    Returns:
        Optional[List[io.BytesIO]]: One in-memory BMP buffer per rendered page
        when images are returned, otherwise None. Errors are logged rather than
        raised, so the list may be empty or shorter than ``pages``.
    """
    pils = []
    if outpath is None:
        return_pil = True
    try:
        # Raw bytes are part of the declared input type, so open them too
        # instead of treating the bytes object as a document.
        if isinstance(pdf, (str, Path, bytes)):
            pdf = pypdfium2.PdfDocument(pdf)
        if pages is None:
            pages = range(len(pdf))
        renderer = pdf.render(
            pypdfium2.PdfBitmap.to_pil,
            page_indices=pages,
            scale=dpi / 72,
        )
        for i, image in zip(pages, renderer):
            if return_pil:
                page_bytes = io.BytesIO()
                image.save(page_bytes, "bmp")
                pils.append(page_bytes)
            else:
                image.save((outpath / ("%02d.png" % (i + 1))), "png")
    except Exception as e:
        # Best effort by design: report the problem and hand back whatever
        # pages were rendered before it occurred.
        logging.error(e)
    if return_pil:
        return pils
async def load_model(
    checkpoint: str = NOUGAT_CHECKPOINT,
):
    """
    Populate the module-level ``model`` on first use.

    Args:
        checkpoint (str): Checkpoint location handed to
            ``NougatModel.from_pretrained``. Defaults to ``NOUGAT_CHECKPOINT``.

    Does nothing when ``model`` is already set. Otherwise the model is loaded,
    converted to bfloat16, moved to CUDA if torch reports it as available, and
    put into eval mode.
    """
    global model
    if model is not None:
        return
    model = NougatModel.from_pretrained(checkpoint).to(torch.bfloat16)
    if torch.cuda.is_available():
        model.to("cuda")
    model.eval()
def root():
    """Health check: report an OK status together with an empty payload."""
    return {"status-code": HTTPStatus.OK, "data": {}}
import shutil | |
async def predict(
    file: UploadFile = File(), start: Optional[int] = None, stop: Optional[int] = None
):
    """
    Perform predictions on a PDF document and return the extracted text in Markdown format.

    Results are cached per document under ``SAVE_DIR/<md5 of the upload>``:
    pages that already have a ``pages/NN.mmd`` file are reused and only the
    remaining pages are rasterized and run through the model.

    Args:
        file (UploadFile): The uploaded PDF file to process.
        start (int, optional): The starting page number (1-based, inclusive).
        stop (int, optional): The ending page number (1-based, inclusive).
            The range is only applied when both ``start`` and ``stop`` are
            given, and it is clamped to the pages that exist.

    Returns:
        str: The extracted text in Markdown format. If the upload cannot be
        read, stored or opened, a dict with a ``"message"`` key is returned
        instead.
    """
    upload_dir = os.path.join(os.getcwd(), "uploads")
    try:
        # Read the stream exactly once and reuse the bytes everywhere; copying
        # from the already-consumed stream would store an empty file.
        pdfbin = file.file.read()
        md5 = hashlib.md5(pdfbin).hexdigest()
        os.makedirs(upload_dir, exist_ok=True)
        # The filename comes from the client: keep only its final component so
        # it cannot point outside the upload directory.
        safe_name = os.path.basename(file.filename or "") or (md5 + ".pdf")
        with open(os.path.join(upload_dir, safe_name), "wb") as f:
            f.write(pdfbin)
        # rasterize_paper drives the document through the pypdfium2 rendering
        # API, so the document has to be opened with pypdfium2.
        pdf = pypdfium2.PdfDocument(pdfbin)
        print(pdf)
        print(md5)
    except Exception:
        logging.exception("Failed to store or open the uploaded file")
        return {"message": "There was an error uploading the file"}
    finally:
        file.file.close()

    save_path = SAVE_DIR / md5
    print(save_path)

    num_pages = len(pdf)
    if start is not None and stop is not None:
        # Keep the requested window inside the document.
        pages = list(range(max(start - 1, 0), min(stop, num_pages)))
    else:
        pages = list(range(num_pages))
    # Zero-based page index -> slot in `predictions`.
    position = {page: slot for slot, page in enumerate(pages)}
    predictions = [""] * len(pages)

    # Reuse page results left behind by an earlier request for the same document.
    cached = set()
    if save_path.exists():
        for computed in (save_path / "pages").glob("*.mmd"):
            try:
                idx = int(computed.stem) - 1
                if idx in position:
                    print("skip page", idx + 1)
                    predictions[position[idx]] = computed.read_text(encoding="utf-8")
                    cached.add(idx)
            except Exception as e:
                print(e)
    compute_pages = [page for page in pages if page not in cached]

    images = rasterize_paper(pdf, pages=compute_pages) if compute_pages else []

    # No startup hook registering load_model is visible in this file, so make
    # sure the model exists before it is used.
    if model is None:
        await load_model()

    # Normalise locally so a string-valued batch size cannot break the
    # DataLoader or the index arithmetic below.
    batch_size = int(BATCHSIZE)
    dataset = ImageDataset(
        images,
        partial(model.encoder.prepare_input, random_padding=False),
    )
    dataloader = torch.utils.data.DataLoader(
        dataset,
        batch_size=batch_size,
        pin_memory=True,
        shuffle=False,
    )

    for idx, sample in tqdm(enumerate(dataloader), total=len(dataloader)):
        if sample is None:
            continue
        model_output = model.inference(image_tensors=sample)
        for j, output in enumerate(model_output["predictions"]):
            repeats = model_output["repeats"][j]
            disclaimer = ""
            if repeats is not None:
                if repeats > 0:
                    template = "\n\n+++ ==WARNING: Truncated because of repetitions==\n%s\n+++\n\n"
                else:
                    template = "\n\n+++ ==ERROR: No output for this page==\n%s\n+++\n\n"
                rest = close_envs(model_output["repetitions"][j]).strip()
                # The disclaimer is only attached when there is text to show in it.
                if rest:
                    disclaimer = template % rest
            # Batches preserve order (shuffle=False), so item j of batch idx
            # belongs to this entry of compute_pages.
            page = compute_pages[idx * batch_size + j]
            predictions[position[page]] = markdown_compatible(output) + disclaimer

    (save_path / "pages").mkdir(parents=True, exist_ok=True)
    # Keep an exact copy of the uploaded document next to its results.
    (save_path / "doc.pdf").write_bytes(pdfbin)
    if len(images) > 0:
        thumb = Image.open(images[0])
        thumb.thumbnail((400, 400))
        thumb.save(save_path / "thumb.jpg")
    for idx, page_num in enumerate(pages):
        (save_path / "pages" / ("%02d.mmd" % (page_num + 1))).write_text(
            predictions[idx], encoding="utf-8"
        )
    final = "".join(predictions).strip()
    (save_path / "doc.mmd").write_text(final, encoding="utf-8")
    return final
def main():
    """Serve the ``app`` object of module ``app`` with uvicorn on 0.0.0.0:8866."""
    from uvicorn import run as serve

    serve("app:app", port=8866, host="0.0.0.0")


if __name__ == "__main__":
    main()