vkumartr committed on
Commit
57f0216
·
verified ·
1 Parent(s): 2d4431d

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +31 -54
app.py CHANGED
@@ -46,6 +46,7 @@ app = FastAPI(docs_url='/')
46
  use_gpu = False
47
  output_dir = 'output'
48
 
 
49
  @app.on_event("startup")
50
  def startup_db():
51
  try:
@@ -54,6 +55,7 @@ def startup_db():
54
  except Exception as e:
55
  logger.error(f"MongoDB connection failed: {str(e)}")
56
 
 
57
  # AWS S3 Configuration
58
  API_KEY = os.getenv("API_KEY")
59
  AWS_ACCESS_KEY = os.getenv("AWS_ACCESS_KEY")
@@ -70,6 +72,7 @@ s3_client = boto3.client(
70
  aws_secret_access_key=AWS_SECRET_KEY
71
  )
72
 
 
73
  # Function to fetch file from S3
74
  def fetch_file_from_s3(file_key):
75
  try:
@@ -80,6 +83,8 @@ def fetch_file_from_s3(file_key):
80
  except Exception as e:
81
  raise Exception(f"Failed to fetch file from S3: {str(e)}")
82
 
 
 
83
  # Updated extraction function that handles PDF and image files differently
84
  def extract_invoice_data(file_data, content_type, json_schema):
85
  """
@@ -87,27 +92,14 @@ def extract_invoice_data(file_data, content_type, json_schema):
87
  For Images: Pass the Base64-encoded image to OpenAI (assuming a multimodal model)
88
  """
89
  system_prompt = "You are an expert in document data extraction."
90
- base64_encoded_images = [] # To store Base64-encoded image data
91
-
92
- extracted_data = {}
93
 
94
  if content_type == "application/pdf":
95
  # Use PyMuPDF to extract text directly from the PDF
96
  try:
97
  doc = fitz.open(stream=file_data, filetype="pdf")
98
- num_pages = doc.page_count
99
-
100
- # Check if the number of pages exceeds 2
101
- if num_pages > 2:
102
- raise ValueError("The PDF contains more than 2 pages, extraction not supported.")
103
-
104
  extracted_text = ""
105
  for page in doc:
106
  extracted_text += page.get_text()
107
-
108
- # Store the extracted text in the dictionary
109
- extracted_data["text"] = extracted_text
110
-
111
  except Exception as e:
112
  logger.error(f"Error extracting text from PDF: {e}")
113
  raise
@@ -120,38 +112,18 @@ def extract_invoice_data(file_data, content_type, json_schema):
120
  )
121
 
122
  elif content_type.startswith("image/"):
123
- # For images, determine if more than 2 images are provided
124
- try:
125
- img = Image.open(io.BytesIO(file_data)) # Open the image file
126
- num_images = img.n_frames # Get number of images (pages in the image file)
127
-
128
- if num_images > 2:
129
- raise ValueError("The image file contains more than 2 pages, extraction not supported.")
130
-
131
- # Process each image page if there are 1 or 2 pages
132
- for page_num in range(num_images):
133
- img.seek(page_num) # Move to the current page
134
- img_bytes = io.BytesIO()
135
- img.save(img_bytes, format="PNG") # Save each page as a PNG image in memory
136
- base64_encoded = base64.b64encode(img_bytes.getvalue()).decode('utf-8')
137
- base64_encoded_images.append(base64_encoded)
138
-
139
- # Add Base64 image data to the extracted data dictionary
140
- extracted_data["base64_images"] = base64_encoded_images
141
-
142
- # Build a prompt containing the image data for OpenAI
143
- prompt = f"Extract the invoice data from the following images (Base64 encoded). Return only valid JSON that adheres to this schema:\n\n{json.dumps(json_schema, indent=2)}\n\n"
144
- for base64_image in base64_encoded_images:
145
- prompt += f"Image Data URL: data:{content_type};base64,{base64_image}\n"
146
-
147
- except Exception as e:
148
- logger.error(f"Error handling images: {e}")
149
- raise
150
-
151
  else:
152
  raise ValueError(f"Unsupported content type: {content_type}")
153
 
154
- # Send request to OpenAI for data extraction
155
  try:
156
  response = openai.ChatCompletion.create(
157
  model="gpt-4o-mini",
@@ -164,20 +136,21 @@ def extract_invoice_data(file_data, content_type, json_schema):
164
  )
165
 
166
  content = response.choices[0].message.content.strip()
 
 
167
  cleaned_content = content.strip().strip('```json').strip('```')
168
-
169
  try:
170
  parsed_content = json.loads(cleaned_content)
171
- extracted_data["extracted_json"] = parsed_content # Store the parsed JSON data
172
- return extracted_data
173
  except json.JSONDecodeError as e:
174
  logger.error(f"JSON Parse Error: {e}")
175
- return {"error": f"JSON Parse Error: {str(e)}"}
176
 
177
  except Exception as e:
178
  logger.error(f"Error in data extraction: {e}")
179
  return {"error": str(e)}
180
 
 
181
  def get_content_type_from_s3(file_key):
182
  """Fetch the content type (MIME type) of a file stored in S3."""
183
  try:
@@ -186,21 +159,24 @@ def get_content_type_from_s3(file_key):
186
  except Exception as e:
187
  raise Exception(f"Failed to get content type from S3: {str(e)}")
188
 
 
189
  # Dependency to check API Key
190
  def verify_api_key(api_key: str = Header(...)):
191
  if api_key != API_KEY:
192
  raise HTTPException(status_code=401, detail="Invalid API Key")
193
 
 
194
  @app.get("/")
195
  def read_root():
196
  return {"message": "Welcome to the Invoice Summarization API!"}
197
 
 
198
  @app.get("/ocr/extraction")
199
  def extract_text_from_file(
200
- api_key: str = Depends(verify_api_key),
201
- file_key: str = Query(..., description="S3 file key for the file"),
202
- document_type: str = Query(..., description="Type of document"),
203
- entity_ref_key: str = Query(..., description="Entity Reference Key")
204
  ):
205
  """Extract text from a PDF or Image stored in S3 and process it based on document size."""
206
  try:
@@ -218,9 +194,9 @@ def extract_text_from_file(
218
 
219
  json_schema = schema_doc.get("json_schema")
220
  if not json_schema:
221
- raise ValueError("Schema is empty or not properly defined.")
222
-
223
- # Retrieve file from S3 and determine content type
224
  content_type = get_content_type_from_s3(file_key)
225
  file_data, _ = fetch_file_from_s3(file_key)
226
  extracted_data = extract_invoice_data(file_data, content_type, json_schema)
@@ -256,7 +232,8 @@ def extract_text_from_file(
256
  "traceback": traceback.format_exc()
257
  }
258
  return {"error": error_details}
259
-
 
260
  # Serve the output folder as static files
261
  app.mount("/output", StaticFiles(directory="output", follow_symlink=True, html=True), name="output")
262
 
 
46
  use_gpu = False
47
  output_dir = 'output'
48
 
49
+
50
  @app.on_event("startup")
51
  def startup_db():
52
  try:
 
55
  except Exception as e:
56
  logger.error(f"MongoDB connection failed: {str(e)}")
57
 
58
+
59
  # AWS S3 Configuration
60
  API_KEY = os.getenv("API_KEY")
61
  AWS_ACCESS_KEY = os.getenv("AWS_ACCESS_KEY")
 
72
  aws_secret_access_key=AWS_SECRET_KEY
73
  )
74
 
75
+
76
  # Function to fetch file from S3
77
  def fetch_file_from_s3(file_key):
78
  try:
 
83
  except Exception as e:
84
  raise Exception(f"Failed to fetch file from S3: {str(e)}")
85
 
86
+
87
+ # Function to summarize text using OpenAI GPT
88
  # Updated extraction function that handles PDF and image files differently
89
  def extract_invoice_data(file_data, content_type, json_schema):
90
  """
 
92
  For Images: Pass the Base64-encoded image to OpenAI (assuming a multimodal model)
93
  """
94
  system_prompt = "You are an expert in document data extraction."
 
 
 
95
 
96
  if content_type == "application/pdf":
97
  # Use PyMuPDF to extract text directly from the PDF
98
  try:
99
  doc = fitz.open(stream=file_data, filetype="pdf")
 
 
 
 
 
 
100
  extracted_text = ""
101
  for page in doc:
102
  extracted_text += page.get_text()
 
 
 
 
103
  except Exception as e:
104
  logger.error(f"Error extracting text from PDF: {e}")
105
  raise
 
112
  )
113
 
114
  elif content_type.startswith("image/"):
115
+ # For images, encode as Base64 and pass to OpenAI
116
+ base64_encoded = base64.b64encode(file_data).decode('utf-8')
117
+ # In this example we assume the model accepts image inputs via a Base64 data URL.
118
+ # (This requires access to a multimodal model.)
119
+ prompt = (
120
+ f"Extract the invoice data from the following image. "
121
+ f"Return only valid JSON that adheres to this schema:\n\n{json.dumps(json_schema, indent=2)}\n\n"
122
+ f"Image Data URL:\n data:{content_type};base64,{base64_encoded}"
123
+ )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
124
  else:
125
  raise ValueError(f"Unsupported content type: {content_type}")
126
 
 
127
  try:
128
  response = openai.ChatCompletion.create(
129
  model="gpt-4o-mini",
 
136
  )
137
 
138
  content = response.choices[0].message.content.strip()
139
+
140
+ # Clean and parse JSON output (remove markdown formatting if present)
141
  cleaned_content = content.strip().strip('```json').strip('```')
 
142
  try:
143
  parsed_content = json.loads(cleaned_content)
144
+ return parsed_content
 
145
  except json.JSONDecodeError as e:
146
  logger.error(f"JSON Parse Error: {e}")
147
+ return None
148
 
149
  except Exception as e:
150
  logger.error(f"Error in data extraction: {e}")
151
  return {"error": str(e)}
152
 
153
+
154
  def get_content_type_from_s3(file_key):
155
  """Fetch the content type (MIME type) of a file stored in S3."""
156
  try:
 
159
  except Exception as e:
160
  raise Exception(f"Failed to get content type from S3: {str(e)}")
161
 
162
+
163
  # Dependency to check API Key
164
  def verify_api_key(api_key: str = Header(...)):
165
  if api_key != API_KEY:
166
  raise HTTPException(status_code=401, detail="Invalid API Key")
167
 
168
+
169
  @app.get("/")
170
  def read_root():
171
  return {"message": "Welcome to the Invoice Summarization API!"}
172
 
173
+
174
  @app.get("/ocr/extraction")
175
  def extract_text_from_file(
176
+ api_key: str = Depends(verify_api_key),
177
+ file_key: str = Query(..., description="S3 file key for the file"),
178
+ document_type: str = Query(..., description="Type of document"),
179
+ entity_ref_key: str = Query(..., description="Entity Reference Key")
180
  ):
181
  """Extract text from a PDF or Image stored in S3 and process it based on document size."""
182
  try:
 
194
 
195
  json_schema = schema_doc.get("json_schema")
196
  if not json_schema:
197
+ raise ValueError("Schema is empty or not properly defined.")
198
+
199
+ # Retrieve file from S3 and determine content type
200
  content_type = get_content_type_from_s3(file_key)
201
  file_data, _ = fetch_file_from_s3(file_key)
202
  extracted_data = extract_invoice_data(file_data, content_type, json_schema)
 
232
  "traceback": traceback.format_exc()
233
  }
234
  return {"error": error_details}
235
+
236
+
237
  # Serve the output folder as static files
238
  app.mount("/output", StaticFiles(directory="output", follow_symlink=True, html=True), name="output")
239