Spaces:
Runtime error
Runtime error
from os import getenv | |
from typing import Union | |
from loguru import logger | |
from carvekit.web.schemas.config import WebAPIConfig, MLConfig, AuthConfig | |
from carvekit.api.interface import Interface | |
from carvekit.ml.wrap.fba_matting import FBAMatting | |
from carvekit.ml.wrap.u2net import U2NET | |
from carvekit.ml.wrap.deeplab_v3 import DeepLabV3 | |
from carvekit.ml.wrap.basnet import BASNET | |
from carvekit.ml.wrap.tracer_b7 import TracerUniversalB7 | |
from carvekit.pipelines.postprocessing import MattingMethod | |
from carvekit.pipelines.preprocessing import PreprocessingStub | |
from carvekit.trimap.generator import TrimapGenerator | |
def _getenv_int(name: str, default) -> int:
    """Read environment variable ``name`` as an ``int``, falling back to ``default``.

    Raises:
        ValueError: if the variable is set but is not a valid integer literal.
    """
    raw = getenv(name, default)
    try:
        return int(raw)
    except ValueError as err:
        raise ValueError(f"{name} must be an integer, got {raw!r}") from err


def _getenv_bool(name: str, default) -> bool:
    """Read environment variable ``name`` as a boolean flag.

    Integer strings keep their historical meaning (``"0"`` is False, any other
    integer is True). The words ``true/yes/on`` and ``false/no/off`` (any case,
    surrounding whitespace ignored) are accepted as well. When the variable is
    unset, ``default`` is coerced with ``bool(int(default))`` exactly as before.

    Raises:
        ValueError: if the variable is set to an unrecognised value.
    """
    raw = getenv(name)
    if raw is None:
        return bool(int(default))
    word = raw.strip().lower()
    if word in ("true", "yes", "on"):
        return True
    if word in ("false", "no", "off"):
        return False
    try:
        return bool(int(word))
    except ValueError as err:
        raise ValueError(f"{name} must be 0/1 or true/false, got {raw!r}") from err


def _getenv_list(name: str, default):
    """Read a comma-separated environment variable as a list of strings.

    Each entry is stripped of surrounding whitespace and empty entries are
    dropped, so ``"a, b,"`` yields ``["a", "b"]``. Returns ``default``
    unchanged when the variable is unset.
    """
    raw = getenv(name)
    if raw is None:
        return default
    return [item.strip() for item in raw.split(",") if item.strip()]


def init_config() -> WebAPIConfig:
    """Build the Web API configuration from ``CARVEKIT_*`` environment variables.

    Every setting falls back to the corresponding value of a default-constructed
    ``WebAPIConfig`` when its environment variable is unset.

    Returns:
        The populated ``WebAPIConfig``.

    Raises:
        ValueError: if a numeric or boolean variable holds an unparseable value.
    """
    default_config = WebAPIConfig()
    default_ml = default_config.ml
    default_auth = default_config.auth

    config = WebAPIConfig(
        port=_getenv_int("CARVEKIT_PORT", default_config.port),
        host=getenv("CARVEKIT_HOST", default_config.host),
        ml=MLConfig(
            segmentation_network=getenv(
                "CARVEKIT_SEGMENTATION_NETWORK", default_ml.segmentation_network
            ),
            preprocessing_method=getenv(
                "CARVEKIT_PREPROCESSING_METHOD", default_ml.preprocessing_method
            ),
            postprocessing_method=getenv(
                "CARVEKIT_POSTPROCESSING_METHOD", default_ml.postprocessing_method
            ),
            device=getenv("CARVEKIT_DEVICE", default_ml.device),
            batch_size_seg=_getenv_int(
                "CARVEKIT_BATCH_SIZE_SEG", default_ml.batch_size_seg
            ),
            batch_size_matting=_getenv_int(
                "CARVEKIT_BATCH_SIZE_MATTING", default_ml.batch_size_matting
            ),
            seg_mask_size=_getenv_int(
                "CARVEKIT_SEG_MASK_SIZE", default_ml.seg_mask_size
            ),
            matting_mask_size=_getenv_int(
                "CARVEKIT_MATTING_MASK_SIZE", default_ml.matting_mask_size
            ),
            fp16=_getenv_bool("CARVEKIT_FP16", default_ml.fp16),
            trimap_prob_threshold=_getenv_int(
                "CARVEKIT_TRIMAP_PROB_THRESHOLD", default_ml.trimap_prob_threshold
            ),
            trimap_dilation=_getenv_int(
                "CARVEKIT_TRIMAP_DILATION", default_ml.trimap_dilation
            ),
            trimap_erosion=_getenv_int(
                "CARVEKIT_TRIMAP_EROSION", default_ml.trimap_erosion
            ),
        ),
        auth=AuthConfig(
            auth=_getenv_bool("CARVEKIT_AUTH_ENABLE", default_auth.auth),
            admin_token=getenv("CARVEKIT_ADMIN_TOKEN", default_auth.admin_token),
            # Blank entries are dropped. NOTE(review): if tokens are checked by
            # membership elsewhere, an empty entry could admit an empty token.
            allowed_tokens=_getenv_list(
                "CARVEKIT_ALLOWED_TOKENS", default_auth.allowed_tokens
            ),
        ),
    )
    # NOTE(review): this writes the admin token to the log (and config.json()
    # below presumably serialises it too); treat log output as sensitive.
    logger.info(f"Admin token for Web API is {config.auth.admin_token}")
    logger.debug(f"Running Web API with this config: {config.json()}")
    return config
def init_interface(config: Union[WebAPIConfig, MLConfig]) -> Interface:
    """Construct the processing ``Interface`` described by an ML configuration.

    Args:
        config: Either a full ``WebAPIConfig`` (only its ``ml`` section is
            used) or an ``MLConfig`` directly.

    Returns:
        An ``Interface`` combining the selected segmentation network, the
        optional pre-processing stage and the optional FBA matting
        post-processing stage.

    An unknown ``segmentation_network`` falls back to ``"tracer_b7"``. Unknown
    pre-/post-processing methods fall back to no stage (``None``). In each
    case a warning is logged so a misconfiguration is visible.
    """
    if isinstance(config, WebAPIConfig):
        config = config.ml

    # All supported segmentation networks share the same constructor kwargs.
    seg_networks = {
        "u2net": U2NET,
        "deeplabv3": DeepLabV3,
        "basnet": BASNET,
        "tracer_b7": TracerUniversalB7,
    }
    seg_cls = seg_networks.get(config.segmentation_network)
    if seg_cls is None:
        logger.warning(
            f"Unknown segmentation network {config.segmentation_network!r}; "
            "falling back to 'tracer_b7'"
        )
        seg_cls = TracerUniversalB7
    seg_net = seg_cls(
        device=config.device,
        batch_size=config.batch_size_seg,
        input_image_size=config.seg_mask_size,
        fp16=config.fp16,
    )

    if config.preprocessing_method == "stub":
        preprocessing = PreprocessingStub()
    else:
        if config.preprocessing_method != "none":
            logger.warning(
                f"Unknown preprocessing method {config.preprocessing_method!r}; "
                "preprocessing disabled"
            )
        preprocessing = None

    if config.postprocessing_method == "fba":
        fba = FBAMatting(
            device=config.device,
            batch_size=config.batch_size_matting,
            input_tensor_size=config.matting_mask_size,
            fp16=config.fp16,
        )
        # The trimap's dilation setting is passed as the generator's kernel size.
        trimap_generator = TrimapGenerator(
            prob_threshold=config.trimap_prob_threshold,
            kernel_size=config.trimap_dilation,
            erosion_iters=config.trimap_erosion,
        )
        postprocessing = MattingMethod(
            device=config.device, matting_module=fba, trimap_generator=trimap_generator
        )
    else:
        if config.postprocessing_method != "none":
            logger.warning(
                f"Unknown postprocessing method {config.postprocessing_method!r}; "
                "postprocessing disabled"
            )
        postprocessing = None

    return Interface(
        pre_pipe=preprocessing,
        post_pipe=postprocessing,
        seg_pipe=seg_net,
        device=config.device,
    )