|
|
from fastapi import FastAPI, File, UploadFile, HTTPException, Header, Request, Form |
|
|
from typing import Optional |
|
|
from fastapi.responses import FileResponse |
|
|
from huggingface_hub import hf_hub_download |
|
|
import uuid |
|
|
import os |
|
|
import io |
|
|
import json |
|
|
from PIL import Image |
|
|
import torch |
|
|
from torchvision import transforms |
|
|
from app.database import ( |
|
|
get_database, |
|
|
log_api_call, |
|
|
log_image_upload, |
|
|
log_colorization, |
|
|
log_media_click, |
|
|
close_connection, |
|
|
) |
|
|
try: |
|
|
from firebase_admin import auth as firebase_auth |
|
|
except ImportError: |
|
|
firebase_auth = None |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
app = FastAPI(title="Text-Guided Image Colorization API")

# Optional Firebase Admin bootstrap. The FIREBASE_CREDENTIALS environment
# variable is parsed as JSON and handed to credentials.Certificate. When the
# variable is absent Firebase stays disabled; any exception raised along the
# way (missing package, malformed JSON, rejected credentials, ...) is printed
# and swallowed so the API still starts.
try:
    import firebase_admin
    from firebase_admin import credentials, app_check

    firebase_json = os.getenv("FIREBASE_CREDENTIALS")

    if not firebase_json:
        print("⚠️ No Firebase credentials found. Firebase disabled.")
    else:
        print("🔥 Loading Firebase credentials from ENV...")
        firebase_dict = json.loads(firebase_json)
        cred = credentials.Certificate(firebase_dict)
        firebase_admin.initialize_app(cred)
except Exception as e:
    print("❌ Firebase initialization failed:", e)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Local scratch directories: raw uploads and colorized outputs live here and
# are served back by the /uploads/ and /results/ routes.
UPLOAD_DIR = "/tmp/uploads"
RESULTS_DIR = "/tmp/results"

for _scratch_dir in (UPLOAD_DIR, RESULTS_DIR):
    os.makedirs(_scratch_dir, exist_ok=True)
del _scratch_dir

# Passed to log_media_click as default_category_id; can be overridden with the
# DEFAULT_CATEGORY_FALLBACK environment variable.
MEDIA_CLICK_DEFAULT_CATEGORY = os.getenv("DEFAULT_CATEGORY_FALLBACK", "69368fcd2e46bd68ae1889b2")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Generator checkpoint on the Hugging Face Hub; fetched once at import time.
MODEL_REPO = "Hammad712/GAN-Colorization-Model"
MODEL_FILENAME = "generator.pt"

print("⬇️ Downloading model...")
model_path = hf_hub_download(
    repo_id=MODEL_REPO,
    filename=MODEL_FILENAME,
)

# NOTE(review): torch.load unpickles the downloaded file, i.e. it trusts the
# remote repository. Once the checkpoint is confirmed to be a plain tensor
# state dict, consider passing weights_only=True.
# NOTE(review): state_dict is not referenced by colorize_image, which is
# still a placeholder — TODO confirm how the weights are meant to be used.
print("📦 Loading model weights...")
state_dict = torch.load(
    model_path,
    map_location="cpu",
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def colorize_image(img: Image.Image):
    """Return an RGB image whose three channels all hold *img*'s luminance.

    This is a placeholder for the real GAN colorizer (replace with the real
    model prediction): the input is converted to single-channel "L" mode and
    that channel is copied into R, G and B, so the result is still visually
    grayscale but has RGB layout (and can therefore be saved as JPEG).
    """
    to_tensor = transforms.ToTensor()

    # "L" image -> (1, H, W) tensor -> (1, 1, H, W) batch of one.
    gray_batch = to_tensor(img.convert("L")).unsqueeze(0)
    # Replicate the single channel into three: (1, 3, H, W).
    rgb_batch = gray_batch.repeat(1, 3, 1, 1)

    # Drop ONLY the batch dimension. A bare squeeze() would also remove H or W
    # when either is 1 (e.g. a one-pixel-tall image), handing ToPILImage a 2-D
    # tensor that it turns into a wrongly shaped single-channel image (or a
    # 1-D tensor it rejects outright for a 1x1 input).
    return transforms.ToPILImage()(rgb_batch.squeeze(0))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@app.on_event("startup")
async def startup_event():
    """Open the MongoDB connection when the app starts.

    Errors from get_database are printed rather than raised, so the service
    still starts without a database.
    """
    try:
        connected = get_database() is not None
    except Exception as e:
        print(f"⚠️ MongoDB initialization failed: {e}")
    else:
        if connected:
            print("✅ MongoDB initialized successfully!")
|
|
|
|
|
@app.on_event("shutdown")
async def shutdown_event():
    """Close the database connection via close_connection, then report shutdown."""
    close_connection()
    print("Application shutdown")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@app.get("/health")
def health_check(request: Request):
    """Liveness probe.

    Always answers ``{"status": "healthy", "model_loaded": True}`` and records
    the call (with the caller's IP when available) through log_api_call.
    """
    payload = {"status": "healthy", "model_loaded": True}

    client = request.client
    caller_ip = client.host if client else None

    log_api_call(
        endpoint="/health",
        method="GET",
        status_code=200,
        response_data=payload,
        ip_address=caller_ip,
    )
    return payload
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def verify_app_check_token(token: str):
    """Check the Firebase App Check token sent by the client.

    A missing token, or one shorter than 20 characters, is always rejected.
    When the Firebase Admin SDK is importable and its default app has been
    initialized, the token is additionally verified with
    ``firebase_admin.app_check.verify_token``. Without an initialized
    Firebase app there is nothing to verify against, so only the format
    check applies (the previous behavior).

    Args:
        token: Raw token string from the request header (may be None).

    Returns:
        True when the token is accepted.

    Raises:
        HTTPException: 401 when the token is missing, too short, or fails
            Firebase verification.
    """
    if not token or len(token) < 20:
        raise HTTPException(status_code=401, detail="Invalid Firebase App Check token")

    try:
        import firebase_admin
        from firebase_admin import app_check

        # get_app() raises ValueError when no default app was initialized
        # (e.g. FIREBASE_CREDENTIALS unset or initialization failed).
        firebase_admin.get_app()
    except (ImportError, ValueError):
        return True

    try:
        app_check.verify_token(token)
    except Exception as exc:
        # Any verification failure (bad signature, expired, wrong project,
        # key fetch error) is an authentication failure for the caller.
        raise HTTPException(status_code=401, detail="Invalid Firebase App Check token") from exc

    return True
|
|
|
|
|
def _resolve_user_id(request: Request, supplied_user_id: Optional[str]) -> Optional[str]: |
|
|
"""Return supplied user_id if provided and not empty, otherwise None (will auto-generate in log_media_click).""" |
|
|
if supplied_user_id and supplied_user_id.strip(): |
|
|
return supplied_user_id.strip() |
|
|
return None |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@app.post("/upload")
async def upload_image(
    request: Request,
    file: UploadFile = File(...),
    x_firebase_appcheck: str = Header(None),
    user_id: Optional[str] = Form(None),
    category_id: Optional[str] = Form(None),
    categoryId: Optional[str] = Form(None),
):
    """Store an uploaded image under UPLOAD_DIR and return its public URL.

    The App Check token is checked first (verify_app_check_token). Requests
    whose content type is missing or not ``image/*`` are logged and rejected
    with HTTP 400. On success the upload, the API call and a media click are
    logged.

    The category may arrive as ``category_id`` or ``categoryId``. The first
    non-blank value (stripped) is used; otherwise None is passed to
    log_media_click together with MEDIA_CLICK_DEFAULT_CATEGORY as its default.

    Returns:
        dict with ``success``, ``image_id`` (UUID without extension) and
        ``file_url`` (the /uploads/ route for the stored file).
    """
    verify_app_check_token(x_firebase_appcheck)

    ip_address = request.client.host if request.client else None
    effective_user_id = _resolve_user_id(request, user_id)
    # Accept both spellings; a blank value falls through to the other field
    # instead of masking it.
    effective_category_id = next(
        (
            value.strip()
            for value in (category_id, categoryId)
            if isinstance(value, str) and value.strip()
        ),
        None,
    )

    # content_type can be None when the part carries no Content-Type header;
    # treat that as "not an image" instead of crashing with AttributeError.
    if not (file.content_type or "").startswith("image/"):
        log_api_call(
            endpoint="/upload",
            method="POST",
            status_code=400,
            error="Invalid file type",
            ip_address=ip_address,
        )
        raise HTTPException(status_code=400, detail="Invalid file type")

    image_stem = str(uuid.uuid4())
    # NOTE(review): the raw bytes are stored with a .jpg name whatever the
    # actual image format is — confirm this is intended.
    image_id = f"{image_stem}.jpg"
    file_path = os.path.join(UPLOAD_DIR, image_id)

    img_bytes = await file.read()
    file_size = len(img_bytes)
    with open(file_path, "wb") as f:
        f.write(img_bytes)

    base_url = "https://logicgoinfotechspaces-text-guided-image-colorization.hf.space"

    response_data = {
        "success": True,
        "image_id": image_stem,
        "file_url": f"{base_url}/uploads/{image_id}",
    }

    log_image_upload(
        image_id=image_stem,
        filename=file.filename or image_id,
        file_size=file_size,
        content_type=file.content_type or "image/jpeg",
        user_id=effective_user_id,
        ip_address=ip_address,
    )

    log_api_call(
        endpoint="/upload",
        method="POST",
        status_code=200,
        request_data={"filename": file.filename, "content_type": file.content_type},
        response_data=response_data,
        user_id=effective_user_id,
        ip_address=ip_address,
    )

    log_media_click(
        user_id=effective_user_id,
        category_id=effective_category_id,
        endpoint_path=str(request.url.path),
        default_category_id=MEDIA_CLICK_DEFAULT_CATEGORY,
    )

    return response_data
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@app.post("/colorize")
async def colorize(
    request: Request,
    file: UploadFile = File(...),
    x_firebase_appcheck: str = Header(None),
    user_id: Optional[str] = Form(None),
    category_id: Optional[str] = Form(None),
    categoryId: Optional[str] = Form(None),
):
    """Colorize an uploaded image, save it under RESULTS_DIR and return its URLs.

    The App Check token is checked first (verify_app_check_token). A missing
    or non-``image/*`` content type yields HTTP 400. Any exception while
    decoding, colorizing, saving or logging yields HTTP 500 whose detail
    includes the error text. Both failures and successes are recorded through
    log_colorization and log_api_call; successes also go to log_media_click.

    The category may arrive as ``category_id`` or ``categoryId``. The first
    non-blank value (stripped) is used, otherwise None.

    Returns:
        dict with ``success``, ``result_id`` (UUID without extension),
        ``download_url`` (the /results/ route) and ``api_download`` (the
        /download/ route, which requires an App Check token).
    """
    import logging
    import time

    # The module defines no logger of its own; without this the error path
    # below would fail with NameError before logging or returning the 500.
    logger = logging.getLogger(__name__)

    # Monotonic clock: the measured duration is immune to wall-clock changes.
    start_time = time.perf_counter()

    verify_app_check_token(x_firebase_appcheck)

    ip_address = request.client.host if request.client else None
    effective_user_id = _resolve_user_id(request, user_id)
    # Accept both spellings; a blank value falls through to the other field
    # instead of masking it.
    effective_category_id = next(
        (
            value.strip()
            for value in (category_id, categoryId)
            if isinstance(value, str) and value.strip()
        ),
        None,
    )

    # content_type can be None when the part carries no Content-Type header;
    # treat that as "not an image" instead of crashing with AttributeError.
    if not (file.content_type or "").startswith("image/"):
        error_msg = "Invalid file type"
        log_api_call(
            endpoint="/colorize",
            method="POST",
            status_code=400,
            error=error_msg,
            ip_address=ip_address,
        )
        log_colorization(
            result_id=None,
            model_type="gan",
            processing_time=None,
            user_id=effective_user_id,
            ip_address=ip_address,
            status="failed",
            error=error_msg,
        )
        raise HTTPException(status_code=400, detail=error_msg)

    try:
        img = Image.open(io.BytesIO(await file.read()))
        output_img = colorize_image(img)

        processing_time = time.perf_counter() - start_time

        result_id_clean = str(uuid.uuid4())
        result_id = f"{result_id_clean}.jpg"
        output_path = os.path.join(RESULTS_DIR, result_id)
        output_img.save(output_path)

        base_url = "https://logicgoinfotechspaces-text-guided-image-colorization.hf.space"

        response_data = {
            "success": True,
            "result_id": result_id_clean,
            "download_url": f"{base_url}/results/{result_id}",
            "api_download": f"{base_url}/download/{result_id_clean}",
        }

        log_colorization(
            result_id=result_id_clean,
            model_type="gan",
            processing_time=processing_time,
            user_id=effective_user_id,
            ip_address=ip_address,
            status="success",
        )

        log_api_call(
            endpoint="/colorize",
            method="POST",
            status_code=200,
            request_data={"filename": file.filename, "content_type": file.content_type},
            response_data=response_data,
            user_id=effective_user_id,
            ip_address=ip_address,
        )

        log_media_click(
            user_id=effective_user_id,
            category_id=effective_category_id,
            endpoint_path=str(request.url.path),
            default_category_id=MEDIA_CLICK_DEFAULT_CATEGORY,
        )

        return response_data

    except Exception as e:
        error_msg = str(e)
        # logger.exception also records the traceback for diagnosis.
        logger.exception("Error colorizing image: %s", error_msg)

        log_colorization(
            result_id=None,
            model_type="gan",
            processing_time=None,
            user_id=effective_user_id,
            ip_address=ip_address,
            status="failed",
            error=error_msg,
        )

        log_api_call(
            endpoint="/colorize",
            method="POST",
            status_code=500,
            error=error_msg,
            user_id=effective_user_id,
            ip_address=ip_address,
        )
        raise HTTPException(status_code=500, detail=f"Error colorizing image: {error_msg}") from e
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@app.get("/download/{file_id}")
def download_result(
    request: Request,
    file_id: str,
    x_firebase_appcheck: str = Header(None)
):
    """Serve ``RESULTS_DIR/<file_id>.jpg`` as image/jpeg after the App Check token is checked.

    Responds 404 ("Result not found") when no such path exists. Both outcomes
    are recorded through log_api_call.
    """
    verify_app_check_token(x_firebase_appcheck)

    client = request.client
    ip_address = client.host if client else None

    endpoint = f"/download/{file_id}"
    path = os.path.join(RESULTS_DIR, f"{file_id}.jpg")

    if os.path.exists(path):
        log_api_call(
            endpoint=endpoint,
            method="GET",
            status_code=200,
            request_data={"file_id": file_id},
            ip_address=ip_address,
        )
        return FileResponse(path, media_type="image/jpeg")

    log_api_call(
        endpoint=endpoint,
        method="GET",
        status_code=404,
        error="Result not found",
        ip_address=ip_address,
    )
    raise HTTPException(status_code=404, detail="Result not found")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@app.get("/results/{filename}")
def get_result(request: Request, filename: str):
    """Serve a file from RESULTS_DIR as image/jpeg.

    This route does not call verify_app_check_token. It responds 404
    ("Result not found") unless *filename* is a bare name that refers to a
    regular file directly inside RESULTS_DIR. Both outcomes are recorded
    through log_api_call.
    """
    ip_address = request.client.host if request.client else None

    path = os.path.join(RESULTS_DIR, filename)
    # Require a bare name and a regular file: names such as "." or ".."
    # resolve to directories, which os.path.exists accepts but FileResponse
    # cannot send (yielding a server error instead of a clean 404).
    if os.path.basename(filename) != filename or not os.path.isfile(path):
        log_api_call(
            endpoint=f"/results/{filename}",
            method="GET",
            status_code=404,
            error="Result not found",
            ip_address=ip_address,
        )
        raise HTTPException(status_code=404, detail="Result not found")

    log_api_call(
        endpoint=f"/results/{filename}",
        method="GET",
        status_code=200,
        request_data={"filename": filename},
        ip_address=ip_address,
    )

    return FileResponse(path, media_type="image/jpeg")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@app.get("/uploads/{filename}")
def get_upload(request: Request, filename: str):
    """Serve a previously uploaded file from UPLOAD_DIR as image/jpeg.

    This route does not call verify_app_check_token. It responds 404
    ("File not found") unless *filename* is a bare name that refers to a
    regular file directly inside UPLOAD_DIR. Both outcomes are recorded
    through log_api_call.
    """
    ip_address = request.client.host if request.client else None

    path = os.path.join(UPLOAD_DIR, filename)
    # Require a bare name and a regular file: names such as "." or ".."
    # resolve to directories, which os.path.exists accepts but FileResponse
    # cannot send (yielding a server error instead of a clean 404).
    if os.path.basename(filename) != filename or not os.path.isfile(path):
        log_api_call(
            endpoint=f"/uploads/{filename}",
            method="GET",
            status_code=404,
            error="File not found",
            ip_address=ip_address,
        )
        raise HTTPException(status_code=404, detail="File not found")

    log_api_call(
        endpoint=f"/uploads/{filename}",
        method="GET",
        status_code=200,
        request_data={"filename": filename},
        ip_address=ip_address,
    )

    return FileResponse(path, media_type="image/jpeg")
|
|
|