"""Gradio demo for CarveKit: background removal and trimap generation.

Builds a two-tab web UI. The first tab runs a full CarveKit ``Interface``
(segmentation followed by FBA matting post-processing); the second tab only
runs the selected segmentation network and feeds its mask to the trimap
generator.
"""

import gradio as gr
import torch
from carvekit.api.interface import Interface
from carvekit.ml.wrap.basnet import BASNET
from carvekit.ml.wrap.deeplab_v3 import DeepLabV3
from carvekit.ml.wrap.fba_matting import FBAMatting
from carvekit.ml.wrap.tracer_b7 import TracerUniversalB7
from carvekit.ml.wrap.u2net import U2NET
from carvekit.pipelines.postprocessing import MattingMethod
from carvekit.pipelines.preprocessing import PreprocessingStub
from carvekit.trimap.generator import TrimapGenerator

# Use the GPU when torch reports one, otherwise fall back to the CPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Segmentation networks offered in the UI, keyed by the name shown in the
# dropdown. All of them are instantiated once at start-up.
segment_net = {
    "U2NET": U2NET(device=device, batch_size=1),
    "BASNET": BASNET(device=device, batch_size=1),
    "DeepLabV3": DeepLabV3(device=device, batch_size=1),
    "TracerUniversalB7": TracerUniversalB7(device=device, batch_size=1),
}

# Shared pipeline stages reused by every request.
fba = FBAMatting(device=device, input_tensor_size=2048, batch_size=1)
trimap = TrimapGenerator()
preprocessing = PreprocessingStub()
postprocessing = MattingMethod(
    matting_module=fba,
    trimap_generator=trimap,
    device=device,
)

# Dropdown entries: the keys of ``segment_net`` in insertion order.
method_choices = list(segment_net)


def _require_image(image):
    """Raise a user-visible error when no image was supplied to a handler."""
    if image is None:
        # gr.Image yields None when the button is clicked before an upload;
        # fail early with a readable message instead of deep inside a model.
        raise gr.Error("Please provide an input image.")


def generate_trimap(method, original):
    """Generate a trimap for ``original`` using the selected segmentation net.

    Args:
        method: Key into ``segment_net`` (one of ``method_choices``).
        original: Input image as delivered by ``gr.Image(type="pil")``.

    Returns:
        Whatever ``trimap`` (a ``TrimapGenerator``) returns for the image and
        the first mask produced by the segmentation network.

    Raises:
        gr.Error: If ``original`` is ``None``.
        KeyError: If ``method`` is not a key of ``segment_net``.
    """
    _require_image(original)
    # The networks are called with a batch (list) of images; a single image
    # is wrapped in a one-element list and the first result is used.
    mask = segment_net[method]([original])
    return trimap(original_image=original, mask=mask[0])


def predict(method, image):
    """Run the full CarveKit pipeline on ``image``.

    Args:
        method: Key into ``segment_net`` (one of ``method_choices``).
        image: Input image as delivered by ``gr.Image(type="pil")``.

    Returns:
        The first element of the result of calling the ``Interface`` on a
        one-element batch containing ``image``.

    Raises:
        gr.Error: If ``image`` is ``None``.
        KeyError: If ``method`` is not a key of ``segment_net``.
    """
    _require_image(image)
    seg_model = segment_net[method]
    interface = Interface(
        pre_pipe=preprocessing,
        post_pipe=postprocessing,
        seg_pipe=seg_model,
    )
    return interface([image])[0]


footer = r"""
CarveKit
Demo based on CarveKit
"""

with gr.Blocks(title="CarveKit") as app:
    gr.Markdown("\n\nCarveKit\n\n")
    gr.HTML("\n\nHigh-quality image background removal\n\n")
    with gr.Tabs() as tabs:
        with gr.TabItem("Remove background", id=0):
            with gr.Row(equal_height=False):
                with gr.Column():
                    input_img = gr.Image(type="pil", label="Input image")
                    drp_itf = gr.Dropdown(
                        value="TracerUniversalB7",
                        label="Segmentor model",
                        choices=method_choices,
                    )
                    run_btn = gr.Button(variant="primary")
                with gr.Column():
                    output_img = gr.Image(type="pil", label="result")
            run_btn.click(predict, [drp_itf, input_img], [output_img])
        with gr.TabItem("Trimap generator", id=1):
            with gr.Row(equal_height=False):
                with gr.Column():
                    trimap_input = gr.Image(type="pil", label="Input image")
                    # Separate name so the first tab's dropdown is not shadowed.
                    trimap_drp_itf = gr.Dropdown(
                        value="TracerUniversalB7",
                        label="Segmentor model",
                        choices=method_choices,
                    )
                    trimap_btn = gr.Button(variant="primary")
                with gr.Column():
                    trimap_output = gr.Image(type="pil", label="result")
            trimap_btn.click(
                generate_trimap,
                [trimap_drp_itf, trimap_input],
                [trimap_output],
            )
    with gr.Row():
        gr.HTML(footer)

app.queue()
app.launch(share=False, debug=True, show_error=True)