|
|
""" |
|
|
FastAPI application for image colorization using ColorizeNet model |
|
|
with Firebase App Check integration |
|
|
""" |
|
|
import asyncio
import io
import logging
import os
import uuid
from pathlib import Path
from typing import Optional

import firebase_admin
import numpy as np
import torch
from fastapi import FastAPI, File, UploadFile, HTTPException, Depends, Request
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import FileResponse, JSONResponse
from fastapi.staticfiles import StaticFiles
from firebase_admin import credentials, app_check, auth as firebase_auth
from PIL import Image

from app.colorize_model import ColorizeModel
from app.config import settings
|
|
|
|
|
|
|
|
# Process-wide logging: INFO and above, timestamped, with the logger name.
logging.basicConfig(
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    level=logging.INFO,
)
logger = logging.getLogger(__name__)

# The ASGI application; title/description/version are the API metadata.
app = FastAPI(
    version="1.0.0",
    description="Image colorization API using ColorizeNet model",
    title="Colorize API",
)

# Fully permissive CORS: any origin, any method, any request header.
# NOTE(review): a wildcard origin combined with allow_credentials=True lets any
# site issue credentialed cross-origin requests. Auth in this app is carried in
# headers (Authorization / X-Firebase-AppCheck), so an explicit origin list or
# allow_credentials=False is probably sufficient -- confirm with the clients.
app.add_middleware(
    CORSMiddleware,
    allow_headers=["*"],
    allow_methods=["*"],
    allow_credentials=True,
    allow_origins=["*"],
)
|
|
|
|
|
|
|
|
firebase_cred_path = os.getenv("FIREBASE_CREDENTIALS_PATH", "colorize-662df-firebase-adminsdk-fbsvc-e080668793.json") |
|
|
if os.path.exists(firebase_cred_path): |
|
|
try: |
|
|
cred = credentials.Certificate(firebase_cred_path) |
|
|
firebase_admin.initialize_app(cred) |
|
|
logger.info("Firebase Admin SDK initialized") |
|
|
except Exception as e: |
|
|
logger.warning("Failed to initialize Firebase: %s", str(e)) |
|
|
firebase_admin.initialize_app() |
|
|
else: |
|
|
logger.warning("Firebase credentials file not found. App Check will be disabled.") |
|
|
try: |
|
|
firebase_admin.initialize_app() |
|
|
except: |
|
|
pass |
|
|
|
|
|
|
|
|
# Local storage for uploaded originals and colorized results. Both paths are
# relative to the process working directory and are created if missing.
UPLOAD_DIR = Path("uploads")
RESULT_DIR = Path("results")
UPLOAD_DIR.mkdir(exist_ok=True)
RESULT_DIR.mkdir(exist_ok=True)

# Serve both directories publicly. The mounts reuse the Path constants above,
# so the served directory cannot drift from the one the endpoints write into.
# NOTE(review): these mounts are registered before the
# @app.get("/results/{filename}") and @app.get("/uploads/{filename}") routes
# defined later in this module, so requests under these prefixes are
# presumably answered by StaticFiles first -- confirm whether those routes are
# reachable at all.
app.mount("/results", StaticFiles(directory=RESULT_DIR), name="results")
app.mount("/uploads", StaticFiles(directory=UPLOAD_DIR), name="uploads")

# Set by the startup hook once the model loads; stays None if loading fails.
colorize_model = None
|
|
|
|
|
@app.get("/")
async def root():
    """Describe the service and list the paths of its main endpoints."""
    endpoints = {"health": "/health", "upload": "/upload", "colorize": "/colorize"}
    return {"app": "Colorize API", "version": "1.0.0", **endpoints}
|
|
|
|
|
@app.on_event("startup")
async def startup_event():
    """
    Load the colorization model once when the application starts.

    On success the model is stored in the module-level ``colorize_model``.
    On failure the error is logged with its traceback and ``colorize_model``
    is left unchanged (``None`` at import), so the app still starts without a
    model.
    """
    global colorize_model
    try:
        logger.info("Loading ColorizeNet model...")
        colorize_model = ColorizeModel(settings.MODEL_ID)
        logger.info("ColorizeNet model loaded successfully")
    except Exception as e:
        # logger.exception keeps the traceback; logging str(e) alone made
        # model-load failures (missing weights, CUDA errors) hard to diagnose.
        logger.exception("Failed to load ColorizeNet model: %s", str(e))
|
|
|
|
|
|
|
|
@app.on_event("shutdown")
async def shutdown_event():
    """
    Drop the reference to the colorization model on shutdown.

    The global is rebound to ``None`` rather than deleted. ``del`` on a
    ``global`` removes the module attribute entirely, so any later
    ``colorize_model is None`` check in this module would raise NameError.
    """
    global colorize_model
    # Identity test: the model object's truthiness is not part of its contract.
    if colorize_model is not None:
        colorize_model = None
    logger.info("Application shutdown")
|
|
|
|
|
def _extract_bearer_token(authorization_header: str | None) -> str | None: |
|
|
if not authorization_header: |
|
|
return None |
|
|
parts = authorization_header.split(" ", 1) |
|
|
if len(parts) == 2 and parts[0].lower() == "bearer": |
|
|
return parts[1].strip() |
|
|
return None |
|
|
|
|
|
|
|
|
async def verify_request(request: Request):
    """
    FastAPI dependency that authorizes a request via Firebase.

    Evaluation order:
      1. Skipped entirely (returns True) when no Firebase app is initialized
         or the ``DISABLE_AUTH`` environment variable is ``"true"``.
      2. A Firebase Auth ID token in ``Authorization: Bearer <id_token>``. On
         success the decoded claims are stored on ``request.state.user`` and
         the request is accepted. A token that fails verification is only
         logged, and evaluation continues with step 3.
      3. When ``settings.ENABLE_APP_CHECK`` is true, an App Check token in
         ``X-Firebase-AppCheck`` is required and must verify. When it is
         false, the request is accepted.

    NOTE(review): ``verify_id_token`` / ``app_check.verify_token`` are
    synchronous and may fetch signing keys over the network, which would block
    the event loop inside this async dependency. Also, with App Check disabled,
    a request carrying an invalid ID token is still accepted. Confirm both are
    intended.

    Returns:
        True when the request is allowed.

    Raises:
        HTTPException(401): App Check is enabled and its token is missing or invalid.
    """
    auth_disabled = os.getenv("DISABLE_AUTH", "false").lower() == "true"
    if auth_disabled or not firebase_admin._apps:
        return True

    id_token = _extract_bearer_token(request.headers.get("Authorization"))
    if id_token:
        try:
            user_claims = firebase_auth.verify_id_token(id_token)
            request.state.user = user_claims
            logger.info("Firebase Auth id_token verified for uid: %s", user_claims.get("uid"))
            return True
        except Exception as exc:
            # Deliberate fall-through: a bad ID token alone does not reject the request.
            logger.warning("Auth token verification failed: %s", str(exc))

    if not settings.ENABLE_APP_CHECK:
        return True

    app_check_token = request.headers.get("X-Firebase-AppCheck")
    if not app_check_token:
        raise HTTPException(status_code=401, detail="Missing App Check token")

    try:
        token_claims = app_check.verify_token(app_check_token)
        logger.info("App Check token verified for: %s", token_claims.get("app_id"))
        return True
    except Exception as exc:
        logger.warning("App Check token verification failed: %s", str(exc))
        raise HTTPException(status_code=401, detail="Invalid App Check token")
|
|
|
|
|
@app.get("/health")
async def health_check():
    """Report liveness plus whether the colorization model finished loading."""
    model_ready = colorize_model is not None
    return dict(status="healthy", model_loaded=model_ready)
|
|
|
|
|
@app.post("/upload")
async def upload_image(
    file: UploadFile = File(...),
    verified: bool = Depends(verify_request)
):
    """
    Upload an image and return the uploaded image URL.

    The file is stored as ``UPLOAD_DIR/<uuid4><ext>``. The extension comes from
    the client filename only when it is a known raster-image suffix. Anything
    else, including ``.svg``, ``.html`` or a missing filename, is stored as
    ``.jpg``. This matters because ``/uploads`` is served publicly, so a
    client-chosen suffix would otherwise decide the Content-Type browsers see.

    The public URL is built from ``BASE_URL``, else ``SPACE_HOST``, else
    ``http://localhost:7860``. A value without a scheme (``SPACE_HOST`` is a
    bare hostname) is prefixed with ``https://``.

    Returns:
        dict with ``success``, ``image_id``, ``image_url`` and ``filename``.

    Raises:
        HTTPException(400): the declared content type is not ``image/*``.
        HTTPException(500): the upload could not be read or written.
    """
    allowed_extensions = {".jpg", ".jpeg", ".png", ".gif", ".bmp", ".webp", ".tif", ".tiff"}

    if not file.content_type or not file.content_type.startswith("image/"):
        raise HTTPException(status_code=400, detail="File must be an image")

    file_id = str(uuid.uuid4())
    # UploadFile.filename may be None; Path(None) would raise TypeError (a 500).
    file_extension = Path(file.filename or "").suffix.lower()
    if file_extension not in allowed_extensions:
        file_extension = ".jpg"
    filename = f"{file_id}{file_extension}"
    filepath = UPLOAD_DIR / filename

    try:
        contents = await file.read()
        filepath.write_bytes(contents)
        logger.info("Image uploaded: %s", filename)
    except Exception as e:
        logger.exception("Error uploading image: %s", str(e))
        raise HTTPException(status_code=500, detail=f"Error uploading image: {str(e)}") from e

    base_url = os.getenv("BASE_URL", os.getenv("SPACE_HOST", "http://localhost:7860"))
    if base_url and "://" not in base_url:
        base_url = f"https://{base_url}"
    image_url = f"{base_url.rstrip('/')}/uploads/{filename}"

    return {
        "success": True,
        "image_id": file_id,
        "image_url": image_url,
        "filename": filename
    }
|
|
|
|
|
# Serializes inference on the shared model. The model call below runs in a
# worker thread so the event loop stays responsive. This lock keeps at most one
# inference in flight, as when the call blocked the loop, because the model is
# not assumed to be safe for concurrent use.
_colorize_lock = asyncio.Lock()


@app.post("/colorize")
async def colorize_image(
    file: UploadFile = File(...),
    verified: bool = Depends(verify_request)
):
    """
    Colorize an uploaded grayscale image using ColorizeNet.

    The input is decoded and converted to RGB, then passed to the model. The
    result is saved as ``RESULT_DIR/<uuid4>.jpg`` (JPEG, quality 95). URLs use
    ``BASE_URL``, else ``SPACE_HOST``, else ``http://localhost:7860``. A value
    without a scheme is prefixed with ``https://``.

    Returns:
        dict with ``success``, ``result_id``, ``download_url`` (static
        ``/results`` path), ``api_download_url`` (``/download/{id}``) and
        ``filename``.

    Raises:
        HTTPException(503): the model has not been loaded.
        HTTPException(400): the content type is not ``image/*``, or the bytes
            cannot be decoded as an image.
        HTTPException(500): colorization or saving failed.
    """
    model = colorize_model
    if model is None:
        raise HTTPException(status_code=503, detail="Colorization model not loaded")

    if not file.content_type or not file.content_type.startswith("image/"):
        raise HTTPException(status_code=400, detail="File must be an image")

    contents = await file.read()
    try:
        image = Image.open(io.BytesIO(contents))
        # Image.open is lazy. Force decoding here so corrupt or non-image data
        # becomes a 400 (client error) instead of a 500 from later processing.
        # UnidentifiedImageError is a subclass of OSError.
        image.load()
    except (OSError, Image.DecompressionBombError) as e:
        logger.warning("Rejected undecodable image: %s", str(e))
        raise HTTPException(status_code=400, detail="File is not a valid image") from e

    try:
        if image.mode != "RGB":
            image = image.convert("RGB")

        logger.info("Colorizing image...")
        async with _colorize_lock:
            # Model inference is blocking; run it off the event loop thread.
            colorized_image = await asyncio.to_thread(model.colorize, image)

        file_id = str(uuid.uuid4())
        result_filename = f"{file_id}.jpg"
        result_filepath = RESULT_DIR / result_filename

        colorized_image.save(result_filepath, "JPEG", quality=95)
        logger.info("Colorized image saved: %s", result_filename)

        base_url = os.getenv("BASE_URL", os.getenv("SPACE_HOST", "http://localhost:7860"))
        if base_url and "://" not in base_url:
            base_url = f"https://{base_url}"
        base_url = base_url.rstrip("/")
        download_url = f"{base_url}/results/{result_filename}"
        api_download_url = f"{base_url}/download/{file_id}"

        return {
            "success": True,
            "result_id": file_id,
            "download_url": download_url,
            "api_download_url": api_download_url,
            "filename": result_filename
        }
    except Exception as e:
        logger.exception("Error colorizing image: %s", str(e))
        raise HTTPException(status_code=500, detail=f"Error colorizing image: {str(e)}") from e
|
|
|
|
|
@app.get("/download/{file_id}")
async def download_result(
    file_id: str,
    verified: bool = Depends(verify_request)
):
    """
    Send a stored colorized result as a JPEG download.

    Looks up ``RESULT_DIR/<file_id>.jpg`` and serves it under the download name
    ``colorized_<file_id>.jpg``. Responds 404 when no such file exists.
    """
    target = RESULT_DIR / f"{file_id}.jpg"
    if target.exists():
        return FileResponse(
            target,
            media_type="image/jpeg",
            filename=f"colorized_{file_id}.jpg",
        )
    raise HTTPException(status_code=404, detail="Result not found")
|
|
|
|
|
@app.get("/results/{filename}")
async def get_result_file(filename: str):
    """
    Serve result files directly (public endpoint for browser access).

    Only regular files located directly inside ``RESULT_DIR`` are served.
    Names that resolve elsewhere (e.g. ``..``) or to a directory get a 404.
    Previously ``..`` passed the ``exists()`` check and made FileResponse fail
    on a directory.

    NOTE(review): ``app.mount("/results", ...)`` is registered earlier in this
    module, so StaticFiles presumably answers these paths first and this
    handler may be unreachable -- confirm before relying on it.
    """
    results_root = RESULT_DIR.resolve()
    result_filepath = (RESULT_DIR / filename).resolve()

    if result_filepath.parent != results_root or not result_filepath.is_file():
        raise HTTPException(status_code=404, detail="File not found")

    return FileResponse(
        result_filepath,
        media_type="image/jpeg"
    )
|
|
|
|
|
@app.get("/uploads/{filename}")
async def get_upload_file(filename: str):
    """
    Serve uploaded files directly.

    Only regular files located directly inside ``UPLOAD_DIR`` are served.
    Other names, e.g. ``..``, get a 404. The media type is guessed from the
    extension but accepted only for raster image types. Anything else,
    including ``image/svg+xml`` and ``text/html``, which a browser could
    execute, falls back to ``image/jpeg`` as before. Previously every upload,
    including PNG/GIF/WebP, was labelled ``image/jpeg``.

    NOTE(review): ``app.mount("/uploads", ...)`` is registered earlier in this
    module, so StaticFiles presumably answers these paths first and this
    handler may be unreachable -- confirm before relying on it.
    """
    import mimetypes

    uploads_root = UPLOAD_DIR.resolve()
    upload_filepath = (UPLOAD_DIR / filename).resolve()

    if upload_filepath.parent != uploads_root or not upload_filepath.is_file():
        raise HTTPException(status_code=404, detail="File not found")

    guessed_type, _ = mimetypes.guess_type(upload_filepath.name)
    if guessed_type and guessed_type.startswith("image/") and guessed_type != "image/svg+xml":
        media_type = guessed_type
    else:
        media_type = "image/jpeg"

    return FileResponse(
        upload_filepath,
        media_type=media_type
    )
|
|
|
|
|
if __name__ == "__main__":
    # Direct execution: serve on all interfaces; the PORT env var overrides 7860.
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=int(os.getenv("PORT", "7860")))
|
|
|