Spaces:
Sleeping
Sleeping
| import uvicorn | |
| from fastapi.staticfiles import StaticFiles | |
| import hashlib | |
| from enum import Enum | |
| from fastapi import FastAPI, Header, Query, Depends, HTTPException | |
| from pdf2image import convert_from_bytes | |
| import io | |
| import fitz # PyMuPDF for PDF handling | |
| import logging | |
| from pymongo import MongoClient | |
| import boto3 | |
| import openai | |
| import os | |
| import traceback # For detailed traceback of errors | |
| import re | |
| import json | |
| from dotenv import load_dotenv | |
| import base64 | |
| from bson.objectid import ObjectId | |
load_dotenv()

# Set up logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# MongoDB Configuration
MONGODB_URI = os.getenv("MONGODB_URI")
DATABASE_NAME = os.getenv("DATABASE_NAME")
COLLECTION_NAME = os.getenv("COLLECTION_NAME")
SCHEMA = os.getenv("SCHEMA")

# Check if environment variables are set.
# All four are required: the names below are used as pymongo subscripts, and a
# missing one would otherwise only surface as an opaque TypeError there.
for _env_name, _env_value in (
    ("MONGODB_URI", MONGODB_URI),
    ("DATABASE_NAME", DATABASE_NAME),
    ("COLLECTION_NAME", COLLECTION_NAME),
    ("SCHEMA", SCHEMA),
):
    if not _env_value:
        raise ValueError(f"{_env_name} is not set. Please add it to your secrets.")

# Initialize MongoDB Connection (connectivity is checked by startup_db()).
db_client = MongoClient(MONGODB_URI)
db = db_client[DATABASE_NAME]
invoice_collection = db[COLLECTION_NAME]  # extracted documents, keyed by "entityrefkey"
schema_collection = db[SCHEMA]  # JSON schemas, keyed by "document_type"

# Swagger UI is served at the root path.
app = FastAPI(docs_url='/')
use_gpu = False  # NOTE(review): not referenced in the visible source
output_dir = 'output'  # local directory name for the /output static mount
def startup_db():
    """Ping MongoDB once and log whether the server is reachable.

    Failures are logged, not raised, so the application keeps running.

    NOTE(review): no caller or decorator is visible in this source;
    presumably meant to be registered as a startup hook — confirm.
    """
    try:
        db_client.server_info()
    except Exception as exc:
        logger.error(f"MongoDB connection failed: {str(exc)}")
    else:
        logger.info("MongoDB connection successful")
# Shared secret that clients must present (compared in verify_api_key).
API_KEY = os.getenv("API_KEY")

# AWS S3 configuration, read from the environment.
AWS_ACCESS_KEY, AWS_SECRET_KEY, S3_BUCKET_NAME = (
    os.getenv(env_name)
    for env_name in ("AWS_ACCESS_KEY", "AWS_SECRET_KEY", "S3_BUCKET_NAME")
)

# OpenAI configuration: module-level key used by openai.ChatCompletion calls.
openai.api_key = os.getenv("OPENAI_API_KEY")

# S3 client built from the explicit credentials above.
s3_client = boto3.client(
    "s3",
    aws_access_key_id=AWS_ACCESS_KEY,
    aws_secret_access_key=AWS_SECRET_KEY,
)
# Function to fetch file from S3
def fetch_file_from_s3(file_key):
    """Download an object from the configured S3 bucket.

    Args:
        file_key: Key of the object inside ``S3_BUCKET_NAME``.

    Returns:
        A ``(file_data, content_type)`` tuple: the object body as ``bytes``
        and its MIME type (``'application/octet-stream'`` when S3 reports
        none, matching get_content_type_from_s3).

    Raises:
        Exception: wrapping any S3/network failure; the original error is
            chained as ``__cause__`` so its traceback is preserved.
    """
    try:
        response = s3_client.get_object(Bucket=S3_BUCKET_NAME, Key=file_key)
        content_type = response.get('ContentType', 'application/octet-stream')
        file_data = response['Body'].read()
    except Exception as e:
        raise Exception(f"Failed to fetch file from S3: {str(e)}") from e
    return file_data, content_type
def extract_pdf_text(file_data):
    """
    Extracts text from a PDF file using PyMuPDF (fitz).

    Args:
        file_data: Raw PDF bytes.

    Returns:
        The text of all pages joined with newlines; ``""`` when the PDF has no
        extractable text; ``None`` if opening or reading the PDF failed (the
        error is logged).
    """
    try:
        # The context manager closes the document even when get_text() raises
        # part-way through, so a failed extraction no longer leaks it.
        with fitz.open(stream=file_data, filetype="pdf") as pdf_document:
            text = "\n".join(page.get_text("text") for page in pdf_document)
    except Exception as e:
        logger.error(f"PDF Extraction Error: {e}")
        return None
    return text if text.strip() else ""  # whitespace-only counts as empty
def _pdf_pages_to_png_data_urls(file_data, max_pages):
    """Render PDF pages to PNG ``data:`` URLs, refusing oversized PDFs.

    Only the first ``max_pages + 1`` pages are rendered: that is enough to
    detect a PDF over the limit without rasterising every page of it.

    Returns:
        One data URL per page, or ``None`` if the PDF has more than
        ``max_pages`` pages.
    """
    images = convert_from_bytes(file_data, last_page=max_pages + 1)
    if len(images) > max_pages:
        return None
    data_urls = []
    for img in images:
        img_byte_arr = io.BytesIO()
        img.save(img_byte_arr, format="PNG", dpi=(300, 300))
        base64_encoded = base64.b64encode(img_byte_arr.getvalue()).decode('utf-8')
        data_urls.append(f"data:image/png;base64,{base64_encoded}")
    return data_urls


# Function to extract structured invoice data using OpenAI GPT
def extract_invoice_data(file_data, content_type, json_schema, max_pages=2):
    """
    Extracts invoice data from PDFs (text-based) and images using OpenAI's GPT-4o-mini model.
    Ensures accurate JSON schema binding.

    Args:
        file_data: Raw bytes of the document.
        content_type: MIME type; ``"application/pdf"`` and ``"image/*"`` are
            supported.
        json_schema: Passed as ``response_format["json_schema"]`` to the chat
            completion request.
        max_pages: Maximum number of PDF pages accepted (default 2).

    Returns:
        Always a 2-tuple ``(extracted_data, base64DataResp)``:
        - success: the parsed JSON reply and the list of data URLs of the input;
        - PDF processing failure or unsupported type: ``({"error": msg}, None)``;
        - OpenAI/parse failure: ``({"error": msg}, base64DataResp)``.
    """
    system_prompt = """You are an expert in invoice data extraction.
Your task is to extract key fields from an invoice image. Ensure accurate extraction and return the data in JSON format.
Extract the following fields:
1. Line Items: A list containing:
- Product Code
- Description
- Amount (numeric)
2. Tax Amount (if available)
3. Vendor GST (if available)
4. Vendor Name
5. Invoice Date (format: "DD-MMM-YYYY")
6. Total Amount (numeric)
7. Invoice Number (alpha-numeric)
8. Vendor Address
9. Invoice Currency
Ensure that:
- All extracted fields match the invoice.
- If any field is missing, return null instead of hallucinating data.
- Do not generate synthetic values—only extract real information from the image.
"""
    base64_images = []   # data URLs sent to the model
    base64DataResp = []  # data URLs returned to the caller
    extracted_text = ""

    if content_type == "application/pdf":
        try:
            # Embedded text layer; "" if there is none or extraction failed.
            extracted_text = extract_pdf_text(file_data) or ""
            # Store PDF as Base64
            base64_pdf = base64.b64encode(file_data).decode('utf-8')
            base64DataResp.append(f"data:application/pdf;base64,{base64_pdf}")
            page_urls = _pdf_pages_to_png_data_urls(file_data, max_pages)
        except Exception as e:
            logger.error(f"Error converting PDF to image: {e}")
            return {"error": "Failed to process PDF"}, None
        if page_urls is None:
            # Reported explicitly instead of being folded into the generic
            # "Failed to process PDF" error.
            logger.error(f"PDF rejected: more than {max_pages} pages")
            return {"error": f"PDF contains more than {max_pages} pages."}, None
        base64_images.extend(page_urls)
    elif content_type and content_type.startswith("image/"):
        # Direct image files: the same data URL is returned and sent to the model.
        base64_img = base64.b64encode(file_data).decode('utf-8')
        data_url = f"data:{content_type};base64,{base64_img}"
        base64DataResp.append(data_url)
        base64_images.append(data_url)
    else:
        # Must still be a 2-tuple: callers unpack the result.
        return {"error": f"Unsupported file type: {content_type}"}, None

    # Prepare OpenAI request: page/images plus, for PDFs, the embedded text layer
    # (previously extracted but never sent).
    openai_content = [{"type": "image_url", "image_url": {"url": url}} for url in base64_images]
    if extracted_text.strip():
        openai_content.insert(
            0, {"type": "text", "text": f"Text extracted from the PDF:\n{extracted_text}"}
        )
    try:
        response = openai.ChatCompletion.create(
            model="gpt-4o-mini",
            messages=[
                {"role": "system", "content": system_prompt},
                {"role": "user", "content": openai_content}
            ],
            response_format={"type": "json_schema", "json_schema": json_schema},
            temperature=0.5,
            max_tokens=16384
        )
        parsed_content = json.loads(response.choices[0].message.content.strip())
        return parsed_content, base64DataResp
    except Exception as e:
        logger.error(f"Error in OpenAI processing: {e}")
        return {"error": str(e)}, base64DataResp
def get_content_type_from_s3(file_key):
    """Fetch the content type (MIME type) of a file stored in S3.

    Uses a HEAD request, so the object body is not downloaded.

    Args:
        file_key: Key of the object inside ``S3_BUCKET_NAME``.

    Returns:
        The stored ``ContentType``, or ``'application/octet-stream'`` if absent.

    Raises:
        Exception: wrapping any S3/network failure, with the original error
            chained as ``__cause__``.
    """
    try:
        response = s3_client.head_object(Bucket=S3_BUCKET_NAME, Key=file_key)
    except Exception as e:
        raise Exception(f"Failed to get content type from S3: {str(e)}") from e
    return response.get('ContentType', 'application/octet-stream')  # Default to binary if not found
# Dependency to check API Key
def verify_api_key(api_key: str = Header(...)):
    """FastAPI dependency that rejects requests lacking the configured API key.

    Args:
        api_key: Value of the request header FastAPI derives from this
            parameter name.

    Raises:
        HTTPException: 401 when no key is configured or the key does not match.
    """
    import hmac

    # Fail closed when API_KEY is unset/empty (an empty header must not match).
    # compare_digest runs in constant time, so response timing does not reveal
    # how much of the key matched; bytes avoid its TypeError on non-ASCII str.
    if not API_KEY or not hmac.compare_digest(
        api_key.encode("utf-8"), API_KEY.encode("utf-8")
    ):
        raise HTTPException(status_code=401, detail="Invalid API Key")
def read_root():
    """Return a static greeting payload for the API.

    NOTE(review): no route decorator is visible for this handler in this
    source; confirm how it is registered.
    """
    greeting = "Welcome to the Invoice Summarization API!"
    return dict(message=greeting)
def extract_text_from_file(
    api_key: str = Depends(verify_api_key),
    file_key: str = Query(..., description="S3 file key for the file"),
    document_type: str = Query(..., description="Type of document"),
    entity_ref_key: str = Query(..., description="Entity Reference Key")
):
    """Extract structured data from a PDF or image stored in S3.

    If ``invoice_collection`` already holds a document for ``entity_ref_key``,
    it is returned as-is (with ``_id`` stringified). Otherwise:
    1. the JSON schema for ``document_type`` is loaded from ``schema_collection``;
    2. the file is fetched from S3;
    3. data is extracted with ``extract_invoice_data``;
    4. the result is stored.

    Returns:
        On success, a dict with ``message``, ``document_id``, ``entityrefkey``,
        ``base64DataResp`` and ``extracted_data``.
        On failure, ``{"error": {"error_type": ..., "error_message": ...}}``.
        The traceback is logged server-side only. Failed extractions are not
        stored, so a later request for the same ``entity_ref_key`` retries
        instead of being served the cached error.

    NOTE(review): no route decorator is visible for this handler in this
    source; confirm how it is registered.
    """
    try:
        # Return a previously stored result for this entity, if any.
        existing_document = invoice_collection.find_one({"entityrefkey": entity_ref_key})
        if existing_document:
            existing_document["_id"] = str(existing_document["_id"])
            return existing_document

        # Fetch JSON schema for the document type
        schema_doc = schema_collection.find_one({"document_type": document_type})
        if not schema_doc:
            raise ValueError("No schema found for the given document type")
        json_schema = schema_doc.get("json_schema")
        if not json_schema:
            raise ValueError("Schema is empty or not properly defined.")

        # Retrieve file from S3
        content_type = get_content_type_from_s3(file_key)
        file_data, _ = fetch_file_from_s3(file_key)

        # Extract structured data from the document
        extracted_data, base64DataResp = extract_invoice_data(file_data, content_type, json_schema)

        # Do not persist failures: the lookup above would otherwise keep
        # returning this error for the entity on every later request.
        if isinstance(extracted_data, dict) and "error" in extracted_data:
            logger.error(
                "Extraction failed for entity_ref_key=%s: %s",
                entity_ref_key,
                extracted_data["error"],
            )
            return {
                "error": {
                    "error_type": "ExtractionError",
                    "error_message": str(extracted_data["error"]),
                }
            }

        # Store document in MongoDB
        document = {
            "file_key": file_key,
            "file_type": content_type,
            "document_type": document_type,
            "entityrefkey": entity_ref_key,
            "base64DataResp": base64DataResp,
            "extracted_data": extracted_data
        }
        inserted_doc = invoice_collection.insert_one(document)
        document_id = str(inserted_doc.inserted_id)
        logger.info(f"Document inserted with ID: {document_id}")
        return {
            "message": "Document successfully stored in MongoDB",
            "document_id": document_id,
            "entityrefkey": entity_ref_key,
            "base64DataResp": base64DataResp,
            "extracted_data": extracted_data
        }
    except Exception as e:
        # Keep the full traceback in the server log. Sending it to API
        # clients would expose internal file paths and code.
        logger.exception(
            "Failed to process file_key=%s entity_ref_key=%s", file_key, entity_ref_key
        )
        return {"error": {"error_type": type(e).__name__, "error_message": str(e)}}
# Serve the output folder as static files.
# StaticFiles checks that the directory exists when it is constructed, so
# create it first to avoid a startup failure on a fresh checkout/container.
os.makedirs(output_dir, exist_ok=True)
app.mount("/output", StaticFiles(directory=output_dir, follow_symlink=True, html=True), name="output")

if __name__ == '__main__':
    uvicorn.run(app=app)