|
from modules import scripts_postprocessing, ui_components |
|
import gradio as gr |
|
|
|
from modules.ui_components import FormRow |
|
from modules.paths_internal import models_path |
|
import rembg |
|
import os |
|
|
|
models = [ |
|
"None", |
|
"isnet-general-use", |
|
"u2net", |
|
"u2netp", |
|
"u2net_human_seg", |
|
"u2net_cloth_seg", |
|
"silueta", |
|
"isnet-general-use", |
|
"isnet-anime", |
|
] |
|
|
|
class ScriptPostprocessingUpscale(scripts_postprocessing.ScriptPostprocessing): |
|
name = "Rembg" |
|
order = 20000 |
|
model = None |
|
|
|
def ui(self): |
|
with ui_components.InputAccordion(False, label="Remove background") as enable: |
|
with gr.Row(): |
|
model = gr.Dropdown(label="Remove background", choices=models, value="None") |
|
return_mask = gr.Checkbox(label="Return mask", value=False) |
|
alpha_matting = gr.Checkbox(label="Alpha matting", value=False) |
|
|
|
with gr.Row(visible=False) as alpha_mask_row: |
|
alpha_matting_erode_size = gr.Slider(label="Erode size", minimum=0, maximum=40, step=1, value=10) |
|
alpha_matting_foreground_threshold = gr.Slider(label="Foreground threshold", minimum=0, maximum=255, step=1, value=240) |
|
alpha_matting_background_threshold = gr.Slider(label="Background threshold", minimum=0, maximum=255, step=1, value=10) |
|
|
|
alpha_matting.change( |
|
fn=lambda x: gr.update(visible=x), |
|
inputs=[alpha_matting], |
|
outputs=[alpha_mask_row], |
|
) |
|
|
|
return { |
|
"enable": enable, |
|
"model": model, |
|
"return_mask": return_mask, |
|
"alpha_matting": alpha_matting, |
|
"alpha_matting_foreground_threshold": alpha_matting_foreground_threshold, |
|
"alpha_matting_background_threshold": alpha_matting_background_threshold, |
|
"alpha_matting_erode_size": alpha_matting_erode_size, |
|
} |
|
|
|
def process(self, pp: scripts_postprocessing.PostprocessedImage, enable, model, return_mask, alpha_matting, alpha_matting_foreground_threshold, alpha_matting_background_threshold, alpha_matting_erode_size): |
|
if not enable: |
|
return |
|
|
|
if not model or model == "None": |
|
return |
|
|
|
if "U2NET_HOME" not in os.environ: |
|
os.environ["U2NET_HOME"] = os.path.join(models_path, "u2net") |
|
|
|
pp.image = rembg.remove( |
|
pp.image, |
|
session=rembg.new_session(model), |
|
only_mask=return_mask, |
|
alpha_matting=alpha_matting, |
|
alpha_matting_foreground_threshold=alpha_matting_foreground_threshold, |
|
alpha_matting_background_threshold=alpha_matting_background_threshold, |
|
alpha_matting_erode_size=alpha_matting_erode_size, |
|
) |
|
|
|
pp.info["Rembg"] = model |
|
|