File size: 4,158 Bytes
59e40e1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
"""
Source url: https://github.com/OPHoperHPO/image-background-remove-tool
Author: Nikita Selin (OPHoperHPO)[https://github.com/OPHoperHPO].
License: Apache License 2.0
"""
import warnings

from carvekit.api.interface import Interface
from carvekit.ml.wrap.fba_matting import FBAMatting
from carvekit.ml.wrap.tracer_b7 import TracerUniversalB7
from carvekit.ml.wrap.u2net import U2NET
from carvekit.pipelines.postprocessing import MattingMethod
from carvekit.trimap.generator import TrimapGenerator


class HiInterface(Interface):
    def __init__(
        self,
        object_type: str = "object",
        batch_size_seg=2,
        batch_size_matting=1,
        device="cpu",
        seg_mask_size=640,
        matting_mask_size=2048,
        trimap_prob_threshold=231,
        trimap_dilation=30,
        trimap_erosion_iters=5,
        fp16=False,
    ):
        """
        Initializes High Level interface.

        Builds a segmentation network, an FBA matting network and a trimap generator,
        and wires them into the base ``Interface`` (no pre-processing pipe, the
        segmentation network as ``seg_pipe`` and ``MattingMethod`` as ``post_pipe``).

        Args:
            object_type: Interest object type. Can be "object" or "hairs-like".
                "object" selects TracerUniversalB7 for segmentation, "hairs-like" selects U2NET.
                Any other value emits a warning and falls back to "object".
            matting_mask_size: The size of the input image for the matting neural network
                (passed to FBAMatting as ``input_tensor_size``).
            seg_mask_size: The size of the input image for the segmentation neural network
                (passed to the segmentation network as ``input_image_size``).
            batch_size_seg: Number of images processed per one segmentation neural network call.
            batch_size_matting: Number of images processed per one matting neural network call.
            device: Processing device
            fp16: Use half precision. Reduce memory usage and increase speed. Experimental support
            trimap_prob_threshold: Probability threshold at which the prob_filter and prob_as_unknown_area operations will be applied
            trimap_dilation: The size of the offset radius from the object mask in pixels when forming an unknown area
                (passed to TrimapGenerator as ``kernel_size``).
            trimap_erosion_iters: The number of iterations of erosion that the object's mask will be subjected to before forming an unknown area

        Notes:
            1. Changing seg_mask_size may cause an out-of-memory error if the value is too large, and it may also
            result in reduced precision. I do not recommend changing this value. You can change matting_mask_size in
            range from (1024 to 4096) to improve object edge refining quality, but it will cause extra large RAM and
            video memory consumption. Also, you can change batch size to accelerate background removal, but it also
            causes extra large video memory consumption if the value is too big.

            2. Changing trimap_prob_threshold, trimap_dilation, trimap_erosion_iters may improve object edge
            refining quality.
        """
        # Pick the segmentation network class; unknown types fall back to "object".
        if object_type == "hairs-like":
            seg_model_cls = U2NET
        else:
            if object_type != "object":
                # stacklevel=2 attributes the warning to the caller that passed the bad value.
                warnings.warn(
                    f"Unknown object type: {object_type}. Using default object type: object",
                    stacklevel=2,
                )
            seg_model_cls = TracerUniversalB7

        # NOTE: the attribute is named ``u2net`` regardless of which segmentation
        # network is selected; external code may rely on this name, so it is kept.
        self.u2net = seg_model_cls(
            device=device,
            batch_size=batch_size_seg,
            input_image_size=seg_mask_size,
            fp16=fp16,
        )

        self.fba = FBAMatting(
            batch_size=batch_size_matting,
            device=device,
            input_tensor_size=matting_mask_size,
            fp16=fp16,
        )
        self.trimap_generator = TrimapGenerator(
            prob_threshold=trimap_prob_threshold,
            kernel_size=trimap_dilation,
            erosion_iters=trimap_erosion_iters,
        )
        super().__init__(
            pre_pipe=None,
            seg_pipe=self.u2net,
            post_pipe=MattingMethod(
                matting_module=self.fba,
                trimap_generator=self.trimap_generator,
                device=device,
            ),
            device=device,
        )