Integrate FastAI colorization with Firebase auth and a Gradio UI: replace main.py with the FastAI implementation; add a Gradio interface for the Space UI; add Firebase authentication to the /colorize endpoint; add curl example documentation; update test.py to send User-Agent headers.
e4599d1
| """ | |
| FastAPI application for FastAI GAN Image Colorization | |
| with Firebase Authentication and Gradio UI | |
| """ | |
| import os | |
# Set these before the heavy third-party imports below (torch, fastai,
# huggingface_hub, matplotlib, ...) so those libraries pick them up.
# All caches are redirected under /tmp, which this module treats as writable.
os.environ.update(
    {
        "OMP_NUM_THREADS": "1",
        "HF_HOME": "/tmp/hf_cache",
        "TRANSFORMERS_CACHE": "/tmp/hf_cache",
        "HF_HUB_CACHE": "/tmp/hf_cache",
        "HUGGINGFACE_HUB_CACHE": "/tmp/hf_cache",
        "XDG_CACHE_HOME": "/tmp/hf_cache",
        "MPLCONFIGDIR": "/tmp/matplotlib_config",
    }
)
import asyncio
import io
import logging
import uuid
from pathlib import Path
from typing import Optional

from fastapi import FastAPI, UploadFile, File, HTTPException, Depends, Request
from fastapi.responses import FileResponse, JSONResponse
from fastapi.middleware.cors import CORSMiddleware
from fastapi.staticfiles import StaticFiles
import firebase_admin
from firebase_admin import credentials, app_check, auth as firebase_auth
from PIL import Image
import torch
import uvicorn
import gradio as gr
# FastAI imports
# NOTE: the star import stays after the explicit imports above, exactly where it
# was, so the set of names it can shadow is unchanged.
from fastai.vision.all import *
from huggingface_hub import from_pretrained_fastai

from app.config import settings
# Root logging for the whole process: INFO and above, each record stamped with
# time, logger name and level.
logging.basicConfig(
    format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
    level=logging.INFO,
)
logger = logging.getLogger(__name__)  # module logger used by every handler below
# Ensure every scratch directory exists before caches, uploads/results and the
# StaticFiles mounts below use it; repeated runs are harmless (exist_ok=True).
for _scratch_dir in (
    "/tmp/hf_cache",
    "/tmp/matplotlib_config",
    "/tmp/colorize_uploads",
    "/tmp/colorize_results",
):
    Path(_scratch_dir).mkdir(parents=True, exist_ok=True)
# FastAPI application; the Gradio UI is mounted onto it at "/" at the end of
# the module, so the REST endpoints and the Space UI share one server.
app = FastAPI(
    version="1.0.0",
    title="FastAI Image Colorizer API",
    description="Image colorization using FastAI GAN model with Firebase authentication",
)

# Fully open CORS: every origin, method and header, with credentials allowed.
# NOTE(review): wildcard origins combined with allow_credentials=True is very
# permissive for an authenticated API — confirm this is intended.
app.add_middleware(
    CORSMiddleware,
    allow_credentials=True,
    allow_headers=["*"],
    allow_methods=["*"],
    allow_origins=["*"],
)
def _initialize_default_firebase_app() -> None:
    """Best-effort fallback: initialize the default Firebase app with no credential file.

    Used when the service-account file is missing or could not be loaded.
    Called without a credential, ``firebase_admin.initialize_app`` falls back to
    the SDK's default credential lookup. Failures are logged instead of raised
    so the service can still start. If afterwards no Firebase app exists at
    all, ``verify_request`` lets every request through, so that is logged too.
    """
    try:
        firebase_admin.initialize_app()
    except Exception as exc:  # deliberate best-effort: never abort startup here
        logger.warning("Default Firebase initialization failed: %s", exc)
    if not firebase_admin._apps:
        logger.warning(
            "No Firebase app is initialized; request authentication will be skipped."
        )


# Initialize the Firebase Admin SDK, preferring an explicit service-account file.
firebase_cred_path = os.getenv("FIREBASE_CREDENTIALS_PATH", "/tmp/firebase-adminsdk.json")
if os.path.exists(firebase_cred_path):
    try:
        cred = credentials.Certificate(firebase_cred_path)
        firebase_admin.initialize_app(cred)
        logger.info("Firebase Admin SDK initialized")
    except Exception as e:
        logger.warning("Failed to initialize Firebase: %s", str(e))
        _initialize_default_firebase_app()
else:
    # Whether App Check is enforced is decided by settings.ENABLE_APP_CHECK in
    # verify_request, not by the presence of this file.
    logger.warning(
        "Firebase credentials file not found at %s; trying default credentials.",
        firebase_cred_path,
    )
    _initialize_default_firebase_app()
# On-disk storage for uploaded inputs and colorized outputs.
RESULT_DIR = Path("/tmp/colorize_results")
UPLOAD_DIR = Path("/tmp/colorize_uploads")

# Serve both directories over HTTP, mounted before Gradio takes "/" below.
# NOTE(review): these mounts are unauthenticated — any stored file is reachable
# by URL. Result names are random UUIDs, but confirm exposing /uploads is wanted.
for _url_prefix, _folder in (("/results", RESULT_DIR), ("/uploads", UPLOAD_DIR)):
    app.mount(
        _url_prefix,
        StaticFiles(directory=str(_folder)),
        name=_url_prefix.lstrip("/"),
    )
# Model state shared by the handlers and filled in by startup_event():
#   learn            - the loaded FastAI learner; None until loading succeeds
#   model_load_error - message from the last failed load attempt, else None
model_load_error: Optional[str] = None
learn = None
async def startup_event():
    """Load the FastAI colorization model when the application starts.

    The Hugging Face repo id comes from ``MODEL_ID`` (default
    ``Hammad712/GAN-Colorization-Model``). On success the learner is stored in
    the module-level ``learn`` and ``model_load_error`` is cleared. On failure
    the error text is stored in ``model_load_error`` and the exception is not
    re-raised, so the server and the health endpoint keep working.
    """
    global learn, model_load_error
    model_id = os.getenv("MODEL_ID", "Hammad712/GAN-Colorization-Model")
    logger.info("🔄 Loading FastAI GAN Colorization Model: %s", model_id)
    try:
        learn = from_pretrained_fastai(model_id)
    except Exception as e:
        error_msg = str(e)
        # logger.exception keeps the traceback, which str(e) alone would lose.
        logger.exception("❌ Failed to load model: %s", error_msg)
        model_load_error = error_msg
    else:
        logger.info("✅ Model loaded successfully!")
        model_load_error = None
async def shutdown_event():
    """Drop the reference to the loaded model when the application shuts down.

    The global is rebound to ``None`` rather than deleted. ``del`` on a global
    removes the module attribute, so any later ``learn`` lookup (health check,
    colorization of a request still in flight) would raise ``NameError``
    instead of seeing "model not loaded".
    """
    global learn
    learn = None
    logger.info("Application shutdown")
| def _extract_bearer_token(authorization_header: str | None) -> str | None: | |
| if not authorization_header: | |
| return None | |
| parts = authorization_header.split(" ", 1) | |
| if len(parts) == 2 and parts[0].lower() == "bearer": | |
| return parts[1].strip() | |
| return None | |
async def verify_request(request: Request):
    """FastAPI dependency that authenticates a request with Firebase.

    Policy, evaluated in order:

    1. No Firebase app initialized, or ``DISABLE_AUTH=true``: allow.
    2. ``Authorization: Bearer <id_token>`` that verifies as a Firebase Auth ID
       token: allow, and store the decoded claims on ``request.state.user``.
       A token that fails verification is only logged; evaluation continues.
    3. ``settings.ENABLE_APP_CHECK`` true: require a valid
       ``X-Firebase-AppCheck`` token.
    4. Otherwise allow. With App Check disabled, requests without a valid ID
       token are still accepted.

    The Firebase Admin SDK verification calls are synchronous (and may need to
    fetch signing keys), so they run in a worker thread instead of blocking the
    event loop.

    Returns:
        ``True`` when the request is allowed.

    Raises:
        HTTPException: 401 when App Check is enforced and its token is missing
            or invalid.
    """
    if not firebase_admin._apps or os.getenv("DISABLE_AUTH", "false").lower() == "true":
        return True

    bearer = _extract_bearer_token(request.headers.get("Authorization"))
    if bearer:
        try:
            decoded = await asyncio.to_thread(firebase_auth.verify_id_token, bearer)
        except Exception as e:
            logger.warning("Auth token verification failed: %s", str(e))
        else:
            request.state.user = decoded
            logger.info("Firebase Auth id_token verified for uid: %s", decoded.get("uid"))
            return True

    if not settings.ENABLE_APP_CHECK:
        # No valid ID token, but App Check is not enforced: allow.
        return True

    app_check_token = request.headers.get("X-Firebase-AppCheck")
    if not app_check_token:
        raise HTTPException(status_code=401, detail="Missing App Check token")
    try:
        app_check_claims = await asyncio.to_thread(app_check.verify_token, app_check_token)
    except Exception as e:
        logger.warning("App Check token verification failed: %s", str(e))
        raise HTTPException(status_code=401, detail="Invalid App Check token") from e
    logger.info("App Check token verified for: %s", app_check_claims.get("app_id"))
    return True
async def api_info():
    """Describe the service: name, version, and the path of each entry point."""
    endpoints = {"health": "/health", "colorize": "/colorize", "gradio": "/"}
    return {"app": "FastAI Image Colorizer API", "version": "1.0.0", **endpoints}
async def health_check():
    """Liveness report: always "healthy", plus the model's load status.

    Echoes the configured ``MODEL_ID`` and adds ``model_error`` whenever a
    non-empty load-failure message has been recorded.
    """
    configured_model = os.getenv("MODEL_ID", "Hammad712/GAN-Colorization-Model")
    report = dict(
        status="healthy",
        model_loaded=learn is not None,
        model_id=configured_model,
    )
    if model_load_error:
        report["model_error"] = model_load_error
    return report
def _tensor_to_pil(tensor: torch.Tensor) -> Image.Image:
    """Convert a CHW (or NCHW, first item used) output tensor to an RGB PIL image.

    Floating-point tensors are treated as intensities in [0, 1] (the range the
    previous conversion clamped to; TODO confirm against the model's decode
    step). They are clamped and scaled to 0-255. Other non-uint8 tensors are
    assumed to already be 0-255 and are clamped into that range. A single
    channel becomes a grey RGB image, and a 4th (alpha) channel is dropped.

    Raises:
        ValueError: if the tensor is not 3-D after dropping a batch dimension,
            or does not have 1, 3 or 4 channels.
    """
    if tensor.dim() == 4:
        tensor = tensor[0]  # take the first item of the batch
    if tensor.dim() != 3:
        raise ValueError(f"Unexpected tensor shape: {tensor.shape}")

    hwc = tensor.detach().cpu().permute(1, 2, 0)  # CHW -> HWC, as PIL expects
    if hwc.is_floating_point():
        hwc = (hwc.float().clamp(0, 1) * 255).to(torch.uint8)
    elif hwc.dtype != torch.uint8:
        # Previously integer tensors reached PIL unconverted and were misread.
        hwc = hwc.clamp(0, 255).to(torch.uint8)

    pixels = hwc.contiguous().numpy()
    channels = pixels.shape[-1]
    if channels == 1:
        pixels = pixels[:, :, 0]  # 2-D uint8 array -> greyscale image
    elif channels not in (3, 4):
        raise ValueError(f"Unexpected tensor shape: {tensor.shape}")
    # Let PIL infer the mode from the uint8 array, then normalize to RGB.
    return Image.fromarray(pixels).convert("RGB")


def colorize_pil(image: Image.Image) -> Image.Image:
    """Colorize ``image`` with the loaded FastAI learner and return an RGB image.

    Raises:
        RuntimeError: if the model has not been loaded.
        ValueError: if the prediction cannot be turned into an image.
    """
    if learn is None:
        raise RuntimeError("Model not loaded")
    if image.mode != "RGB":
        image = image.convert("RGB")

    pred = learn.predict(image)
    # predict() may return a tuple/list; its first element is used as the output.
    if isinstance(pred, (list, tuple)):
        colorized = pred[0] if len(pred) > 0 else image
    else:
        colorized = pred

    if isinstance(colorized, torch.Tensor):
        colorized = _tensor_to_pil(colorized)
    elif not isinstance(colorized, Image.Image):
        raise ValueError(f"Unexpected prediction type: {type(colorized)}")

    return colorized if colorized.mode == "RGB" else colorized.convert("RGB")
# Inference runs in a worker thread so it no longer blocks the event loop. This
# lock keeps API requests reaching the shared learner one at a time, because
# predict() on a single learner is not assumed to be safe for concurrent calls.
# (Gradio calls colorize_pil from its own workers, outside this lock.)
_colorize_api_lock = asyncio.Lock()


async def colorize_api(
    file: UploadFile = File(...),
    verified: bool = Depends(verify_request),  # run for its auth check; value unused
):
    """Colorize an uploaded black & white image and return it as a PNG.

    Authentication is delegated to the ``verify_request`` dependency. Depending
    on configuration (``DISABLE_AUTH``, App Check settings) it may let
    unauthenticated requests through.

    Responses:
        200: the colorized PNG, also written to ``RESULT_DIR``.
        400: the upload is not declared as ``image/*`` or cannot be decoded.
        503: the model has not been loaded.
        500: colorization or saving the result failed.
    """
    if learn is None:
        raise HTTPException(status_code=503, detail="Colorization model not loaded")
    if not file.content_type or not file.content_type.startswith("image/"):
        raise HTTPException(status_code=400, detail="File must be an image")

    img_bytes = await file.read()
    try:
        image = Image.open(io.BytesIO(img_bytes)).convert("RGB")
    except (OSError, Image.DecompressionBombError) as exc:
        # Bad client input (unidentified/truncated/oversized image), not a server error.
        logger.warning("Rejected undecodable image upload: %s", exc)
        raise HTTPException(status_code=400, detail="Could not decode image file") from exc

    try:
        logger.info("Colorizing image...")
        async with _colorize_api_lock:
            colorized = await asyncio.to_thread(colorize_pil, image)
        output_filename = f"{uuid.uuid4()}.png"
        output_path = RESULT_DIR / output_filename
        # NOTE(review): result files are never cleaned up from RESULT_DIR.
        colorized.save(output_path, "PNG")
        logger.info("Colorized image saved: %s", output_filename)
    except Exception as e:
        logger.exception("Error colorizing image: %s", str(e))
        raise HTTPException(status_code=500, detail=f"Error colorizing image: {str(e)}") from e

    return FileResponse(
        output_path,
        media_type="image/png",
        filename=f"colorized_{output_filename}",
    )
| # ========================================================== | |
| # Gradio Interface (for Space UI) | |
| # ========================================================== | |
def gradio_colorize(image):
    """Gradio callback: colorize the uploaded PIL image for display.

    Returns ``None`` (an empty output) when no image was supplied. If the model
    is not loaded or colorization fails, a ``gr.Error`` is raised so Gradio
    shows a message to the user instead of silently rendering a blank result.
    """
    if image is None:
        return None
    if learn is None:
        raise gr.Error("Colorization model is not loaded.")
    try:
        return colorize_pil(image)
    except Exception as e:
        logger.exception("Gradio colorization error: %s", str(e))
        raise gr.Error("Colorization failed. Check the server logs for details.") from e
title = "🎨 FastAI GAN Image Colorizer"
description = "Upload a black & white photo to generate a colorized version using the FastAI GAN model."

# Gradio front end for the Space; each upload is passed to gradio_colorize.
iface = gr.Interface(
    fn=gradio_colorize,
    inputs=gr.Image(type="pil", label="Upload B&W Image"),
    outputs=gr.Image(type="pil", label="Colorized Image"),
    title=title,
    description=description,
)


def _register_api(target: FastAPI) -> None:
    """Attach the REST endpoints and lifecycle hooks to ``target`` unless already present.

    NOTE(review): no route decorators are visible on api_info, health_check,
    colorize_api, startup_event or shutdown_event. Without this they would
    never be served or run, and the model would never be loaded. The
    path/handler checks make this a no-op for anything registered elsewhere.

    Must run before Gradio is mounted at "/". Routes are matched in
    registration order, so that catch-all mount would otherwise shadow these
    paths.
    """
    registered_paths = {getattr(route, "path", None) for route in target.routes}
    for path, endpoint, method in (
        ("/api", api_info, "GET"),
        ("/health", health_check, "GET"),
        ("/colorize", colorize_api, "POST"),
    ):
        if path not in registered_paths:
            target.add_api_route(path, endpoint, methods=[method])
    if startup_event not in getattr(target.router, "on_startup", ()):
        target.add_event_handler("startup", startup_event)
    if shutdown_event not in getattr(target.router, "on_shutdown", ()):
        target.add_event_handler("shutdown", shutdown_event)


_register_api(app)

# Mount Gradio at the root (this is the Space UI). It handles every path not
# matched by a route registered above, which is why API info lives at /api.
app = gr.mount_gradio_app(app, iface, path="/")
| # ========================================================== | |
| # Run Server | |
| # ========================================================== | |
if __name__ == "__main__":
    # Direct-run entry point: serve on all interfaces, PORT env var or 7860.
    uvicorn.run(app, host="0.0.0.0", port=int(os.getenv("PORT", "7860")))