Elgene committed on
Commit ddb7273
1 parent: 8249272

create main.py

Files changed (1)
  1. main.py +45 -0
main.py ADDED
@@ -0,0 +1,45 @@
+ import re
+ import torch
+ from fastapi import FastAPI, File, UploadFile
+ from transformers import DonutProcessor, VisionEncoderDecoderModel
+ from PIL import Image
+ from io import BytesIO
+
+ app = FastAPI()
+
+ # Load the Donut processor and model fine-tuned on the CORD-v2 receipt dataset.
+ processor = DonutProcessor.from_pretrained("naver-clova-ix/donut-base-finetuned-cord-v2", use_fast=False)
+ model = VisionEncoderDecoderModel.from_pretrained("naver-clova-ix/donut-base-finetuned-cord-v2")
+
+ # Run on GPU when available, otherwise fall back to CPU.
+ device = "cuda" if torch.cuda.is_available() else "cpu"
+ model.to(device)
+
+ # Task prompt that tells the decoder to produce CORD-v2 structured output.
+ task_prompt = "<s_cord-v2>"
+ decoder_input_ids = processor.tokenizer(task_prompt, add_special_tokens=False, return_tensors="pt").input_ids
+
+ def generate_output(file_data):
+     # Decode the uploaded bytes into a PIL image and downscale it before preprocessing.
+     pil_image = Image.open(BytesIO(file_data))
+     pil_image = pil_image.resize((800, 600))  # resize() returns a new image, so reassign it
+     pixel_values = processor(pil_image, return_tensors="pt").pixel_values
+
+     outputs = model.generate(
+         pixel_values.to(device),
+         decoder_input_ids=decoder_input_ids.to(device),
+         max_length=model.decoder.config.max_position_embeddings,
+         pad_token_id=processor.tokenizer.pad_token_id,
+         eos_token_id=processor.tokenizer.eos_token_id,
+         use_cache=True,
+         bad_words_ids=[[processor.tokenizer.unk_token_id]],
+         return_dict_in_generate=True,
+     )
+     return outputs
+
+ @app.post("/ocr/")
+ async def analyze_image(file: UploadFile = File(...)):
+     content = await file.read()
+     outputs = generate_output(content)
+     sequence = processor.batch_decode(outputs.sequences)[0]
+     sequence = sequence.replace(processor.tokenizer.eos_token, "").replace(processor.tokenizer.pad_token, "")
+     sequence = re.sub(r"<.*?>", "", sequence, count=1).strip()  # remove the first task start token
+     return processor.token2json(sequence)
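
For reference, a minimal client sketch for exercising the /ocr/ endpoint. It assumes the app is started with `uvicorn main:app` and is reachable at http://127.0.0.1:8000, and that "receipt.jpg" is a placeholder for any local test image; neither is defined in the commit itself.

# client.py — hypothetical test client, not part of this commit.
# Assumes the server is running locally on port 8000 ("uvicorn main:app").
import requests

with open("receipt.jpg", "rb") as f:  # placeholder image path
    response = requests.post(
        "http://127.0.0.1:8000/ocr/",
        files={"file": ("receipt.jpg", f, "image/jpeg")},  # field name matches the UploadFile parameter
    )

print(response.json())  # JSON fields parsed from the Donut output by processor.token2json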