# senga-dnotes / routers /inference.py
# serenarolloh's picture
# Update routers/inference.py
# 30875d3 verified
from fastapi import APIRouter, File, UploadFile, Form, HTTPException
from typing import Optional
from PIL import Image
import urllib.request
from io import BytesIO
from config import settings
import utils
import os
import json
from routers.donut_inference import process_document_donut
import logging
import io
# Set up logging: request INFO-level output via the root logger configuration
# and create a logger named after this module.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
# Router on which the endpoints defined in this module are registered.
router = APIRouter()
def count_values(obj):
    """Count the leaf values held in a nested dict/list structure.

    A dict contributes the leaf count of its values (keys are not counted),
    a list contributes the leaf count of its items, and any other object,
    including ``None``, counts as exactly one value. Empty dicts and empty
    lists therefore count as zero.
    """
    if isinstance(obj, dict):
        children = obj.values()
    elif isinstance(obj, list):
        children = obj
    else:
        # Anything that is not a dict or list is a single leaf.
        return 1
    return sum(count_values(child) for child in children)
# Upper bound, in seconds, on how long a remote image download may block.
_URL_FETCH_TIMEOUT_SECONDS = 30


def _fetch_image_from_url(image_url):
    """Download ``image_url`` and open the response body as a PIL image.

    Only ``http`` and ``https`` URLs are accepted. ``urllib.request.urlopen``
    also handles other schemes (for example ``file://``), which would let a
    client make the server read its own local files.

    Raises:
        ValueError: if the URL scheme is neither ``http`` nor ``https``.
    """
    from urllib.parse import urlsplit

    scheme = urlsplit(image_url).scheme.lower()
    if scheme not in ("http", "https"):
        raise ValueError(f"Unsupported URL scheme: {scheme or 'none'}")
    # The timeout keeps a stalled remote host from hanging the request forever.
    with urllib.request.urlopen(image_url, timeout=_URL_FETCH_TIMEOUT_SECONDS) as response:
        return Image.open(BytesIO(response.read()))


def _run_donut(image, shipper_id_str, source_name, timing_label):
    """Run Donut inference on ``image``, record the stats, return the result.

    ``source_name`` is the file name written to the stats file and
    ``timing_label`` is the prefix of the processing-time log line.
    """
    result, processing_time = process_document_donut(image, shipper_id_str)
    utils.log_stats(settings.inference_stats_file, [processing_time, count_values(result), source_name, settings.model])
    logger.info(f"{timing_label}: {processing_time:.2f} seconds with model: {settings.model}")
    return result


@router.post("/inference")
async def run_inference(
    file: Optional[UploadFile] = File(None),
    image_url: Optional[str] = Form(None),
    model_in_use: str = Form('donut'),
    shipper_id: Optional[int] = Form(None)
):
    """Run document inference on an uploaded JPG or on an image fetched by URL.

    Args:
        file: Uploaded image; its content type must be ``image/jpeg`` or
            ``image/jpg``. Takes precedence over ``image_url`` when both are sent.
        image_url: ``http``/``https`` URL of the image, used when no file is sent.
        model_in_use: Model selector; only ``'donut'`` is supported.
        shipper_id: Optional shipper identifier forwarded (as a string) to
            ``process_document_donut``; ``"default_shipper"`` is used when absent.

    Returns:
        The result produced by ``process_document_donut`` on success,
        ``{"info": ...}`` when neither input is provided, or ``{"error": ...}``
        describing the failure. Errors are reported in the response body
        rather than raised.
    """
    # Validate input
    if not file and not image_url:
        return {"info": "No input provided"}

    # Log the shipper_id that was received
    logger.info(f"Received inference request with shipper_id: {shipper_id}")
    # Convert shipper_id to string if provided (config.py expects a string)
    shipper_id_str = str(shipper_id) if shipper_id is not None else "default_shipper"
    logger.info(f"Using shipper_id: {shipper_id_str} for model selection")

    # NOTE(review): the URL download and process_document_donut are synchronous
    # calls made from this async handler, so they hold the event loop while
    # they run.
    try:
        if file:
            # Ensure the uploaded file is a JPG image
            if file.content_type not in ["image/jpeg", "image/jpg"]:
                logger.warning(f"Invalid file type: {file.content_type}")
                return {"error": "Invalid file type. Only JPG images are allowed."}
            logger.info(f"Processing file: {file.filename}")
            image = Image.open(BytesIO(await file.read()))
            if model_in_use != 'donut':
                logger.warning(f"Unsupported model: {model_in_use}")
                return {"error": f"Unsupported model: {model_in_use}"}
            return _run_donut(image, shipper_id_str, file.filename, "Processing time")

        # No file was uploaded, so image_url is set (checked above).
        logger.info(f"Processing image from URL: {image_url}")
        # test image url: https://raw.githubusercontent.com/katanaml/sparrow/main/sparrow-data/docs/input/invoices/processed/images/invoice_10.jpg
        try:
            image = _fetch_image_from_url(image_url)
            if model_in_use != 'donut':
                logger.warning(f"Unsupported model: {model_in_use}")
                return {"error": f"Unsupported model: {model_in_use}"}
            # parse file name from url
            file_name = image_url.split("/")[-1]
            return _run_donut(image, shipper_id_str, file_name, "Processing time inference")
        except Exception as e:
            logger.error(f"Error processing image URL: {str(e)}")
            return {"error": f"Error processing image URL: {str(e)}"}
    except Exception as e:
        logger.error(f"Error during inference: {str(e)}")
        return {"error": f"Inference failed: {str(e)}"}
@router.get("/statistics")
async def get_statistics():
    """Return the recorded inference statistics.

    Reads the JSON file at ``settings.inference_stats_file``.

    Returns:
        The parsed JSON content, or an empty list when the file does not
        exist or does not contain valid JSON.
    """
    try:
        # Open directly rather than checking os.path.exists() first: the file
        # could be removed between the check and the open.
        with open(settings.inference_stats_file, 'r') as stats_file:
            return json.load(stats_file)
    except (FileNotFoundError, json.JSONDecodeError):
        # A missing or unparsable stats file is reported as "no statistics".
        return []