Spaces:
Running
Running
| import os | |
| import torch | |
| import numpy as np | |
| from fastapi import FastAPI, UploadFile, File, Form, HTTPException | |
| from fastapi.responses import StreamingResponse, HTMLResponse | |
| from PIL import Image | |
| from io import BytesIO | |
| import requests | |
| from transformers import AutoModelForImageSegmentation | |
| import uvicorn | |
# ---------------------------------------------------------
# CPU thread limits (keeps inference predictable on shared
# HF Spaces CPUs)
# ---------------------------------------------------------
# NOTE(review): torch and numpy are imported before this point, so these
# environment variables may be read too late to affect thread pools those
# libraries have already set up; torch.set_num_threads() below is the call
# that directly caps torch's intra-op threads. Move the env assignments above
# the imports if the OMP/MKL caps must be guaranteed.
os.environ.update(OMP_NUM_THREADS="1", MKL_NUM_THREADS="1")
torch.set_num_threads(1)
# ---------------------------------------------------------
# Config (tuned for speed)
# ---------------------------------------------------------
# Model input resolution (width, height): the mask is predicted at this size
# and then upscaled, so smaller is faster but coarser.
TARGET_SIZE = (320, 320)
# Uploads whose encoded size exceeds this many bytes are downscaled and
# re-encoded before inference.
MAX_FILE_SIZE = 5 * 2**20  # 5 MiB
# Longest-side pixel cap applied when an oversized upload is compressed.
MAX_COMPRESS_DIM = 1400
# ---------------------------------------------------------
# Model loading (runs once, at import time)
# ---------------------------------------------------------
MODEL_DIR = "models/BiRefNet"
os.makedirs(MODEL_DIR, exist_ok=True)

# Half precision when a GPU is present, full precision on CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
dtype = torch.float16 if device == "cuda" else torch.float32

print("Loading model...")
model = AutoModelForImageSegmentation.from_pretrained(
    "ZhengPeng7/BiRefNet",
    cache_dir=MODEL_DIR,  # keep downloaded weights under models/BiRefNet
    trust_remote_code=True,  # executes model code fetched from the Hub repo
)
model.to(device, dtype=dtype)
model.eval()
print("Model ready")
# ---------------------------------------------------------
# Image helpers
# ---------------------------------------------------------
def load_image_from_url(url: str):
    """Download *url* and decode the response body as an RGB PIL image.

    Raises:
        requests.HTTPError: the server answered with a non-2xx status.
        requests.RequestException: connection failure or timeout (timeout=10).
        PIL.UnidentifiedImageError: the body is not a decodable image.
    """
    # NOTE(review): *url* is client-supplied, so this fetch can reach internal
    # hosts (SSRF) and downloads the whole body with no size cap; consider an
    # allow-list and a streamed, size-limited read.
    response = requests.get(url, timeout=10)
    response.raise_for_status()
    payload = BytesIO(response.content)
    return Image.open(payload).convert("RGB")
# Fast compression for oversized uploads: one downscale plus one JPEG
# re-encode (no quality-search loop, to keep latency low).
def compress_if_needed(img: Image.Image, raw_bytes: bytes, max_bytes=None,
                       max_dim=None, quality=70):
    """Shrink *img* when the encoded upload it came from is too large.

    Args:
        img: The decoded image.
        raw_bytes: The original encoded bytes; only their length is used.
        max_bytes: Size threshold in bytes; defaults to ``MAX_FILE_SIZE``.
        max_dim: Longest-side pixel cap; defaults to ``MAX_COMPRESS_DIM``.
        quality: JPEG quality used for the re-encode.

    Returns:
        *img* unchanged when ``len(raw_bytes) <= max_bytes``. Otherwise, a new
        RGB image downscaled (aspect ratio kept) so that its longest side is
        at most *max_dim*, then round-tripped through JPEG at *quality*.
    """
    # Defaults are resolved at call time so module constants stay the single
    # source of truth.
    if max_bytes is None:
        max_bytes = MAX_FILE_SIZE
    if max_dim is None:
        max_dim = MAX_COMPRESS_DIM

    if len(raw_bytes) <= max_bytes:
        return img

    print(f"[INFO] Compressing image >{max_bytes / 2**20:g}MB")
    img = img.convert("RGB")  # JPEG cannot store alpha/palette modes

    w, h = img.size
    scale = min(1.0, max_dim / max(w, h))
    if scale < 1.0:
        # Clamp each side to >= 1 px: with extreme aspect ratios the short
        # side would otherwise truncate to 0, which resize() rejects.
        new_size = (max(1, int(w * scale)), max(1, int(h * scale)))
        img = img.resize(new_size, Image.BILINEAR)

    # A single fixed-quality re-encode (no quality search) for speed.
    buffer = BytesIO()
    img.save(buffer, format="JPEG", quality=quality, optimize=True)
    buffer.seek(0)
    return Image.open(buffer).convert("RGB")
# Per-channel RGB normalization statistics (the standard ImageNet mean/std).
# They are float32 so the subtraction/division does not silently upcast the
# whole image to float64 before the final dtype cast.
_NORM_MEAN = np.array([0.485, 0.456, 0.406], dtype=np.float32)
_NORM_STD = np.array([0.229, 0.224, 0.225], dtype=np.float32)


def transform(img, size=None):
    """Turn a PIL image into a normalized single-image model batch.

    Args:
        img: PIL image; any non-RGB mode is converted to RGB first.
        size: ``(width, height)`` to resize to; defaults to ``TARGET_SIZE``.

    Returns:
        A tensor of shape ``(1, 3, height, width)`` on ``device`` with
        dtype ``dtype``.
    """
    if size is None:
        size = TARGET_SIZE
    # The normalization below broadcasts over exactly 3 channels.
    if img.mode != "RGB":
        img = img.convert("RGB")
    img = img.resize(size, Image.BILINEAR)

    arr = np.asarray(img, dtype=np.float32) / 255.0  # (H, W, 3), float32
    arr = (arr - _NORM_MEAN) / _NORM_STD
    arr = np.transpose(arr, (2, 0, 1))  # HWC -> CHW
    return torch.from_numpy(arr).unsqueeze(0).to(device=device, dtype=dtype)
# Fast inference: predict the mask at TARGET_SIZE, then upscale it back.
def remove_background(img: Image.Image):
    """Cut out the foreground of *img*.

    Returns:
        An RGBA copy of *img* whose alpha channel is the sigmoid of the model
        output scaled to 0-255, resized back to the input's dimensions.
    """
    original_wh = img.size
    batch = transform(img)

    with torch.inference_mode():
        outputs = model(batch)
        # The model may return a list/tuple of predictions; use the last one.
        logits = outputs[-1] if isinstance(outputs, (list, tuple)) else outputs
        probs = logits.sigmoid()[0, 0].cpu()

    alpha_u8 = probs.mul(255).byte().numpy()
    alpha = Image.fromarray(alpha_u8).resize(original_wh, Image.BILINEAR)

    cutout = img.convert("RGBA")
    cutout.putalpha(alpha)
    return cutout
# ---------------------------------------------------------
# FastAPI
# ---------------------------------------------------------
app = FastAPI()


@app.post("/remove-background")
async def remove_bg(file: UploadFile = File(None), image_url: str = Form(None)):
    """Remove the background from an uploaded file or from an image URL.

    If both are sent, *file* is used.

    Returns:
        A ``image/png`` StreamingResponse of the RGBA cutout.

    Raises:
        HTTPException(400): no input was given, the image could not be
            decoded (or is a decompression bomb), or the URL could not be
            fetched.
        HTTPException(500): any other failure while processing.
    """
    from PIL import UnidentifiedImageError

    try:
        if file is not None:
            raw = await file.read()
            img = Image.open(BytesIO(raw)).convert("RGB")
            # Step 1: shrink oversized uploads before inference.
            img = compress_if_needed(img, raw)
        elif image_url:
            img = load_image_from_url(image_url)
        else:
            raise HTTPException(400, "Provide file or URL")

        # Step 2: remove the background and stream the result back as PNG.
        # NOTE(review): decoding, the URL fetch and model inference are all
        # blocking calls inside an async handler, so they stall the event
        # loop for the duration of each request.
        result = remove_background(img)
        buf = BytesIO()
        result.save(buf, format="PNG")
        buf.seek(0)
        return StreamingResponse(buf, media_type="image/png")
    except HTTPException:
        # Keep deliberate client errors (such as the 400 above) instead of
        # re-wrapping them as 500s in the catch-all below.
        raise
    except (UnidentifiedImageError, Image.DecompressionBombError,
            requests.RequestException) as e:
        # Bad client input: undecodable/oversized image data or an
        # unfetchable URL.
        raise HTTPException(400, str(e)) from e
    except Exception as e:
        raise HTTPException(500, str(e)) from e
# ---------------------------------------------------------
# Simple UI
# ---------------------------------------------------------
@app.get("/", response_class=HTMLResponse)
async def home():
    """Serve the single-page UI.

    The page posts either the chosen file or the entered URL to
    ``/remove-background`` and shows the input and the result side by side.
    Non-2xx responses are reported with an alert instead of being rendered
    as a broken image.
    """
    return """<!DOCTYPE html>
<html>
<head>
<meta charset='utf-8'>
<meta name='viewport' content='width=device-width, initial-scale=1'>
<title>Fast Background Remover</title>
<link rel='stylesheet'
href='https://cdn.jsdelivr.net/npm/bootstrap@5.3.2/dist/css/bootstrap.min.css'>
</head>
<body class='bg-light'>
<div class='container py-4 text-center'>
<h2>Fast Background Remover</h2>
<div class='row mt-4'>
<div class='col-md-6'>
<h5>Input</h5>
<img id='inputImg' style='max-width:100%; border-radius:10px;'>
</div>
<div class='col-md-6'>
<h5>Output</h5>
<img id='outputImg' style='max-width:100%; border-radius:10px;'>
</div>
</div>
<hr>
<form id="uploadForm">
<input type='file' id='fileInput' class='form-control mb-3'>
<button class='btn btn-primary'>Upload</button>
</form>
<hr>
<form id='urlForm'>
<input id='urlInput' class='form-control mb-3'
placeholder='Enter image URL'>
<button class='btn btn-success'>Use URL</button>
</form>
</div>
<script>
const inputImg = document.getElementById("inputImg");
const outputImg = document.getElementById("outputImg");

// Show a Blob in an <img> (or clear it when blob is null), releasing the
// previous blob: URL so repeated runs do not leak memory.
function setBlobSrc(imgEl, blob) {
  if (imgEl.src.startsWith("blob:")) URL.revokeObjectURL(imgEl.src);
  if (blob) imgEl.src = URL.createObjectURL(blob);
  else imgEl.removeAttribute("src");
}

// POST the form data and display the PNG, or alert the server's error.
async function removeBackground(fd) {
  setBlobSrc(outputImg, null);
  const r = await fetch("/remove-background", { method:"POST", body:fd });
  if (!r.ok) {
    let msg = r.statusText;
    try {
      const detail = (await r.json()).detail;
      if (detail) msg = typeof detail === "string" ? detail : JSON.stringify(detail);
    } catch (_) {}
    alert("Error " + r.status + ": " + msg);
    return;
  }
  setBlobSrc(outputImg, await r.blob());
}

document.getElementById("uploadForm").addEventListener("submit", async e => {
  e.preventDefault();
  const file = document.getElementById("fileInput").files[0];
  if (!file) return alert("Select file");
  setBlobSrc(inputImg, file);
  const fd = new FormData();
  fd.append("file", file);
  await removeBackground(fd);
});

document.getElementById("urlForm").addEventListener("submit", async e => {
  e.preventDefault();
  const url = document.getElementById("urlInput").value;
  setBlobSrc(inputImg, null);
  inputImg.src = url;
  const fd = new FormData();
  fd.append("image_url", url);
  await removeBackground(fd);
});
</script>
</body>
</html>
"""
# ---------------------------------------------------------
# Entry point
# ---------------------------------------------------------
def _serve():
    """Run the app with uvicorn, listening on all interfaces at port 7860."""
    uvicorn.run(app, host="0.0.0.0", port=7860)


if __name__ == "__main__":
    _serve()