PBRemTools / scripts /main.py
fkunn1326's picture
Upload 5 files
0c43c79
raw
history blame contribute delete
No virus
2.92 kB
import os
import io
import json
import numpy as np
import cv2
import gradio as gr
import modules.scripts as scripts
from modules import script_callbacks
from scripts.td_abg import get_foreground
from scripts.convertor import pil2cv
def processing(input_image, td_abg_enabled, h_split, v_split, n_cluster, alpha, th_rate, cascadePSP_enabled, fast, psp_L):
    """Extract the foreground of ``input_image`` and return it with its mask.

    Args:
        input_image: Image handed over by the UI, or ``None`` when the
            callback fires without an image loaded.
        td_abg_enabled, h_split, v_split, n_cluster, alpha, th_rate,
        cascadePSP_enabled, fast, psp_L: Option values forwarded unchanged,
            in this order, to ``get_foreground``.

    Returns:
        ``(image, mask)`` -- the two values produced by ``get_foreground``,
        swapped into the order the caller's outputs are wired in. Returns
        ``(None, None)`` when ``input_image`` is ``None``.
    """
    # Nothing to process: return empty results instead of failing inside
    # the conversion helpers below.
    if input_image is None:
        return None, None

    image = pil2cv(input_image)
    # NOTE(review): assumes pil2cv yields BGR channel order and that
    # get_foreground expects RGB -- confirm against those helpers.
    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
    # get_foreground returns (mask, image); callers expect (image, mask).
    mask, image = get_foreground(image, td_abg_enabled, h_split, v_split, n_cluster, alpha, th_rate, cascadePSP_enabled, fast, psp_L)
    return image, mask
class Script(scripts.Script):
    """Script subclass titled "PBRemTools" whose ``ui`` creates no controls."""

    def __init__(self) -> None:
        super().__init__()

    def title(self) -> str:
        """Return the name this script is listed under."""
        return "PBRemTools"

    def show(self, is_img2img):
        """Return ``scripts.AlwaysVisible`` regardless of ``is_img2img``."""
        return scripts.AlwaysVisible

    def ui(self, is_img2img) -> tuple:
        """Return an empty tuple; no components are built here."""
        return ()
def on_ui_tabs():
    """Build the PBRemTools Gradio tab and wire its Submit button.

    Returns:
        A one-element list containing the tuple
        ``(blocks, "PBRemTools", "pbremtools")``.
    """
    # NOTE(review): the nesting of the layout containers below was
    # reconstructed from a source whose indentation was lost -- confirm the
    # rendered layout (inputs column beside the output tabs) is as intended.
    with gr.Blocks(analytics_enabled=False) as tab_root:
        with gr.Row():
            with gr.Column():
                input_image = gr.Image(type="pil")

                # Options for the tile-division background remover.
                with gr.Accordion("tile division BG Remover", open=True), gr.Box():
                    td_abg_enabled = gr.Checkbox(label="enabled", show_label=True)
                    h_split = gr.Slider(1, 2048, value=256, step=4, label="horizontal split num", show_label=True)
                    v_split = gr.Slider(1, 2048, value=256, step=4, label="vertical split num", show_label=True)
                    n_cluster = gr.Slider(1, 1000, value=500, step=10, label="cluster num", show_label=True)
                    alpha = gr.Slider(1, 255, value=50, step=1, label="alpha threshold", show_label=True)
                    th_rate = gr.Slider(0, 1, value=0.1, step=0.01, label="mask content ratio", show_label=True)

                # Options for the cascadePSP step.
                with gr.Accordion("cascadePSP", open=True), gr.Box():
                    cascadePSP_enabled = gr.Checkbox(label="enabled", show_label=True)
                    fast = gr.Checkbox(label="fast", show_label=True)
                    psp_L = gr.Slider(1, 2048, value=900, step=1, label="Memory usage", show_label=True)

                run_button = gr.Button(value="Submit")

            with gr.Row(), gr.Column():
                with gr.Tab("output"):
                    result_view = gr.Image()
                with gr.Tab("mask"):
                    mask_view = gr.Image()

        # Order must match the positional parameters of processing().
        processing_inputs = [
            input_image,
            td_abg_enabled,
            h_split,
            v_split,
            n_cluster,
            alpha,
            th_rate,
            cascadePSP_enabled,
            fast,
            psp_L,
        ]
        # processing() returns (image, mask), matching this output order.
        run_button.click(
            processing,
            inputs=processing_inputs,
            outputs=[result_view, mask_view],
        )

    return [(tab_root, "PBRemTools", "pbremtools")]
# Hand the tab builder to script_callbacks at import time (presumably the
# host web UI invokes it later when assembling its tabs -- confirm).
script_callbacks.on_ui_tabs(on_ui_tabs)