from fastapi import FastAPI, File, UploadFile, HTTPException
from fastapi.responses import JSONResponse
from fastapi.middleware.cors import CORSMiddleware
from transformers import pipeline
from PIL import Image
import io
import logging
from datetime import datetime
import asyncio
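
# Assumed runtime dependencies (not pinned anywhere in this file): fastapi,
# uvicorn, transformers, torch (inference backend for the pipeline), pillow,
# and python-multipart (required by FastAPI to parse UploadFile form data).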

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

app = FastAPI(title="Age Detection API", version="1.0.0")

# Add CORS middleware so browser-based clients (e.g. a FlutterFlow app) can call the API
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],  # In production, restrict this to your FlutterFlow domain
    allow_credentials=True,
    allow_methods=["GET", "POST"],
    allow_headers=["*"],
)
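
# A hedged sketch of the production setup hinted at above; the domain below is
# a placeholder, not this project's real FlutterFlow URL:
#
#   app.add_middleware(
#       CORSMiddleware,
#       allow_origins=["https://your-app.flutterflow.app"],  # hypothetical domain
#       allow_credentials=True,
#       allow_methods=["GET", "POST"],
#       allow_headers=["*"],
#   )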

# Global variable to store the model
pipe = None

def load_model():
    """Load the model with error handling"""
    global pipe
    try:
        logger.info("Loading age classification model...")
        pipe = pipeline("image-classification", model="nateraw/vit-age-classifier")
        logger.info("Model loaded successfully")
        return True
    except Exception as e:
        logger.error(f"Failed to load model: {e}")
        return False

# Load model on startup
@app.on_event("startup")
async def startup_event():
    success = load_model()
    if not success:
        logger.error("Failed to initialize model on startup")

# NOTE: the route paths below ("/", "/health", "/predict", "/warmup") are
# conventional assumptions; the decorators were missing from the pasted code.
@app.get("/")
async def root():
    return {"message": "Age Detection API is running", "status": "healthy"}

@app.get("/health")
async def health_check():
    """Keep-alive endpoint to prevent the Space from sleeping"""
    global pipe
    model_status = "loaded" if pipe is not None else "not_loaded"
    return {
        "status": "alive",
        "timestamp": datetime.now().isoformat(),
        "model_status": model_status
    }
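
# Example keep-alive ping from an external uptime/cron service (placeholder URL):
#   curl "https://<your-space>.hf.space/health"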

@app.post("/predict")
async def predict(file: UploadFile = File(...)):
    global pipe
    try:
        # Check if model is loaded; lazily load it if startup failed
        if pipe is None:
            logger.warning("Model not loaded, attempting to load...")
            success = load_model()
            if not success:
                raise HTTPException(status_code=500, detail="Model failed to load")

        # Validate file type - don't rely solely on content_type, as clients
        # sometimes send it incorrectly; accept a matching file extension too
        valid_extensions = ['.jpg', '.jpeg', '.png', '.bmp', '.gif', '.webp']
        filename_lower = (file.filename or '').lower()

        is_valid_content_type = file.content_type and file.content_type.startswith('image/')
        is_valid_extension = any(filename_lower.endswith(ext) for ext in valid_extensions)

        if not (is_valid_content_type or is_valid_extension):
            logger.warning(f"Invalid file type: content_type={file.content_type}, filename={file.filename}")
            raise HTTPException(status_code=400, detail="File must be an image")

        # Read and process image
        logger.info(f"Processing image: {file.filename}")
        image_data = await file.read()

        try:
            image = Image.open(io.BytesIO(image_data))
            # verify() raises if the data is not a valid image, but it also
            # leaves the Image object unusable, so reopen it afterwards
            image.verify()
            image = Image.open(io.BytesIO(image_data)).convert("RGB")

            # Downscale large images to speed up inference
            max_size = (1024, 1024)
            if image.size[0] > max_size[0] or image.size[1] > max_size[1]:
                image.thumbnail(max_size, Image.Resampling.LANCZOS)
                logger.info(f"Resized image to {image.size}")
        except Exception as e:
            logger.error(f"Image processing error: {e}")
            raise HTTPException(status_code=400, detail="Invalid or corrupted image file")

        # Run prediction in a worker thread with a timeout so a stuck forward
        # pass cannot block the event loop (asyncio.to_thread needs Python 3.9+)
        try:
            logger.info("Running model prediction...")
            results = await asyncio.wait_for(
                asyncio.to_thread(pipe, image),
                timeout=30.0
            )
            logger.info(f"Prediction completed: {len(results)} results")
        except asyncio.TimeoutError:
            logger.error("Model prediction timed out")
            raise HTTPException(status_code=504, detail="Prediction timed out")
        except Exception as e:
            logger.error(f"Model prediction error: {e}")
            raise HTTPException(status_code=500, detail="Prediction failed")

        return JSONResponse(content={
            "results": results,
            "timestamp": datetime.now().isoformat(),
            "image_size": image.size
        })

    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"Unexpected error: {e}")
        raise HTTPException(status_code=500, detail="Internal server error")
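
# Example request against a deployed Space (placeholder URL and file name):
#   curl -X POST "https://<your-space>.hf.space/predict" -F "file=@face.jpg"
#
# The image-classification pipeline returns label/score pairs, so a successful
# response has this shape (scores illustrative):
#   {"results": [{"label": "20-29", "score": 0.83}, ...],
#    "timestamp": "2024-01-01T12:00:00", "image_size": [640, 480]}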

# Additional endpoint to warm up the model
@app.get("/warmup")  # method and path assumed
async def warmup():
    """Endpoint to warm up the model"""
    global pipe
    if pipe is None:
        success = load_model()
        return {"status": "loaded" if success else "failed"}
    return {"status": "already_loaded"}