"""FastAPI service wrapping the Real-ESRGAN ``inference_realesrgan.py`` script.

Endpoints:
    GET  /          -- liveness message.
    GET  /models/   -- list model weights found on disk.
    POST /enhance/  -- upscale an uploaded image via a subprocess call to the
                       inference script, retrying with tiling on OOM errors.
    GET  /status/   -- busy flag, resource usage, and path sanity checks.

A single boolean ``processing_lock`` serializes enhancement requests; this is
adequate because the app runs under one asyncio worker and the check/set pair
contains no ``await`` between them.
"""

import os
import uuid
import gc
import mimetypes
import subprocess
import sys
import traceback
import shutil
import logging
from typing import Optional, List
from pathlib import Path

from fastapi import FastAPI, File, UploadFile, HTTPException, BackgroundTasks, Form
from fastapi.responses import FileResponse
from fastapi.middleware.cors import CORSMiddleware
import uvicorn
import psutil

# --- Configuration ---
SCRIPT_DIR = Path(__file__).parent.resolve()
REAL_ESRGAN_DIR = SCRIPT_DIR / "Real-ESRGAN"
INFERENCE_SCRIPT = REAL_ESRGAN_DIR / "inference_realesrgan.py"
MODEL_DIR = REAL_ESRGAN_DIR / "weights"
INPUT_DIR = SCRIPT_DIR / "api_inputs"
OUTPUT_DIR = SCRIPT_DIR / "api_outputs"
API_PORT = 8000
LOG_FILE = SCRIPT_DIR / "api.log"

# --- Setup Logging ---
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
    handlers=[
        logging.FileHandler(LOG_FILE),
        logging.StreamHandler(sys.stdout),  # Also print logs to console
    ],
)
logger = logging.getLogger(__name__)

# --- Create Directories ---
# parents=True so a missing parent directory (e.g. first run on a fresh
# checkout) does not crash startup.
INPUT_DIR.mkdir(parents=True, exist_ok=True)
OUTPUT_DIR.mkdir(parents=True, exist_ok=True)

# --- FastAPI App Initialization ---
app = FastAPI(
    title="Image Enhancer API",
    description="API for enhancing images.",
    version="1.0.0",
)

# --- CORS Middleware ---
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],  # Allow all origins for simplicity, adjust in production
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# --- Global State ---
processing_lock = False  # True while an enhancement subprocess is running.
available_models = []  # Model stems discovered in MODEL_DIR (see update_available_models).
DEFAULT_MODEL_PREFERENCE = "RealESRGAN_x4plus"  # Preferred default

# Define allowed values for API input validation
AVAILABLE_MODELS_API = ["RealESRGAN_x4plus", "RealESRGAN_x2plus"]
ALLOWED_SCALES_API = [1.0, 2.0, 4.0, 8.0]
DEFAULT_MODEL_API = "RealESRGAN_x4plus"
DEFAULT_SCALE_API = 4.0
DEFAULT_TILE_SIZE = 400  # Default tile size to use on memory error retry


def update_available_models():
    """Scan ``MODEL_DIR`` for ``*.pth`` weights and refresh ``available_models``.

    On any scanning error the list is reset to empty rather than left stale.
    """
    global available_models
    try:
        models = [f.stem for f in MODEL_DIR.glob("*.pth")]
        if not models:
            logger.warning(f"No model files (.pth) found in {MODEL_DIR}")
        available_models = sorted(models)
        logger.info(f"Available models updated: {available_models}")
    except Exception as e:
        logger.error(f"Error scanning model directory {MODEL_DIR}: {e}")
        available_models = []


# Initialize models on startup
update_available_models()


# --- Helper Functions ---
def release_lock():
    """Releases the processing lock."""
    global processing_lock
    processing_lock = False
    logger.info("Processing lock released.")


def _guess_media_type(path: Path) -> str:
    """Return a proper MIME type for an image file.

    Uses :mod:`mimetypes` so e.g. ``.jpg`` maps to the valid ``image/jpeg``
    (a naive ``image/{suffix}`` would yield the invalid ``image/jpg``).
    Falls back to the suffix-derived form for unknown extensions.
    """
    guessed, _ = mimetypes.guess_type(path.name)
    return guessed or f"image/{path.suffix.lstrip('.').lower()}"


# --- API Endpoints ---
@app.get("/")
async def root():
    """Root endpoint providing basic API information."""
    return {"message": "Image Enhancer API is running"}


@app.get("/models/", response_model=List[str])
async def get_models():
    """Returns a list of available Real-ESRGAN models."""
    if not available_models:
        update_available_models()  # Attempt to rescan if list is empty
        if not available_models:
            raise HTTPException(status_code=404, detail=f"No models found in {MODEL_DIR}")
    return available_models


@app.post("/enhance/", response_class=FileResponse)
async def enhance_image(
    background_tasks: BackgroundTasks,
    file: UploadFile = File(...),
    model_name: Optional[str] = Form(DEFAULT_MODEL_API),
    outscale: float = Form(DEFAULT_SCALE_API),
    face_enhance: bool = Form(False),
    fp32: bool = Form(False),
    tile: Optional[int] = Form(0),
):
    """
    Enhances an uploaded image using Real-ESRGAN.
    Automatically retries with tiling if an out-of-memory error is detected.

    Raises:
        HTTPException 429: another request is already being processed.
        HTTPException 400: invalid model name, scale, or file type.
        HTTPException 500: model missing on disk, save failure, or the
            inference subprocess failed (even after the tiled retry).
    """
    global processing_lock
    temp_input_path = None
    temp_output_dir_for_request = None
    temp_input_dir_for_request = None  # Added for consistency

    # --- Request Handling ---
    request_id = uuid.uuid4().hex
    logger.info(f"Received enhancement request ID: {request_id}")

    # Check processing lock. No await between the check and the set, so this
    # check-then-set is atomic under a single asyncio event loop.
    if processing_lock:
        logger.warning(f"Request {request_id}: Server busy, denying request.")
        raise HTTPException(
            status_code=429,
            detail="Server is busy processing another image. Please try again shortly.",
        )
    processing_lock = True
    logger.info(f"Request {request_id}: Processing lock acquired.")

    # --- Input Validation ---
    # Validate model name against allowed list
    if model_name not in AVAILABLE_MODELS_API:
        logger.warning(
            f"Request {request_id}: Invalid model_name specified: '{model_name}'. Allowed: {AVAILABLE_MODELS_API}"
        )
        release_lock()
        raise HTTPException(
            status_code=400,
            detail=f"Invalid model name '{model_name}'. Allowed values: {AVAILABLE_MODELS_API}",
        )

    # Validate scale against allowed list
    if outscale not in ALLOWED_SCALES_API:
        logger.warning(
            f"Request {request_id}: Invalid outscale specified: '{outscale}'. Allowed: {ALLOWED_SCALES_API}"
        )
        release_lock()
        raise HTTPException(
            status_code=400,
            detail=f"Invalid scale value '{outscale}'. Allowed values: {ALLOWED_SCALES_API}",
        )

    # Validate file type
    if not file.content_type or not file.content_type.startswith("image/"):
        logger.warning(f"Request {request_id}: Invalid file type uploaded: {file.content_type}")
        release_lock()
        raise HTTPException(status_code=400, detail="Invalid file type. Please upload an image.")

    # --- Model Existence Check ---
    # Check if the validated model actually exists in the scanned directory
    if model_name not in available_models:
        logger.error(
            f"Request {request_id}: Model '{model_name}' is allowed but not found in {MODEL_DIR}. Scanned models: {available_models}"
        )
        update_available_models()  # Try rescanning
        if model_name not in available_models:
            release_lock()
            raise HTTPException(
                status_code=500,
                detail=f"Model '{model_name}' not found on server, even though it's an allowed option. Please check server configuration.",
            )

    final_model_name = model_name  # Use the validated model name
    logger.info(f"Request {request_id}: Using validated model: {final_model_name}, scale: {outscale}")

    try:
        # --- File Handling ---
        # Create unique temporary paths for this request
        input_suffix = Path(file.filename).suffix if file.filename else ".png"
        # Use original filename for input file within its own request dir.
        # Path(...).name strips any directory components, preventing path traversal.
        temp_input_filename = (
            Path(file.filename).name if file.filename else f"input_{request_id}{input_suffix}"
        )
        # Input directory for this specific request
        temp_input_dir_for_request = INPUT_DIR / request_id
        temp_input_dir_for_request.mkdir(exist_ok=True)
        temp_input_path = temp_input_dir_for_request / temp_input_filename
        # Output directory for this specific request's results
        temp_output_dir_for_request = OUTPUT_DIR / request_id
        temp_output_dir_for_request.mkdir(exist_ok=True)

        # Save uploaded file to its request-specific input dir
        try:
            logger.info(f"Request {request_id}: Saving uploaded file to {temp_input_path}")
            contents = await file.read()
            with open(temp_input_path, "wb") as buffer:
                buffer.write(contents)
            logger.info(f"Request {request_id}: Uploaded file saved successfully.")
        except Exception as e:
            logger.error(f"Request {request_id}: Failed to save uploaded file: {e}")
            raise HTTPException(status_code=500, detail="Failed to save uploaded file.")
        finally:
            await file.close()  # Ensure file handle is closed

        # --- Inference Execution ---
        # Construct command (list form, shell=False -- no shell injection risk)
        base_cmd = [
            sys.executable,
            str(INFERENCE_SCRIPT),
            "-i", str(temp_input_path),
            "-o", str(temp_output_dir_for_request),
            "-n", final_model_name,
            "-s", str(outscale),
        ]
        if face_enhance:
            base_cmd.append("--face_enhance")
        if fp32:
            base_cmd.append("--fp32")
        # Add tile param only if explicitly provided (> 0) or during retry.
        # None-safe: an empty form value yields tile=None, not 0.
        if tile is not None and tile > 0:
            base_cmd.extend(["-t", str(tile)])

        logger.info(f"Request {request_id}: Preparing initial inference command...")

        # Execute the script - Attempt 1 (No Tile unless specified)
        try:
            logger.info(
                f"Request {request_id}: Running inference (Attempt 1): {' '.join(base_cmd)}"
            )
            process = subprocess.run(
                base_cmd,
                capture_output=True,
                text=True,
                check=True,
                cwd=REAL_ESRGAN_DIR,
            )
            logger.info(f"Request {request_id}: Inference script (Attempt 1) stdout:{process.stdout}")
            if process.stderr:
                logger.warning(
                    f"Request {request_id}: Inference script (Attempt 1) stderr:{process.stderr}"
                )
        except (subprocess.CalledProcessError, RuntimeError) as e:
            error_output = ""
            if isinstance(e, subprocess.CalledProcessError):
                error_output = e.stderr
                logger.error(
                    f"Request {request_id}: Inference script failed (Attempt 1) with exit code {e.returncode}"
                )
                logger.error(f"Request {request_id}: Stdout: {e.stdout}")
                logger.error(f"Request {request_id}: Stderr: {e.stderr}")
            else:
                # Handle RuntimeError which might be raised by realesrgan directly
                error_output = str(e)
                logger.error(
                    f"Request {request_id}: Inference script raised RuntimeError (Attempt 1): {e}"
                )

            # Check if it's a memory error and tile wasn't already manually set
            is_memory_error = "memory" in error_output.lower() or "cuda" in error_output.lower()
            tile_arg_present = any(arg == "-t" for arg in base_cmd)

            if is_memory_error and not tile_arg_present:
                logger.warning(
                    f"Request {request_id}: Detected potential memory error. Retrying with tiling (tile_size={DEFAULT_TILE_SIZE})..."
                )
                # Attempt 2 (With Tile)
                retry_cmd = base_cmd + ["-t", str(DEFAULT_TILE_SIZE)]
                try:
                    logger.info(
                        f"Request {request_id}: Running inference (Attempt 2 - Tiled): {' '.join(retry_cmd)}"
                    )
                    process = subprocess.run(
                        retry_cmd,
                        capture_output=True,
                        text=True,
                        check=True,
                        cwd=REAL_ESRGAN_DIR,
                    )
                    logger.info(
                        f"Request {request_id}: Inference script (Attempt 2 - Tiled) stdout:{process.stdout}"
                    )
                    if process.stderr:
                        logger.warning(
                            f"Request {request_id}: Inference script (Attempt 2 - Tiled) stderr:{process.stderr}"
                        )
                except (subprocess.CalledProcessError, RuntimeError) as e2:
                    logger.error(
                        f"Request {request_id}: Inference script failed even on retry with tiling."
                    )
                    # Log the second error
                    if isinstance(e2, subprocess.CalledProcessError):
                        logger.error(
                            f"Request {request_id}: Retry Exit Code: {e2.returncode}, Stderr: {e2.stderr}"
                        )
                        error_output = e2.stderr  # Use the error from the retry attempt
                    else:
                        logger.error(f"Request {request_id}: Retry RuntimeError: {e2}")
                        error_output = str(e2)
                    # Raise original error type but with potentially updated message from retry
                    raise HTTPException(
                        status_code=500,
                        detail=f"Image enhancement failed, even with tiling: {error_output or 'Unknown error'}",
                    )
            else:
                # Not a memory error, or tile was already specified - fail normally
                raise HTTPException(
                    status_code=500,
                    detail=f"Image enhancement script failed: {error_output or 'Unknown error'}",
                )
        except Exception as e:
            # Catch any other unexpected errors during subprocess execution
            logger.error(f"Request {request_id}: Unexpected error executing inference script: {e}")
            logger.error(traceback.format_exc())
            raise HTTPException(status_code=500, detail=f"Failed to run enhancement process: {e}")

        # --- Result Handling ---
        # Find the output file (assumes script outputs one file with '_out' suffix)
        # The script `inference_realesrgan.py` saves the output as `{basename}_out.{ext}`
        original_basename = Path(temp_input_filename).stem
        expected_output_stem = f"{original_basename}_out"
        output_files = list(temp_output_dir_for_request.glob(f"{expected_output_stem}.*"))

        if not output_files:
            logger.error(
                f"Request {request_id}: No output file found in {temp_output_dir_for_request} matching stem {expected_output_stem}"
            )
            raise HTTPException(
                status_code=500, detail="Enhancement finished, but output file not found."
            )

        output_path = output_files[0]
        output_media_type = _guess_media_type(output_path)
        output_filename = (
            f"enhanced_{Path(file.filename).name}"
            if file.filename
            else f"enhanced_{request_id}{output_path.suffix}"
        )

        logger.info(f"Request {request_id}: Enhancement successful. Output: {output_path}")

        # Schedule cleanup task (input file and the whole output dir for this request)
        # background_tasks.add_task(cleanup_files, [temp_input_path, temp_output_dir_for_request])  # Removed cleanup

        # Release lock AFTER scheduling cleanup but BEFORE returning response
        background_tasks.add_task(release_lock)

        # Return the enhanced image file
        return FileResponse(
            path=output_path,
            media_type=output_media_type,
            filename=output_filename,
        )

    except HTTPException as http_exc:
        # If an HTTPException occurred (validation, busy, etc.), release lock immediately
        release_lock()
        # Re-raise the exception to be handled by FastAPI
        raise http_exc
    except Exception as e:
        error_msg = f"Request {request_id}: Unexpected error during enhancement: {str(e)}"
        logger.error(error_msg)
        logger.error(traceback.format_exc())
        # Ensure cleanup happens even on unexpected errors (Cleanup is removed, but keep release_lock)
        # We need to potentially clean up the created input directory as well if saving failed
        # For simplicity now, inputs/outputs persist on errors too, consistent with success path
        release_lock()
        raise HTTPException(status_code=500, detail=f"An unexpected error occurred: {str(e)}")


@app.get("/status/")
async def status():
    """Checks the API status and resource usage."""
    logger.info("Status check requested.")
    return {
        "status": "ok" if not processing_lock else "busy",
        "processing_active": processing_lock,
        "available_models": available_models,
        "memory_usage": {
            "percent": f"{psutil.virtual_memory().percent}%",
            "available": f"{psutil.virtual_memory().available / (1024**3):.2f} GB",
        },
        "cpu_usage": f"{psutil.cpu_percent()}%",
        "real_esrgan_dir_exists": REAL_ESRGAN_DIR.exists(),
        "inference_script_exists": INFERENCE_SCRIPT.exists(),
        "model_dir_exists": MODEL_DIR.exists(),
        "input_dir_exists": INPUT_DIR.exists(),
        "output_dir_exists": OUTPUT_DIR.exists(),
    }


# --- Server Execution ---
if __name__ == "__main__":
    logger.info(f"Starting Image Enhancer API server on port {API_PORT}...")
    logger.info(f"Real-ESRGAN Directory: {REAL_ESRGAN_DIR}")
    logger.info(f"Inference Script: {INFERENCE_SCRIPT}")
    logger.info(f"Model Directory: {MODEL_DIR}")
    logger.info(f"API Input Directory: {INPUT_DIR}")
    logger.info(f"API Output Directory: {OUTPUT_DIR}")
    update_available_models()  # Ensure models are listed on startup
    uvicorn.run(
        "api:app",
        host="0.0.0.0",
        port=API_PORT,
        reload=False,  # Use reload carefully, can cause issues with locking/models
        log_level="info",  # Uvicorn's own log level
    )