|
import logging |
|
import numpy as np |
|
import gradio as gr |
|
from rembg import new_session |
|
from cutter import remove, make_label |
|
from utils import * |
|
|
|
# Display name -> model identifier handed to rembg's new_session().
remove_bg_models = {
    "U2NET": "u2net",
    "U2NET Human Seg": "u2net_human_seg",
    "U2NET Cloth Seg": "u2net_cloth_seg",
}

# Key into remove_bg_models selecting the model used by predict().
default_model = "U2NET"
|
|
|
# rembg sessions keyed by model identifier, so a session is built once per
# model instead of on every button click.
_sessions = {}


def _get_session(model_name):
    """Return the rembg session for *model_name*, creating it on first use."""
    if model_name not in _sessions:
        _sessions[model_name] = new_session(model_name)
    return _sessions[model_name]


def predict(image):
    """Remove the background from *image* using the default model.

    Args:
        image: The input image supplied by the Gradio input component
            (configured with ``type="pil"``).

    Returns:
        The first element returned by ``cutter.remove``; a NumPy array result
        is converted to a PIL image first. If ``remove`` (or the conversion)
        raises ``ValueError``, the error is logged and the value produced by
        ``make_label(str(err))`` is returned instead.
    """
    session = _get_session(remove_bg_models[default_model])

    # Fixed processing options, passed positionally to cutter.remove().
    smooth = False
    matting = (0, 0, 0)
    bg_color = False

    try:
        # remove() returns a pair; only the first element is shown in the UI.
        result, _ = remove(session, image, smooth, matting, bg_color)
        if isinstance(result, np.ndarray):
            # NOTE(review): `Image` is not imported explicitly in this file;
            # presumably it comes from `from utils import *` -- confirm.
            result = Image.fromarray(result.astype("uint8"))
        return result
    except ValueError as err:
        logging.error(err)
        # This function feeds exactly one output component, so return a
        # single value here, matching the success path (not a tuple).
        return make_label(str(err))
|
|
|
def _reset_images():
    """Return empty values for the input and output image components."""
    return None, None


# Page layout: a heading, input/result images side by side, then the
# action buttons with their click handlers.
with gr.Blocks(css="custom.css", title="Remove background") as app:
    gr.HTML("<center><h1>Background Remover</h1></center>")

    # Image panes.
    with gr.Row(equal_height=False):
        with gr.Column():
            input_img = gr.Image(type="pil", label="Input image")
        with gr.Column():
            output_img = gr.Image(type="pil", label="Result image")

    # Action buttons.
    with gr.Row(equal_height=True):
        run_btn = gr.Button(value="Remove background", variant="primary")
        clear_btn = gr.Button(value="Clear", variant="secondary")

    # Event wiring: run the model on the input image, or blank both images.
    run_btn.click(predict, inputs=[input_img], outputs=[output_img])
    clear_btn.click(_reset_images, inputs=None, outputs=[input_img, output_img])
|
|
|
# Enable request queuing through Blocks.queue(); the `enable_queue` keyword
# of launch() is deprecated in favour of this method.
app.queue()

# Start the local server (no public share link), with debug mode on and
# errors shown to the user.
app.launch(share=False, debug=True, show_error=True)
|
|