import base64
import hmac
import os
import shutil
import traceback
import uuid
from io import BytesIO

import cv2
import gradio as gr
import numpy as np
import onnxruntime as ort
from fastapi import FastAPI, File, UploadFile, Form, HTTPException, Depends, Request
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import FileResponse
from fastapi.staticfiles import StaticFiles
from fastapi.templating import Jinja2Templates
from PIL import Image


# Shared secret that API clients must send in the ``api_key`` form field.
# It is None when the environment variable is not set.
API_KEY = os.environ.get("API_KEY")

# The ASGI application; every route below is registered on it.
app = FastAPI(title="Background Removal API")

# Fully permissive CORS policy.
# NOTE(review): wildcard origins combined with allow_credentials=True is
# discouraged by the FastAPI CORS documentation -- confirm whether
# credentialed cross-origin requests are really needed here.
_cors_options = {
    "allow_origins": ["*"],
    "allow_credentials": True,
    "allow_methods": ["*"],
    "allow_headers": ["*"],
}
app.add_middleware(CORSMiddleware, **_cors_options)


# Result images are written to a "tmp" directory that sits next to this file.
_MODULE_DIR = os.path.dirname(os.path.abspath(__file__))
TMP_FOLDER = os.path.join(_MODULE_DIR, "tmp")
os.makedirs(TMP_FOLDER, exist_ok=True)
print(f"Created tmp folder at: {TMP_FOLDER}")

# Serve the generated files under /tmp/<filename>.
_tmp_static_files = StaticFiles(directory=TMP_FOLDER)
app.mount("/tmp", _tmp_static_files, name="tmp")

# HTML templates are looked up in ./templates (relative to the working directory).
templates = Jinja2Templates(directory="templates")


# ONNX model file, resolved relative to the current working directory.
model_path = "BiRefNet-general-resolution_512x512-fp16-epoch_216.onnx"
session = ort.InferenceSession(model_path, providers=["CPUExecutionProvider"])

# (width, height) every image is resized to before inference.
INPUT_SIZE = (512, 512)

# Take the input name from the loaded model instead of hard-coding it, so the
# feed dict passed to session.run() always matches the model's first input.
input_info = session.get_inputs()[0]
input_name = input_info.name

# Startup diagnostics describing the tensor the model expects.
print("Input name:", input_info.name)
print("Input shape:", input_info.shape)
print("Input type:", input_info.type)


def verify_api_key(api_key: str = Form(...)):
    """Validate the API key supplied by the client.

    Args:
        api_key: Key sent by the client in the ``api_key`` form field.

    Returns:
        The supplied key, unchanged, when it matches ``API_KEY``.

    Raises:
        HTTPException: 401 when the key does not match, or when no key is
            configured on the server (fail closed).
    """
    expected = API_KEY
    # compare_digest runs in time independent of where the values differ,
    # so response timing does not leak how much of the key was correct.
    # Encoding to bytes lets it accept non-ASCII keys as well.
    if not expected or not hmac.compare_digest(
        api_key.encode("utf-8"), expected.encode("utf-8")
    ):
        raise HTTPException(status_code=401, detail="Invalid API key")
    return api_key
|
|
def preprocess_image(image):
    """Decode an image and build the model input tensor from it.

    Args:
        image: One of
            * ``str`` -- path of an image file, read with ``cv2.imread``;
            * ``np.ndarray`` -- an already decoded image, used as-is and
              treated as BGR (it is passed through ``cv2.COLOR_BGR2RGB``);
            * bytes-like -- an encoded image, decoded with ``cv2.imdecode``.

    Returns:
        tuple: ``(input_tensor, original_shape, original_img)`` where
        ``input_tensor`` is a float32 array laid out as (1, channels,
        height, width) at ``INPUT_SIZE``, ``original_shape`` is the
        ``(height, width)`` of the decoded image and ``original_img`` is a
        copy of the decoded image.

    Raises:
        ValueError: If the image is empty or cannot be decoded.
    """
    if isinstance(image, str):
        img = cv2.imread(image)
    elif isinstance(image, np.ndarray):
        img = image
    else:
        encoded = np.frombuffer(image, np.uint8)
        if encoded.size == 0:
            raise ValueError("Could not decode image: empty input")
        img = cv2.imdecode(encoded, cv2.IMREAD_COLOR)

    # cv2.imread / cv2.imdecode report failure by returning None rather than
    # raising, so check explicitly to give callers a meaningful error.
    if img is None or img.size == 0:
        raise ValueError("Could not decode image: unsupported or corrupt data")

    # Keep the untouched pixels and size so the mask can be applied later.
    original_img = img.copy()
    original_shape = img.shape[:2]

    rgb = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    resized = cv2.resize(rgb, INPUT_SIZE)

    # Scale 8-bit values to [0, 1], then shift/scale to [-1, 1].
    # NOTE(review): mean/std of 0.5 is assumed to match the model's training
    # preprocessing -- confirm against the model card.
    normalized = resized.astype(np.float32) / 255.0
    normalized = (normalized - 0.5) / 0.5

    # HWC -> CHW, then add the leading batch axis.
    transposed = np.transpose(normalized, (2, 0, 1))
    input_tensor = np.expand_dims(transposed, axis=0).astype(np.float32)

    return input_tensor, original_shape, original_img
|
|
|
|
def apply_mask(original_img, mask_array, original_shape, output_path):
    """Cut the foreground out of an image and save it as a transparent PNG.

    Args:
        original_img: Decoded image in BGR channel order.
        mask_array: Raw model output; singleton axes are squeezed away.
        original_shape: ``(height, width)`` of ``original_img``.
        output_path: Destination path of the PNG file.

    Returns:
        tuple: ``(bgra, True)`` with the 4-channel result on success, or
        ``(None, False)`` if anything went wrong (the error is printed).
    """
    try:
        # Cast to float32 so cv2.resize receives a dtype it supports whatever
        # precision the model emitted.
        mask = np.squeeze(mask_array).astype(np.float32)
        mask = cv2.resize(mask, (original_shape[1], original_shape[0]))
        mask = np.clip(mask, 0, 1)

        # Hard threshold: a pixel is either fully kept or fully removed.
        binary_mask = (mask > 0.5).astype(np.uint8)

        img = original_img.astype(np.uint8)
        masked_img = cv2.bitwise_and(img, img, mask=binary_mask)

        # Alpha is 255 for foreground pixels and 0 for background pixels.
        alpha = (binary_mask * 255).astype(np.uint8)
        bgra = cv2.cvtColor(masked_img, cv2.COLOR_BGR2BGRA)
        bgra[:, :, 3] = alpha

        # A bare filename has an empty dirname, which os.makedirs rejects.
        output_dir = os.path.dirname(output_path)
        if output_dir:
            os.makedirs(output_dir, exist_ok=True)

        # cv2.imwrite returns False instead of raising when the write fails.
        written = cv2.imwrite(output_path, bgra, [cv2.IMWRITE_PNG_COMPRESSION, 0])
        if not written:
            raise OSError(f"Could not write image to: {output_path}")
        print(f"Saved masked object image to: {output_path} with size {bgra.shape[:2]}")

        return bgra, True
    except Exception as e:
        print(f"Error applying mask: {e}")
        return None, False


@app.post("/")
async def index_post(
    request: Request,
    main_photo: UploadFile = File(...),
    bg_photo: UploadFile = File(None)
):
    """Handle the HTML form: remove the background, optionally replace it.

    Args:
        request: Incoming request; passed on to the template context.
        main_photo: Uploaded image whose background is removed.
        bg_photo: Optional uploaded image used as the new background.

    Returns:
        The rendered ``index.html`` template. On success the context holds
        ``output_image`` (file name of the result inside ``TMP_FOLDER``); on
        failure it holds ``error`` with a description of the problem.
    """
    try:
        main_image_data = await main_photo.read()
        input_tensor, original_shape, original_img = preprocess_image(main_image_data)
        output = session.run(None, {input_name: input_tensor})
        mask = output[0]

        result_filename = f"{uuid.uuid4()}.png"
        output_path = os.path.join(TMP_FOLDER, result_filename)

        transparent_img, success = apply_mask(original_img, mask, original_shape, output_path)
        if not success:
            # apply_mask swallows its own errors; without this check the page
            # would link to a file that was never written.
            raise RuntimeError("Failed to process image")
        final_result_path = output_path

        # A form submitted with the background field left blank can still
        # carry an (empty) upload, so decide on the actual payload.
        bg_image_data = await bg_photo.read() if bg_photo is not None else b""
        if bg_image_data:
            bg_np = np.frombuffer(bg_image_data, np.uint8)
            bg_img = cv2.imdecode(bg_np, cv2.IMREAD_COLOR)
            if bg_img is None:
                raise ValueError("Could not decode background image")
            bg_img_resized = cv2.resize(bg_img, (original_shape[1], original_shape[0]))

            # Alpha-composite the cut-out over the resized background.
            alpha = transparent_img[:, :, 3] / 255.0
            foreground = transparent_img[:, :, :3]
            blended = (
                foreground * alpha[..., None]
                + bg_img_resized * (1 - alpha[..., None])
            ).astype(np.uint8)

            final_result_path = os.path.join(TMP_FOLDER, f"bg_replaced_{uuid.uuid4()}.png")
            if not cv2.imwrite(final_result_path, blended):
                raise OSError(f"Could not write image to: {final_result_path}")

        return templates.TemplateResponse("index.html", {
            "request": request,
            "output_image": os.path.basename(final_result_path)
        })

    except Exception as e:
        print("Error in index_post:", str(e))
        print(traceback.format_exc())
        return templates.TemplateResponse("index.html", {
            "request": request,
            "error": f"Error: {str(e)}"
        })


@app.post("/remove-background")
async def remove_background(
    request: Request,
    api_key: str = Form(...),
    main_photo: UploadFile = File(...)
):
    """Remove the background of an uploaded image and return its URL.

    Args:
        request: Incoming request; its base URL is used to build the link.
        api_key: Client API key, checked with ``verify_api_key``.
        main_photo: Uploaded image to process.

    Returns:
        dict: ``status`` "success" with ``message``, ``filename`` and
        ``image_url`` (served from the ``/tmp`` mount), or ``status``
        "failure" with a ``message``.

    Raises:
        HTTPException: 401 when the API key is invalid.
    """
    # Checked outside the try block so the 401 is not turned into a
    # "failure" payload by the generic handler below.
    verify_api_key(api_key)

    try:
        image_data = await main_photo.read()

        result_filename = f"{uuid.uuid4()}.png"
        output_path = os.path.join(TMP_FOLDER, result_filename)

        input_tensor, original_shape, original_img = preprocess_image(image_data)
        output = session.run(None, {input_name: input_tensor})
        mask = output[0]

        # apply_mask creates the output directory itself if it is missing.
        _, success = apply_mask(original_img, mask, original_shape, output_path)
        if not success:
            return {
                "status": "failure",
                "message": "Failed to process image"
            }

        # Drop the single trailing slash so the URL has no "//".
        base_url = str(request.base_url)
        if base_url.endswith("/"):
            base_url = base_url[:-1]

        # The same URL layout is used on every host (the former "hf.space"
        # special case produced an identical URL in both branches).
        full_url = f"{base_url}/tmp/{result_filename}"

        return {
            "status": "success",
            "message": "Background removed successfully",
            "filename": result_filename,
            "image_url": full_url
        }

    except Exception as e:
        print(f"Error in remove_background: {str(e)}")
        print(traceback.format_exc())
        return {
            "status": "failure",
            "message": f"Error: {str(e)}"
        }


def process_image_gradio(image):
    """Gradio callback: remove the background of the submitted image.

    Args:
        image: Image from the ``gr.Image(type="numpy")`` input, i.e. an RGB
            array, or None when nothing was uploaded.

    Returns:
        A PIL RGBA image with a transparent background, or None when there
        is no input or processing failed.
    """
    # Gradio calls the function with None when the input is left empty.
    if image is None:
        return None

    # Gradio delivers RGB, while preprocess_image / apply_mask work in
    # OpenCV's BGR order. Without this conversion the model sees swapped
    # channels and the returned image has red and blue exchanged.
    bgr_image = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)

    input_tensor, original_shape, original_img = preprocess_image(bgr_image)
    output = session.run(None, {input_name: input_tensor})
    mask = output[0]

    filename = f"{uuid.uuid4()}.png"
    output_path = os.path.join(TMP_FOLDER, filename)
    os.makedirs(TMP_FOLDER, exist_ok=True)

    result_img, success = apply_mask(original_img, mask, original_shape, output_path)
    if not success:
        return None

    # BGRA (OpenCV) -> RGBA (PIL).
    return Image.fromarray(cv2.cvtColor(result_img, cv2.COLOR_BGRA2RGBA))


# Gradio UI: one image in (numpy array), one image out (PIL).
_gradio_input = gr.Image(type="numpy")
_gradio_output = gr.Image(type="pil")
interface = gr.Interface(
    fn=process_image_gradio,
    inputs=_gradio_input,
    outputs=_gradio_output,
    title="Background Removal",
    description="Upload an image to remove its background",
)

# Expose the Gradio UI under /gradio on the FastAPI application.
app = gr.mount_gradio_app(app, interface, path="/gradio")


@app.get("/")
async def index_get(request: Request):
    """Render the upload form (``index.html``) with only the request in context."""
    context = {"request": request}
    return templates.TemplateResponse("index.html", context)


@app.post("/process_image")
async def process_image(
    request: Request,
    image: UploadFile = File(...),
    api_key: str = Form(...)
):
    """Remove the background of an uploaded image and return it as base64.

    Args:
        request: Incoming request (unused, kept for interface stability).
        image: Uploaded image to process.
        api_key: Client API key, checked with ``verify_api_key``.

    Returns:
        dict: ``status`` "success" with ``image_code`` (base64-encoded PNG),
        or ``status`` "failure" with a ``message``.

    Raises:
        HTTPException: 401 when the API key is invalid.
    """
    # Checked outside the try block so the 401 is not turned into a
    # "failure" payload by the generic handler below.
    verify_api_key(api_key)

    try:
        image_data = await image.read()

        result_filename = f"{uuid.uuid4()}.png"
        output_path = os.path.join(TMP_FOLDER, result_filename)

        input_tensor, original_shape, original_img = preprocess_image(image_data)
        output = session.run(None, {input_name: input_tensor})
        mask = output[0]

        # apply_mask creates the output directory itself if it is missing.
        _, success = apply_mask(original_img, mask, original_shape, output_path)
        if not success:
            return {
                "status": "failure",
                "message": "Failed to process image"
            }

        try:
            with open(output_path, "rb") as img_file:
                base64_image = base64.b64encode(img_file.read()).decode('utf-8')
        finally:
            # The response carries the image inline and never reveals the
            # random file name, so the file on disk would be orphaned.
            try:
                os.remove(output_path)
            except OSError:
                pass

        return {
            "status": "success",
            "image_code": base64_image
        }

    except Exception as e:
        print(f"Error in process_image: {str(e)}")
        print(traceback.format_exc())
        return {
            "status": "failure",
            "message": f"Error: {str(e)}"
        }


@app.get("/download/{filename}")
async def download_file(filename: str):
    """Send a previously generated PNG from ``TMP_FOLDER`` as a download.

    Args:
        filename: Bare file name of the result image (no directory parts).

    Returns:
        FileResponse: The file, served as ``image/png``.

    Raises:
        HTTPException: 404 when the name contains path components or does
            not refer to an existing regular file.
    """
    # Accept bare names only, so the lookup can never leave TMP_FOLDER.
    safe_name = os.path.basename(filename)
    file_path = os.path.join(TMP_FOLDER, safe_name)

    # isfile (not exists) also rejects directories such as "." or "..".
    if safe_name != filename or not os.path.isfile(file_path):
        raise HTTPException(status_code=404, detail="File not found")

    return FileResponse(
        path=file_path,
        filename=safe_name,
        media_type="image/png"
    )


if __name__ == "__main__":
    # Development entry point: serve the app directly with uvicorn.
    import uvicorn

    working_dir = os.getcwd()
    tmp_abs_path = os.path.abspath(TMP_FOLDER)
    print(f"Current working directory: {working_dir}")
    print(f"TMP_FOLDER absolute path: {tmp_abs_path}")

    # Listen on all interfaces, port 7860.
    uvicorn.run(app, host="0.0.0.0", port=7860)
|
|
|
|