FaceSwap / app.py
HariLogicgo's picture
using GridFS for storage
0fd380e
raw
history blame
9.98 kB
import os
os.environ["OMP_NUM_THREADS"] = "1"
import gradio as gr
import cv2
import shutil
import uuid
import insightface
from insightface.app import FaceAnalysis
from huggingface_hub import hf_hub_download
import subprocess
import numpy as np
import threading
from fastapi import FastAPI, UploadFile, File, HTTPException, Response
from fastapi.responses import RedirectResponse
from pydantic import BaseModel
from motor.motor_asyncio import AsyncIOMotorClient
from bson.objectid import ObjectId
from gridfs import AsyncIOMotorGridFSBucket
from gradio import mount_gradio_app
import uvicorn
import logging
import io
# -------------------------------------------------
# Logging
# -------------------------------------------------
# Root logging is configured once at import time; everything in this module
# logs through the module-scoped `logger` below.
_LOG_LEVEL = logging.INFO
logging.basicConfig(level=_LOG_LEVEL)
logger = logging.getLogger(name=__name__)
# -------------------------------------------------
# Paths
# -------------------------------------------------
# Hugging Face model repo that hf_hub_download pulls the ONNX weights from.
REPO_ID = "HariLogicgo/face_swap_models"
# Working area: UPLOAD_DIR holds the intermediate swapped image and
# RESULT_DIR receives CodeFormer's output.
BASE_DIR = "./workspace"
UPLOAD_DIR = os.path.join(BASE_DIR, "uploads")
RESULT_DIR = os.path.join(BASE_DIR, "results")
# Local destination for downloaded model weights.
MODELS_DIR = "./models"

# Make sure every directory exists before anything writes into it.
for _required_dir in (UPLOAD_DIR, RESULT_DIR, MODELS_DIR):
    os.makedirs(_required_dir, exist_ok=True)
del _required_dir
# -------------------------------------------------
# Download models
# -------------------------------------------------
def download_models():
    """Download the inswapper model and the buffalo_l model pack from REPO_ID.

    Each file is fetched with ``hf_hub_download`` into MODELS_DIR, using its
    repo-relative ``models/...`` path as the ``filename``.

    Returns:
        The local path of ``models/inswapper_128.onnx`` as returned by
        ``hf_hub_download``.
    """
    logger.info("Downloading models...")

    def fetch(repo_file):
        # Same download options for every file; only the repo path varies.
        return hf_hub_download(
            repo_id=REPO_ID,
            filename=repo_file,
            repo_type="model",
            local_dir=MODELS_DIR,
        )

    swapper_file = fetch("models/inswapper_128.onnx")

    buffalo_onnx_names = (
        "1k3d68.onnx",
        "2d106det.onnx",
        "genderage.onnx",
        "det_10g.onnx",
        "w600k_r50.onnx",
    )
    for onnx_name in buffalo_onnx_names:
        fetch(f"models/buffalo_l/{onnx_name}")

    logger.info("Models downloaded successfully")
    return swapper_file
inswapper_path = download_models()
# -------------------------------------------------
# Face Analysis + Swapper
# -------------------------------------------------
# onnxruntime execution providers, listed in preference order: CUDA, then CPU.
providers = ["CUDAExecutionProvider", "CPUExecutionProvider"]
logger.info("Initializing FaceAnalysis with providers: %s", providers)

face_analysis_app = FaceAnalysis(
    name="buffalo_l",
    root=MODELS_DIR,
    providers=providers,
)
face_analysis_app.prepare(ctx_id=0, det_size=(640, 640))

swapper = insightface.model_zoo.get_model(
    inswapper_path,
    providers=providers,
)
logger.info("FaceAnalysis and swapper initialized")
# -------------------------------------------------
# CodeFormer setup
# -------------------------------------------------
CODEFORMER_PATH = "CodeFormer/inference_codeformer.py"


def ensure_codeformer():
    """Clone and install CodeFormer unless a previous run already finished setup.

    Every step runs with ``sys.executable`` (the interpreter running this app),
    so ``pip install`` and ``setup.py develop`` target the same environment the
    app imports from. Commands are passed as argument lists, not shell strings.

    A marker file is written only after all steps succeed. A setup that failed
    halfway is therefore retried on the next start, instead of being skipped
    just because the ``CodeFormer`` directory exists.

    Raises:
        subprocess.CalledProcessError: if any setup command exits non-zero.
    """
    import sys

    repo_dir = "CodeFormer"
    marker = os.path.join(repo_dir, ".setup_complete")
    if os.path.exists(marker):
        return

    if not os.path.exists(repo_dir):
        logger.info("Cloning CodeFormer repository...")
        subprocess.run(
            ["git", "clone", "https://github.com/sczhou/CodeFormer.git"],
            check=True,
        )

    python = sys.executable
    download_script = os.path.join(repo_dir, "scripts", "download_pretrained_models.py")
    setup_steps = [
        [python, "-m", "pip", "install", "-r", os.path.join(repo_dir, "requirements.txt")],
        [python, os.path.join(repo_dir, "basicsr", "setup.py"), "develop"],
        [python, download_script, "facelib"],
        [python, download_script, "CodeFormer"],
    ]
    for cmd in setup_steps:
        subprocess.run(cmd, check=True)

    # Record success only after every step passed.
    with open(marker, "w", encoding="utf-8"):
        pass
    logger.info("CodeFormer setup complete")


ensure_codeformer()
# -------------------------------------------------
# MongoDB + GridFS
# -------------------------------------------------
# The connection string comes from the environment only. Never hard-code
# credentials here; any credential that was committed to source control
# must be rotated.
_LOCAL_MONGODB_URL = "mongodb://localhost:27017"
MONGODB_URL = os.getenv("MONGODB_URL")
if not MONGODB_URL:
    logger.warning(
        "MONGODB_URL is not set; falling back to %s. GridFS-backed endpoints "
        "will fail if no MongoDB server is reachable there.",
        _LOCAL_MONGODB_URL,
    )
    MONGODB_URL = _LOCAL_MONGODB_URL

# NOTE(review): AsyncIOMotorGridFSBucket is exported by motor.motor_asyncio,
# not by pymongo's `gridfs` package -- confirm the import at the top of the file.
client = AsyncIOMotorClient(MONGODB_URL)
database = client.FaceSwap
fs_bucket = AsyncIOMotorGridFSBucket(database)
logger.info("MongoDB + GridFS initialized")
# -------------------------------------------------
# Lock for face swap
# -------------------------------------------------
# Serialises the pipeline: every run wipes and reuses the shared UPLOAD_DIR and
# RESULT_DIR, so two runs must never overlap.
swap_lock = threading.Lock()


# -------------------------------------------------
# Face Swap Pipeline
# -------------------------------------------------
def _reset_work_dirs():
    """Delete and recreate UPLOAD_DIR and RESULT_DIR so each run starts empty."""
    for work_dir in (UPLOAD_DIR, RESULT_DIR):
        shutil.rmtree(work_dir, ignore_errors=True)
        os.makedirs(work_dir, exist_ok=True)


def _run_codeformer(input_path, output_dir):
    """Run CodeFormer on ``input_path`` and return the CompletedProcess.

    Uses an argument list with the current interpreter, so no shell is involved
    and paths containing spaces are passed through intact.
    """
    import sys

    cmd = [
        sys.executable, CODEFORMER_PATH,
        "-w", "0.7",
        "--input_path", input_path,
        "--output_path", output_dir,
        "--bg_upsampler", "realesrgan",
        "--face_upsample",
    ]
    return subprocess.run(cmd, capture_output=True, text=True)


def _find_enhanced_image(results_dir):
    """Return the first ``.png`` (by name) in ``results_dir/final_results``, or None.

    Returns None when the ``final_results`` directory does not exist.
    """
    final_dir = os.path.join(results_dir, "final_results")
    try:
        pngs = sorted(name for name in os.listdir(final_dir) if name.endswith(".png"))
    except FileNotFoundError:
        return None
    return os.path.join(final_dir, pngs[0]) if pngs else None


def face_swap_and_enhance(src_img, tgt_img):
    """Swap the first face of ``src_img`` into ``tgt_img``, then enhance with CodeFormer.

    Args:
        src_img: RGB ``np.ndarray``; its first detected face is the donor.
        tgt_img: RGB ``np.ndarray``; its first detected face is replaced.

    Returns:
        A tuple ``(final_img, final_path, error)``:

        - On success: the enhanced RGB array, the path of the PNG written by
          CodeFormer, and ``""``.
        - On failure: ``(None, None, message)``. Exceptions are logged and
          reported this way instead of being raised.

    Note:
        ``final_path`` lives in RESULT_DIR, which the next call wipes. The file
        is only guaranteed to exist until another swap starts, so prefer
        ``final_img`` when the result must outlive that window.
    """
    # Validate before taking the lock, so bad input never wipes the shared
    # directories or waits behind a running swap.
    if not isinstance(src_img, np.ndarray) or not isinstance(tgt_img, np.ndarray):
        return None, None, "❌ Invalid input images"

    logger.info("Starting face swap and enhancement")
    try:
        with swap_lock:
            _reset_work_dirs()

            src_bgr = cv2.cvtColor(src_img, cv2.COLOR_RGB2BGR)
            tgt_bgr = cv2.cvtColor(tgt_img, cv2.COLOR_RGB2BGR)
            src_faces = face_analysis_app.get(src_bgr)
            tgt_faces = face_analysis_app.get(tgt_bgr)
            if not src_faces or not tgt_faces:
                return None, None, "❌ Face not detected"

            swapped_bgr = swapper.get(tgt_bgr, tgt_faces[0], src_faces[0])
            if swapped_bgr is None:
                return None, None, "❌ Face swap failed"

            swapped_path = os.path.join(UPLOAD_DIR, f"swapped_{uuid.uuid4().hex[:8]}.jpg")
            if not cv2.imwrite(swapped_path, swapped_bgr):
                return None, None, "❌ Could not write swapped image"

            result = _run_codeformer(swapped_path, RESULT_DIR)
            if result.returncode != 0:
                logger.error("CodeFormer failed: %s", result.stderr)
                return None, None, f"❌ CodeFormer failed:\n{result.stderr}"

            final_path = _find_enhanced_image(RESULT_DIR)
            if final_path is None:
                return None, None, "❌ No enhanced image found"

            final_bgr = cv2.imread(final_path)
            if final_bgr is None:
                return None, None, "❌ Could not read enhanced image"
            return cv2.cvtColor(final_bgr, cv2.COLOR_BGR2RGB), final_path, ""
    except Exception as e:
        logger.exception("Face swap pipeline failed")
        return None, None, f"❌ Error: {str(e)}"
# -------------------------------------------------
# Gradio Interface
# -------------------------------------------------
with gr.Blocks() as demo:
    gr.Markdown("Face Swap")

    # Side-by-side pickers: the donor face and the image to edit.
    with gr.Row():
        src_input = gr.Image(label="Upload Your Face", type="numpy")
        tgt_input = gr.Image(label="Upload Target Image", type="numpy")

    btn = gr.Button("Swap Face")
    output_img = gr.Image(label="Enhanced Output", type="numpy")
    download = gr.File(label="⬇️ Download Enhanced Image")
    error_box = gr.Textbox(interactive=False, label="Logs / Errors")

    def process(src, tgt):
        """Run the pipeline; its (image, path, error) triple maps onto the three outputs."""
        return face_swap_and_enhance(src, tgt)

    btn.click(
        fn=process,
        inputs=[src_input, tgt_input],
        outputs=[output_img, download, error_box],
    )
# -------------------------------------------------
# FastAPI App
# -------------------------------------------------
fastapi_app = FastAPI()


@fastapi_app.get("/")
def root():
    """Redirect the bare root URL to the Gradio UI path."""
    return RedirectResponse(url="/gradio")


@fastapi_app.get("/health")
async def health():
    """Liveness probe: always answers with a static healthy status."""
    health_payload = {"status": "healthy"}
    return health_payload
# -------- Upload Endpoints with GridFS --------
async def _store_image_upload(image: UploadFile) -> str:
    """Read an uploaded image, check that it decodes, store it in GridFS, and return its id.

    Raises:
        HTTPException(400): the upload is empty or is not a decodable image.
    """
    contents = await image.read()
    if not contents:
        raise HTTPException(status_code=400, detail="Uploaded file is empty")
    # Reject non-images here rather than letting /faceswap fail later with an
    # opaque error.
    if cv2.imdecode(np.frombuffer(contents, np.uint8), cv2.IMREAD_COLOR) is None:
        raise HTTPException(status_code=400, detail="Uploaded file is not a decodable image")
    # UploadFile.filename may be None; GridFS needs a string name.
    file_id = await fs_bucket.upload_from_stream(image.filename or "upload", contents)
    return str(file_id)


@fastapi_app.post("/source")
async def upload_source(image: UploadFile = File(...)):
    """Store the donor-face image in GridFS and return ``{"source_id": <id>}``."""
    return {"source_id": await _store_image_upload(image)}


@fastapi_app.post("/target")
async def upload_target(image: UploadFile = File(...)):
    """Store the target image in GridFS and return ``{"target_id": <id>}``."""
    return {"target_id": await _store_image_upload(image)}
# -------- Faceswap Endpoint --------
class FaceSwapRequest(BaseModel):
    """Body of POST /faceswap: GridFS ids of the source and target images."""

    source_id: str
    target_id: str


async def _read_gridfs_image_rgb(file_id: str, role: str):
    """Load the GridFS file ``file_id`` and decode it into an RGB ``np.ndarray``.

    ``role`` ("source" / "target") is used only in error messages.

    Raises:
        HTTPException(400): ``file_id`` is not a valid ObjectId, or the stored
            bytes are empty or not a decodable image.
        HTTPException(404): no GridFS file has that id.
    """
    from bson.errors import InvalidId
    from gridfs.errors import NoFile

    try:
        object_id = ObjectId(file_id)
    except InvalidId:
        raise HTTPException(status_code=400, detail=f"Invalid {role} id: {file_id}") from None
    try:
        stream = await fs_bucket.open_download_stream(object_id)
    except NoFile:
        raise HTTPException(status_code=404, detail=f"{role} image not found") from None

    data = await stream.read()
    bgr = cv2.imdecode(np.frombuffer(data, np.uint8), cv2.IMREAD_COLOR) if data else None
    if bgr is None:
        raise HTTPException(status_code=400, detail=f"{role} file is not a decodable image")
    return cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)


@fastapi_app.post("/faceswap")
async def perform_faceswap(request: FaceSwapRequest):
    """Swap the source face into the target image and store the result in GridFS.

    Returns:
        ``{"result_id": <GridFS id of the enhanced PNG>}``.

    Raises:
        HTTPException(400/404): bad or unknown input ids, or undecodable inputs.
        HTTPException(500): the pipeline reported an error (its message is the
            detail) or any other unexpected failure.
    """
    from fastapi.concurrency import run_in_threadpool

    try:
        source_rgb = await _read_gridfs_image_rgb(request.source_id, "source")
        target_rgb = await _read_gridfs_image_rgb(request.target_id, "target")

        # The pipeline blocks (model inference plus a CodeFormer subprocess), so
        # run it in a worker thread to keep the event loop responsive.
        final_img, _final_path, err = await run_in_threadpool(
            face_swap_and_enhance, source_rgb, target_rgb
        )
        if err:
            raise HTTPException(status_code=500, detail=err)

        # Encode the in-memory result instead of re-reading the returned path:
        # the next swap wipes RESULT_DIR, so the file may already be gone.
        ok, png = cv2.imencode(".png", cv2.cvtColor(final_img, cv2.COLOR_RGB2BGR))
        if not ok:
            raise HTTPException(status_code=500, detail="Failed to encode result image")
        result_id = await fs_bucket.upload_from_stream("enhanced.png", png.tobytes())
    except HTTPException:
        # Already carries the right status/detail; don't re-wrap it as "500: ...".
        raise
    except Exception as e:
        logger.exception("Face swap request failed")
        raise HTTPException(status_code=500, detail=str(e)) from e
    return {"result_id": str(result_id)}
# -------- Download Endpoint --------
@fastapi_app.get("/download/{result_id}")
async def download_result(result_id: str):
    """Return the GridFS file ``result_id`` as a PNG attachment named ``enhanced.png``.

    Raises:
        HTTPException(404): ``result_id`` is not a valid ObjectId, or no file
            has that id.
        HTTPException(500): any other failure (e.g. the database is
            unreachable). It is reported as a server error instead of being
            masked as "not found".
    """
    from bson.errors import InvalidId
    from gridfs.errors import NoFile

    try:
        stream = await fs_bucket.open_download_stream(ObjectId(result_id))
        file_data = await stream.read()
    except (InvalidId, NoFile):
        raise HTTPException(status_code=404, detail="Result not found") from None
    except Exception as e:
        logger.exception("Failed to download result %s", result_id)
        raise HTTPException(status_code=500, detail="Failed to read result") from e
    return Response(
        content=file_data,
        media_type="image/png",
        headers={"Content-Disposition": "attachment; filename=enhanced.png"},
    )
# Mount Gradio: serve the `demo` Blocks UI under /gradio on the FastAPI app.
fastapi_app = mount_gradio_app(fastapi_app, demo, path="/gradio")

if __name__ == "__main__":
    # Listen on all interfaces, port 7860.
    server_options = dict(host="0.0.0.0", port=7860)
    uvicorn.run(fastapi_app, **server_options)