Spaces:
Sleeping
Sleeping
File size: 4,466 Bytes
30875d3 c2d58b3 f1483b9 c2d58b3 948d2eb c2d58b3 f1483b9 c2d58b3 f1483b9 c2d58b3 30875d3 7894c3d 30875d3 f1483b9 1002c61 30875d3 f1483b9 c2d58b3 30875d3 f1483b9 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 |
from fastapi import APIRouter, File, UploadFile, Form, HTTPException
from typing import Optional
from PIL import Image
import urllib.request
from io import BytesIO
from config import settings
import utils
import os
import json
from routers.donut_inference import process_document_donut
import logging
import io
# Set up logging: configure the root logger at INFO level when this module is
# imported. NOTE(review): basicConfig here affects the whole application's
# logging, and it does nothing if the root logger already has handlers.
logging.basicConfig(level=logging.INFO)
# Module-level logger used by the endpoint functions in this file.
logger = logging.getLogger(__name__)
# Router on which this module's endpoints (/inference, /statistics) are registered.
router = APIRouter()
def count_values(obj):
    """Count the leaf values in a nested structure of dicts and lists.

    A dict contributes the counts of its values (keys are never counted), a
    list contributes the counts of its items, and any other object --
    including strings, tuples and None -- counts as exactly one value.
    Empty dicts and lists therefore contribute 0.
    """
    if isinstance(obj, dict):
        children = obj.values()
    elif isinstance(obj, list):
        children = obj
    else:
        # Leaf value.
        return 1
    return sum(count_values(child) for child in children)
def _image_name_from_url(image_url):
    """Return the last path segment of *image_url*, ignoring any query string or fragment."""
    from urllib.parse import urlparse

    return os.path.basename(urlparse(image_url).path)


def _load_image_from_url(image_url, timeout=30.0):
    """Download *image_url* and open the response body as a PIL image.

    Only ``http``/``https`` URLs are accepted. ``urllib.request.urlopen`` would
    otherwise also open client-supplied ``file://`` (local files), ``ftp://``
    and ``data:`` URLs. *timeout* (seconds) bounds each blocking network
    operation, so a slow host cannot hang the request indefinitely.

    Raises:
        ValueError: if the URL scheme is not http or https.
        OSError: (including urllib.error.URLError) on network failures/timeouts.
        Exception: whatever PIL raises if the body is not a readable image.
    """
    from urllib.parse import urlparse

    scheme = urlparse(image_url).scheme.lower()
    if scheme not in ("http", "https"):
        raise ValueError(f"Unsupported URL scheme: {scheme!r}")
    # NOTE(review): requests to internal http(s) hosts are still possible (SSRF);
    # consider a host allow-list if this endpoint is publicly reachable.
    # NOTE(review): urlopen is synchronous and blocks the event loop while it runs.
    with urllib.request.urlopen(image_url, timeout=timeout) as response:
        return Image.open(BytesIO(response.read()))


@router.post("/inference")
async def run_inference(
    file: Optional[UploadFile] = File(None),
    image_url: Optional[str] = Form(None),
    model_in_use: str = Form('donut'),
    shipper_id: Optional[int] = Form(None)
):
    """Run document extraction on an uploaded JPG or on an image fetched from a URL.

    An uploaded ``file`` takes precedence over ``image_url``. Only the
    ``'donut'`` model is supported. On success, the result of
    ``process_document_donut`` is returned unchanged. Its processing time, value
    count, source file name and ``settings.model`` are recorded via
    ``utils.log_stats``.

    Failures are reported in the response body rather than as HTTP errors:
    ``{"info": ...}`` when no input is given, and ``{"error": ...}`` for a
    non-JPG upload, an unsupported model, an image URL that cannot be fetched or
    decoded, or a failure during inference.
    """
    if not file and not image_url:
        return {"info": "No input provided"}

    logger.info(f"Received inference request with shipper_id: {shipper_id}")
    # Convert shipper_id to string if provided (config.py expects a string)
    shipper_id_str = str(shipper_id) if shipper_id is not None else "default_shipper"
    logger.info(f"Using shipper_id: {shipper_id_str} for model selection")

    try:
        if file:
            # Ensure the uploaded file is a JPG image
            if file.content_type not in ["image/jpeg", "image/jpg"]:
                logger.warning(f"Invalid file type: {file.content_type}")
                return {"error": "Invalid file type. Only JPG images are allowed."}
            logger.info(f"Processing file: {file.filename}")
            image = Image.open(BytesIO(await file.read()))
            source_name = file.filename
            time_label = "Processing time"
        else:
            # The guard above guarantees image_url is non-empty here.
            logger.info(f"Processing image from URL: {image_url}")
            # test image url: https://raw.githubusercontent.com/katanaml/sparrow/main/sparrow-data/docs/input/invoices/processed/images/invoice_10.jpg
            # Only the download/decode is covered here, so inference failures are
            # reported as inference failures rather than as URL errors.
            try:
                image = _load_image_from_url(image_url)
            except Exception as e:
                logger.exception(f"Error processing image URL: {str(e)}")
                return {"error": f"Error processing image URL: {str(e)}"}
            source_name = _image_name_from_url(image_url)
            time_label = "Processing time inference"

        if model_in_use != 'donut':
            logger.warning(f"Unsupported model: {model_in_use}")
            return {"error": f"Unsupported model: {model_in_use}"}

        # Pass the shipper_id to the processing function
        result, processing_time = process_document_donut(image, shipper_id_str)
        utils.log_stats(settings.inference_stats_file, [processing_time, count_values(result), source_name, settings.model])
        logger.info(f"{time_label}: {processing_time:.2f} seconds with model: {settings.model}")
    except Exception as e:
        logger.exception(f"Error during inference: {str(e)}")
        return {"error": f"Inference failed: {str(e)}"}
    return result
@router.get("/statistics")
async def get_statistics():
    """Return the recorded inference statistics.

    Reads and returns the JSON document stored at
    ``settings.inference_stats_file``. A missing file, or one whose content is
    not valid JSON, yields an empty list.
    """
    file_path = settings.inference_stats_file
    # EAFP: opening directly avoids the race where the file disappears between
    # an existence check and the open() call.
    try:
        with open(file_path, 'r') as stats_file:
            return json.load(stats_file)
    except (FileNotFoundError, json.JSONDecodeError):
        return []