# ageDetection / app/main.py
# Last commit by divyarspoton: "Update app/main.py" (cc6bef8, verified)
from fastapi import FastAPI, File, UploadFile, HTTPException
from fastapi.responses import JSONResponse
from fastapi.middleware.cors import CORSMiddleware
from transformers import pipeline
from PIL import Image
import io
import logging
from datetime import datetime
import asyncio
# --- Logging -------------------------------------------------------------------
# Process-wide INFO-level logging; this module logs through `logger`.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# --- Application ---------------------------------------------------------------
app = FastAPI(title="Age Detection API", version="1.0.0")

# CORS settings that let a browser-hosted front end (the deployment targets a
# FlutterFlow app) call this API from another origin.
# NOTE(review): the CORS spec does not allow a literal "*" origin on
# credentialed responses, yet allow_credentials=True is combined with
# allow_origins=["*"] here. Check how the installed Starlette version handles
# this combination, and list the real front-end origin(s) in production.
_CORS_OPTIONS = {
    "allow_origins": ["*"],
    "allow_credentials": True,
    "allow_methods": ["GET", "POST"],
    "allow_headers": ["*"],
}
app.add_middleware(CORSMiddleware, **_CORS_OPTIONS)

# Transformers pipeline shared by all requests; stays None until load_model()
# succeeds.
pipe = None
def load_model() -> bool:
    """Build the age-classification pipeline and store it in the global ``pipe``.

    The call is synchronous and can be slow: transformers may need to fetch the
    ``nateraw/vit-age-classifier`` weights on first use. Async callers should
    therefore run it in a worker thread rather than on the event loop.

    Returns:
        True if the pipeline was created and assigned to ``pipe``.
        False if creating it raised. The error and its traceback are logged,
        and ``pipe`` keeps its previous value.
    """
    global pipe
    try:
        logger.info("Loading age classification model...")
        pipe = pipeline("image-classification", model="nateraw/vit-age-classifier")
    except Exception as e:
        # logger.exception keeps the traceback, which is essential for
        # diagnosing download / dependency failures; the message is unchanged.
        logger.exception("Failed to load model: %s", e)
        return False
    else:
        logger.info("Model loaded successfully")
        return True
# Try to load the model once, when the application starts.
@app.on_event("startup")
async def startup_event():
    """Attempt an eager model load at startup.

    A failure is only logged; it does not stop the application from starting.
    The /predict and /warmup endpoints try to load the model again on demand.
    """
    if load_model():
        return
    logger.error("Failed to initialize model on startup")
@app.get("/")
async def root():
    """Return a fixed message confirming the API process is up."""
    banner = dict(message="Age Detection API is running", status="healthy")
    return banner
@app.get("/health")
async def health_check():
    """Keep-alive probe reporting liveness, local time and model state.

    Presumably polled by an external monitor so the host does not idle the app.
    ``model_status`` is "loaded" once ``pipe`` has been set, else "not_loaded".
    """
    # Only reads the module-level ``pipe``, so no ``global`` statement is needed.
    if pipe is None:
        model_status = "not_loaded"
    else:
        model_status = "loaded"
    now_iso = datetime.now().isoformat()
    return {"status": "alive", "timestamp": now_iso, "model_status": model_status}
# Filename suffixes accepted as a fallback when the client's content type is
# missing or wrong; PIL decoding in _decode_image is the authoritative check.
_VALID_IMAGE_EXTENSIONS = ('.jpg', '.jpeg', '.png', '.bmp', '.gif', '.webp')
# Images larger than this (width, height) are downscaled before inference.
_MAX_IMAGE_SIZE = (1024, 1024)
# Maximum time, in seconds, a request waits for the model's prediction.
_PREDICTION_TIMEOUT_S = 30.0


def _is_probably_image(file: UploadFile) -> bool:
    """Return True if the upload's content type OR filename suffix looks like an image."""
    content_type = file.content_type or ''
    filename_lower = (file.filename or '').lower()
    return content_type.startswith('image/') or filename_lower.endswith(_VALID_IMAGE_EXTENSIONS)


def _decode_image(image_data: bytes) -> Image.Image:
    """Validate raw bytes as an image and return an RGB copy, downscaled if too large.

    Raises whatever PIL raises for empty, truncated or non-image data.
    """
    # verify() checks integrity without decoding pixels but leaves the image
    # object unusable, so the bytes are opened a second time for real use.
    Image.open(io.BytesIO(image_data)).verify()
    image = Image.open(io.BytesIO(image_data)).convert("RGB")
    if image.size[0] > _MAX_IMAGE_SIZE[0] or image.size[1] > _MAX_IMAGE_SIZE[1]:
        # Smaller inputs make inference faster; thumbnail() keeps the aspect ratio.
        image.thumbnail(_MAX_IMAGE_SIZE, Image.Resampling.LANCZOS)
        logger.info("Resized image to %s", image.size)
    return image


async def _ensure_model_loaded() -> None:
    """Load the model if needed without blocking the event loop.

    Raises:
        HTTPException(500): if load_model() reports failure.
    """
    if pipe is not None:
        return
    logger.warning("Model not loaded, attempting to load...")
    # load_model() is synchronous and potentially very slow. Running it in a
    # worker thread keeps /health and other requests responsive meanwhile.
    if not await asyncio.to_thread(load_model):
        raise HTTPException(status_code=500, detail="Model failed to load")


@app.post("/predict")
async def predict(file: UploadFile = File(...)):
    """Run the age-classification pipeline on an uploaded image.

    Returns:
        JSON with:
        - ``results``: the pipeline output.
        - ``timestamp``: local, naive ISO-8601 time.
        - ``image_size``: [width, height] of the image actually classified,
          i.e. after any downscaling.

    Raises:
        HTTPException(400): the upload does not look like an image, or cannot
            be decoded as one.
        HTTPException(500): the model cannot be loaded, prediction raises, or
            any other unexpected error occurs.
        HTTPException(504): prediction exceeds _PREDICTION_TIMEOUT_S.
    """
    try:
        await _ensure_model_loaded()

        if not _is_probably_image(file):
            logger.warning(
                "Invalid file type: content_type=%s, filename=%s",
                file.content_type,
                file.filename,
            )
            raise HTTPException(status_code=400, detail="File must be an image")

        logger.info("Processing image: %s", file.filename)
        image_data = await file.read()

        try:
            # Decoding and resizing are CPU-bound; keep them off the event loop.
            image = await asyncio.to_thread(_decode_image, image_data)
        except Exception as e:
            logger.error("Image processing error: %s", e)
            raise HTTPException(status_code=400, detail="Invalid or corrupted image file") from e

        try:
            logger.info("Running model prediction...")
            # wait_for stops *waiting* after the timeout, but the worker thread
            # cannot be interrupted and keeps running the model until it finishes.
            results = await asyncio.wait_for(
                asyncio.to_thread(pipe, image),
                timeout=_PREDICTION_TIMEOUT_S,
            )
            logger.info("Prediction completed: %d results", len(results))
        except asyncio.TimeoutError:
            logger.error("Model prediction timed out")
            raise HTTPException(status_code=504, detail="Prediction timed out")
        except Exception as e:
            logger.exception("Model prediction error: %s", e)
            raise HTTPException(status_code=500, detail="Prediction failed") from e

        return JSONResponse(content={
            "results": results,
            "timestamp": datetime.now().isoformat(),
            "image_size": image.size,
        })
    except HTTPException:
        # Already mapped to a client-facing status; pass through unchanged.
        raise
    except Exception as e:
        logger.exception("Unexpected error: %s", e)
        raise HTTPException(status_code=500, detail="Internal server error") from e
# Endpoint that loads the model ahead of the first /predict call.
@app.post("/warmup")
async def warmup():
    """Load the model if it is not loaded yet.

    Returns:
        ``{"status": "already_loaded"}`` if ``pipe`` is already set.
        Otherwise ``{"status": "loaded"}`` or ``{"status": "failed"}``,
        depending on what load_model() returned.
    """
    if pipe is not None:
        return {"status": "already_loaded"}
    # load_model() is synchronous and can take a long time. Run it in a worker
    # thread so the event loop (and /health) keeps serving requests meanwhile.
    success = await asyncio.to_thread(load_model)
    return {"status": "loaded" if success else "failed"}