MJobe committed on
Commit
86a0b7a
·
1 Parent(s): 251ed69

Update main.py

Browse files
Files changed (1) hide show
  1. main.py +57 -32
main.py CHANGED
@@ -6,28 +6,21 @@ from PIL import Image
6
  from io import BytesIO
7
  from starlette.middleware import Middleware
8
  from starlette.middleware.cors import CORSMiddleware
 
 
9
 
10
- app = FastAPI()
11
-
12
- # Use a pipeline as a high-level helper
13
- nlp_qa = pipeline("document-question-answering", model="impira/layoutlm-invoices")
14
- # Use a pipeline as a high-level helper
15
- nlp_ner = pipeline('question-answering', model='deepset/roberta-base-squad2', tokenizer='deepset/roberta-base-squad2')
16
-
17
 
18
- description = """
19
- ## Image-based Document QA
20
- This API performs document question answering using a LayoutLM-based model.
21
 
22
- ### Endpoints:
23
- - **POST /uploadfile/:** Upload an image file to extract text and answer provided questions.
24
- - **POST /pdfUpload/:** Provide a file to extract text and answer provided questions.
25
- """
26
 
27
- app = FastAPI(docs_url="/", description=description)
 
28
 
29
- @app.post("/uploadfile/", description=description)
30
- async def perform_document_qa(
31
  file: UploadFile = File(...),
32
  questions: str = Form(...),
33
  ):
@@ -38,25 +31,57 @@ async def perform_document_qa(
38
  # Open the image using PIL
39
  image = Image.open(BytesIO(contents))
40
 
41
- # Perform document question answering for each question using LayoutLM-based model
42
- answers_dict = {}
43
- for question in questions.split(','):
44
- result = nlp_qa(
45
- image,
46
- question.strip()
47
- )
48
-
49
- # Access the 'answer' key from the first item in the result list
50
- answer = result[0]['answer']
51
 
52
- # Format the question as a string without extra characters
53
- formatted_question = question.strip("[]")
54
 
55
- answers_dict[formatted_question] = answer
 
 
56
 
57
- return answers_dict
58
  except Exception as e:
59
- return JSONResponse(content=f"Error processing file: {str(e)}", status_code=500)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
60
 
61
  @app.post("/pdfQA/", description=description)
62
  async def pdf_question_answering(
 
6
  from io import BytesIO
7
  from starlette.middleware import Middleware
8
  from starlette.middleware.cors import CORSMiddleware
9
+ import torch
10
+ import re
11
 
12
+ from transformers import DonutProcessor, VisionEncoderDecoderModel
 
 
 
 
 
 
13
 
14
+ app = FastAPI()
 
 
15
 
16
+ processor = DonutProcessor.from_pretrained("naver-clova-ix/donut-base-finetuned-docvqa")
17
+ model = VisionEncoderDecoderModel.from_pretrained("naver-clova-ix/donut-base-finetuned-docvqa")
 
 
18
 
19
+ device = "cuda" if torch.cuda.is_available() else "cpu"
20
+ model.to(device)
21
 
22
+ @app.post("/donutQA/")
23
+ async def donut_question_answering(
24
  file: UploadFile = File(...),
25
  questions: str = Form(...),
26
  ):
 
31
  # Open the image using PIL
32
  image = Image.open(BytesIO(contents))
33
 
34
+ # Split the questions into a list
35
+ question_list = questions.split(',')
 
 
 
 
 
 
 
 
36
 
37
+ # Process document with Donut model for each question
38
+ answers = process_document(image, question_list)
39
 
40
+ # Return a dictionary with questions and corresponding answers
41
+ result_dict = dict(zip(question_list, answers))
42
+ return result_dict
43
 
 
44
  except Exception as e:
45
+ return {"error": f"Error processing file: {str(e)}"}
46
+
47
def process_document(image, questions):
    """Answer each question about a document image using the Donut DocVQA model.

    Parameters
    ----------
    image : PIL.Image.Image
        The document image to query.
    questions : iterable of str
        The questions to ask about the document.

    Returns
    -------
    list
        One ``processor.token2json(...)`` result per question, in the same
        order as *questions* (typically a dict holding the parsed answer).
    """
    # The encoder input is question-independent: encode the image and move it
    # to the model's device ONCE, outside the loop, instead of per question.
    pixel_values = processor(image, return_tensors="pt").pixel_values.to(device)

    # Donut DocVQA task prompt; {user_input} is replaced by each question.
    task_prompt = "<s_docvqa><s_question>{user_input}</s_question><s_answer>"

    answers = []
    for question in questions:
        prompt = task_prompt.replace("{user_input}", question)
        decoder_input_ids = processor.tokenizer(
            prompt, add_special_tokens=False, return_tensors="pt"
        ).input_ids.to(device)

        # Inference only: disable autograd to cut memory use and latency.
        with torch.no_grad():
            outputs = model.generate(
                pixel_values,
                decoder_input_ids=decoder_input_ids,
                max_length=model.decoder.config.max_position_embeddings,
                early_stopping=True,
                pad_token_id=processor.tokenizer.pad_token_id,
                eos_token_id=processor.tokenizer.eos_token_id,
                use_cache=True,
                num_beams=1,
                bad_words_ids=[[processor.tokenizer.unk_token_id]],
                return_dict_in_generate=True,
            )

        # Strip padding/EOS special tokens, then remove the leading task start
        # token before converting the generated sequence to structured JSON.
        sequence = processor.batch_decode(outputs.sequences)[0]
        sequence = sequence.replace(processor.tokenizer.eos_token, "").replace(
            processor.tokenizer.pad_token, ""
        )
        sequence = re.sub(r"<.*?>", "", sequence, count=1).strip()

        answers.append(processor.token2json(sequence))

    return answers
85
 
86
  @app.post("/pdfQA/", description=description)
87
  async def pdf_question_answering(