# Container's picture
# Upload 2 files
# f25dfb4 verified
from fastapi import FastAPI, Request
from fastapi.responses import Response
from starlette.formparsers import MultiPartParser
import uvicorn
import aiohttp
import asyncio
import base64
import logging
import os
import re
import time
from io import BytesIO
import numpy as np
import torch
import torch.nn.functional as F
from torchvision.transforms.functional import normalize
from PIL import Image
from huggingface_hub import snapshot_download
from transformers import AutoModelForImageSegmentation
# pkg_resources is treated as optional: if importing it fails for any reason
# the name is bound to None, which the /pkg-version endpoint checks for.
try:
    import pkg_resources
except Exception:
    pkg_resources = None

# Upload limit: 50 MB. The same value is written to both size attributes of
# Starlette's multipart parser class, and is passed again explicitly when the
# request form is parsed in the /rmbg handler.
MAX_UPLOAD_SIZE = 1024 * 1024 * 50
MultiPartParser.spool_max_size = MultiPartParser.max_part_size = MAX_UPLOAD_SIZE

logging.basicConfig(
    format="%(asctime)s | %(levelname)s | %(message)s",
    level=logging.INFO,
)
logger = logging.getLogger(__name__)

app = FastAPI()

# Hugging Face access token used for the snapshot download; None when the
# HF_TOKEN environment variable is not set.
read_key = os.environ.get("HF_TOKEN")

# Hub repository (downloaded with repo_type="dataset") that holds the model,
# and the only files that are fetched from it.
HF_DATASET_REPO = "Maid-10000/RMBG-DataBase"
MODEL_ALLOW_PATTERNS = [
    "config.json",
    "MyConfig.py",
    "briarmbg.py",
    "pytorch_model.bin",
]

# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
def load_model():
    """Download the model snapshot and build the segmentation network.

    Only the files listed in ``MODEL_ALLOW_PATTERNS`` are fetched from
    ``HF_DATASET_REPO`` (a dataset-type repository), authenticated with
    ``read_key``.

    Returns:
        The model produced by
        ``AutoModelForImageSegmentation.from_pretrained`` for the downloaded
        snapshot directory.
    """
    snapshot_dir = snapshot_download(
        token=read_key,
        repo_type="dataset",
        repo_id=HF_DATASET_REPO,
        allow_patterns=MODEL_ALLOW_PATTERNS,
    )
    # NOTE(review): trust_remote_code=True presumably lets transformers use
    # the .py files downloaded from the repo -- confirm the repo is trusted.
    # use_safetensors=False matches the pytorch_model.bin weights listed in
    # MODEL_ALLOW_PATTERNS.
    model = AutoModelForImageSegmentation.from_pretrained(
        snapshot_dir,
        use_safetensors=False,
        trust_remote_code=True,
    )
    return model
# Load the network once at import time, move it to the selected device and
# switch it to evaluation mode; request handlers share this single instance.
logger.info(f"Loading BriaRMBG model on {device}...")
net = load_model()
net.to(device).eval()
logger.info("Model loaded.")
def end_time(start_time, text):
    """Log how many seconds have passed since ``start_time``.

    Args:
        start_time: A timestamp previously obtained from ``time.time()``.
        text: Label placed in front of the elapsed-time message.
    """
    elapsed = time.time() - start_time
    logger.info(f"{text}: 共执行 {elapsed:.3f} 秒")
def resize_image(image: Image.Image) -> Image.Image:
    """Return an RGB version of ``image`` resized to 1024x1024.

    The resize uses bilinear resampling and a fixed target size, so the
    original aspect ratio is not preserved.
    """
    target_size = (1024, 1024)
    rgb_image = image.convert("RGB")
    return rgb_image.resize(target_size, Image.BILINEAR)
def process(image: Image.Image) -> Response:
    """Remove the background of ``image`` and return it as a PNG response.

    The image is converted to RGB, resized to the network input size,
    scaled to [0, 1] and normalised, then passed through ``net``. The
    prediction is resized back to the original resolution, min-max
    normalised to [0, 1] and used as the alpha channel of the original
    image.

    Args:
        image: The input picture (any PIL mode; it is converted to RGB).

    Returns:
        A ``Response`` whose body is the RGBA PNG and whose media type is
        ``image/png``.

    Raises:
        ValueError: If the prediction cannot be reshaped to ``(h, w)``,
            i.e. the model did not return exactly one mask channel.
    """
    orig_image = image.convert("RGB")
    w, h = orig_image.size

    # Pre-processing: HWC uint8 -> 1xCxHxW float in [0, 1], then shift by 0.5.
    input_image = resize_image(orig_image)
    im_tensor = torch.tensor(np.array(input_image), dtype=torch.float32)
    im_tensor = im_tensor.permute(2, 0, 1).unsqueeze(0) / 255.0
    im_tensor = normalize(im_tensor, [0.5, 0.5, 0.5], [1.0, 1.0, 1.0])
    im_tensor = im_tensor.to(device)

    with torch.no_grad():
        result = net(im_tensor)
        # NOTE(review): assumes the model output is a nested sequence whose
        # [0][0] entry is the (1, 1, H, W) mask prediction -- confirm.
        result = result[0][0]
        result = F.interpolate(
            result, size=(h, w), mode="bilinear", align_corners=False
        )
        result = torch.squeeze(result, 0)

        # Min-max normalise; a constant prediction maps to an all-zero mask
        # instead of dividing by zero.
        ma = torch.max(result)
        mi = torch.min(result)
        if ma == mi:
            result = torch.zeros_like(result)
        else:
            result = (result - mi) / (ma - mi)

    result_array = (result * 255).cpu().numpy().astype(np.uint8)
    # Reshape explicitly to (h, w). np.squeeze would also drop a height or
    # width of 1, producing a mask whose size no longer matches the image
    # and making putalpha fail for 1-pixel-high or 1x1 inputs.
    pil_mask = Image.fromarray(result_array.reshape(h, w))

    new_im = orig_image.copy()
    new_im.putalpha(pil_mask)

    buf = BytesIO()
    new_im.save(buf, format="PNG")
    return Response(content=buf.getvalue(), media_type="image/png")
async def fetch_url(url: str) -> bytes:
    """Download ``url`` and return the response body as bytes.

    The body is streamed and capped at ``MAX_UPLOAD_SIZE`` so that a remote
    server cannot make the service buffer an arbitrarily large payload; this
    is the same limit that applies to uploaded files.

    Args:
        url: The address to fetch. It comes from the request parameters.

    Returns:
        The complete response body.

    Raises:
        RuntimeError: If the server answers with a status other than 200 or
            the body is larger than ``MAX_UPLOAD_SIZE``.
        Exception: Errors raised by aiohttp (connection failures, the
            15-second total timeout, invalid URLs) propagate unchanged.
    """
    # NOTE(review): the URL is caller-supplied and no host filtering is done
    # here, so addresses on the internal network are reachable (SSRF).
    # Restrict this at the deployment level if it matters.
    timeout = aiohttp.ClientTimeout(total=15)
    async with aiohttp.ClientSession(timeout=timeout) as session:
        async with session.get(url) as response:
            if response.status != 200:
                raise RuntimeError(f"URL returned status {response.status}")

            # Reject early when the server announces an oversized body.
            declared_size = response.content_length
            if declared_size is not None and declared_size > MAX_UPLOAD_SIZE:
                raise RuntimeError(
                    f"URL body of {declared_size} bytes exceeds the "
                    f"{MAX_UPLOAD_SIZE}-byte limit"
                )

            # The header may be absent or wrong, so enforce the cap while
            # streaming as well.
            body = bytearray()
            async for chunk in response.content.iter_chunked(64 * 1024):
                body.extend(chunk)
                if len(body) > MAX_UPLOAD_SIZE:
                    raise RuntimeError(
                        f"URL body exceeds the {MAX_UPLOAD_SIZE}-byte limit"
                    )
            return bytes(body)
@app.get("/")
async def main():
    """Root endpoint: always answers with a fixed success payload."""
    return dict(code=200, msg="Success")
@app.api_route("/rmbg", methods=["GET", "POST"])
async def rmbg(request: Request):
    """Remove the background from an image supplied with the request.

    Parameters are merged from the query string and, for body-carrying
    methods, from a JSON body or a form. The image source is chosen in this
    order: ``file`` (an upload that is read with ``await file.read()``),
    ``base64`` (optionally prefixed with a ``data:image/...;base64,`` header),
    then ``url`` (downloaded with ``fetch_url``).

    Returns:
        The ``Response`` produced by ``process`` on success. On any failure
        a plain dict ``{"code": 503, "msg": ...}`` is returned; no explicit
        HTTP status is set in that case.
    """
    init_time = time.time()

    # Stage 1: collect parameters from query string and request body.
    try:
        start_time = time.time()
        params = dict(request.query_params)
        if request.method in ("POST", "PUT", "PATCH"):
            content_type = request.headers.get("Content-Type", "").lower()
            if "application/json" in content_type:
                params.update(await request.json())
            else:
                # The size limit has to be passed per call as well as being
                # set on the parser class.
                form_data = await request.form(
                    max_part_size=MultiPartParser.max_part_size
                )
                params.update(form_data.items())
        end_time(start_time, "参数合并完成")
    except Exception as e:
        logger.exception(e)
        return {
            "code": 503,
            "msg": "An unexpected error occurred during the parameter assignment process"
        }

    url = params.get("url")
    file = params.get("file")
    b64 = params.get("base64")

    # Stage 2: obtain the raw image bytes from the first source present.
    try:
        start_time = time.time()
        if file:
            data = await file.read()
        elif b64:
            # Strip an optional data-URI header; the pattern is anchored at
            # the start, so a string without the header is left untouched.
            pattern = r"^data:image\/[a-zA-Z]+;base64,"
            b64 = re.sub(pattern, "", b64)
            data = base64.b64decode(b64, validate=True)
        elif url:
            data = await fetch_url(url)
        else:
            return {"code": 503, "msg": "No image parameters entered"}
        end_time(start_time, "图片下载完成")
    except Exception as e:
        logger.exception(e)
        return {"code": 503, "msg": "Image parameter parsing error"}

    # Stage 3: decode the bytes into an RGB image off the event loop.
    try:
        start_time = time.time()

        def _decode_image():
            return Image.open(BytesIO(data)).convert("RGB")

        image = await asyncio.get_running_loop().run_in_executor(
            None, _decode_image
        )
        end_time(start_time, "图片读取完成")
    except Exception as e:
        logger.exception(e)
        return {"code": 503, "msg": "The input is not an image"}

    # Stage 4: run the model in the default executor and return its response.
    try:
        start_time = time.time()
        result = await asyncio.get_running_loop().run_in_executor(
            None, process, image
        )
        end_time(start_time, "图片分析完成")
        end_time(init_time, "[ 任务总耗时 ]")
        return result
    except Exception as e:
        logger.exception(e)
        return {"code": 503, "msg": "Image processing failed"}
@app.get("/pkg-version")
async def pkg_version():
    """List the installed distributions and their versions, sorted by name.

    Returns:
        ``{"code": 200, "msg": "Success", "pkg-version": [...]}`` where each
        entry has ``name`` and ``version`` keys, or a ``code`` 503 payload
        when ``pkg_resources`` could not be imported.
    """
    if pkg_resources is None:
        return {"code": 503, "msg": "pkg_resources unavailable"}

    packages_list = sorted(
        (
            {"name": dist.project_name, "version": dist.version}
            for dist in pkg_resources.working_set
        ),
        key=lambda entry: entry["name"],
    )
    return {
        "code": 200,
        "msg": "Success",
        "pkg-version": packages_list
    }
if __name__ == "__main__":
    # Start the ASGI server when the file is executed directly: one worker,
    # listening on all interfaces at port 7860.
    server_options = {"host": "0.0.0.0", "port": 7860, "workers": 1}
    uvicorn.run("app:app", **server_options)