# Captured page header (not part of the program): mobisoft — "Update app.py" — commit 72cef1f (verified)
import os
import torch
import numpy as np
from fastapi import FastAPI, UploadFile, File, Form, HTTPException
from fastapi.responses import StreamingResponse, HTMLResponse
from PIL import Image
from io import BytesIO
import requests
from transformers import AutoModelForImageSegmentation
import uvicorn
# ---------------------------------------------------------
# CPU optimization (important for HF Spaces)
# ---------------------------------------------------------
# Limit the math libraries to a single thread each.
# NOTE(review): OMP/MKL generally read these variables when their runtimes
# initialise, and numpy/torch are imported above this point, so the env vars
# may not affect the already-loaded runtimes. torch.set_num_threads() is the
# call that reliably applies to torch.
os.environ.update(OMP_NUM_THREADS="1", MKL_NUM_THREADS="1")
torch.set_num_threads(1)
# ---------------------------------------------------------
# Config (speed focused)
# ---------------------------------------------------------
# Model input resolution as (width, height); a small input trades mask
# detail for faster inference.
# NOTE(review): presumably below the resolution the model was trained at —
# confirm the quality trade-off is acceptable.
TARGET_SIZE = (320, 320)
# Uploaded files larger than this many bytes (5 MiB) are downscaled and
# recompressed before inference.
MAX_FILE_SIZE = 5 * 1024 * 1024
# Longest-side cap, in pixels, used when downscaling an oversized upload.
MAX_COMPRESS_DIM = 1400
# ---------------------------------------------------------
# Load model
# ---------------------------------------------------------
MODEL_DIR = "models/BiRefNet"
os.makedirs(MODEL_DIR, exist_ok=True)

# Use the GPU in half precision when one is available; otherwise use the
# CPU in full precision.
device = "cuda" if torch.cuda.is_available() else "cpu"
dtype = torch.float32 if device == "cpu" else torch.float16

print("Loading model...")
# Downloaded files are cached under MODEL_DIR. trust_remote_code=True lets
# transformers use the model class defined in the repository itself.
model = AutoModelForImageSegmentation.from_pretrained(
    "ZhengPeng7/BiRefNet",
    trust_remote_code=True,
    cache_dir=MODEL_DIR,
)
model.to(device, dtype=dtype)
model.eval()
print("Model ready")
# ---------------------------------------------------------
# Image helpers
# ---------------------------------------------------------
def load_image_from_url(url: str):
    """Download the image at *url* and return it decoded as an RGB PIL image.

    Raises:
        requests.HTTPError: The server answered with a non-2xx status.
        requests.RequestException: Other network failures, including the
            10-second timeout.
        PIL.UnidentifiedImageError: The body is not a decodable image.
    """
    # NOTE(review): *url* is supplied by the client, so this fetches any
    # address the server can reach (SSRF risk). It also buffers the whole
    # response in memory with no size limit. Consider an allow-list and a
    # byte cap.
    response = requests.get(url, timeout=10)
    response.raise_for_status()
    payload = BytesIO(response.content)
    return Image.open(payload).convert("RGB")
# Fast compression for oversized uploads (key part)
def compress_if_needed(img: Image.Image, raw_bytes: bytes):
    """Downscale and JPEG-recompress *img* when its source file is too large.

    Args:
        img: The decoded image.
        raw_bytes: The original encoded file. Only its length is used, to
            decide whether compression is needed.

    Returns:
        *img* itself when ``len(raw_bytes) <= MAX_FILE_SIZE``. Otherwise, a
        new RGB image whose longest side is at most ``MAX_COMPRESS_DIM``
        pixels and which has been round-tripped through JPEG at quality 70.
    """
    if len(raw_bytes) <= MAX_FILE_SIZE:
        return img

    print("[INFO] Compressing image >5MB")
    img = img.convert("RGB")

    # Resize aggressively: cap the longest side at MAX_COMPRESS_DIM and never
    # upscale.
    w, h = img.size
    scale = min(1.0, MAX_COMPRESS_DIM / max(w, h))
    # Clamp each side to at least 1 px. A very thin image (e.g. 20000x5)
    # would otherwise round its short side down to 0, which resize rejects.
    new_size = (max(1, int(w * scale)), max(1, int(h * scale)))
    img = img.resize(new_size, Image.BILINEAR)

    # Reduce quality in a single pass (no quality-search loop, for speed).
    # optimize=True is deliberately not used. It only makes the encoded bytes
    # smaller (lossless Huffman table tuning), and those bytes are discarded
    # here, so it would cost CPU without changing the decoded pixels.
    buffer = BytesIO()
    img.save(buffer, format="JPEG", quality=70)
    buffer.seek(0)
    return Image.open(buffer).convert("RGB")
# ImageNet per-channel mean/std used to normalise model inputs. They are
# stored as float32 so normalisation does not upcast the image to float64,
# and they are built once instead of on every call.
_NORM_MEAN = np.array([0.485, 0.456, 0.406], dtype=np.float32)
_NORM_STD = np.array([0.229, 0.224, 0.225], dtype=np.float32)


def transform(img):
    """Convert a PIL image into a normalised model input tensor.

    The image is converted to RGB if needed, resized to ``TARGET_SIZE``,
    scaled to [0, 1], normalised per channel, and reordered from HWC to CHW.

    Returns:
        A tensor of shape ``(1, 3, H, W)`` on ``device`` with dtype ``dtype``,
        where ``(W, H) == TARGET_SIZE``.
    """
    # Force three channels: RGBA or greyscale input would otherwise fail to
    # broadcast against the 3-element mean/std below.
    if img.mode != "RGB":
        img = img.convert("RGB")
    img = img.resize(TARGET_SIZE, Image.BILINEAR)
    arr = np.asarray(img, dtype=np.float32) / 255.0
    arr = (arr - _NORM_MEAN) / _NORM_STD
    arr = np.transpose(arr, (2, 0, 1))  # HWC -> CHW
    return torch.from_numpy(arr).unsqueeze(0).to(device=device, dtype=dtype)
# Fast inference: a single forward pass at the model input size; the mask
# is upscaled back to the original size afterwards.
def remove_background(img: Image.Image):
    """Return an RGBA copy of *img* whose alpha channel is the predicted mask.

    The model output is passed through a sigmoid and scaled to 0-255. It is
    then resized back to the input image's size and attached as the alpha
    channel.
    """
    original_size = img.size
    batch = transform(img)
    with torch.inference_mode():
        outputs = model(batch)
        # The model may return several predictions; the last one is used.
        logits = outputs[-1] if isinstance(outputs, (list, tuple)) else outputs
        # First image in the batch, first channel.
        probs = logits.sigmoid()[0, 0].cpu()
        mask_bytes = probs.mul(255).byte().numpy()
    alpha = Image.fromarray(mask_bytes).resize(original_size, Image.BILINEAR)
    rgba = img.convert("RGBA")
    rgba.putalpha(alpha)
    return rgba
# ---------------------------------------------------------
# FastAPI
# ---------------------------------------------------------
app = FastAPI()


@app.post("/remove-background")
async def remove_bg(file: UploadFile = File(None), image_url: str = Form(None)):
    """Remove the background from an uploaded image or an image URL.

    Form fields (multipart):
        file: Image upload. Takes precedence when both fields are sent.
        image_url: URL to download the image from when no file is sent.

    Returns:
        The cut-out image as a PNG stream.

    Raises:
        HTTPException(400): Neither input was provided, the URL could not be
            fetched, or the data is not a decodable image.
        HTTPException(500): Any other failure during processing.
    """
    # Local imports keep this handler self-contained; Pillow and requests are
    # already dependencies of this module.
    from PIL import UnidentifiedImageError
    from requests import RequestException

    # NOTE(review): image decoding, the URL download and model inference are
    # synchronous calls inside an async handler, so they run on the event-loop
    # thread and stall other requests while in progress.
    try:
        if file:
            raw = await file.read()
            img = Image.open(BytesIO(raw)).convert("RGB")
            # Step 1: compress if the upload exceeds MAX_FILE_SIZE (5MB)
            img = compress_if_needed(img, raw)
        elif image_url:
            img = load_image_from_url(image_url)
        else:
            raise HTTPException(400, "Provide file or URL")

        # Step 2: remove background
        result = remove_background(img)
        buf = BytesIO()
        result.save(buf, format="PNG")
        buf.seek(0)
        return StreamingResponse(buf, media_type="image/png")
    except HTTPException:
        # Already carries the intended status code. Without this clause, the
        # generic handler below would turn the 400 above into a 500.
        raise
    except (UnidentifiedImageError, RequestException) as e:
        # Undecodable data or an unreachable/failing URL is a client error.
        raise HTTPException(400, str(e)) from e
    except Exception as e:
        raise HTTPException(500, str(e)) from e
# ---------------------------------------------------------
# Simple UI
# ---------------------------------------------------------
@app.get("/", response_class=HTMLResponse)
async def home():
    """Serve the single-page demo UI.

    The page POSTs either the selected file or the entered URL to
    ``/remove-background`` and shows the input next to the returned PNG.
    On an error response, the server's message is shown in an alert instead
    of loading the error body as an image.
    """
    return """
<html>
<head>
<title>Fast Background Remover</title>
<link rel='stylesheet'
href='https://cdn.jsdelivr.net/npm/bootstrap@5.3.2/dist/css/bootstrap.min.css'>
</head>
<body class='bg-light'>
<div class='container py-4 text-center'>
<h2>Fast Background Remover</h2>
<div class='row mt-4'>
<div class='col-md-6'>
<h5>Input</h5>
<img id='inputImg' style='max-width:100%; border-radius:10px;'>
</div>
<div class='col-md-6'>
<h5>Output</h5>
<img id='outputImg' style='max-width:100%; border-radius:10px;'>
</div>
</div>
<hr>
<form id="uploadForm">
<input type='file' id='fileInput' class='form-control mb-3'>
<button class='btn btn-primary'>Upload</button>
</form>
<hr>
<form id='urlForm'>
<input id='urlInput' class='form-control mb-3'
placeholder='Enter image URL'>
<button class='btn btn-success'>Use URL</button>
</form>
</div>
<script>
const inputImg = document.getElementById("inputImg");
const outputImg = document.getElementById("outputImg");
// POST the form data and show the PNG on success. On an error response,
// alert the server's message (FastAPI errors are JSON with a "detail" field)
// instead of trying to display the JSON body as an image.
async function removeBackground(fd) {
try {
const r = await fetch("/remove-background", { method:"POST", body:fd });
if (!r.ok) {
const err = await r.json().catch(() => ({}));
outputImg.removeAttribute("src");
alert(typeof err.detail === "string" ? err.detail : "Request failed: " + r.status);
return;
}
outputImg.src = URL.createObjectURL(await r.blob());
} catch (err) {
alert("Request failed: " + err);
}
}
document.getElementById("uploadForm").addEventListener("submit", async e => {
e.preventDefault();
const file = document.getElementById("fileInput").files[0];
if (!file) return alert("Select file");
inputImg.src = URL.createObjectURL(file);
const fd = new FormData();
fd.append("file", file);
await removeBackground(fd);
});
document.getElementById("urlForm").addEventListener("submit", async e => {
e.preventDefault();
const url = document.getElementById("urlInput").value;
inputImg.src = url;
const fd = new FormData();
fd.append("image_url", url);
await removeBackground(fd);
});
</script>
</body>
</html>
"""
# ---------------------------------------------------------
# Run
# ---------------------------------------------------------
if __name__ == "__main__":
    # Listen on all interfaces. The port defaults to 7860 and can be
    # overridden with the PORT environment variable.
    uvicorn.run(app, host="0.0.0.0", port=int(os.environ.get("PORT", "7860")))