# Hugging Face Space status: Running on T4
| import os | |
| os.environ["OMP_NUM_THREADS"] = "1" | |
| import gradio as gr | |
| import cv2 | |
| import shutil | |
| import uuid | |
| import insightface | |
| from insightface.app import FaceAnalysis | |
| from huggingface_hub import hf_hub_download | |
| import subprocess | |
| import numpy as np | |
| import threading | |
| from fastapi import FastAPI, UploadFile, File, HTTPException, Response | |
| from fastapi.responses import RedirectResponse | |
| from pydantic import BaseModel | |
| from motor.motor_asyncio import AsyncIOMotorClient | |
| from bson.objectid import ObjectId | |
| from gridfs import AsyncIOMotorGridFSBucket | |
| from gradio import mount_gradio_app | |
| import uvicorn | |
| import logging | |
| import io | |
# -------------------------------------------------
# Logging
# -------------------------------------------------
# Configure the root logger at INFO so the logger.info progress messages
# emitted throughout this module are actually output.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(name=__name__)
# -------------------------------------------------
# Paths
# -------------------------------------------------
# Hugging Face Hub model repository that hosts the ONNX weights.
REPO_ID = "HariLogicgo/face_swap_models"

# Scratch space for pipeline inputs/outputs.
BASE_DIR = "./workspace"
UPLOAD_DIR = os.path.join(BASE_DIR, "uploads")
RESULT_DIR = os.path.join(BASE_DIR, "results")

# Destination for downloaded model weights.
MODELS_DIR = "./models"

# Make sure every working directory exists before anything writes to it.
for _dir in (UPLOAD_DIR, RESULT_DIR, MODELS_DIR):
    os.makedirs(_dir, exist_ok=True)
del _dir
# -------------------------------------------------
# Download models
# -------------------------------------------------
def download_models():
    """Fetch the face-swap ONNX weights from the Hugging Face Hub.

    Downloads ``inswapper_128.onnx`` and the five ``buffalo_l`` model files
    from ``REPO_ID`` into ``MODELS_DIR``, keeping their ``models/...`` repo
    paths.

    Returns:
        Local filesystem path of the downloaded ``inswapper_128.onnx``.
    """
    logger.info("Downloading models...")

    def fetch(repo_path):
        # Every file comes from the same model repo into the same local dir.
        return hf_hub_download(
            repo_id=REPO_ID,
            filename=repo_path,
            repo_type="model",
            local_dir=MODELS_DIR,
        )

    swapper_weights = fetch("models/inswapper_128.onnx")
    for name in (
        "1k3d68.onnx",
        "2d106det.onnx",
        "genderage.onnx",
        "det_10g.onnx",
        "w600k_r50.onnx",
    ):
        fetch(f"models/buffalo_l/{name}")

    logger.info("Models downloaded successfully")
    return swapper_weights


inswapper_path = download_models()
# -------------------------------------------------
# Face Analysis + Swapper
# -------------------------------------------------
# ONNX Runtime execution providers, listed in priority order (CUDA first).
providers = ['CUDAExecutionProvider', 'CPUExecutionProvider']
logger.info("Initializing FaceAnalysis with providers: %s", providers)

# "buffalo_l" detection/recognition pack loaded from MODELS_DIR; detection
# input size 640x640 (ctx_id=0 presumably selects the first device).
face_analysis_app = FaceAnalysis(name="buffalo_l", root=MODELS_DIR, providers=providers)
face_analysis_app.prepare(ctx_id=0, det_size=(640, 640))

# Face-swap model loaded from the path returned by download_models().
swapper = insightface.model_zoo.get_model(inswapper_path, providers=providers)
logger.info("FaceAnalysis and swapper initialized")
# -------------------------------------------------
# CodeFormer setup
# -------------------------------------------------
CODEFORMER_PATH = "CodeFormer/inference_codeformer.py"


def ensure_codeformer():
    """Clone CodeFormer and install its dependencies/weights on first run.

    Does nothing when a ``CodeFormer`` directory already exists in the working
    directory. Every step runs with ``check=True``, so a failing step raises
    ``subprocess.CalledProcessError`` and aborts startup.
    """
    import sys

    if os.path.exists("CodeFormer"):
        # NOTE(review): a previously interrupted setup also leaves this
        # directory behind and would then be skipped forever.
        return

    logger.info("Cloning CodeFormer repository...")
    # Argument lists (no shell) and the *running* interpreter instead of
    # whatever "python"/"pip" resolve to on PATH, so packages are installed
    # into the same environment that later runs the inference script.
    python = sys.executable
    steps = [
        ["git", "clone", "https://github.com/sczhou/CodeFormer.git"],
        [python, "-m", "pip", "install", "-r", "CodeFormer/requirements.txt"],
        [python, "CodeFormer/basicsr/setup.py", "develop"],
        [python, "CodeFormer/scripts/download_pretrained_models.py", "facelib"],
        [python, "CodeFormer/scripts/download_pretrained_models.py", "CodeFormer"],
    ]
    for cmd in steps:
        subprocess.run(cmd, check=True)
    logger.info("CodeFormer setup complete")


ensure_codeformer()
# -------------------------------------------------
# MongoDB + GridFS
# -------------------------------------------------
# The connection string embeds database credentials, so it must be provided
# through the environment (e.g. a Space secret), never hard-coded in source.
# NOTE(review): a credential used to be committed here as the default value;
# it should be considered leaked and rotated.
MONGODB_URL = os.getenv("MONGODB_URL")
if not MONGODB_URL:
    raise RuntimeError(
        "MONGODB_URL is not set; provide the MongoDB connection string "
        "via an environment variable or secret."
    )
client = AsyncIOMotorClient(MONGODB_URL)
database = client.FaceSwap
fs_bucket = AsyncIOMotorGridFSBucket(database)
logger.info("MongoDB + GridFS initialized")
# -------------------------------------------------
# Lock for face swap
# -------------------------------------------------
# Serializes pipeline runs: each run wipes and reuses the shared UPLOAD_DIR
# and RESULT_DIR, so two runs must never overlap.
swap_lock = threading.Lock()
# -------------------------------------------------
# Face Swap Pipeline
# -------------------------------------------------
def face_swap_and_enhance(src_img, tgt_img):
    """Swap the first source face onto the target image, then enhance it.

    Steps: detect faces with ``face_analysis_app``; paste the first source
    face onto the first target face with ``swapper``; restore the result by
    running the CodeFormer inference script as a subprocess.

    Args:
        src_img: RGB numpy array containing the face to transplant.
        tgt_img: RGB numpy array whose first detected face is replaced.

    Returns:
        ``(final_img, final_path, error)``. On success ``final_img`` is the
        enhanced RGB array, ``final_path`` the PNG written by CodeFormer and
        ``error`` is ``""``. On failure the first two are ``None`` and
        ``error`` is a human-readable message (errors are reported, not
        raised).

    Note:
        ``final_path`` lives in the shared RESULT_DIR, which the next run
        deletes; callers must consume it before another run can start.
    """
    import sys

    logger.info("Starting face swap and enhancement")
    try:
        with swap_lock:
            # Start every run from empty scratch directories.
            shutil.rmtree(UPLOAD_DIR, ignore_errors=True)
            shutil.rmtree(RESULT_DIR, ignore_errors=True)
            os.makedirs(UPLOAD_DIR, exist_ok=True)
            os.makedirs(RESULT_DIR, exist_ok=True)

            if not isinstance(src_img, np.ndarray) or not isinstance(tgt_img, np.ndarray):
                return None, None, "β Invalid input images"

            # Inputs arrive as RGB; detection/swapping operate on BGR arrays.
            src_bgr = cv2.cvtColor(src_img, cv2.COLOR_RGB2BGR)
            tgt_bgr = cv2.cvtColor(tgt_img, cv2.COLOR_RGB2BGR)
            src_faces = face_analysis_app.get(src_bgr)
            tgt_faces = face_analysis_app.get(tgt_bgr)
            if not src_faces or not tgt_faces:
                return None, None, "β Face not detected"

            swapped_bgr = swapper.get(tgt_bgr, tgt_faces[0], src_faces[0])
            if swapped_bgr is None:
                return None, None, "β Face swap failed"
            swapped_path = os.path.join(UPLOAD_DIR, f"swapped_{uuid.uuid4().hex[:8]}.jpg")
            if not cv2.imwrite(swapped_path, swapped_bgr):
                return None, None, "β Could not save swapped image"

            # Argument list (no shell) and the current interpreter, so the
            # script runs in the same environment as this process.
            cmd = [
                sys.executable, CODEFORMER_PATH,
                "-w", "0.7",
                "--input_path", swapped_path,
                "--output_path", RESULT_DIR,
                "--bg_upsampler", "realesrgan",
                "--face_upsample",
            ]
            result = subprocess.run(cmd, capture_output=True, text=True)
            if result.returncode != 0:
                logger.error("CodeFormer failed (exit %s):\n%s", result.returncode, result.stderr)
                return None, None, f"β CodeFormer failed:\n{result.stderr}"

            # The enhanced PNG is expected under <output_path>/final_results.
            final_results_dir = os.path.join(RESULT_DIR, "final_results")
            final_files = []
            if os.path.isdir(final_results_dir):
                final_files = sorted(
                    f for f in os.listdir(final_results_dir) if f.endswith(".png")
                )
            if not final_files:
                return None, None, "β No enhanced image found"

            final_path = os.path.join(final_results_dir, final_files[0])
            final_bgr = cv2.imread(final_path)
            if final_bgr is None:
                return None, None, "β Could not read enhanced image"
            return cv2.cvtColor(final_bgr, cv2.COLOR_BGR2RGB), final_path, ""
    except Exception as e:
        # Keep the best-effort "report, don't raise" contract, but leave a
        # traceback in the logs instead of silently discarding it.
        logger.exception("Face swap pipeline failed")
        return None, None, f"β Error: {str(e)}"
# -------------------------------------------------
# Gradio Interface
# -------------------------------------------------
with gr.Blocks() as demo:
    gr.Markdown("Face Swap")
    # Inputs side by side: the face to transplant and the image receiving it.
    with gr.Row():
        src_input = gr.Image(type="numpy", label="Upload Your Face")
        tgt_input = gr.Image(type="numpy", label="Upload Target Image")
    btn = gr.Button("Swap Face")
    output_img = gr.Image(type="numpy", label="Enhanced Output")
    download = gr.File(label="β¬οΈ Download Enhanced Image")
    error_box = gr.Textbox(label="Logs / Errors", interactive=False)

    def process(src, tgt):
        """Gradio callback that forwards both images to the pipeline.

        Returns the pipeline's ``(image, file path, message)`` triple, which
        fills ``output_img``, ``download`` and ``error_box`` respectively.
        """
        return face_swap_and_enhance(src, tgt)

    btn.click(
        process,
        inputs=[src_input, tgt_input],
        outputs=[output_img, download, error_box],
    )
# -------------------------------------------------
# FastAPI App
# -------------------------------------------------
fastapi_app = FastAPI()


def root():
    """Redirect to the Gradio UI, which is mounted at ``/gradio``."""
    # NOTE(review): no route decorator is visible for this handler, so it does
    # not appear to be registered on fastapi_app; confirm it was not lost.
    gradio_url = "/gradio"
    return RedirectResponse(gradio_url)
async def health():
    """Return a static health-status payload."""
    # NOTE(review): no route decorator is visible for this handler; confirm.
    payload = {"status": "healthy"}
    return payload
# -------- Upload Endpoints with GridFS --------
async def _store_upload(image):
    """Read an uploaded file completely and store it in GridFS.

    Returns:
        The new GridFS file id, as a string.

    Raises:
        HTTPException(400): the upload contained no bytes.
    """
    contents = await image.read()
    if not contents:
        raise HTTPException(status_code=400, detail="Uploaded file is empty")
    # UploadFile.filename may be None; fall back to a generic name.
    file_id = await fs_bucket.upload_from_stream(image.filename or "upload", contents)
    return str(file_id)


async def upload_source(image: UploadFile = File(...)):
    """Store the face image; returns ``{"source_id": <GridFS id>}``."""
    # NOTE(review): no route decorator is visible for this handler; confirm.
    return {"source_id": await _store_upload(image)}


async def upload_target(image: UploadFile = File(...)):
    """Store the target image; returns ``{"target_id": <GridFS id>}``."""
    # NOTE(review): no route decorator is visible for this handler; confirm.
    return {"target_id": await _store_upload(image)}
# -------- Faceswap Endpoint --------
class FaceSwapRequest(BaseModel):
    """JSON body for the face-swap endpoint."""

    source_id: str  # GridFS id of the face image (as returned by upload_source)
    target_id: str  # GridFS id of the target image (as returned by upload_target)


async def _load_rgb_from_gridfs(file_id, role):
    """Fetch an image from GridFS by id and decode it into an RGB array.

    ``role`` ("source"/"target") is only used in error messages.

    Raises:
        HTTPException(400): malformed id, empty file, or undecodable image.
        HTTPException(404): no GridFS file with that id.
    """
    from bson.errors import InvalidId
    from gridfs.errors import NoFile

    try:
        object_id = ObjectId(file_id)
    except (InvalidId, TypeError):
        raise HTTPException(status_code=400, detail=f"Invalid {role}_id") from None
    try:
        stream = await fs_bucket.open_download_stream(object_id)
    except NoFile:
        raise HTTPException(
            status_code=404, detail=f"{role.capitalize()} image not found"
        ) from None
    data = await stream.read()
    # imdecode returns None for undecodable bytes; skip it for empty files.
    bgr = cv2.imdecode(np.frombuffer(data, np.uint8), cv2.IMREAD_COLOR) if data else None
    if bgr is None:
        raise HTTPException(
            status_code=400, detail=f"{role.capitalize()} file is not a valid image"
        )
    return cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)


async def perform_faceswap(request: FaceSwapRequest):
    """Swap the source face onto the target image and store the result.

    Returns:
        ``{"result_id": <GridFS id of the enhanced PNG>}``.

    Raises:
        HTTPException(400/404): invalid or missing input images.
        HTTPException(500): the pipeline reported an error or anything else failed.
    """
    # NOTE(review): no route decorator is visible for this handler; confirm.
    from fastapi.concurrency import run_in_threadpool

    try:
        source_rgb = await _load_rgb_from_gridfs(request.source_id, "source")
        target_rgb = await _load_rgb_from_gridfs(request.target_id, "target")

        # The pipeline is synchronous and slow (face analysis plus a
        # CodeFormer subprocess); run it in a worker thread so the event
        # loop keeps serving other requests. swap_lock still serializes runs.
        final_img, _final_path, err = await run_in_threadpool(
            face_swap_and_enhance, source_rgb, target_rgb
        )
        if err:
            raise HTTPException(status_code=500, detail=err)

        # Encode the returned array instead of re-reading the file: it sits
        # in the shared RESULT_DIR, which the next pipeline run (possibly
        # already started in another thread) deletes.
        ok, png = cv2.imencode(".png", cv2.cvtColor(final_img, cv2.COLOR_RGB2BGR))
        if not ok:
            raise HTTPException(status_code=500, detail="Failed to encode result image")
        result_id = await fs_bucket.upload_from_stream("enhanced.png", png.tobytes())
        return {"result_id": str(result_id)}
    except HTTPException:
        # Already carries the intended status and detail; don't re-wrap it.
        raise
    except Exception as e:
        logger.exception("Face swap request failed")
        raise HTTPException(status_code=500, detail=str(e)) from e
# -------- Download Endpoint --------
async def download_result(result_id: str):
    """Return a stored GridFS file as an ``enhanced.png`` attachment.

    Raises:
        HTTPException(404): ``result_id`` is malformed or no such file exists.
        HTTPException(500): any other failure (e.g. the database is unreachable).
    """
    # NOTE(review): no route decorator is visible for this handler; confirm.
    from bson.errors import InvalidId
    from gridfs.errors import NoFile

    try:
        stream = await fs_bucket.open_download_stream(ObjectId(result_id))
        file_data = await stream.read()
    except (InvalidId, NoFile):
        raise HTTPException(status_code=404, detail="Result not found") from None
    except Exception as e:
        # Infrastructure errors are not "not found"; surface them as 500s.
        logger.exception("Failed to read result %s", result_id)
        raise HTTPException(status_code=500, detail="Failed to read result") from e
    return Response(
        content=file_data,
        media_type="image/png",
        headers={"Content-Disposition": "attachment; filename=enhanced.png"},
    )
# Mount Gradio
# Serve the Gradio UI under /gradio on the same FastAPI application.
fastapi_app = mount_gradio_app(app=fastapi_app, blocks=demo, path="/gradio")

if __name__ == "__main__":
    # 7860 is the default port a Hugging Face Space is expected to listen on.
    uvicorn.run(fastapi_app, port=7860, host="0.0.0.0")