# got_ocr_test / app.py
# Author: acharyaaditya26
# Commit: da60b63 ("changes")
# Size: 3.87 kB
from fastapi import FastAPI, File, UploadFile, HTTPException
from fastapi.responses import HTMLResponse
from fastapi.staticfiles import StaticFiles
from fastapi.templating import Jinja2Templates
import fitz # PyMuPDF
from transformers import AutoModel, AutoTokenizer
from PIL import Image
import numpy as np
import os
import base64
import io
import uuid
import tempfile
import time
import shutil
from pathlib import Path
import json
from starlette.requests import Request
import uvicorn
app = FastAPI()

# GOT-OCR 2.0 checkpoint on the Hugging Face Hub. The repository ships its own
# modelling code, which is why both loaders pass trust_remote_code=True.
_GOT_CHECKPOINT = 'ucaslcl/GOT-OCR2_0'

tokenizer = AutoTokenizer.from_pretrained(_GOT_CHECKPOINT, trust_remote_code=True)
model = AutoModel.from_pretrained(
    _GOT_CHECKPOINT,
    trust_remote_code=True,
    device_map='cuda',
    use_safetensors=True,
)
# Inference only: put the module in eval mode and make sure it sits on the GPU.
model = model.eval().cuda()
UPLOAD_FOLDER = "./uploads"
RESULTS_FOLDER = "./results"

# Ensure the working directories exist. exist_ok=True makes this idempotent and
# avoids the check-then-create race of os.path.exists() followed by makedirs()
# (e.g. when several worker processes import this module at the same time).
for folder in (UPLOAD_FOLDER, RESULTS_FOLDER):
    os.makedirs(folder, exist_ok=True)
def image_to_base64(image):
    """Serialize an image to PNG and return it as a base64 text string.

    Args:
        image: an object exposing ``save(fp, format=...)``, such as a
            ``PIL.Image.Image``.

    Returns:
        str: the base64 encoding of the PNG bytes, decoded to text.
    """
    with io.BytesIO() as png_buffer:
        image.save(png_buffer, format="PNG")
        png_bytes = png_buffer.getvalue()
    return base64.b64encode(png_bytes).decode()
def pdf_to_images(pdf_path):
    """Rasterize every page of a PDF into an RGB PIL image.

    Args:
        pdf_path: filesystem path of the PDF, opened with PyMuPDF (``fitz``).

    Returns:
        list[PIL.Image.Image]: one image per page, in page order, rendered with
        ``Page.get_pixmap()`` defaults.
    """
    images = []
    # The context manager closes the document and its file handle even if a
    # page fails to render. Callers delete the PDF right afterwards, so the
    # handle must not outlive this function.
    with fitz.open(pdf_path) as pdf_document:
        for page in pdf_document:
            pix = page.get_pixmap()
            # get_pixmap() defaults to alpha=False, so the samples are packed RGB.
            images.append(Image.frombytes("RGB", [pix.width, pix.height], pix.samples))
    return images
def _remove_if_exists(path):
    """Delete ``path`` if it is present; do nothing otherwise."""
    if os.path.exists(path):
        os.remove(path)


def _ocr_page(image, unique_id, page_number):
    """OCR one page image and return its result record; temp files are always removed.

    Args:
        image: page image with a ``save(path)`` method (PIL image).
        unique_id: per-request id used to build unique temp file names.
        page_number: 1-based page index, used in file names and the record.

    Returns:
        dict: ``{"page_number": ..., "text": ..., "html": ...}``.
    """
    image_path = os.path.join(UPLOAD_FOLDER, f"{unique_id}_page_{page_number}.png")
    result_path = os.path.join(RESULTS_FOLDER, f"{unique_id}_page_{page_number}.html")
    try:
        image.save(image_path)
        # render=True / save_render_file is relied on to write the rendered page
        # to result_path, which is read back below.
        res = model.chat_crop(tokenizer, image_path, ocr_type='format', render=True, save_render_file=result_path)
        # NOTE(review): assumes the renderer writes UTF-8 — confirm.
        with open(result_path, 'r', encoding='utf-8') as f:
            html_content = f.read()
    finally:
        # Clean up this page's artefacts on success and on failure alike.
        _remove_if_exists(image_path)
        _remove_if_exists(result_path)
    return {
        "page_number": page_number,
        "text": res,
        "html": html_content,
    }


def run_GOT(pdf_file):
    """OCR every page of a PDF with GOT-OCR2.0 and collect per-page results.

    The PDF is copied into UPLOAD_FOLDER under a random UUID name and rasterized.
    Each page is then passed to ``model.chat_crop`` in ``'format'`` mode with
    rendering enabled. All intermediate files (the PDF copy, page PNGs and
    rendered HTML) are removed whether or not processing succeeds.

    Args:
        pdf_file: path of the PDF to process.

    Returns:
        tuple: On success, ``(json_text, results)``. ``results`` is a list of
        ``{"page_number", "text", "html"}`` dicts, and ``json_text`` is that list
        serialized with ``json.dumps(indent=4)``. If any ``Exception`` occurs
        (copying, PDF parsing, or OCR), the function returns
        ``("Error: <message>", None)`` instead.
    """
    unique_id = str(uuid.uuid4())
    pdf_path = os.path.join(UPLOAD_FOLDER, f"{unique_id}.pdf")
    results = []
    try:
        # Copy and rasterization are inside the try so that a bad PDF is
        # reported through the error tuple and its copy is still removed.
        shutil.copy(pdf_file, pdf_path)
        images = pdf_to_images(pdf_path)
        for page_number, image in enumerate(images, start=1):
            results.append(_ocr_page(image, unique_id, page_number))
    except Exception as e:
        return f"Error: {str(e)}", None
    finally:
        _remove_if_exists(pdf_path)
    return json.dumps(results, indent=4), results
def cleanup_old_files(*, max_age_seconds=3600, folders=None):
    """Delete files older than ``max_age_seconds`` from the working folders.

    Args:
        max_age_seconds: age threshold in seconds, compared against each file's
            modification time. Defaults to one hour.
        folders: directories to sweep (non-recursively). Defaults to
            UPLOAD_FOLDER and RESULTS_FOLDER.

    Sub-directories are left in place. Files that disappear mid-sweep (for
    example, removed by a concurrent request) are skipped instead of raising.
    A folder that does not exist yields nothing to delete.
    """
    if folders is None:
        folders = [UPLOAD_FOLDER, RESULTS_FOLDER]
    cutoff = time.time() - max_age_seconds
    for folder in folders:
        for file_path in Path(folder).glob('*'):
            try:
                # unlink() cannot remove directories, so only regular files are considered.
                if file_path.is_file() and file_path.stat().st_mtime < cutoff:
                    file_path.unlink()
            except FileNotFoundError:
                continue
# Sweep stale upload/result files once, when this module is imported.
cleanup_old_files()

# Expose the local "static" directory at the /static URL prefix.
_static_assets = StaticFiles(directory="static")
app.mount("/static", _static_assets, name="static")

# HTML templates are loaded from the local "templates" directory.
templates = Jinja2Templates(directory="templates")
@app.get("/", response_class=HTMLResponse)
async def read_root(request: Request):
    """Serve the landing page by rendering the ``index.html`` template."""
    context = {"request": request}
    return templates.TemplateResponse("index.html", context)
@app.post("/uploadfile/")
async def upload_file(request: Request, file: UploadFile = File(...)):
    """Accept an uploaded PDF, OCR it with run_GOT and render ``result.html``.

    The upload is staged under a fixed name inside a private temporary
    directory. The client-supplied ``file.filename`` is deliberately not used
    as a path component: it is untrusted and may contain ``..`` segments or be
    an absolute path, which ``os.path.join`` would honor and so escape the temp
    directory. It may also be ``None``. run_GOT only needs the path, so the
    original name is not required.

    The template receives ``json_output`` and ``results`` exactly as returned
    by run_GOT (on failure these are an ``"Error: ..."`` string and ``None``).
    """
    # The context manager removes the temporary directory even if run_GOT raises.
    with tempfile.TemporaryDirectory() as temp_dir:
        temp_pdf_path = os.path.join(temp_dir, "upload.pdf")
        with open(temp_pdf_path, "wb") as buffer:
            buffer.write(await file.read())
        # NOTE(review): run_GOT performs blocking model inference inside this
        # async handler, which stalls the event loop while it runs. Consider
        # offloading it once the model's thread-safety is confirmed.
        json_output, results = run_GOT(temp_pdf_path)
    return templates.TemplateResponse(
        "result.html",
        {"request": request, "json_output": json_output, "results": results},
    )