Spaces:
Runtime error
Runtime error
import pathlib | |
import sys | |
import time | |
from enum import Enum | |
from typing import IO, cast | |
import aiohttp | |
import click | |
import filetype | |
import uvicorn | |
from asyncer import asyncify | |
from fastapi import Depends, FastAPI, File, Form, Query | |
from fastapi.middleware.cors import CORSMiddleware | |
from starlette.responses import Response | |
from tqdm import tqdm | |
from watchdog.events import FileSystemEvent, FileSystemEventHandler | |
from watchdog.observers import Observer | |
from . import _version | |
from .bg import remove | |
from .session_base import BaseSession | |
from .session_factory import new_session | |
def main() -> None:
    """Top-level entry point; it has no behavior of its own in this file.

    NOTE(review): presumably meant to be the CLI group that ``i``, ``p`` and
    ``s`` hang off, but nothing visible here wires that up -- confirm.
    """
def i(model: str, input: IO, output: IO, **kwargs) -> None:
    """Remove the background from a single image read from a stream.

    Reads the whole of ``input``, runs ``remove`` on it with a freshly created
    session for ``model``, and writes the result to ``output``.

    Args:
        model: model name handed to ``new_session``.
        input: stream holding the source image; read in full.
        output: stream that receives whatever ``remove`` returns.
        **kwargs: extra options forwarded unchanged to ``remove``.
    """
    source = input.read()
    result = remove(source, session=new_session(model), **kwargs)
    output.write(result)
def p(
    model: str, input: pathlib.Path, output: pathlib.Path, watch: bool, **kwargs
) -> None:
    """Remove the background from every image found under a folder.

    Every file under ``input`` (searched recursively) whose detected MIME type
    is ``image/*`` is processed with one shared session. The result is saved as
    ``output / <file name with its last suffix replaced by .png>``.

    Existing output files are never overwritten, so re-runs skip images that
    were already processed. The output folder is flat: inputs that differ only
    by sub-folder or final suffix map to the same output path, and only the one
    processed first is written.

    Args:
        model: model name handed to ``new_session``; one session serves all
            files.
        input: folder to scan.
        output: folder receiving the ``.png`` results; created if missing.
        watch: if true, after the initial pass keep watching ``input``
            (non-recursively) and process files as filesystem events arrive.
            This loops until an exception such as ``KeyboardInterrupt`` stops
            it.
        **kwargs: extra options forwarded unchanged to ``remove``.
    """
    session = new_session(model)

    def process(each_input: pathlib.Path) -> None:
        # Best-effort per file: report the failure and carry on, so that one
        # bad file does not abort the batch or kill the watcher.
        try:
            mimetype = filetype.guess(each_input)
            # Skip undetectable and non-image files. Match the "image/" top-level
            # type rather than any MIME string that merely contains "image".
            if mimetype is None or not mimetype.mime.startswith("image/"):
                return

            each_output = (output / each_input.name).with_suffix(".png")
            each_output.parent.mkdir(parents=True, exist_ok=True)

            if not each_output.exists():
                each_output.write_bytes(
                    cast(
                        bytes,
                        remove(each_input.read_bytes(), session=session, **kwargs),
                    )
                )

                if watch:
                    print(
                        f"processed: {each_input.absolute()} -> {each_output.absolute()}"
                    )
        except Exception as e:
            # tqdm.write avoids garbling an active progress bar. The path tells
            # the user which file failed.
            tqdm.write(f"error processing {each_input}: {e}", file=sys.stderr)

    inputs = list(input.glob("**/*"))
    if not watch:
        inputs = tqdm(inputs)

    for each_input in inputs:
        if not each_input.is_dir():
            process(each_input)

    if watch:
        observer = Observer()

        class EventHandler(FileSystemEventHandler):
            def on_any_event(self, event: FileSystemEvent) -> None:
                if event.is_directory or event.event_type in ["deleted", "closed"]:
                    return
                # For a rename/move, src_path is the old location, which no
                # longer exists. The file to process is at dest_path.
                path = event.src_path
                if event.event_type == "moved":
                    path = getattr(event, "dest_path", None) or path
                process(pathlib.Path(path))

        event_handler = EventHandler()
        observer.schedule(event_handler, input, recursive=False)
        observer.start()

        try:
            # Idle until interrupted; the observer runs on its own thread.
            while True:
                time.sleep(1)
        finally:
            observer.stop()
            observer.join()
def s(port: int, log_level: str) -> None:
    """Run the rembg HTTP API with uvicorn; blocks until the server exits.

    Registered endpoints, both tagged "Background Removal" and returning
    ``image/png``:

    * ``GET /?url=...`` downloads the image at ``url`` and processes it.
    * ``POST /`` processes an image uploaded as the multipart ``file`` field.

    Processing options (model, alpha matting thresholds and erode size,
    mask-only, mask post-processing) come from query parameters on GET and
    from form fields on POST.

    Args:
        port: TCP port to listen on. The server binds to all interfaces
            (``0.0.0.0``).
        log_level: log level name passed through to ``uvicorn.run``.
    """
    # Model name -> session, created lazily on first use and reused afterwards.
    sessions: dict[str, BaseSession] = {}

    tags_metadata = [
        {
            "name": "Background Removal",
            "description": "Endpoints that perform background removal with different image sources.",
            "externalDocs": {
                "description": "GitHub Source",
                "url": "https://github.com/danielgatis/rembg",
            },
        },
    ]

    app = FastAPI(
        title="Rembg",
        description="Rembg is a tool to remove images background. That is it.",
        version=_version.get_versions()["version"],
        contact={
            "name": "Daniel Gatis",
            "url": "https://github.com/danielgatis",
            "email": "danielgatis@gmail.com",
        },
        license_info={
            "name": "MIT License",
            "url": "https://github.com/danielgatis/rembg/blob/main/LICENSE.txt",
        },
        openapi_tags=tags_metadata,
    )

    app.add_middleware(
        CORSMiddleware,
        allow_credentials=True,
        allow_origins=["*"],
        allow_methods=["*"],
        allow_headers=["*"],
    )

    # Models selectable through the API; the value is the name given to
    # new_session.
    class ModelType(str, Enum):
        u2net = "u2net"
        u2netp = "u2netp"
        u2net_human_seg = "u2net_human_seg"
        u2net_cloth_seg = "u2net_cloth_seg"

    class CommonQueryParams:
        # Processing options read from the query string (used by GET).
        def __init__(
            self,
            model: ModelType = Query(
                default=ModelType.u2net,
                description="Model to use when processing image",
            ),
            a: bool = Query(default=False, description="Enable Alpha Matting"),
            af: int = Query(
                default=240,
                ge=0,
                le=255,
                description="Alpha Matting (Foreground Threshold)",
            ),
            ab: int = Query(
                default=10,
                ge=0,
                le=255,
                description="Alpha Matting (Background Threshold)",
            ),
            ae: int = Query(
                default=10, ge=0, description="Alpha Matting (Erode Structure Size)"
            ),
            om: bool = Query(default=False, description="Only Mask"),
            ppm: bool = Query(default=False, description="Post Process Mask"),
        ):
            self.model = model
            self.a = a
            self.af = af
            self.ab = ab
            self.ae = ae
            self.om = om
            self.ppm = ppm

    class CommonQueryPostParams:
        # Same options as CommonQueryParams, read from form fields (used by POST).
        def __init__(
            self,
            model: ModelType = Form(
                default=ModelType.u2net,
                description="Model to use when processing image",
            ),
            a: bool = Form(default=False, description="Enable Alpha Matting"),
            af: int = Form(
                default=240,
                ge=0,
                le=255,
                description="Alpha Matting (Foreground Threshold)",
            ),
            ab: int = Form(
                default=10,
                ge=0,
                le=255,
                description="Alpha Matting (Background Threshold)",
            ),
            ae: int = Form(
                default=10, ge=0, description="Alpha Matting (Erode Structure Size)"
            ),
            om: bool = Form(default=False, description="Only Mask"),
            ppm: bool = Form(default=False, description="Post Process Mask"),
        ):
            self.model = model
            self.a = a
            self.af = af
            self.ab = ab
            self.ae = ae
            self.om = om
            self.ppm = ppm

    def im_without_bg(
        content: bytes, commons: "CommonQueryParams | CommonQueryPostParams"
    ) -> Response:
        """Run ``remove`` on ``content`` and wrap the result as a PNG response."""
        model_name = commons.model.value
        # Look up first. ``sessions.setdefault(k, new_session(k))`` evaluates
        # new_session (i.e. loads the model) on every request, even when a
        # session is already cached. Two simultaneous first requests may each
        # build one; setdefault keeps whichever is stored first.
        session = sessions.get(model_name)
        if session is None:
            session = sessions.setdefault(model_name, new_session(model_name))

        return Response(
            remove(
                content,
                session=session,
                alpha_matting=commons.a,
                alpha_matting_foreground_threshold=commons.af,
                alpha_matting_background_threshold=commons.ab,
                alpha_matting_erode_size=commons.ae,
                only_mask=commons.om,
                post_process_mask=commons.ppm,
            ),
            media_type="image/png",
        )

    @app.get(path="/", tags=["Background Removal"])
    async def get_index(
        url: str = Query(
            default=..., description="URL of the image that has to be processed."
        ),
        commons: CommonQueryParams = Depends(),
    ):
        """Download the image at ``url`` and return it with its background removed."""
        # NOTE(review): the server fetches any caller-supplied URL (SSRF risk if
        # publicly exposed) and ignores the HTTP status. Consider an allow-list
        # and ``raise_for_status``.
        async with aiohttp.ClientSession() as session:
            async with session.get(url) as response:
                file = await response.read()
        # im_without_bg is synchronous; asyncify runs it in a worker thread so
        # the event loop is not blocked.
        return await asyncify(im_without_bg)(file, commons)

    @app.post(path="/", tags=["Background Removal"])
    async def post_index(
        file: bytes = File(
            default=...,
            description="Image file (byte stream) that has to be processed.",
        ),
        commons: CommonQueryPostParams = Depends(),
    ):
        """Return the uploaded image with its background removed."""
        return await asyncify(im_without_bg)(file, commons)

    uvicorn.run(app, host="0.0.0.0", port=port, log_level=log_level)