# duithive-ocr-1 / main.py
# Header captured from the hosting page: last commit 8fe0386 "fix: add comma" by Elgene; file size 2.96 kB.
import re
from io import BytesIO

import torch
from fastapi import FastAPI, File, HTTPException, UploadFile
from fastapi.concurrency import run_in_threadpool
from fastapi.middleware.cors import CORSMiddleware
from PIL import Image, UnidentifiedImageError
from transformers import DonutProcessor, VisionEncoderDecoderModel
# Browser origins allowed to call this API: the deployed front-end, its
# staging deployment, and a local dev server.
origins = [
    "https://duithive.vercel.app",
    "https://staging-duithive.vercel.app",
    "http://localhost:3000",
]

# CORS policy: the origins above may send credentialed cross-origin requests
# using any HTTP method and any request header.
_cors_options = dict(
    allow_origins=origins,
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

app = FastAPI()
app.add_middleware(CORSMiddleware, **_cors_options)
# Donut checkpoint fine-tuned on CORD-v2 (per its model id); both the
# processor and the model are loaded from the same checkpoint.
MODEL_ID = "naver-clova-ix/donut-base-finetuned-cord-v2"

processor = DonutProcessor.from_pretrained(MODEL_ID, use_fast=False)
model = VisionEncoderDecoderModel.from_pretrained(MODEL_ID)

# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
model.to(device)

# Decoder prompt: the task-start token, tokenised once at import time and
# reused for every generation request.
task_prompt = "<s_cord-v2>"
_prompt_encoding = processor.tokenizer(
    task_prompt,
    add_special_tokens=False,
    return_tensors="pt",
)
decoder_input_ids = _prompt_encoding.input_ids
# def generateOutput(fileData):
# pil_image = Image.open(BytesIO(fileData))
# resized_image = pil_image.resize((800, 600)).convert('RGB')
# rgb_image = Image.new('RGB', resized_image.size)
# rgb_image.paste(resized_image)
# output_buffer = BytesIO()
# rgb_image.save(output_buffer, format="JPEG", quality = 100)
# jpeg_image = Image.open(BytesIO(output_buffer.getvalue()))
# pixel_values = processor(jpeg_image, return_tensors="pt").pixel_values
# outputs = model.generate(
# pixel_values.to(device),
# decoder_input_ids=decoder_input_ids.to(device),
# max_length=model.decoder.config.max_position_embeddings,
# pad_token_id=processor.tokenizer.pad_token_id,
# eos_token_id=processor.tokenizer.eos_token_id,
# use_cache=True,
# bad_words_ids=[[processor.tokenizer.unk_token_id]],
# return_dict_in_generate=True,
# )
# return outputs
def generateOutput(fileData):
    """Run the Donut model on raw image bytes and return the generation output.

    Args:
        fileData: Encoded image bytes in any format Pillow can open.

    Returns:
        The result of ``model.generate(..., return_dict_in_generate=True)``;
        callers decode ``outputs.sequences`` with the processor.

    Raises:
        PIL.UnidentifiedImageError: if ``fileData`` is not a readable image.
    """
    with Image.open(BytesIO(fileData)) as pil_image:
        # Normalise to 3-channel RGB. RGBA/LA (e.g. PNG screenshots), palette
        # ("P") and greyscale uploads would otherwise reach the image processor
        # with a channel layout it does not expect. convert() fully loads the
        # pixels, so the result stays valid after the source is closed.
        rgb_image = pil_image.convert("RGB")

    # No manual resize here: the Donut processor performs its own resizing /
    # padding. (A previous `pil_image.resize((800, 600))` call discarded its
    # return value and therefore never had any effect.)
    pixel_values = processor(rgb_image, return_tensors="pt").pixel_values

    outputs = model.generate(
        pixel_values.to(device),
        decoder_input_ids=decoder_input_ids.to(device),
        max_length=model.decoder.config.max_position_embeddings,
        pad_token_id=processor.tokenizer.pad_token_id,
        eos_token_id=processor.tokenizer.eos_token_id,
        use_cache=True,
        # Never let the decoder emit the unknown token.
        bad_words_ids=[[processor.tokenizer.unk_token_id]],
        return_dict_in_generate=True,
    )
    return outputs
@app.post("/ocr/")
async def analyze_image(file: UploadFile = File(...)):
    """Run OCR on an uploaded image and return the parsed fields as JSON.

    The image bytes go through ``generateOutput``. The decoded token sequence
    is stripped of EOS/PAD tokens and of its leading task-start token, then
    converted to a dict with ``processor.token2json``.

    Raises:
        HTTPException: 400 if the upload cannot be decoded as an image.
    """
    content = await file.read()
    try:
        # Model inference is synchronous and CPU/GPU-bound. Running it in the
        # thread pool keeps the event loop free to serve other requests
        # (including CORS preflights) while a generation is in progress.
        outputs = await run_in_threadpool(generateOutput, content)
    except UnidentifiedImageError as err:
        # Empty or non-image uploads are a client error, not a server fault.
        raise HTTPException(
            status_code=400,
            detail="Uploaded file is not a readable image.",
        ) from err

    sequence = processor.batch_decode(outputs.sequences)[0]
    sequence = sequence.replace(processor.tokenizer.eos_token, "").replace(
        processor.tokenizer.pad_token, ""
    )
    # Remove only the first tag, i.e. the task-start token (e.g. "<s_cord-v2>").
    sequence = re.sub(r"<.*?>", "", sequence, count=1).strip()
    return processor.token2json(sequence)