MJobe commited on
Commit
574f9e3
1 Parent(s): 420d3c9

Update main.py

Browse files
Files changed (1) hide show
  1. main.py +84 -67
main.py CHANGED
@@ -1,26 +1,40 @@
1
  import fitz
2
- from fastapi import FastAPI, File, UploadFile, Form, Request, Response
3
  from fastapi.responses import JSONResponse
4
  from transformers import pipeline
5
  from PIL import Image
6
  from io import BytesIO
7
  from starlette.middleware import Middleware
8
  from starlette.middleware.cors import CORSMiddleware
9
- import torch
10
- import re
11
-
12
- from transformers import DonutProcessor, VisionEncoderDecoderModel
13
 
14
  app = FastAPI()
15
 
16
- processor = DonutProcessor.from_pretrained("naver-clova-ix/donut-base-finetuned-docvqa")
17
- model = VisionEncoderDecoderModel.from_pretrained("naver-clova-ix/donut-base-finetuned-docvqa")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
18
 
19
- device = "cuda" if torch.cuda.is_available() else "cpu"
20
- model.to(device)
21
 
22
- @app.post("/donutQA/")
23
- async def donut_question_answering(
24
  file: UploadFile = File(...),
25
  questions: str = Form(...),
26
  ):
@@ -31,68 +45,71 @@ async def donut_question_answering(
31
  # Open the image using PIL
32
  image = Image.open(BytesIO(contents))
33
 
34
- image = image.convert("RGB")
 
 
 
 
 
 
35
 
36
- # Split the questions into a list
37
- question_list = questions.split(',')
38
 
39
- # Process document with Donut model for each question
40
- answers = process_document(image, question_list)
41
 
42
- # Return a dictionary with questions and corresponding answers
43
- result_dict = dict(zip(question_list, answers))
44
- return result_dict
45
 
 
46
  except Exception as e:
47
- return {"error": f"Error processing file: {str(e)}"}
48
-
49
- def process_document(image, questions):
50
- image = image.convert("RGB")
51
-
52
- # prepare encoder inputs
53
- pixel_values = processor(image, return_tensors="pt").pixel_values
54
-
55
- # prepare decoder inputs
56
- task_prompt = "<s_docvqa><s_question>{user_input}</s_question><s_answer>"
57
-
58
- # Initialize a list to store answers for each question
59
- answers = []
60
-
61
- # Process each question
62
- for question in questions:
63
- prompt = task_prompt.replace("{user_input}", question)
64
- decoder_input_ids = processor.tokenizer(prompt, add_special_tokens=False, return_tensors="pt").input_ids
65
-
66
- # generate answer
67
- outputs = model.generate(
68
- pixel_values.to(device),
69
- decoder_input_ids=decoder_input_ids.to(device),
70
- max_length=model.decoder.config.max_position_embeddings,
71
- early_stopping=True,
72
- pad_token_id=processor.tokenizer.pad_token_id,
73
- eos_token_id=processor.tokenizer.eos_token_id,
74
- use_cache=True,
75
- num_beams=1,
76
- bad_words_ids=[[processor.tokenizer.unk_token_id]],
77
- return_dict_in_generate=True,
78
- )
79
 
80
- # postprocess
81
- sequence = processor.batch_decode(outputs.sequences)[0]
82
- sequence = sequence.replace(processor.tokenizer.eos_token, "").replace(processor.tokenizer.pad_token, "")
83
- sequence = re.sub(r"<.*?>", "", sequence, count=1).strip() # remove first task start token
 
 
84
 
85
- # Append the answer to the list
86
- answers.append(processor.token2json(sequence))
87
 
88
- return answers
 
89
 
90
- # Set up CORS middleware
91
- origins = ["*"] # or specify your list of allowed origins
92
- app.add_middleware(
93
- CORSMiddleware,
94
- allow_origins=origins,
95
- allow_credentials=True,
96
- allow_methods=["*"],
97
- allow_headers=["*"],
98
- )
 
 
 
 
 
 
 
 
 
 
 
 
1
  import fitz
2
+ from fastapi import FastAPI, File, UploadFile, Form
3
  from fastapi.responses import JSONResponse
4
  from transformers import pipeline
5
  from PIL import Image
6
  from io import BytesIO
7
  from starlette.middleware import Middleware
8
  from starlette.middleware.cors import CORSMiddleware
 
 
 
 
9
 
10
  app = FastAPI()
11
 
12
+ # Set up CORS middleware
13
+ origins = ["*"] # or specify your list of allowed origins
14
+ app.add_middleware(
15
+ CORSMiddleware,
16
+ allow_origins=origins,
17
+ allow_credentials=True,
18
+ allow_methods=["*"],
19
+ allow_headers=["*"],
20
+ )
21
+
22
+ # Use a pipeline as a high-level helper
23
+ nlp_qa = pipeline("document-question-answering", model="tiennvcs/layoutlmv2-base-uncased-finetuned-docvqa")
24
+
25
+ description = """
26
+ ## Image-based Document QA
27
+ This API performs document question answering using a LayoutLMv2-based model.
28
+
29
+ ### Endpoints:
30
+ - **POST /uploadfile/:** Upload an image file to extract text and answer provided questions.
31
+ - **POST /pdfQA/:** Provide a PDF file to extract text and answer provided questions.
32
+ """
33
 
34
+ app = FastAPI(docs_url="/", description=description)
 
35
 
36
+ @app.post("/uploadfile/", description="Upload an image file to extract text and answer provided questions.")
37
+ async def perform_document_qa(
38
  file: UploadFile = File(...),
39
  questions: str = Form(...),
40
  ):
 
45
  # Open the image using PIL
46
  image = Image.open(BytesIO(contents))
47
 
48
+ # Perform document question answering for each question using LayoutLMv2-based model
49
+ answers_dict = {}
50
+ for question in questions.split(','):
51
+ result = nlp_qa(
52
+ image,
53
+ question.strip()
54
+ )
55
 
56
+ # Access the 'answer' key from the first item in the result list
57
+ answer = result[0]['answer']
58
 
59
+ # Format the question as a string without extra characters
60
+ formatted_question = question.strip("[]")
61
 
62
+ answers_dict[formatted_question] = answer
 
 
63
 
64
+ return answers_dict
65
  except Exception as e:
66
+ return JSONResponse(content=f"Error processing file: {str(e)}", status_code=500)
67
+
68
+ @app.post("/pdfQA/", description="Provide a PDF file to extract text and answer provided questions.")
69
+ async def pdf_question_answering(
70
+ file: UploadFile = File(...),
71
+ questions: str = Form(...),
72
+ ):
73
+ try:
74
+ # Read the uploaded file as bytes
75
+ contents = await file.read()
76
+
77
+ # Initialize an empty string to store the text content of the PDF
78
+ all_text = ""
79
+
80
+ # Use PyMuPDF to process the PDF and extract text
81
+ pdf_document = fitz.open_from_bytes(contents)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
82
 
83
+ # Loop through each page and perform OCR
84
+ for page_num in range(pdf_document.page_count):
85
+ page = pdf_document.load_page(page_num)
86
+ print(f"Processing page {page_num + 1}...")
87
+ text = page.get_text()
88
+ all_text += text + '\n'
89
 
90
+ # Print or do something with the collected text
91
+ print(all_text)
92
 
93
+ # List of questions
94
+ question_list = questions.split(',')
95
 
96
+ # Initialize an empty dictionary to store questions and answers
97
+ qa_dict = {}
98
+
99
+ # Get answers for each question with the same context
100
+ for question in question_list:
101
+ result = nlp_qa({
102
+ 'question': question,
103
+ 'context': all_text
104
+ })
105
+
106
+ # Access the 'answer' key from the result
107
+ answer = result['answer']
108
+
109
+ # Store the question and answer in the dictionary
110
+ qa_dict[question] = answer
111
+
112
+ return qa_dict
113
+
114
+ except Exception as e:
115
+ return JSONResponse(content=f"Error processing PDF file: {str(e)}", status_code=500)