"""FastAPI service that enhances uploaded images with GFPGAN or Real-ESRGAN.

Endpoints:
    POST /process_image/  multipart form (file, version, scale); responds with
                          the enhanced image as JPEG, or HTTP 500 with
                          ``{"error": ...}``.
    GET  /                the front-end page ``static/index.html``.

Every other path falls through to the ``static`` directory. Missing model
checkpoints are downloaded into the working directory at import time.
"""
import contextlib
import os
import tempfile
import threading
import urllib.request

import cv2
import torch
from basicsr.archs.srvgg_arch import SRVGGNetCompact
from fastapi import FastAPI, File, Form, Request, UploadFile
from fastapi.concurrency import run_in_threadpool
from fastapi.responses import FileResponse, HTMLResponse, JSONResponse, Response
from fastapi.staticfiles import StaticFiles
from fastapi.templating import Jinja2Templates
from gfpgan.utils import GFPGANer
from realesrgan.utils import RealESRGANer

app = FastAPI()
app.mount("/static", StaticFiles(directory="static"), name="static")
# NOTE(review): not referenced anywhere in this file.
templates = Jinja2Templates(directory="templates")

# ``version`` form value that selects whole-image Real-ESRGAN upscaling.
REALESR_VERSION = 'RealESR-General-x4v3'
REALESR_WEIGHTS = 'realesr-general-x4v3.pth'
# ``version`` form value -> GFPGAN checkpoint used for face restoration.
GFPGAN_WEIGHTS = {
    'v1.2': 'GFPGANv1.2.pth',
    'v1.3': 'GFPGANv1.3.pth',
    'v1.4': 'GFPGANv1.4.pth',
}


def download_weights():
    """Download every model checkpoint that is missing from the working directory.

    Each file is fetched to a ``<name>.part`` sibling and renamed only after the
    transfer completes, so an interrupted download is never mistaken for a valid
    checkpoint on the next start-up. Download failures are printed and skipped
    (best-effort, like the previous ``wget`` call); the affected version will
    then fail at inference time.
    """
    weights = [
        ('realesr-general-x4v3.pth', 'https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-general-x4v3.pth'),
        ('GFPGANv1.2.pth', 'https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.2.pth'),
        ('GFPGANv1.3.pth', 'https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.3.pth'),
        ('GFPGANv1.4.pth', 'https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.4.pth'),
    ]
    for weight_file, weight_url in weights:
        if os.path.exists(weight_file):
            continue
        partial_file = weight_file + '.part'
        try:
            urllib.request.urlretrieve(weight_url, partial_file)
            os.replace(partial_file, weight_file)
        except OSError as e:  # URLError / HTTPError are OSError subclasses
            print(f"Failed to download {weight_url}: {e}")
            with contextlib.suppress(FileNotFoundError):
                os.remove(partial_file)


def initialize_models():
    """Build the Real-ESRGAN backbone network and pick the precision mode.

    Returns:
        (model, half): an untrained ``SRVGGNetCompact`` x4 network (weights are
        supplied later through ``RealESRGANer``'s ``model_path``), and True when
        CUDA is available so that half precision is only requested on GPU.
    """
    model = SRVGGNetCompact(num_in_ch=3, num_out_ch=3, num_feat=64, num_conv=32, upscale=4, act_type='prelu')
    half = torch.cuda.is_available()
    return model, half


def enhance_image(img_path, version, scale, model, half, output_path='output/out.jpg'):
    """Enhance the image at ``img_path`` and write the result as a JPEG.

    Args:
        img_path: Path of the input image (any format ``cv2.imread`` decodes).
        version: 'v1.2', 'v1.3' or 'v1.4' to restore faces with that GFPGAN
            checkpoint, or 'RealESR-General-x4v3' to upscale with Real-ESRGAN.
        scale: Final size relative to the input (2 -> twice the width and
            height). Must be positive; may be fractional.
        model: Backbone from ``initialize_models``; used only by Real-ESRGAN.
        half: Whether Real-ESRGAN runs in half precision.
        output_path: Where to write the JPEG; parent directories are created.

    Returns:
        ``output_path`` on success, or None on any failure (unknown version,
        non-positive scale, unreadable input, inference or write error).
        Errors are printed, never raised.
    """
    try:
        if scale <= 0:
            print(f"Error enhancing image: scale must be positive, got {scale}")
            return None
        input_img = cv2.imread(img_path)
        if input_img is None:
            print(f"Error enhancing image: could not read {img_path}")
            return None

        if version in GFPGAN_WEIGHTS:
            face_enhancer = GFPGANer(
                model_path=GFPGAN_WEIGHTS[version], upscale=2, arch='clean',
                channel_multiplier=2, bg_upsampler=None)
            _, _, output = face_enhancer.enhance(
                input_img, has_aligned=False, only_center_face=False, paste_back=True)
        elif version == REALESR_VERSION:
            upsampler = RealESRGANer(
                scale=4, model_path=REALESR_WEIGHTS, model=model,
                tile=0, tile_pad=10, pre_pad=0, half=half)
            # Unlike GFPGANer, RealESRGANer.enhance takes no face options and
            # returns (image, image_mode); the GFPGAN-style call always raised.
            output, _ = upsampler.enhance(input_img)
        else:
            return None
        if output is None:
            print("Error enhancing image: enhancer returned no image")
            return None

        # Resize to ``scale`` x the INPUT size. The enhancers have already
        # upscaled (GFPGAN by 2, Real-ESRGAN by 4), so shrink with INTER_AREA
        # and enlarge with INTER_LANCZOS4.
        in_h, in_w = input_img.shape[:2]
        target_size = (max(1, int(in_w * scale)), max(1, int(in_h * scale)))
        out_h, out_w = output.shape[:2]
        if (out_w, out_h) != target_size:
            interpolation = cv2.INTER_AREA if out_w > target_size[0] else cv2.INTER_LANCZOS4
            output = cv2.resize(output, target_size, interpolation=interpolation)

        out_dir = os.path.dirname(output_path)
        if out_dir:
            os.makedirs(out_dir, exist_ok=True)
        # cv2.imwrite signals failure by returning False rather than raising.
        if not cv2.imwrite(output_path, output):
            print(f"Error enhancing image: could not write {output_path}")
            return None
        return output_path
    except Exception as e:
        print(f"Error enhancing image: {e}")
        return None


# Download weights
download_weights()
# Initialize model
model, half = initialize_models()

# Inference runs in worker threads so it does not block the event loop. This
# lock still allows only one enhancement at a time (as when the blocking call
# ran inside the event loop), so concurrent requests never share the
# module-level ``model`` or compete for GPU memory.
_inference_lock = threading.Lock()


def _enhance_serialized(img_path, version, scale, output_path):
    """Run ``enhance_image`` with the module-level model while holding the lock."""
    with _inference_lock:
        return enhance_image(img_path, version, scale, model, half, output_path)


@app.post("/process_image/")
async def process_image(file: UploadFile = File(...), version: str = Form(...), scale: float = Form(...)):
    """Enhance an uploaded image and return it as a JPEG.

    Form fields: ``file`` (the image), ``version`` and ``scale`` (see
    ``enhance_image``). On failure responds with HTTP 500 and ``{"error": ...}``.
    """
    try:
        contents = await file.read()
        # Per-request scratch directory: fixed names such as "temp.jpg" let
        # concurrent requests overwrite each other's input and output.
        with tempfile.TemporaryDirectory() as work_dir:
            img_path = os.path.join(work_dir, "input")
            out_path = os.path.join(work_dir, "out.jpg")
            with open(img_path, "wb") as f:
                f.write(contents)
            output_path = await run_in_threadpool(_enhance_serialized, img_path, version, scale, out_path)
            if output_path is None:
                return JSONResponse({"error": "Failed to process the image."}, status_code=500)
            # Read before the scratch directory is deleted on exit.
            with open(output_path, "rb") as f:
                image_bytes = f.read()
        return Response(content=image_bytes, media_type='image/jpeg')
    except Exception as e:
        return JSONResponse({"error": f"An error occurred: {e}"}, status_code=500)


@app.get("/")
def index() -> FileResponse:
    """Serve the front-end page from the same ``static`` directory as the mounts.

    Registered before the catch-all "/" mount below: routes are matched in
    registration order, and the mount matches "/" too, so registering this
    route afterwards left it unreachable.
    """
    return FileResponse(path=os.path.join("static", "index.html"), media_type="text/html")


# Catch-all for the remaining static files; must stay last because it matches
# every path. Named "frontend" so it does not clash with the "/static" mount's
# route name "static".
app.mount("/", StaticFiles(directory="static", html=True), name="frontend")