# Spaces:
# Running
# Running
import json
import os
import tempfile
import webbrowser
from typing import Optional, Tuple, Union, cast

import aiohttp
import click
import gradio as gr
import uvicorn
from asyncer import asyncify
from fastapi import Depends, FastAPI, File, Form, HTTPException, Query
from fastapi.middleware.cors import CORSMiddleware
from starlette.responses import Response

from .._version import get_versions
from ..bg import remove
from ..session_factory import new_session
from ..sessions import sessions_names
from ..sessions.base import BaseSession
def s_command(port: int, host: str, log_level: str, threads: Optional[int]) -> None:
    """
    Command-line interface for running the FastAPI web server.

    Builds a FastAPI app that exposes background removal at ``/api/remove``.
    GET fetches the image from a URL, and POST takes an uploaded file. The
    app serves the OpenAPI docs at ``/api``, mounts a Gradio UI at ``/`` and
    runs everything with uvicorn. This call blocks until the server stops.

    Args:
        port: Port uvicorn listens on.
        host: Address uvicorn binds to. ``"0.0.0.0"`` is shown as
            ``localhost`` in the printed URLs.
        log_level: Log level passed to ``uvicorn.run``.
        threads: When not None, the capacity of anyio's default thread
            limiter, set from a startup hook. It bounds the worker threads
            that ``asyncify`` uses to run the synchronous removal code.
    """
    # Loaded model sessions keyed by model name, shared by all HTTP requests
    # so a model is normally loaded once and then reused.
    sessions: dict[str, BaseSession] = {}

    # Anchored so the *whole* value must be a known model name. An unanchored
    # alternation also accepts any string that merely contains one, which
    # would then reach new_session with an unknown name.
    model_pattern = r"^(" + "|".join(sessions_names) + r")$"

    tags_metadata = [
        {
            "name": "Background Removal",
            "description": "Endpoints that perform background removal with different image sources.",
            "externalDocs": {
                "description": "GitHub Source",
                "url": "https://github.com/danielgatis/rembg",
            },
        },
    ]

    app = FastAPI(
        title="Rembg",
        description="Rembg is a tool to remove images background. That is it.",
        version=get_versions()["version"],
        contact={
            "name": "Daniel Gatis",
            "url": "https://github.com/danielgatis",
            "email": "danielgatis@gmail.com",
        },
        license_info={
            "name": "MIT License",
            "url": "https://github.com/danielgatis/rembg/blob/main/LICENSE.txt",
        },
        openapi_tags=tags_metadata,
        docs_url="/api",
    )

    app.add_middleware(
        CORSMiddleware,
        allow_credentials=True,
        allow_origins=["*"],
        allow_methods=["*"],
        allow_headers=["*"],
    )

    bgc_error_detail = "bgc must be four comma-separated integers: R,G,B,A"

    def parse_bgc(bgc: Optional[str]) -> Optional[Tuple[int, int, int, int]]:
        """
        Parse the ``bgc`` parameter (``"R,G,B,A"``) into a 4-tuple of ints.

        Returns None when ``bgc`` is None or empty. Raises HTTPException(422)
        when the value is not exactly four comma-separated integers. A bad
        client value is then reported as a client error rather than as an
        unhandled ValueError (HTTP 500).
        """
        if not bgc:
            return None
        try:
            color = tuple(int(part) for part in bgc.split(","))
        except ValueError:
            raise HTTPException(status_code=422, detail=bgc_error_detail) from None
        if len(color) != 4:
            raise HTTPException(status_code=422, detail=bgc_error_detail)
        return cast(Tuple[int, int, int, int], color)

    class CommonQueryParams:
        """Processing options for GET requests, read from the query string."""

        def __init__(
            self,
            model: str = Query(
                description="Model to use when processing image",
                regex=model_pattern,
                default="u2net",
            ),
            a: bool = Query(default=False, description="Enable Alpha Matting"),
            af: int = Query(
                default=240,
                ge=0,
                le=255,
                description="Alpha Matting (Foreground Threshold)",
            ),
            ab: int = Query(
                default=10,
                ge=0,
                le=255,
                description="Alpha Matting (Background Threshold)",
            ),
            ae: int = Query(
                default=10, ge=0, description="Alpha Matting (Erode Structure Size)"
            ),
            om: bool = Query(default=False, description="Only Mask"),
            ppm: bool = Query(default=False, description="Post Process Mask"),
            bgc: Optional[str] = Query(default=None, description="Background Color"),
            extras: Optional[str] = Query(
                default=None, description="Extra parameters as JSON"
            ),
        ):
            self.model = model
            self.a = a
            self.af = af
            self.ab = ab
            self.ae = ae
            self.om = om
            self.ppm = ppm
            self.extras = extras
            self.bgc = parse_bgc(bgc)

    class CommonQueryPostParams:
        """Processing options for POST requests, read mostly from form fields."""

        def __init__(
            self,
            model: str = Form(
                description="Model to use when processing image",
                regex=model_pattern,
                default="u2net",
            ),
            a: bool = Form(default=False, description="Enable Alpha Matting"),
            af: int = Form(
                default=240,
                ge=0,
                le=255,
                description="Alpha Matting (Foreground Threshold)",
            ),
            ab: int = Form(
                default=10,
                ge=0,
                le=255,
                description="Alpha Matting (Background Threshold)",
            ),
            ae: int = Form(
                default=10, ge=0, description="Alpha Matting (Erode Structure Size)"
            ),
            om: bool = Form(default=False, description="Only Mask"),
            ppm: bool = Form(default=False, description="Post Process Mask"),
            # NOTE(review): bgc and extras come from the query string even on
            # POST, unlike the form fields above. They are kept as Query so
            # existing clients keep working.
            bgc: Optional[str] = Query(default=None, description="Background Color"),
            extras: Optional[str] = Query(
                default=None, description="Extra parameters as JSON"
            ),
        ):
            self.model = model
            self.a = a
            self.af = af
            self.ab = ab
            self.ae = ae
            self.om = om
            self.ppm = ppm
            self.extras = extras
            self.bgc = parse_bgc(bgc)

    def im_without_bg(
        content: bytes,
        commons: Union[CommonQueryParams, CommonQueryPostParams],
    ) -> Response:
        """
        Remove the background from ``content`` and return it as a PNG response.

        ``commons.extras``, when it parses as JSON, is merged into keyword
        arguments. These are forwarded to ``remove`` and, on a cache miss, to
        ``new_session``. Unparseable extras are ignored (best effort).
        """
        kwargs: dict = {}
        if commons.extras:
            try:
                kwargs.update(json.loads(commons.extras))
            except (ValueError, TypeError):
                # Best effort: invalid JSON (JSONDecodeError is a ValueError)
                # or a non-mapping value is ignored instead of failing.
                pass

        # Build a session only on a cache miss. Passing new_session(...)
        # straight to setdefault would construct a session on every call,
        # because call arguments are evaluated eagerly. setdefault still keeps
        # the first stored session if two requests race on the same model.
        session = sessions.get(commons.model)
        if session is None:
            session = sessions.setdefault(
                commons.model, new_session(commons.model, **kwargs)
            )

        return Response(
            remove(
                content,
                session=session,
                alpha_matting=commons.a,
                alpha_matting_foreground_threshold=commons.af,
                alpha_matting_background_threshold=commons.ab,
                alpha_matting_erode_size=commons.ae,
                only_mask=commons.om,
                post_process_mask=commons.ppm,
                bgcolor=commons.bgc,
                **kwargs,
            ),
            media_type="image/png",
        )

    def startup() -> None:
        """Open the UI in a browser and, if requested, cap the thread pool."""
        # Best effort only: headless hosts have no browser to open.
        try:
            webbrowser.open(f"http://localhost:{port}")
        except Exception:
            pass

        if threads is not None:
            from anyio import CapacityLimiter
            from anyio.lowlevel import RunVar

            # RunVar is scoped to the running event loop, so this has to
            # happen in a startup hook rather than before the server starts.
            RunVar("_default_thread_limiter").set(CapacityLimiter(threads))

    # Without registration the hook would never run.
    app.add_event_handler("startup", startup)

    @app.get(
        path="/api/remove",
        tags=["Background Removal"],
        summary="Remove from URL",
        description="Removes the background from an image obtained by retrieving a URL.",
    )
    async def get_index(
        url: str = Query(
            default=..., description="URL of the image that has to be processed."
        ),
        commons: CommonQueryParams = Depends(),
    ):
        """Download the image at ``url`` and return it without its background."""
        # NOTE(review): the server fetches an arbitrary client-supplied URL
        # (SSRF risk). Restrict schemes/hosts if this is exposed publicly.
        try:
            async with aiohttp.ClientSession() as http:
                async with http.get(url) as response:
                    if response.status >= 400:
                        raise HTTPException(
                            status_code=400,
                            detail=f"Could not fetch image: HTTP {response.status}",
                        )
                    file = await response.read()
        except aiohttp.ClientError as err:
            raise HTTPException(
                status_code=400, detail=f"Could not fetch image: {err}"
            ) from err
        return await asyncify(im_without_bg)(file, commons)

    @app.post(
        path="/api/remove",
        tags=["Background Removal"],
        summary="Remove from Stream",
        description="Removes the background from an image sent within the request itself.",
    )
    async def post_index(
        file: bytes = File(
            default=...,
            description="Image file (byte stream) that has to be processed.",
        ),
        commons: CommonQueryPostParams = Depends(),
    ):
        """Return the uploaded image without its background."""
        return await asyncify(im_without_bg)(file, commons)

    def gr_app(fastapi_app):
        """Mount the Gradio UI at ``/`` on ``fastapi_app`` and return the result."""

        def inference(input_path, model, *args):
            """
            Gradio callback: remove the background from the image at ``input_path``.

            ``args`` holds the remaining UI inputs, in order:
            - alpha matting flag
            - foreground threshold
            - background threshold
            - erosion size
            - only-mask flag
            - post-process flag
            - JSON string of extra arguments

            Returns the path of a PNG file that holds the result.
            """
            if input_path is None:
                raise gr.Error("Please upload an image.")

            a, af, ab, ae, om, ppm, cmd_args = args
            kwargs = {
                "alpha_matting": a,
                "alpha_matting_foreground_threshold": af,
                "alpha_matting_background_threshold": ab,
                "alpha_matting_erode_size": ae,
                "only_mask": om,
                "post_process_mask": ppm,
            }
            if cmd_args:
                kwargs.update(json.loads(cmd_args))
            kwargs["session"] = new_session(model, **kwargs)

            with open(input_path, "rb") as src:
                data = src.read()
            result = remove(data, **kwargs)

            # A unique file per call: concurrency_limit allows several runs at
            # once, and a fixed shared path could be overwritten by another run
            # before Gradio reads it.
            # NOTE(review): these temp files are not deleted here.
            fd, output_path = tempfile.mkstemp(suffix=".png")
            with os.fdopen(fd, "wb") as dst:
                dst.write(result)
            return output_path

        interface = gr.Interface(
            inference,
            [
                gr.components.Image(type="filepath", label="Input"),
                gr.components.Dropdown(sessions_names, value="u2net", label="Models"),
                gr.components.Checkbox(value=True, label="Alpha matting"),
                gr.components.Slider(
                    value=240, minimum=0, maximum=255, label="Foreground threshold"
                ),
                gr.components.Slider(
                    value=10, minimum=0, maximum=255, label="Background threshold"
                ),
                gr.components.Slider(
                    value=40, minimum=0, maximum=255, label="Erosion size"
                ),
                gr.components.Checkbox(value=False, label="Only mask"),
                gr.components.Checkbox(value=True, label="Post process mask"),
                gr.components.Textbox(label="Arguments"),
            ],
            gr.components.Image(type="filepath", label="Output"),
            concurrency_limit=3,
        )

        return gr.mount_gradio_app(fastapi_app, interface, path="/")

    display_host = "localhost" if host == "0.0.0.0" else host
    print(f"To access the API documentation, go to http://{display_host}:{port}/api")
    print(f"To access the UI, go to http://{display_host}:{port}")

    uvicorn.run(gr_app(app), host=host, port=port, log_level=log_level)