import gradio as gr import torch from carvekit.api.interface import Interface from carvekit.ml.wrap.basnet import BASNET from carvekit.ml.wrap.deeplab_v3 import DeepLabV3 from carvekit.ml.wrap.fba_matting import FBAMatting from carvekit.ml.wrap.tracer_b7 import TracerUniversalB7 from carvekit.ml.wrap.u2net import U2NET from carvekit.pipelines.postprocessing import MattingMethod from carvekit.pipelines.preprocessing import PreprocessingStub from carvekit.trimap.generator import TrimapGenerator device = 'cuda' if torch.cuda.is_available() else 'cpu' segment_net = { "U2NET": U2NET(device=device, batch_size=1), "BASNET": BASNET(device=device, batch_size=1), "DeepLabV3": DeepLabV3(device=device, batch_size=1), "TracerUniversalB7": TracerUniversalB7(device=device, batch_size=1) } fba = FBAMatting(device=device, input_tensor_size=2048, batch_size=1) trimap = TrimapGenerator() preprocessing = PreprocessingStub() postprocessing = MattingMethod(matting_module=fba, trimap_generator=trimap, device=device) method_choices = [k for k, v in segment_net.items()] def generate_trimap(method, original): mask = segment_net[method]([original]) return trimap(original_image=original, mask=mask[0]) def predict(method, image): method = segment_net[method] return Interface(pre_pipe=preprocessing, post_pipe=postprocessing, seg_pipe=method)([image])[0] footer = r"""