jhj0517
rollback pytoshop because it doesn't work in huggingface
9eb2384
raw
history blame contribute delete
No virus
2.28 kB
import gradio as gr
from modules import sam
from modules.ui_utils import *
from modules.html_constants import *
from modules.model_downloader import *
class App:
    """Gradio front-end for SAM automatic mask generation.

    On construction the instance calls ``download_sam_model_url()``
    (presumably fetching the SAM checkpoint -- confirm in
    ``modules.model_downloader``), creates an empty ``gr.Blocks`` styled
    with ``CSS``, and creates a ``sam.SamInference`` backend.
    ``launch`` then lays out the interface inside that ``Blocks`` and
    serves it.
    """

    def __init__(self):
        # Keep the original ordering: download step first, then the UI
        # container, then the inference backend.
        download_sam_model_url()
        self.app = gr.Blocks(css=CSS)
        self.sam = sam.SamInference()

    @staticmethod
    def _build_param_controls():
        """Create the tunable SAM parameter inputs in the current layout context.

        Components are created left to right, so their on-screen order
        matches the returned list. The list order is also the positional
        order in which their values follow the image when passed to
        ``generate_mask_app``: points_per_side, pred_iou_thresh,
        stability_score_thresh, crop_n_layers,
        crop_n_points_downscale_factor, min_mask_region_area.
        """
        return [
            gr.Number(label="points_per_side", value=32),
            gr.Slider(label="pred_iou_thresh", value=0.88, minimum=0, maximum=1),
            gr.Slider(label="stability_score_thresh", value=0.95, minimum=0,
                      maximum=1),
            gr.Number(label="crop_n_layers", value=0),
            gr.Number(label="crop_n_points_downscale_factor", value=1),
            gr.Number(label="min_mask_region_area", value=0),
        ]

    def launch(self):
        """Build the UI inside ``self.app`` and start the server.

        Layout, top to bottom:
          * a markdown note,
          * a row with the input image beside the tunable parameters and
            their HTML explanation,
          * a GENERATE button,
          * an output gallery.

        Clicking GENERATE calls ``self.sam.generate_mask_app`` with the
        image followed by the six parameter values. The result is shown
        in the gallery. The app is queued with its API closed
        (``api_open=False``) before launching.
        """
        with self.app:
            with gr.Row():
                gr.Markdown(MARKDOWN_NOTE, elem_id="md_pgroject")

            # equal_height is set via .style(); see
            # https://github.com/gradio-app/gradio/issues/3202
            with gr.Row().style(equal_height=True):
                with gr.Column(scale=5):
                    image_box = gr.Image(label="Input image here")
                with gr.Column(scale=5):
                    # Tunable SAM parameters, followed by their explanation.
                    mask_params = self._build_param_controls()
                    gr.HTML(PARAMS_EXPLANATION, elem_id="html_param_explain")

            with gr.Row():
                generate_button = gr.Button("GENERATE", variant="primary")

            with gr.Row():
                result_gallery = gr.Gallery(
                    label="Output will be shown here", show_label=True
                ).style(grid=5, height="auto")

            # Event listeners must be registered inside the Blocks context.
            generate_button.click(
                fn=self.sam.generate_mask_app,
                inputs=[image_box, *mask_params],
                outputs=[result_gallery],
            )

        self.app.queue(api_open=False).launch()
if __name__ == "__main__":
    # Script entry point: construct the app (its constructor runs the
    # model-download step) and serve the UI.
    App().launch()