Towsif7 committed
Commit 59e40e1
1 Parent(s): ccfae17

firrst commit

This view is limited to 50 files because it contains too many changes.
Files changed (50)
  1. app.py +36 -0
  2. carvekit/__init__.py +1 -0
  3. carvekit/__main__.py +149 -0
  4. carvekit/__pycache__/__init__.cpython-38.pyc +0 -0
  5. carvekit/api/__init__.py +0 -0
  6. carvekit/api/__pycache__/__init__.cpython-38.pyc +0 -0
  7. carvekit/api/__pycache__/high.cpython-38.pyc +0 -0
  8. carvekit/api/__pycache__/interface.cpython-38.pyc +0 -0
  9. carvekit/api/high.py +100 -0
  10. carvekit/api/interface.py +77 -0
  11. carvekit/ml/__init__.py +4 -0
  12. carvekit/ml/__pycache__/__init__.cpython-38.pyc +0 -0
  13. carvekit/ml/arch/__init__.py +0 -0
  14. carvekit/ml/arch/__pycache__/__init__.cpython-38.pyc +0 -0
  15. carvekit/ml/arch/basnet/__init__.py +0 -0
  16. carvekit/ml/arch/basnet/__pycache__/__init__.cpython-38.pyc +0 -0
  17. carvekit/ml/arch/basnet/__pycache__/basnet.cpython-38.pyc +0 -0
  18. carvekit/ml/arch/basnet/basnet.py +478 -0
  19. carvekit/ml/arch/fba_matting/__init__.py +0 -0
  20. carvekit/ml/arch/fba_matting/__pycache__/__init__.cpython-38.pyc +0 -0
  21. carvekit/ml/arch/fba_matting/__pycache__/layers_WS.cpython-38.pyc +0 -0
  22. carvekit/ml/arch/fba_matting/__pycache__/models.cpython-38.pyc +0 -0
  23. carvekit/ml/arch/fba_matting/__pycache__/resnet_GN_WS.cpython-38.pyc +0 -0
  24. carvekit/ml/arch/fba_matting/__pycache__/resnet_bn.cpython-38.pyc +0 -0
  25. carvekit/ml/arch/fba_matting/__pycache__/transforms.cpython-38.pyc +0 -0
  26. carvekit/ml/arch/fba_matting/layers_WS.py +57 -0
  27. carvekit/ml/arch/fba_matting/models.py +341 -0
  28. carvekit/ml/arch/fba_matting/resnet_GN_WS.py +151 -0
  29. carvekit/ml/arch/fba_matting/resnet_bn.py +169 -0
  30. carvekit/ml/arch/fba_matting/transforms.py +45 -0
  31. carvekit/ml/arch/tracerb7/__init__.py +0 -0
  32. carvekit/ml/arch/tracerb7/__pycache__/__init__.cpython-38.pyc +0 -0
  33. carvekit/ml/arch/tracerb7/__pycache__/att_modules.cpython-38.pyc +0 -0
  34. carvekit/ml/arch/tracerb7/__pycache__/conv_modules.cpython-38.pyc +0 -0
  35. carvekit/ml/arch/tracerb7/__pycache__/effi_utils.cpython-38.pyc +0 -0
  36. carvekit/ml/arch/tracerb7/__pycache__/efficientnet.cpython-38.pyc +0 -0
  37. carvekit/ml/arch/tracerb7/__pycache__/tracer.cpython-38.pyc +0 -0
  38. carvekit/ml/arch/tracerb7/att_modules.py +290 -0
  39. carvekit/ml/arch/tracerb7/conv_modules.py +88 -0
  40. carvekit/ml/arch/tracerb7/effi_utils.py +579 -0
  41. carvekit/ml/arch/tracerb7/efficientnet.py +325 -0
  42. carvekit/ml/arch/tracerb7/tracer.py +97 -0
  43. carvekit/ml/arch/u2net/__init__.py +0 -0
  44. carvekit/ml/arch/u2net/__pycache__/__init__.cpython-38.pyc +0 -0
  45. carvekit/ml/arch/u2net/__pycache__/u2net.cpython-38.pyc +0 -0
  46. carvekit/ml/arch/u2net/u2net.py +172 -0
  47. carvekit/ml/files/__init__.py +7 -0
  48. carvekit/ml/files/__pycache__/__init__.cpython-38.pyc +0 -0
  49. carvekit/ml/files/__pycache__/models_loc.cpython-38.pyc +0 -0
  50. carvekit/ml/files/models_loc.py +70 -0
app.py ADDED
@@ -0,0 +1,36 @@
+ import streamlit as st
+ from carvekit.api.interface import Interface
+ from carvekit.ml.wrap.fba_matting import FBAMatting
+ from carvekit.ml.wrap.tracer_b7 import TracerUniversalB7
+ from carvekit.pipelines.postprocessing import MattingMethod
+ from carvekit.pipelines.preprocessing import PreprocessingStub
+ from carvekit.trimap.generator import TrimapGenerator
+ from PIL import Image
+
+ # Create Streamlit app title
+ st.title("Image Background Remover")
+
+ # Create a file uploader
+ uploaded_file = st.file_uploader("Choose an image...", type=["jpg", "png"])
+
+ if uploaded_file is not None:
+     # Load the image
+     image = Image.open(uploaded_file)
+
+     # Set up ML pipeline
+     seg_net = TracerUniversalB7(device='cpu', batch_size=1)
+     fba = FBAMatting(device='cpu', input_tensor_size=2048, batch_size=1)
+     trimap = TrimapGenerator()
+     preprocessing = PreprocessingStub()
+     postprocessing = MattingMethod(matting_module=fba, trimap_generator=trimap, device='cpu')
+     interface = Interface(pre_pipe=preprocessing, post_pipe=postprocessing, seg_pipe=seg_net)
+
+     # Process the image
+     processed_bg = interface([image])[0]
+
+     # Display original and processed images
+     col1, col2 = st.columns(2)
+     with col1:
+         st.image(image, caption='Original Image', use_column_width=True)
+     with col2:
+         st.image(processed_bg, caption='Background Removed', use_column_width=True)
carvekit/__init__.py ADDED
@@ -0,0 +1 @@
+ version = "4.1.0"
carvekit/__main__.py ADDED
@@ -0,0 +1,149 @@
+ from pathlib import Path
+
+ import click
+ import tqdm
+
+ from carvekit.utils.image_utils import ALLOWED_SUFFIXES
+ from carvekit.utils.pool_utils import batch_generator, thread_pool_processing
+ from carvekit.web.schemas.config import MLConfig
+ from carvekit.web.utils.init_utils import init_interface
+ from carvekit.utils.fs_utils import save_file
+
+
+ @click.command(
+     "removebg",
+     help="Performs background removal on specified photos using console interface.",
+ )
+ @click.option("-i", required=True, type=str, help="Path to input file or dir")
+ @click.option("-o", default="none", type=str, help="Path to output file or dir")
+ @click.option("--pre", default="none", type=str, help="Preprocessing method")
+ @click.option("--post", default="fba", type=str, help="Postprocessing method.")
+ @click.option("--net", default="tracer_b7", type=str, help="Segmentation Network")
+ @click.option(
+     "--recursive",
+     default=False,
+     type=bool,
+     help="Enables recursive search for images in a folder",
+ )
+ @click.option(
+     "--batch_size",
+     default=10,
+     type=int,
+     help="Batch Size for list of images to be loaded to RAM",
+ )
+ @click.option(
+     "--batch_size_seg",
+     default=5,
+     type=int,
+     help="Batch size for list of images to be processed by segmentation " "network",
+ )
+ @click.option(
+     "--batch_size_mat",
+     default=1,
+     type=int,
+     help="Batch size for list of images to be processed by matting " "network",
+ )
+ @click.option(
+     "--seg_mask_size",
+     default=640,
+     type=int,
+     help="The size of the input image for the segmentation neural network.",
+ )
+ @click.option(
+     "--matting_mask_size",
+     default=2048,
+     type=int,
+     help="The size of the input image for the matting neural network.",
+ )
+ @click.option(
+     "--trimap_dilation",
+     default=30,
+     type=int,
+     help="The size of the offset radius from the object mask in "
+     "pixels when forming an unknown area",
+ )
+ @click.option(
+     "--trimap_erosion",
+     default=5,
+     type=int,
+     help="The number of iterations of erosion that the object's "
+     "mask will be subjected to before forming an unknown area",
+ )
+ @click.option(
+     "--trimap_prob_threshold",
+     default=231,
+     type=int,
+     help="Probability threshold at which the prob_filter "
+     "and prob_as_unknown_area operations will be "
+     "applied",
+ )
+ @click.option("--device", default="cpu", type=str, help="Processing Device.")
+ @click.option(
+     "--fp16", default=False, type=bool, help="Enables mixed precision processing."
+ )
+ def removebg(
+     i: str,
+     o: str,
+     pre: str,
+     post: str,
+     net: str,
+     recursive: bool,
+     batch_size: int,
+     batch_size_seg: int,
+     batch_size_mat: int,
+     seg_mask_size: int,
+     matting_mask_size: int,
+     device: str,
+     fp16: bool,
+     trimap_dilation: int,
+     trimap_erosion: int,
+     trimap_prob_threshold: int,
+ ):
+     out_path = Path(o)
+     input_path = Path(i)
+     if input_path.is_dir():
+         if recursive:
+             all_images = input_path.rglob("*.*")
+         else:
+             all_images = input_path.glob("*.*")
+         all_images = [
+             i
+             for i in all_images
+             if i.suffix.lower() in ALLOWED_SUFFIXES and "_bg_removed" not in i.name
+         ]
+     else:
+         all_images = [input_path]
+
+     interface_config = MLConfig(
+         segmentation_network=net,
+         preprocessing_method=pre,
+         postprocessing_method=post,
+         device=device,
+         batch_size_seg=batch_size_seg,
+         batch_size_matting=batch_size_mat,
+         seg_mask_size=seg_mask_size,
+         matting_mask_size=matting_mask_size,
+         fp16=fp16,
+         trimap_dilation=trimap_dilation,
+         trimap_erosion=trimap_erosion,
+         trimap_prob_threshold=trimap_prob_threshold,
+     )
+
+     interface = init_interface(interface_config)
+
+     for image_batch in tqdm.tqdm(
+         batch_generator(all_images, n=batch_size),
+         total=int(len(all_images) / batch_size),
+         desc="Removing background",
+         unit=" image batch",
+         colour="blue",
+     ):
+         images_without_background = interface(image_batch)  # Remove background
+         thread_pool_processing(
+             lambda x: save_file(out_path, image_batch[x], images_without_background[x]),
+             range(len(image_batch)),
+         )  # Drop images to fs
+
+
+ if __name__ == "__main__":
+     removebg()
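The removebg command above is a thin wrapper around MLConfig and init_interface, which it imports at the top. A rough programmatic equivalent, sketched here for illustration only: the input path is an example, and the MLConfig fields omitted below are assumed to have defaults.

# Hypothetical sketch, not part of this commit: drive the same pipeline
# the CLI builds, but from Python.
from carvekit.web.schemas.config import MLConfig
from carvekit.web.utils.init_utils import init_interface

config = MLConfig(
    segmentation_network="tracer_b7",  # same default as the --net option
    preprocessing_method="none",
    postprocessing_method="fba",
    device="cpu",
    seg_mask_size=640,
    matting_mask_size=2048,
)
interface = init_interface(config)
result = interface(["photo.jpg"])[0]   # "photo.jpg" is an example path
result.save("photo_bg_removed.png")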
carvekit/__pycache__/__init__.cpython-38.pyc ADDED
Binary file (187 Bytes).
carvekit/api/__init__.py ADDED
File without changes
carvekit/api/__pycache__/__init__.cpython-38.pyc ADDED
Binary file (174 Bytes).
carvekit/api/__pycache__/high.cpython-38.pyc ADDED
Binary file (3.71 kB).
carvekit/api/__pycache__/interface.cpython-38.pyc ADDED
Binary file (2.87 kB).
carvekit/api/high.py ADDED
@@ -0,0 +1,100 @@
+ """
+ Source url: https://github.com/OPHoperHPO/image-background-remove-tool
+ Author: Nikita Selin (OPHoperHPO)[https://github.com/OPHoperHPO].
+ License: Apache License 2.0
+ """
+ import warnings
+
+ from carvekit.api.interface import Interface
+ from carvekit.ml.wrap.fba_matting import FBAMatting
+ from carvekit.ml.wrap.tracer_b7 import TracerUniversalB7
+ from carvekit.ml.wrap.u2net import U2NET
+ from carvekit.pipelines.postprocessing import MattingMethod
+ from carvekit.trimap.generator import TrimapGenerator
+
+
+ class HiInterface(Interface):
+     def __init__(
+         self,
+         object_type: str = "object",
+         batch_size_seg=2,
+         batch_size_matting=1,
+         device="cpu",
+         seg_mask_size=640,
+         matting_mask_size=2048,
+         trimap_prob_threshold=231,
+         trimap_dilation=30,
+         trimap_erosion_iters=5,
+         fp16=False,
+     ):
+         """
+         Initializes the high-level interface.
+
+         Args:
+             object_type: Interest object type. Can be "object" or "hairs-like".
+             matting_mask_size: The size of the input image for the matting neural network.
+             seg_mask_size: The size of the input image for the segmentation neural network.
+             batch_size_seg: Number of images processed per segmentation neural network call.
+             batch_size_matting: Number of images processed per matting neural network call.
+             device: Processing device.
+             fp16: Use half precision to reduce memory usage and increase speed. Experimental support.
+             trimap_prob_threshold: Probability threshold at which the prob_filter and prob_as_unknown_area operations will be applied.
+             trimap_dilation: The size of the offset radius from the object mask in pixels when forming an unknown area.
+             trimap_erosion_iters: The number of iterations of erosion applied to the object's mask before forming an unknown area.
+
+         Notes:
+             1. Increasing seg_mask_size may cause an out-of-memory error if the value is too large, and it may also
+             reduce precision, so changing it is not recommended. You can change matting_mask_size in the
+             range 1024 to 4096 to improve object edge refinement quality, but this increases RAM and
+             video memory consumption. You can also increase the batch sizes to speed up background removal,
+             but this too increases video memory consumption if the values are too large.
+
+             2. Changing trimap_prob_threshold, trimap_dilation, and trimap_erosion_iters may improve object edge
+             refinement quality.
+         """
+         if object_type == "object":
+             self.u2net = TracerUniversalB7(
+                 device=device,
+                 batch_size=batch_size_seg,
+                 input_image_size=seg_mask_size,
+                 fp16=fp16,
+             )
+         elif object_type == "hairs-like":
+             self.u2net = U2NET(
+                 device=device,
+                 batch_size=batch_size_seg,
+                 input_image_size=seg_mask_size,
+                 fp16=fp16,
+             )
+         else:
+             warnings.warn(
+                 f"Unknown object type: {object_type}. Using default object type: object"
+             )
+             self.u2net = TracerUniversalB7(
+                 device=device,
+                 batch_size=batch_size_seg,
+                 input_image_size=seg_mask_size,
+                 fp16=fp16,
+             )
+
+         self.fba = FBAMatting(
+             batch_size=batch_size_matting,
+             device=device,
+             input_tensor_size=matting_mask_size,
+             fp16=fp16,
+         )
+         self.trimap_generator = TrimapGenerator(
+             prob_threshold=trimap_prob_threshold,
+             kernel_size=trimap_dilation,
+             erosion_iters=trimap_erosion_iters,
+         )
+         super(HiInterface, self).__init__(
+             pre_pipe=None,
+             seg_pipe=self.u2net,
+             post_pipe=MattingMethod(
+                 matting_module=self.fba,
+                 trimap_generator=self.trimap_generator,
+                 device=device,
+             ),
+             device=device,
+         )
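HiInterface bundles the segmentation network, trimap generator, and FBA matting behind one constructor. A hedged usage sketch follows; the image path is an example and the snippet is not part of the commit.

# Hypothetical usage of the high-level API added in carvekit/api/high.py.
from carvekit.api.high import HiInterface

interface = HiInterface(object_type="object", device="cpu", fp16=False)
# Interface.__call__ accepts paths or PIL images and returns PIL images
# with the background removed (see carvekit/api/interface.py below).
images = interface(["./example.jpg"])
images[0].save("example_no_bg.png")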
carvekit/api/interface.py ADDED
@@ -0,0 +1,77 @@
+ """
+ Source url: https://github.com/OPHoperHPO/image-background-remove-tool
+ Author: Nikita Selin (OPHoperHPO)[https://github.com/OPHoperHPO].
+ License: Apache License 2.0
+ """
+ from pathlib import Path
+ from typing import Union, List, Optional
+
+ from PIL import Image
+
+ from carvekit.ml.wrap.basnet import BASNET
+ from carvekit.ml.wrap.deeplab_v3 import DeepLabV3
+ from carvekit.ml.wrap.u2net import U2NET
+ from carvekit.ml.wrap.tracer_b7 import TracerUniversalB7
+ from carvekit.pipelines.preprocessing import PreprocessingStub
+ from carvekit.pipelines.postprocessing import MattingMethod
+ from carvekit.utils.image_utils import load_image
+ from carvekit.utils.mask_utils import apply_mask
+ from carvekit.utils.pool_utils import thread_pool_processing
+
+
+ class Interface:
+     def __init__(
+         self,
+         seg_pipe: Union[U2NET, BASNET, DeepLabV3, TracerUniversalB7],
+         pre_pipe: Optional[Union[PreprocessingStub]] = None,
+         post_pipe: Optional[Union[MattingMethod]] = None,
+         device="cpu",
+     ):
+         """
+         Initializes an object for interacting with pipelines and other components of the CarveKit framework.
+
+         Args:
+             pre_pipe: Initialized pre-processing pipeline object
+             seg_pipe: Initialized segmentation network object
+             post_pipe: Initialized postprocessing pipeline object
+             device: The processing device that will be used to apply the masks to the images.
+         """
+         self.device = device
+         self.preprocessing_pipeline = pre_pipe
+         self.segmentation_pipeline = seg_pipe
+         self.postprocessing_pipeline = post_pipe
+
+     def __call__(
+         self, images: List[Union[str, Path, Image.Image]]
+     ) -> List[Image.Image]:
+         """
+         Removes the background from the specified images.
+
+         Args:
+             images: list of input images
+
+         Returns:
+             List of images without background as PIL.Image.Image instances
+         """
+         images = thread_pool_processing(load_image, images)
+         if self.preprocessing_pipeline is not None:
+             masks: List[Image.Image] = self.preprocessing_pipeline(
+                 interface=self, images=images
+             )
+         else:
+             masks: List[Image.Image] = self.segmentation_pipeline(images=images)
+
+         if self.postprocessing_pipeline is not None:
+             images: List[Image.Image] = self.postprocessing_pipeline(
+                 images=images, masks=masks
+             )
+         else:
+             images = list(
+                 map(
+                     lambda x: apply_mask(
+                         image=images[x], mask=masks[x], device=self.device
+                     ),
+                     range(len(images)),
+                 )
+             )
+         return images
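When post_pipe is None, Interface.__call__ skips matting and simply applies the raw segmentation mask via apply_mask. A minimal sketch of that configuration, using only the Tracer segmentation network; paths are examples and the snippet is not part of the commit.

# Hypothetical segmentation-only setup: exercises the else-branch that calls apply_mask.
from carvekit.api.interface import Interface
from carvekit.ml.wrap.tracer_b7 import TracerUniversalB7

seg_net = TracerUniversalB7(device="cpu", batch_size=1)
interface = Interface(seg_pipe=seg_net, pre_pipe=None, post_pipe=None, device="cpu")
cutout = interface(["photo.jpg"])[0]  # example input path
cutout.save("photo_cutout.png")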
carvekit/ml/__init__.py ADDED
@@ -0,0 +1,4 @@
+ from carvekit.utils.models_utils import fix_seed, suppress_warnings
+
+ fix_seed()
+ suppress_warnings()
carvekit/ml/__pycache__/__init__.cpython-38.pyc ADDED
Binary file (277 Bytes).
carvekit/ml/arch/__init__.py ADDED
File without changes
carvekit/ml/arch/__pycache__/__init__.cpython-38.pyc ADDED
Binary file (178 Bytes).
carvekit/ml/arch/basnet/__init__.py ADDED
File without changes
carvekit/ml/arch/basnet/__pycache__/__init__.cpython-38.pyc ADDED
Binary file (185 Bytes).
carvekit/ml/arch/basnet/__pycache__/basnet.cpython-38.pyc ADDED
Binary file (10 kB).
carvekit/ml/arch/basnet/basnet.py ADDED
@@ -0,0 +1,478 @@
1
+ """
2
+ Source url: https://github.com/NathanUA/BASNet
3
+ Modified by Nikita Selin (OPHoperHPO)[https://github.com/OPHoperHPO].
4
+ License: MIT License
5
+ """
6
+ import torch
7
+ import torch.nn as nn
8
+ from torchvision import models
9
+
10
+
11
+ def conv3x3(in_planes, out_planes, stride=1):
12
+ """3x3 convolution with padding"""
13
+ return nn.Conv2d(
14
+ in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=False
15
+ )
16
+
17
+
18
+ class BasicBlock(nn.Module):
19
+ expansion = 1
20
+
21
+ def __init__(self, inplanes, planes, stride=1, downsample=None):
22
+ super(BasicBlock, self).__init__()
23
+ self.conv1 = conv3x3(inplanes, planes, stride)
24
+ self.bn1 = nn.BatchNorm2d(planes)
25
+ self.relu = nn.ReLU(inplace=True)
26
+ self.conv2 = conv3x3(planes, planes)
27
+ self.bn2 = nn.BatchNorm2d(planes)
28
+ self.downsample = downsample
29
+ self.stride = stride
30
+
31
+ def forward(self, x):
32
+ residual = x
33
+
34
+ out = self.conv1(x)
35
+ out = self.bn1(out)
36
+ out = self.relu(out)
37
+
38
+ out = self.conv2(out)
39
+ out = self.bn2(out)
40
+
41
+ if self.downsample is not None:
42
+ residual = self.downsample(x)
43
+
44
+ out += residual
45
+ out = self.relu(out)
46
+
47
+ return out
48
+
49
+
50
+ class BasicBlockDe(nn.Module):
51
+ expansion = 1
52
+
53
+ def __init__(self, inplanes, planes, stride=1, downsample=None):
54
+ super(BasicBlockDe, self).__init__()
55
+
56
+ self.convRes = conv3x3(inplanes, planes, stride)
57
+ self.bnRes = nn.BatchNorm2d(planes)
58
+ self.reluRes = nn.ReLU(inplace=True)
59
+
60
+ self.conv1 = conv3x3(inplanes, planes, stride)
61
+ self.bn1 = nn.BatchNorm2d(planes)
62
+ self.relu = nn.ReLU(inplace=True)
63
+ self.conv2 = conv3x3(planes, planes)
64
+ self.bn2 = nn.BatchNorm2d(planes)
65
+ self.downsample = downsample
66
+ self.stride = stride
67
+
68
+ def forward(self, x):
69
+ residual = self.convRes(x)
70
+ residual = self.bnRes(residual)
71
+ residual = self.reluRes(residual)
72
+
73
+ out = self.conv1(x)
74
+ out = self.bn1(out)
75
+ out = self.relu(out)
76
+
77
+ out = self.conv2(out)
78
+ out = self.bn2(out)
79
+
80
+ if self.downsample is not None:
81
+ residual = self.downsample(x)
82
+
83
+ out += residual
84
+ out = self.relu(out)
85
+
86
+ return out
87
+
88
+
89
+ class Bottleneck(nn.Module):
90
+ expansion = 4
91
+
92
+ def __init__(self, inplanes, planes, stride=1, downsample=None):
93
+ super(Bottleneck, self).__init__()
94
+ self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
95
+ self.bn1 = nn.BatchNorm2d(planes)
96
+ self.conv2 = nn.Conv2d(
97
+ planes, planes, kernel_size=3, stride=stride, padding=1, bias=False
98
+ )
99
+ self.bn2 = nn.BatchNorm2d(planes)
100
+ self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False)
101
+ self.bn3 = nn.BatchNorm2d(planes * 4)
102
+ self.relu = nn.ReLU(inplace=True)
103
+ self.downsample = downsample
104
+ self.stride = stride
105
+
106
+ def forward(self, x):
107
+ residual = x
108
+
109
+ out = self.conv1(x)
110
+ out = self.bn1(out)
111
+ out = self.relu(out)
112
+
113
+ out = self.conv2(out)
114
+ out = self.bn2(out)
115
+ out = self.relu(out)
116
+
117
+ out = self.conv3(out)
118
+ out = self.bn3(out)
119
+
120
+ if self.downsample is not None:
121
+ residual = self.downsample(x)
122
+
123
+ out += residual
124
+ out = self.relu(out)
125
+
126
+ return out
127
+
128
+
129
+ class RefUnet(nn.Module):
130
+ def __init__(self, in_ch, inc_ch):
131
+ super(RefUnet, self).__init__()
132
+
133
+ self.conv0 = nn.Conv2d(in_ch, inc_ch, 3, padding=1)
134
+
135
+ self.conv1 = nn.Conv2d(inc_ch, 64, 3, padding=1)
136
+ self.bn1 = nn.BatchNorm2d(64)
137
+ self.relu1 = nn.ReLU(inplace=True)
138
+
139
+ self.pool1 = nn.MaxPool2d(2, 2, ceil_mode=True)
140
+
141
+ self.conv2 = nn.Conv2d(64, 64, 3, padding=1)
142
+ self.bn2 = nn.BatchNorm2d(64)
143
+ self.relu2 = nn.ReLU(inplace=True)
144
+
145
+ self.pool2 = nn.MaxPool2d(2, 2, ceil_mode=True)
146
+
147
+ self.conv3 = nn.Conv2d(64, 64, 3, padding=1)
148
+ self.bn3 = nn.BatchNorm2d(64)
149
+ self.relu3 = nn.ReLU(inplace=True)
150
+
151
+ self.pool3 = nn.MaxPool2d(2, 2, ceil_mode=True)
152
+
153
+ self.conv4 = nn.Conv2d(64, 64, 3, padding=1)
154
+ self.bn4 = nn.BatchNorm2d(64)
155
+ self.relu4 = nn.ReLU(inplace=True)
156
+
157
+ self.pool4 = nn.MaxPool2d(2, 2, ceil_mode=True)
158
+
159
+ self.conv5 = nn.Conv2d(64, 64, 3, padding=1)
160
+ self.bn5 = nn.BatchNorm2d(64)
161
+ self.relu5 = nn.ReLU(inplace=True)
162
+
163
+ self.conv_d4 = nn.Conv2d(128, 64, 3, padding=1)
164
+ self.bn_d4 = nn.BatchNorm2d(64)
165
+ self.relu_d4 = nn.ReLU(inplace=True)
166
+
167
+ self.conv_d3 = nn.Conv2d(128, 64, 3, padding=1)
168
+ self.bn_d3 = nn.BatchNorm2d(64)
169
+ self.relu_d3 = nn.ReLU(inplace=True)
170
+
171
+ self.conv_d2 = nn.Conv2d(128, 64, 3, padding=1)
172
+ self.bn_d2 = nn.BatchNorm2d(64)
173
+ self.relu_d2 = nn.ReLU(inplace=True)
174
+
175
+ self.conv_d1 = nn.Conv2d(128, 64, 3, padding=1)
176
+ self.bn_d1 = nn.BatchNorm2d(64)
177
+ self.relu_d1 = nn.ReLU(inplace=True)
178
+
179
+ self.conv_d0 = nn.Conv2d(64, 1, 3, padding=1)
180
+
181
+ self.upscore2 = nn.Upsample(
182
+ scale_factor=2, mode="bilinear", align_corners=False
183
+ )
184
+
185
+ def forward(self, x):
186
+ hx = x
187
+ hx = self.conv0(hx)
188
+
189
+ hx1 = self.relu1(self.bn1(self.conv1(hx)))
190
+ hx = self.pool1(hx1)
191
+
192
+ hx2 = self.relu2(self.bn2(self.conv2(hx)))
193
+ hx = self.pool2(hx2)
194
+
195
+ hx3 = self.relu3(self.bn3(self.conv3(hx)))
196
+ hx = self.pool3(hx3)
197
+
198
+ hx4 = self.relu4(self.bn4(self.conv4(hx)))
199
+ hx = self.pool4(hx4)
200
+
201
+ hx5 = self.relu5(self.bn5(self.conv5(hx)))
202
+
203
+ hx = self.upscore2(hx5)
204
+
205
+ d4 = self.relu_d4(self.bn_d4(self.conv_d4(torch.cat((hx, hx4), 1))))
206
+ hx = self.upscore2(d4)
207
+
208
+ d3 = self.relu_d3(self.bn_d3(self.conv_d3(torch.cat((hx, hx3), 1))))
209
+ hx = self.upscore2(d3)
210
+
211
+ d2 = self.relu_d2(self.bn_d2(self.conv_d2(torch.cat((hx, hx2), 1))))
212
+ hx = self.upscore2(d2)
213
+
214
+ d1 = self.relu_d1(self.bn_d1(self.conv_d1(torch.cat((hx, hx1), 1))))
215
+
216
+ residual = self.conv_d0(d1)
217
+
218
+ return x + residual
219
+
220
+
221
+ class BASNet(nn.Module):
222
+ def __init__(self, n_channels, n_classes):
223
+ super(BASNet, self).__init__()
224
+
225
+ resnet = models.resnet34(pretrained=False)
226
+
227
+ # -------------Encoder--------------
228
+
229
+ self.inconv = nn.Conv2d(n_channels, 64, 3, padding=1)
230
+ self.inbn = nn.BatchNorm2d(64)
231
+ self.inrelu = nn.ReLU(inplace=True)
232
+
233
+ # stage 1
234
+ self.encoder1 = resnet.layer1 # 224
235
+ # stage 2
236
+ self.encoder2 = resnet.layer2 # 112
237
+ # stage 3
238
+ self.encoder3 = resnet.layer3 # 56
239
+ # stage 4
240
+ self.encoder4 = resnet.layer4 # 28
241
+
242
+ self.pool4 = nn.MaxPool2d(2, 2, ceil_mode=True)
243
+
244
+ # stage 5
245
+ self.resb5_1 = BasicBlock(512, 512)
246
+ self.resb5_2 = BasicBlock(512, 512)
247
+ self.resb5_3 = BasicBlock(512, 512) # 14
248
+
249
+ self.pool5 = nn.MaxPool2d(2, 2, ceil_mode=True)
250
+
251
+ # stage 6
252
+ self.resb6_1 = BasicBlock(512, 512)
253
+ self.resb6_2 = BasicBlock(512, 512)
254
+ self.resb6_3 = BasicBlock(512, 512) # 7
255
+
256
+ # -------------Bridge--------------
257
+
258
+ # stage Bridge
259
+ self.convbg_1 = nn.Conv2d(512, 512, 3, dilation=2, padding=2) # 7
260
+ self.bnbg_1 = nn.BatchNorm2d(512)
261
+ self.relubg_1 = nn.ReLU(inplace=True)
262
+ self.convbg_m = nn.Conv2d(512, 512, 3, dilation=2, padding=2)
263
+ self.bnbg_m = nn.BatchNorm2d(512)
264
+ self.relubg_m = nn.ReLU(inplace=True)
265
+ self.convbg_2 = nn.Conv2d(512, 512, 3, dilation=2, padding=2)
266
+ self.bnbg_2 = nn.BatchNorm2d(512)
267
+ self.relubg_2 = nn.ReLU(inplace=True)
268
+
269
+ # -------------Decoder--------------
270
+
271
+ # stage 6d
272
+ self.conv6d_1 = nn.Conv2d(1024, 512, 3, padding=1) # 16
273
+ self.bn6d_1 = nn.BatchNorm2d(512)
274
+ self.relu6d_1 = nn.ReLU(inplace=True)
275
+
276
+ self.conv6d_m = nn.Conv2d(512, 512, 3, dilation=2, padding=2)
277
+ self.bn6d_m = nn.BatchNorm2d(512)
278
+ self.relu6d_m = nn.ReLU(inplace=True)
279
+
280
+ self.conv6d_2 = nn.Conv2d(512, 512, 3, dilation=2, padding=2)
281
+ self.bn6d_2 = nn.BatchNorm2d(512)
282
+ self.relu6d_2 = nn.ReLU(inplace=True)
283
+
284
+ # stage 5d
285
+ self.conv5d_1 = nn.Conv2d(1024, 512, 3, padding=1) # 16
286
+ self.bn5d_1 = nn.BatchNorm2d(512)
287
+ self.relu5d_1 = nn.ReLU(inplace=True)
288
+
289
+ self.conv5d_m = nn.Conv2d(512, 512, 3, padding=1)
290
+ self.bn5d_m = nn.BatchNorm2d(512)
291
+ self.relu5d_m = nn.ReLU(inplace=True)
292
+
293
+ self.conv5d_2 = nn.Conv2d(512, 512, 3, padding=1)
294
+ self.bn5d_2 = nn.BatchNorm2d(512)
295
+ self.relu5d_2 = nn.ReLU(inplace=True)
296
+
297
+ # stage 4d
298
+ self.conv4d_1 = nn.Conv2d(1024, 512, 3, padding=1) # 32
299
+ self.bn4d_1 = nn.BatchNorm2d(512)
300
+ self.relu4d_1 = nn.ReLU(inplace=True)
301
+
302
+ self.conv4d_m = nn.Conv2d(512, 512, 3, padding=1)
303
+ self.bn4d_m = nn.BatchNorm2d(512)
304
+ self.relu4d_m = nn.ReLU(inplace=True)
305
+
306
+ self.conv4d_2 = nn.Conv2d(512, 256, 3, padding=1)
307
+ self.bn4d_2 = nn.BatchNorm2d(256)
308
+ self.relu4d_2 = nn.ReLU(inplace=True)
309
+
310
+ # stage 3d
311
+ self.conv3d_1 = nn.Conv2d(512, 256, 3, padding=1) # 64
312
+ self.bn3d_1 = nn.BatchNorm2d(256)
313
+ self.relu3d_1 = nn.ReLU(inplace=True)
314
+
315
+ self.conv3d_m = nn.Conv2d(256, 256, 3, padding=1)
316
+ self.bn3d_m = nn.BatchNorm2d(256)
317
+ self.relu3d_m = nn.ReLU(inplace=True)
318
+
319
+ self.conv3d_2 = nn.Conv2d(256, 128, 3, padding=1)
320
+ self.bn3d_2 = nn.BatchNorm2d(128)
321
+ self.relu3d_2 = nn.ReLU(inplace=True)
322
+
323
+ # stage 2d
324
+
325
+ self.conv2d_1 = nn.Conv2d(256, 128, 3, padding=1) # 128
326
+ self.bn2d_1 = nn.BatchNorm2d(128)
327
+ self.relu2d_1 = nn.ReLU(inplace=True)
328
+
329
+ self.conv2d_m = nn.Conv2d(128, 128, 3, padding=1)
330
+ self.bn2d_m = nn.BatchNorm2d(128)
331
+ self.relu2d_m = nn.ReLU(inplace=True)
332
+
333
+ self.conv2d_2 = nn.Conv2d(128, 64, 3, padding=1)
334
+ self.bn2d_2 = nn.BatchNorm2d(64)
335
+ self.relu2d_2 = nn.ReLU(inplace=True)
336
+
337
+ # stage 1d
338
+ self.conv1d_1 = nn.Conv2d(128, 64, 3, padding=1) # 256
339
+ self.bn1d_1 = nn.BatchNorm2d(64)
340
+ self.relu1d_1 = nn.ReLU(inplace=True)
341
+
342
+ self.conv1d_m = nn.Conv2d(64, 64, 3, padding=1)
343
+ self.bn1d_m = nn.BatchNorm2d(64)
344
+ self.relu1d_m = nn.ReLU(inplace=True)
345
+
346
+ self.conv1d_2 = nn.Conv2d(64, 64, 3, padding=1)
347
+ self.bn1d_2 = nn.BatchNorm2d(64)
348
+ self.relu1d_2 = nn.ReLU(inplace=True)
349
+
350
+ # -------------Bilinear Upsampling--------------
351
+ self.upscore6 = nn.Upsample(
352
+ scale_factor=32, mode="bilinear", align_corners=False
353
+ )
354
+ self.upscore5 = nn.Upsample(
355
+ scale_factor=16, mode="bilinear", align_corners=False
356
+ )
357
+ self.upscore4 = nn.Upsample(
358
+ scale_factor=8, mode="bilinear", align_corners=False
359
+ )
360
+ self.upscore3 = nn.Upsample(
361
+ scale_factor=4, mode="bilinear", align_corners=False
362
+ )
363
+ self.upscore2 = nn.Upsample(
364
+ scale_factor=2, mode="bilinear", align_corners=False
365
+ )
366
+
367
+ # -------------Side Output--------------
368
+ self.outconvb = nn.Conv2d(512, 1, 3, padding=1)
369
+ self.outconv6 = nn.Conv2d(512, 1, 3, padding=1)
370
+ self.outconv5 = nn.Conv2d(512, 1, 3, padding=1)
371
+ self.outconv4 = nn.Conv2d(256, 1, 3, padding=1)
372
+ self.outconv3 = nn.Conv2d(128, 1, 3, padding=1)
373
+ self.outconv2 = nn.Conv2d(64, 1, 3, padding=1)
374
+ self.outconv1 = nn.Conv2d(64, 1, 3, padding=1)
375
+
376
+ # -------------Refine Module-------------
377
+ self.refunet = RefUnet(1, 64)
378
+
379
+ def forward(self, x):
380
+ hx = x
381
+
382
+ # -------------Encoder-------------
383
+ hx = self.inconv(hx)
384
+ hx = self.inbn(hx)
385
+ hx = self.inrelu(hx)
386
+
387
+ h1 = self.encoder1(hx) # 256
388
+ h2 = self.encoder2(h1) # 128
389
+ h3 = self.encoder3(h2) # 64
390
+ h4 = self.encoder4(h3) # 32
391
+
392
+ hx = self.pool4(h4) # 16
393
+
394
+ hx = self.resb5_1(hx)
395
+ hx = self.resb5_2(hx)
396
+ h5 = self.resb5_3(hx)
397
+
398
+ hx = self.pool5(h5) # 8
399
+
400
+ hx = self.resb6_1(hx)
401
+ hx = self.resb6_2(hx)
402
+ h6 = self.resb6_3(hx)
403
+
404
+ # -------------Bridge-------------
405
+ hx = self.relubg_1(self.bnbg_1(self.convbg_1(h6))) # 8
406
+ hx = self.relubg_m(self.bnbg_m(self.convbg_m(hx)))
407
+ hbg = self.relubg_2(self.bnbg_2(self.convbg_2(hx)))
408
+
409
+ # -------------Decoder-------------
410
+
411
+ hx = self.relu6d_1(self.bn6d_1(self.conv6d_1(torch.cat((hbg, h6), 1))))
412
+ hx = self.relu6d_m(self.bn6d_m(self.conv6d_m(hx)))
413
+ hd6 = self.relu6d_2(self.bn6d_2(self.conv6d_2(hx)))
414
+
415
+ hx = self.upscore2(hd6) # 8 -> 16
416
+
417
+ hx = self.relu5d_1(self.bn5d_1(self.conv5d_1(torch.cat((hx, h5), 1))))
418
+ hx = self.relu5d_m(self.bn5d_m(self.conv5d_m(hx)))
419
+ hd5 = self.relu5d_2(self.bn5d_2(self.conv5d_2(hx)))
420
+
421
+ hx = self.upscore2(hd5) # 16 -> 32
422
+
423
+ hx = self.relu4d_1(self.bn4d_1(self.conv4d_1(torch.cat((hx, h4), 1))))
424
+ hx = self.relu4d_m(self.bn4d_m(self.conv4d_m(hx)))
425
+ hd4 = self.relu4d_2(self.bn4d_2(self.conv4d_2(hx)))
426
+
427
+ hx = self.upscore2(hd4) # 32 -> 64
428
+
429
+ hx = self.relu3d_1(self.bn3d_1(self.conv3d_1(torch.cat((hx, h3), 1))))
430
+ hx = self.relu3d_m(self.bn3d_m(self.conv3d_m(hx)))
431
+ hd3 = self.relu3d_2(self.bn3d_2(self.conv3d_2(hx)))
432
+
433
+ hx = self.upscore2(hd3) # 64 -> 128
434
+
435
+ hx = self.relu2d_1(self.bn2d_1(self.conv2d_1(torch.cat((hx, h2), 1))))
436
+ hx = self.relu2d_m(self.bn2d_m(self.conv2d_m(hx)))
437
+ hd2 = self.relu2d_2(self.bn2d_2(self.conv2d_2(hx)))
438
+
439
+ hx = self.upscore2(hd2) # 128 -> 256
440
+
441
+ hx = self.relu1d_1(self.bn1d_1(self.conv1d_1(torch.cat((hx, h1), 1))))
442
+ hx = self.relu1d_m(self.bn1d_m(self.conv1d_m(hx)))
443
+ hd1 = self.relu1d_2(self.bn1d_2(self.conv1d_2(hx)))
444
+
445
+ # -------------Side Output-------------
446
+ db = self.outconvb(hbg)
447
+ db = self.upscore6(db) # 8->256
448
+
449
+ d6 = self.outconv6(hd6)
450
+ d6 = self.upscore6(d6) # 8->256
451
+
452
+ d5 = self.outconv5(hd5)
453
+ d5 = self.upscore5(d5) # 16->256
454
+
455
+ d4 = self.outconv4(hd4)
456
+ d4 = self.upscore4(d4) # 32->256
457
+
458
+ d3 = self.outconv3(hd3)
459
+ d3 = self.upscore3(d3) # 64->256
460
+
461
+ d2 = self.outconv2(hd2)
462
+ d2 = self.upscore2(d2) # 128->256
463
+
464
+ d1 = self.outconv1(hd1) # 256
465
+
466
+ # -------------Refine Module-------------
467
+ dout = self.refunet(d1) # 256
468
+
469
+ return (
470
+ torch.sigmoid(dout),
471
+ torch.sigmoid(d1),
472
+ torch.sigmoid(d2),
473
+ torch.sigmoid(d3),
474
+ torch.sigmoid(d4),
475
+ torch.sigmoid(d5),
476
+ torch.sigmoid(d6),
477
+ torch.sigmoid(db),
478
+ )
carvekit/ml/arch/fba_matting/__init__.py ADDED
File without changes
carvekit/ml/arch/fba_matting/__pycache__/__init__.cpython-38.pyc ADDED
Binary file (190 Bytes).
carvekit/ml/arch/fba_matting/__pycache__/layers_WS.cpython-38.pyc ADDED
Binary file (1.6 kB).
carvekit/ml/arch/fba_matting/__pycache__/models.cpython-38.pyc ADDED
Binary file (8.24 kB).
carvekit/ml/arch/fba_matting/__pycache__/resnet_GN_WS.cpython-38.pyc ADDED
Binary file (4.45 kB).
carvekit/ml/arch/fba_matting/__pycache__/resnet_bn.cpython-38.pyc ADDED
Binary file (4.69 kB).
carvekit/ml/arch/fba_matting/__pycache__/transforms.cpython-38.pyc ADDED
Binary file (1.58 kB).
carvekit/ml/arch/fba_matting/layers_WS.py ADDED
@@ -0,0 +1,57 @@
+ """
+ Modified by Nikita Selin (OPHoperHPO)[https://github.com/OPHoperHPO].
+ Source url: https://github.com/MarcoForte/FBA_Matting
+ License: MIT License
+ """
+ import torch
+ import torch.nn as nn
+ from torch.nn import functional as F
+
+
+ class Conv2d(nn.Conv2d):
+     def __init__(
+         self,
+         in_channels,
+         out_channels,
+         kernel_size,
+         stride=1,
+         padding=0,
+         dilation=1,
+         groups=1,
+         bias=True,
+     ):
+         super(Conv2d, self).__init__(
+             in_channels,
+             out_channels,
+             kernel_size,
+             stride,
+             padding,
+             dilation,
+             groups,
+             bias,
+         )
+
+     def forward(self, x):
+         # return super(Conv2d, self).forward(x)
+         weight = self.weight
+         weight_mean = (
+             weight.mean(dim=1, keepdim=True)
+             .mean(dim=2, keepdim=True)
+             .mean(dim=3, keepdim=True)
+         )
+         weight = weight - weight_mean
+         # std = (weight).view(weight.size(0), -1).std(dim=1).view(-1, 1, 1, 1) + 1e-5
+         std = (
+             torch.sqrt(torch.var(weight.view(weight.size(0), -1), dim=1) + 1e-12).view(
+                 -1, 1, 1, 1
+             )
+             + 1e-5
+         )
+         weight = weight / std.expand_as(weight)
+         return F.conv2d(
+             x, weight, self.bias, self.stride, self.padding, self.dilation, self.groups
+         )
+
+
+ def BatchNorm2d(num_features):
+     return nn.GroupNorm(num_channels=num_features, num_groups=32)
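The Conv2d subclass above implements weight standardization (kernel weights are re-centered and rescaled before each convolution), and the BatchNorm2d helper is in fact 32-group GroupNorm. A small sanity-check sketch, assuming torch is installed; this snippet is illustrative and not part of the commit.

# Illustrative only: shapes pass through the weight-standardized conv + GroupNorm.
import torch
from carvekit.ml.arch.fba_matting.layers_WS import Conv2d, BatchNorm2d

conv = Conv2d(3, 64, kernel_size=3, padding=1)
norm = BatchNorm2d(64)          # GroupNorm(32, 64) under the hood
x = torch.randn(1, 3, 32, 32)
print(norm(conv(x)).shape)      # torch.Size([1, 64, 32, 32])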
carvekit/ml/arch/fba_matting/models.py ADDED
@@ -0,0 +1,341 @@
1
+ """
2
+ Modified by Nikita Selin (OPHoperHPO)[https://github.com/OPHoperHPO].
3
+ Source url: https://github.com/MarcoForte/FBA_Matting
4
+ License: MIT License
5
+ """
6
+ import torch
7
+ import torch.nn as nn
8
+ import carvekit.ml.arch.fba_matting.resnet_GN_WS as resnet_GN_WS
9
+ import carvekit.ml.arch.fba_matting.layers_WS as L
10
+ import carvekit.ml.arch.fba_matting.resnet_bn as resnet_bn
11
+ from functools import partial
12
+
13
+
14
+ class FBA(nn.Module):
15
+ def __init__(self, encoder: str):
16
+ super(FBA, self).__init__()
17
+ self.encoder = build_encoder(arch=encoder)
18
+ self.decoder = fba_decoder(batch_norm=True if "BN" in encoder else False)
19
+
20
+ def forward(self, image, two_chan_trimap, image_n, trimap_transformed):
21
+ resnet_input = torch.cat((image_n, trimap_transformed, two_chan_trimap), 1)
22
+ conv_out, indices = self.encoder(resnet_input, return_feature_maps=True)
23
+ return self.decoder(conv_out, image, indices, two_chan_trimap)
24
+
25
+
26
+ class ResnetDilatedBN(nn.Module):
27
+ def __init__(self, orig_resnet, dilate_scale=8):
28
+ super(ResnetDilatedBN, self).__init__()
29
+
30
+ if dilate_scale == 8:
31
+ orig_resnet.layer3.apply(partial(self._nostride_dilate, dilate=2))
32
+ orig_resnet.layer4.apply(partial(self._nostride_dilate, dilate=4))
33
+ elif dilate_scale == 16:
34
+ orig_resnet.layer4.apply(partial(self._nostride_dilate, dilate=2))
35
+
36
+ # take pretrained resnet, except AvgPool and FC
37
+ self.conv1 = orig_resnet.conv1
38
+ self.bn1 = orig_resnet.bn1
39
+ self.relu1 = orig_resnet.relu1
40
+ self.conv2 = orig_resnet.conv2
41
+ self.bn2 = orig_resnet.bn2
42
+ self.relu2 = orig_resnet.relu2
43
+ self.conv3 = orig_resnet.conv3
44
+ self.bn3 = orig_resnet.bn3
45
+ self.relu3 = orig_resnet.relu3
46
+ self.maxpool = orig_resnet.maxpool
47
+ self.layer1 = orig_resnet.layer1
48
+ self.layer2 = orig_resnet.layer2
49
+ self.layer3 = orig_resnet.layer3
50
+ self.layer4 = orig_resnet.layer4
51
+
52
+ def _nostride_dilate(self, m, dilate):
53
+ classname = m.__class__.__name__
54
+ if classname.find("Conv") != -1:
55
+ # the convolution with stride
56
+ if m.stride == (2, 2):
57
+ m.stride = (1, 1)
58
+ if m.kernel_size == (3, 3):
59
+ m.dilation = (dilate // 2, dilate // 2)
60
+ m.padding = (dilate // 2, dilate // 2)
61
+ # other convoluions
62
+ else:
63
+ if m.kernel_size == (3, 3):
64
+ m.dilation = (dilate, dilate)
65
+ m.padding = (dilate, dilate)
66
+
67
+ def forward(self, x, return_feature_maps=False):
68
+ conv_out = [x]
69
+ x = self.relu1(self.bn1(self.conv1(x)))
70
+ x = self.relu2(self.bn2(self.conv2(x)))
71
+ x = self.relu3(self.bn3(self.conv3(x)))
72
+ conv_out.append(x)
73
+ x, indices = self.maxpool(x)
74
+ x = self.layer1(x)
75
+ conv_out.append(x)
76
+ x = self.layer2(x)
77
+ conv_out.append(x)
78
+ x = self.layer3(x)
79
+ conv_out.append(x)
80
+ x = self.layer4(x)
81
+ conv_out.append(x)
82
+
83
+ if return_feature_maps:
84
+ return conv_out, indices
85
+ return [x]
86
+
87
+
88
+ class Resnet(nn.Module):
89
+ def __init__(self, orig_resnet):
90
+ super(Resnet, self).__init__()
91
+
92
+ # take pretrained resnet, except AvgPool and FC
93
+ self.conv1 = orig_resnet.conv1
94
+ self.bn1 = orig_resnet.bn1
95
+ self.relu1 = orig_resnet.relu1
96
+ self.conv2 = orig_resnet.conv2
97
+ self.bn2 = orig_resnet.bn2
98
+ self.relu2 = orig_resnet.relu2
99
+ self.conv3 = orig_resnet.conv3
100
+ self.bn3 = orig_resnet.bn3
101
+ self.relu3 = orig_resnet.relu3
102
+ self.maxpool = orig_resnet.maxpool
103
+ self.layer1 = orig_resnet.layer1
104
+ self.layer2 = orig_resnet.layer2
105
+ self.layer3 = orig_resnet.layer3
106
+ self.layer4 = orig_resnet.layer4
107
+
108
+ def forward(self, x, return_feature_maps=False):
109
+ conv_out = []
110
+
111
+ x = self.relu1(self.bn1(self.conv1(x)))
112
+ x = self.relu2(self.bn2(self.conv2(x)))
113
+ x = self.relu3(self.bn3(self.conv3(x)))
114
+ conv_out.append(x)
115
+ x, indices = self.maxpool(x)
116
+
117
+ x = self.layer1(x)
118
+ conv_out.append(x)
119
+ x = self.layer2(x)
120
+ conv_out.append(x)
121
+ x = self.layer3(x)
122
+ conv_out.append(x)
123
+ x = self.layer4(x)
124
+ conv_out.append(x)
125
+
126
+ if return_feature_maps:
127
+ return conv_out
128
+ return [x]
129
+
130
+
131
+ class ResnetDilated(nn.Module):
132
+ def __init__(self, orig_resnet, dilate_scale=8):
133
+ super(ResnetDilated, self).__init__()
134
+
135
+ if dilate_scale == 8:
136
+ orig_resnet.layer3.apply(partial(self._nostride_dilate, dilate=2))
137
+ orig_resnet.layer4.apply(partial(self._nostride_dilate, dilate=4))
138
+ elif dilate_scale == 16:
139
+ orig_resnet.layer4.apply(partial(self._nostride_dilate, dilate=2))
140
+
141
+ # take pretrained resnet, except AvgPool and FC
142
+ self.conv1 = orig_resnet.conv1
143
+ self.bn1 = orig_resnet.bn1
144
+ self.relu = orig_resnet.relu
145
+ self.maxpool = orig_resnet.maxpool
146
+ self.layer1 = orig_resnet.layer1
147
+ self.layer2 = orig_resnet.layer2
148
+ self.layer3 = orig_resnet.layer3
149
+ self.layer4 = orig_resnet.layer4
150
+
151
+ def _nostride_dilate(self, m, dilate):
152
+ classname = m.__class__.__name__
153
+ if classname.find("Conv") != -1:
154
+ # the convolution with stride
155
+ if m.stride == (2, 2):
156
+ m.stride = (1, 1)
157
+ if m.kernel_size == (3, 3):
158
+ m.dilation = (dilate // 2, dilate // 2)
159
+ m.padding = (dilate // 2, dilate // 2)
160
+ # other convoluions
161
+ else:
162
+ if m.kernel_size == (3, 3):
163
+ m.dilation = (dilate, dilate)
164
+ m.padding = (dilate, dilate)
165
+
166
+ def forward(self, x, return_feature_maps=False):
167
+ conv_out = [x]
168
+ x = self.relu(self.bn1(self.conv1(x)))
169
+ conv_out.append(x)
170
+ x, indices = self.maxpool(x)
171
+ x = self.layer1(x)
172
+ conv_out.append(x)
173
+ x = self.layer2(x)
174
+ conv_out.append(x)
175
+ x = self.layer3(x)
176
+ conv_out.append(x)
177
+ x = self.layer4(x)
178
+ conv_out.append(x)
179
+
180
+ if return_feature_maps:
181
+ return conv_out, indices
182
+ return [x]
183
+
184
+
185
+ def norm(dim, bn=False):
186
+ if bn is False:
187
+ return nn.GroupNorm(32, dim)
188
+ else:
189
+ return nn.BatchNorm2d(dim)
190
+
191
+
192
+ def fba_fusion(alpha, img, F, B):
193
+ F = alpha * img + (1 - alpha**2) * F - alpha * (1 - alpha) * B
194
+ B = (1 - alpha) * img + (2 * alpha - alpha**2) * B - alpha * (1 - alpha) * F
195
+
196
+ F = torch.clamp(F, 0, 1)
197
+ B = torch.clamp(B, 0, 1)
198
+ la = 0.1
199
+ alpha = (alpha * la + torch.sum((img - B) * (F - B), 1, keepdim=True)) / (
200
+ torch.sum((F - B) * (F - B), 1, keepdim=True) + la
201
+ )
202
+ alpha = torch.clamp(alpha, 0, 1)
203
+ return alpha, F, B
204
+
205
+
206
+ class fba_decoder(nn.Module):
207
+ def __init__(self, batch_norm=False):
208
+ super(fba_decoder, self).__init__()
209
+ pool_scales = (1, 2, 3, 6)
210
+ self.batch_norm = batch_norm
211
+
212
+ self.ppm = []
213
+
214
+ for scale in pool_scales:
215
+ self.ppm.append(
216
+ nn.Sequential(
217
+ nn.AdaptiveAvgPool2d(scale),
218
+ L.Conv2d(2048, 256, kernel_size=1, bias=True),
219
+ norm(256, self.batch_norm),
220
+ nn.LeakyReLU(),
221
+ )
222
+ )
223
+ self.ppm = nn.ModuleList(self.ppm)
224
+
225
+ self.conv_up1 = nn.Sequential(
226
+ L.Conv2d(
227
+ 2048 + len(pool_scales) * 256, 256, kernel_size=3, padding=1, bias=True
228
+ ),
229
+ norm(256, self.batch_norm),
230
+ nn.LeakyReLU(),
231
+ L.Conv2d(256, 256, kernel_size=3, padding=1),
232
+ norm(256, self.batch_norm),
233
+ nn.LeakyReLU(),
234
+ )
235
+
236
+ self.conv_up2 = nn.Sequential(
237
+ L.Conv2d(256 + 256, 256, kernel_size=3, padding=1, bias=True),
238
+ norm(256, self.batch_norm),
239
+ nn.LeakyReLU(),
240
+ )
241
+ if self.batch_norm:
242
+ d_up3 = 128
243
+ else:
244
+ d_up3 = 64
245
+ self.conv_up3 = nn.Sequential(
246
+ L.Conv2d(256 + d_up3, 64, kernel_size=3, padding=1, bias=True),
247
+ norm(64, self.batch_norm),
248
+ nn.LeakyReLU(),
249
+ )
250
+
251
+ self.unpool = nn.MaxUnpool2d(2, stride=2)
252
+
253
+ self.conv_up4 = nn.Sequential(
254
+ nn.Conv2d(64 + 3 + 3 + 2, 32, kernel_size=3, padding=1, bias=True),
255
+ nn.LeakyReLU(),
256
+ nn.Conv2d(32, 16, kernel_size=3, padding=1, bias=True),
257
+ nn.LeakyReLU(),
258
+ nn.Conv2d(16, 7, kernel_size=1, padding=0, bias=True),
259
+ )
260
+
261
+ def forward(self, conv_out, img, indices, two_chan_trimap):
262
+ conv5 = conv_out[-1]
263
+
264
+ input_size = conv5.size()
265
+ ppm_out = [conv5]
266
+ for pool_scale in self.ppm:
267
+ ppm_out.append(
268
+ nn.functional.interpolate(
269
+ pool_scale(conv5),
270
+ (input_size[2], input_size[3]),
271
+ mode="bilinear",
272
+ align_corners=False,
273
+ )
274
+ )
275
+ ppm_out = torch.cat(ppm_out, 1)
276
+ x = self.conv_up1(ppm_out)
277
+
278
+ x = torch.nn.functional.interpolate(
279
+ x, scale_factor=2, mode="bilinear", align_corners=False
280
+ )
281
+
282
+ x = torch.cat((x, conv_out[-4]), 1)
283
+
284
+ x = self.conv_up2(x)
285
+ x = torch.nn.functional.interpolate(
286
+ x, scale_factor=2, mode="bilinear", align_corners=False
287
+ )
288
+
289
+ x = torch.cat((x, conv_out[-5]), 1)
290
+ x = self.conv_up3(x)
291
+
292
+ x = torch.nn.functional.interpolate(
293
+ x, scale_factor=2, mode="bilinear", align_corners=False
294
+ )
295
+ x = torch.cat((x, conv_out[-6][:, :3], img, two_chan_trimap), 1)
296
+
297
+ output = self.conv_up4(x)
298
+
299
+ alpha = torch.clamp(output[:, 0][:, None], 0, 1)
300
+ F = torch.sigmoid(output[:, 1:4])
301
+ B = torch.sigmoid(output[:, 4:7])
302
+
303
+ # FBA Fusion
304
+ alpha, F, B = fba_fusion(alpha, img, F, B)
305
+
306
+ output = torch.cat((alpha, F, B), 1)
307
+
308
+ return output
309
+
310
+
311
+ def build_encoder(arch="resnet50_GN"):
312
+ if arch == "resnet50_GN_WS":
313
+ orig_resnet = resnet_GN_WS.__dict__["l_resnet50"]()
314
+ net_encoder = ResnetDilated(orig_resnet, dilate_scale=8)
315
+ elif arch == "resnet50_BN":
316
+ orig_resnet = resnet_bn.__dict__["l_resnet50"]()
317
+ net_encoder = ResnetDilatedBN(orig_resnet, dilate_scale=8)
318
+
319
+ else:
320
+ raise ValueError("Architecture undefined!")
321
+
322
+ num_channels = 3 + 6 + 2
323
+
324
+ if num_channels > 3:
325
+ net_encoder_sd = net_encoder.state_dict()
326
+ conv1_weights = net_encoder_sd["conv1.weight"]
327
+
328
+ c_out, c_in, h, w = conv1_weights.size()
329
+ conv1_mod = torch.zeros(c_out, num_channels, h, w)
330
+ conv1_mod[:, :3, :, :] = conv1_weights
331
+
332
+ conv1 = net_encoder.conv1
333
+ conv1.in_channels = num_channels
334
+ conv1.weight = torch.nn.Parameter(conv1_mod)
335
+
336
+ net_encoder.conv1 = conv1
337
+
338
+ net_encoder_sd["conv1.weight"] = conv1_mod
339
+
340
+ net_encoder.load_state_dict(net_encoder_sd)
341
+ return net_encoder
carvekit/ml/arch/fba_matting/resnet_GN_WS.py ADDED
@@ -0,0 +1,151 @@
1
+ """
2
+ Modified by Nikita Selin (OPHoperHPO)[https://github.com/OPHoperHPO].
3
+ Source url: https://github.com/MarcoForte/FBA_Matting
4
+ License: MIT License
5
+ """
6
+ import torch.nn as nn
7
+ import carvekit.ml.arch.fba_matting.layers_WS as L
8
+
9
+ __all__ = ["ResNet", "l_resnet50"]
10
+
11
+
12
+ def conv3x3(in_planes, out_planes, stride=1):
13
+ """3x3 convolution with padding"""
14
+ return L.Conv2d(
15
+ in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=False
16
+ )
17
+
18
+
19
+ def conv1x1(in_planes, out_planes, stride=1):
20
+ """1x1 convolution"""
21
+ return L.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)
22
+
23
+
24
+ class BasicBlock(nn.Module):
25
+ expansion = 1
26
+
27
+ def __init__(self, inplanes, planes, stride=1, downsample=None):
28
+ super(BasicBlock, self).__init__()
29
+ self.conv1 = conv3x3(inplanes, planes, stride)
30
+ self.bn1 = L.BatchNorm2d(planes)
31
+ self.relu = nn.ReLU(inplace=True)
32
+ self.conv2 = conv3x3(planes, planes)
33
+ self.bn2 = L.BatchNorm2d(planes)
34
+ self.downsample = downsample
35
+ self.stride = stride
36
+
37
+ def forward(self, x):
38
+ identity = x
39
+
40
+ out = self.conv1(x)
41
+ out = self.bn1(out)
42
+ out = self.relu(out)
43
+
44
+ out = self.conv2(out)
45
+ out = self.bn2(out)
46
+
47
+ if self.downsample is not None:
48
+ identity = self.downsample(x)
49
+
50
+ out += identity
51
+ out = self.relu(out)
52
+
53
+ return out
54
+
55
+
56
+ class Bottleneck(nn.Module):
57
+ expansion = 4
58
+
59
+ def __init__(self, inplanes, planes, stride=1, downsample=None):
60
+ super(Bottleneck, self).__init__()
61
+ self.conv1 = conv1x1(inplanes, planes)
62
+ self.bn1 = L.BatchNorm2d(planes)
63
+ self.conv2 = conv3x3(planes, planes, stride)
64
+ self.bn2 = L.BatchNorm2d(planes)
65
+ self.conv3 = conv1x1(planes, planes * self.expansion)
66
+ self.bn3 = L.BatchNorm2d(planes * self.expansion)
67
+ self.relu = nn.ReLU(inplace=True)
68
+ self.downsample = downsample
69
+ self.stride = stride
70
+
71
+ def forward(self, x):
72
+ identity = x
73
+
74
+ out = self.conv1(x)
75
+ out = self.bn1(out)
76
+ out = self.relu(out)
77
+
78
+ out = self.conv2(out)
79
+ out = self.bn2(out)
80
+ out = self.relu(out)
81
+
82
+ out = self.conv3(out)
83
+ out = self.bn3(out)
84
+
85
+ if self.downsample is not None:
86
+ identity = self.downsample(x)
87
+
88
+ out += identity
89
+ out = self.relu(out)
90
+
91
+ return out
92
+
93
+
94
+ class ResNet(nn.Module):
95
+ def __init__(self, block, layers, num_classes=1000):
96
+ super(ResNet, self).__init__()
97
+ self.inplanes = 64
98
+ self.conv1 = L.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False)
99
+ self.bn1 = L.BatchNorm2d(64)
100
+ self.relu = nn.ReLU(inplace=True)
101
+ self.maxpool = nn.MaxPool2d(
102
+ kernel_size=3, stride=2, padding=1, return_indices=True
103
+ )
104
+ self.layer1 = self._make_layer(block, 64, layers[0])
105
+ self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
106
+ self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
107
+ self.layer4 = self._make_layer(block, 512, layers[3], stride=2)
108
+ self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
109
+ self.fc = nn.Linear(512 * block.expansion, num_classes)
110
+
111
+ def _make_layer(self, block, planes, blocks, stride=1):
112
+ downsample = None
113
+ if stride != 1 or self.inplanes != planes * block.expansion:
114
+ downsample = nn.Sequential(
115
+ conv1x1(self.inplanes, planes * block.expansion, stride),
116
+ L.BatchNorm2d(planes * block.expansion),
117
+ )
118
+
119
+ layers = []
120
+ layers.append(block(self.inplanes, planes, stride, downsample))
121
+ self.inplanes = planes * block.expansion
122
+ for _ in range(1, blocks):
123
+ layers.append(block(self.inplanes, planes))
124
+
125
+ return nn.Sequential(*layers)
126
+
127
+ def forward(self, x):
128
+ x = self.conv1(x)
129
+ x = self.bn1(x)
130
+ x = self.relu(x)
131
+ x = self.maxpool(x)
132
+
133
+ x = self.layer1(x)
134
+ x = self.layer2(x)
135
+ x = self.layer3(x)
136
+ x = self.layer4(x)
137
+
138
+ x = self.avgpool(x)
139
+ x = x.view(x.size(0), -1)
140
+ x = self.fc(x)
141
+
142
+ return x
143
+
144
+
145
+ def l_resnet50(pretrained=False, **kwargs):
146
+ """Constructs a ResNet-50 model.
147
+ Args:
148
+ pretrained (bool): If True, returns a model pre-trained on ImageNet
149
+ """
150
+ model = ResNet(Bottleneck, [3, 4, 6, 3], **kwargs)
151
+ return model
carvekit/ml/arch/fba_matting/resnet_bn.py ADDED
@@ -0,0 +1,169 @@
1
+ """
2
+ Modified by Nikita Selin (OPHoperHPO)[https://github.com/OPHoperHPO].
3
+ Source url: https://github.com/MarcoForte/FBA_Matting
4
+ License: MIT License
5
+ """
6
+ import torch.nn as nn
7
+ import math
8
+ from torch.nn import BatchNorm2d
9
+
10
+ __all__ = ["ResNet"]
11
+
12
+
13
+ def conv3x3(in_planes, out_planes, stride=1):
14
+ "3x3 convolution with padding"
15
+ return nn.Conv2d(
16
+ in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=False
17
+ )
18
+
19
+
20
+ class BasicBlock(nn.Module):
21
+ expansion = 1
22
+
23
+ def __init__(self, inplanes, planes, stride=1, downsample=None):
24
+ super(BasicBlock, self).__init__()
25
+ self.conv1 = conv3x3(inplanes, planes, stride)
26
+ self.bn1 = BatchNorm2d(planes)
27
+ self.relu = nn.ReLU(inplace=True)
28
+ self.conv2 = conv3x3(planes, planes)
29
+ self.bn2 = BatchNorm2d(planes)
30
+ self.downsample = downsample
31
+ self.stride = stride
32
+
33
+ def forward(self, x):
34
+ residual = x
35
+
36
+ out = self.conv1(x)
37
+ out = self.bn1(out)
38
+ out = self.relu(out)
39
+
40
+ out = self.conv2(out)
41
+ out = self.bn2(out)
42
+
43
+ if self.downsample is not None:
44
+ residual = self.downsample(x)
45
+
46
+ out += residual
47
+ out = self.relu(out)
48
+
49
+ return out
50
+
51
+
52
+ class Bottleneck(nn.Module):
53
+ expansion = 4
54
+
55
+ def __init__(self, inplanes, planes, stride=1, downsample=None):
56
+ super(Bottleneck, self).__init__()
57
+ self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
58
+ self.bn1 = BatchNorm2d(planes)
59
+ self.conv2 = nn.Conv2d(
60
+ planes, planes, kernel_size=3, stride=stride, padding=1, bias=False
61
+ )
62
+ self.bn2 = BatchNorm2d(planes, momentum=0.01)
63
+ self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False)
64
+ self.bn3 = BatchNorm2d(planes * 4)
65
+ self.relu = nn.ReLU(inplace=True)
66
+ self.downsample = downsample
67
+ self.stride = stride
68
+
69
+ def forward(self, x):
70
+ residual = x
71
+
72
+ out = self.conv1(x)
73
+ out = self.bn1(out)
74
+ out = self.relu(out)
75
+
76
+ out = self.conv2(out)
77
+ out = self.bn2(out)
78
+ out = self.relu(out)
79
+
80
+ out = self.conv3(out)
81
+ out = self.bn3(out)
82
+
83
+ if self.downsample is not None:
84
+ residual = self.downsample(x)
85
+
86
+ out += residual
87
+ out = self.relu(out)
88
+
89
+ return out
90
+
91
+
92
+ class ResNet(nn.Module):
93
+ def __init__(self, block, layers, num_classes=1000):
94
+ self.inplanes = 128
95
+ super(ResNet, self).__init__()
96
+ self.conv1 = conv3x3(3, 64, stride=2)
97
+ self.bn1 = BatchNorm2d(64)
98
+ self.relu1 = nn.ReLU(inplace=True)
99
+ self.conv2 = conv3x3(64, 64)
100
+ self.bn2 = BatchNorm2d(64)
101
+ self.relu2 = nn.ReLU(inplace=True)
102
+ self.conv3 = conv3x3(64, 128)
103
+ self.bn3 = BatchNorm2d(128)
104
+ self.relu3 = nn.ReLU(inplace=True)
105
+ self.maxpool = nn.MaxPool2d(
106
+ kernel_size=3, stride=2, padding=1, return_indices=True
107
+ )
108
+
109
+ self.layer1 = self._make_layer(block, 64, layers[0])
110
+ self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
111
+ self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
112
+ self.layer4 = self._make_layer(block, 512, layers[3], stride=2)
113
+ self.avgpool = nn.AvgPool2d(7, stride=1)
114
+ self.fc = nn.Linear(512 * block.expansion, num_classes)
115
+
116
+ for m in self.modules():
117
+ if isinstance(m, nn.Conv2d):
118
+ n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
119
+ m.weight.data.normal_(0, math.sqrt(2.0 / n))
120
+ elif isinstance(m, BatchNorm2d):
121
+ m.weight.data.fill_(1)
122
+ m.bias.data.zero_()
123
+
124
+ def _make_layer(self, block, planes, blocks, stride=1):
125
+ downsample = None
126
+ if stride != 1 or self.inplanes != planes * block.expansion:
127
+ downsample = nn.Sequential(
128
+ nn.Conv2d(
129
+ self.inplanes,
130
+ planes * block.expansion,
131
+ kernel_size=1,
132
+ stride=stride,
133
+ bias=False,
134
+ ),
135
+ BatchNorm2d(planes * block.expansion),
136
+ )
137
+
138
+ layers = []
139
+ layers.append(block(self.inplanes, planes, stride, downsample))
140
+ self.inplanes = planes * block.expansion
141
+ for i in range(1, blocks):
142
+ layers.append(block(self.inplanes, planes))
143
+
144
+ return nn.Sequential(*layers)
145
+
146
+ def forward(self, x):
147
+ x = self.relu1(self.bn1(self.conv1(x)))
148
+ x = self.relu2(self.bn2(self.conv2(x)))
149
+ x = self.relu3(self.bn3(self.conv3(x)))
150
+ x, indices = self.maxpool(x)
151
+
152
+ x = self.layer1(x)
153
+ x = self.layer2(x)
154
+ x = self.layer3(x)
155
+ x = self.layer4(x)
156
+
157
+ x = self.avgpool(x)
158
+ x = x.view(x.size(0), -1)
159
+ x = self.fc(x)
160
+ return x
161
+
162
+
163
+ def l_resnet50():
164
+ """Constructs a ResNet-50 model.
165
+ Returns a randomly initialized model; no pretrained weights are loaded here.
167
+ """
168
+ model = ResNet(Bottleneck, [3, 4, 6, 3])
169
+ return model
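As a quick sanity check for this backbone (a sketch; it assumes only what this commit adds under carvekit/ml/arch/fba_matting/resnet_bn.py and a working torch install):

```python
import torch
from carvekit.ml.arch.fba_matting.resnet_bn import l_resnet50

model = l_resnet50().eval()                      # randomly initialized ResNet-50
with torch.no_grad():
    logits = model(torch.randn(1, 3, 224, 224))  # stem + 7x7 avgpool are sized for 224x224 inputs
print(logits.shape)                              # torch.Size([1, 1000])
```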
carvekit/ml/arch/fba_matting/transforms.py ADDED
@@ -0,0 +1,45 @@
1
+ """
2
+ Modified by Nikita Selin (OPHoperHPO)[https://github.com/OPHoperHPO].
3
+ Source url: https://github.com/MarcoForte/FBA_Matting
4
+ License: MIT License
5
+ """
6
+ import cv2
7
+ import numpy as np
8
+
9
+ group_norm_std = [0.229, 0.224, 0.225]
10
+ group_norm_mean = [0.485, 0.456, 0.406]
11
+
12
+
13
+ def dt(a):
14
+ return cv2.distanceTransform((a * 255).astype(np.uint8), cv2.DIST_L2, 0)
15
+
16
+
17
+ def trimap_transform(trimap):
18
+ h, w = trimap.shape[0], trimap.shape[1]
19
+
20
+ clicks = np.zeros((h, w, 6))
21
+ for k in range(2):
22
+ if np.count_nonzero(trimap[:, :, k]) > 0:
23
+ dt_mask = -dt(1 - trimap[:, :, k]) ** 2
24
+ L = 320
25
+ clicks[:, :, 3 * k] = np.exp(dt_mask / (2 * ((0.02 * L) ** 2)))
26
+ clicks[:, :, 3 * k + 1] = np.exp(dt_mask / (2 * ((0.08 * L) ** 2)))
27
+ clicks[:, :, 3 * k + 2] = np.exp(dt_mask / (2 * ((0.16 * L) ** 2)))
28
+
29
+ return clicks
30
+
31
+
32
+ def groupnorm_normalise_image(img, format="nhwc"):
33
+ """
34
+ Accepts an RGB image with values in [0, 1]; normalises channels in place using the means/stds above.
35
+ """
36
+ if format == "nhwc":
37
+ for i in range(3):
38
+ img[..., i] = (img[..., i] - group_norm_mean[i]) / group_norm_std[i]
39
+ else:
40
+ for i in range(3):
41
+ img[..., i, :, :] = (
42
+ img[..., i, :, :] - group_norm_mean[i]
43
+ ) / group_norm_std[i]
44
+
45
+ return img
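A minimal usage sketch for the two helpers above (the 2-channel trimap layout and the 320x320 size are illustrative assumptions, not values taken from this commit):

```python
import numpy as np
from carvekit.ml.arch.fba_matting.transforms import trimap_transform, groupnorm_normalise_image

h, w = 320, 320
trimap = np.zeros((h, w, 2), dtype=np.float32)   # channel 0: background region, channel 1: foreground region
trimap[:40, :, 0] = 1.0
trimap[-40:, :, 1] = 1.0

clicks = trimap_transform(trimap)                    # (h, w, 6) distance-transform encoding
image = np.random.rand(h, w, 3).astype(np.float32)   # RGB in [0, 1]
image = groupnorm_normalise_image(image, format="nhwc")
print(clicks.shape, image.shape)                     # (320, 320, 6) (320, 320, 3)
```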
carvekit/ml/arch/tracerb7/__init__.py ADDED
File without changes
carvekit/ml/arch/tracerb7/__pycache__/__init__.cpython-38.pyc ADDED
Binary file (187 Bytes). View file
 
carvekit/ml/arch/tracerb7/__pycache__/att_modules.cpython-38.pyc ADDED
Binary file (7.42 kB). View file
 
carvekit/ml/arch/tracerb7/__pycache__/conv_modules.cpython-38.pyc ADDED
Binary file (2.44 kB). View file
 
carvekit/ml/arch/tracerb7/__pycache__/effi_utils.cpython-38.pyc ADDED
Binary file (14.9 kB). View file
 
carvekit/ml/arch/tracerb7/__pycache__/efficientnet.cpython-38.pyc ADDED
Binary file (8.02 kB). View file
 
carvekit/ml/arch/tracerb7/__pycache__/tracer.cpython-38.pyc ADDED
Binary file (2.83 kB). View file
 
carvekit/ml/arch/tracerb7/att_modules.py ADDED
@@ -0,0 +1,290 @@
1
+ """
2
+ Source url: https://github.com/Karel911/TRACER
3
+ Author: Min Seok Lee and Wooseok Shin
4
+ License: Apache License 2.0
5
+ """
6
+ import torch
7
+ import torch.nn as nn
8
+ import torch.nn.functional as F
9
+
10
+ from carvekit.ml.arch.tracerb7.conv_modules import BasicConv2d, DWConv, DWSConv
11
+
12
+
13
+ class RFB_Block(nn.Module):
14
+ def __init__(self, in_channel, out_channel):
15
+ super(RFB_Block, self).__init__()
16
+ self.relu = nn.ReLU(True)
17
+ self.branch0 = nn.Sequential(
18
+ BasicConv2d(in_channel, out_channel, 1),
19
+ )
20
+ self.branch1 = nn.Sequential(
21
+ BasicConv2d(in_channel, out_channel, 1),
22
+ BasicConv2d(out_channel, out_channel, kernel_size=(1, 3), padding=(0, 1)),
23
+ BasicConv2d(out_channel, out_channel, kernel_size=(3, 1), padding=(1, 0)),
24
+ BasicConv2d(out_channel, out_channel, 3, padding=3, dilation=3),
25
+ )
26
+ self.branch2 = nn.Sequential(
27
+ BasicConv2d(in_channel, out_channel, 1),
28
+ BasicConv2d(out_channel, out_channel, kernel_size=(1, 5), padding=(0, 2)),
29
+ BasicConv2d(out_channel, out_channel, kernel_size=(5, 1), padding=(2, 0)),
30
+ BasicConv2d(out_channel, out_channel, 3, padding=5, dilation=5),
31
+ )
32
+ self.branch3 = nn.Sequential(
33
+ BasicConv2d(in_channel, out_channel, 1),
34
+ BasicConv2d(out_channel, out_channel, kernel_size=(1, 7), padding=(0, 3)),
35
+ BasicConv2d(out_channel, out_channel, kernel_size=(7, 1), padding=(3, 0)),
36
+ BasicConv2d(out_channel, out_channel, 3, padding=7, dilation=7),
37
+ )
38
+ self.conv_cat = BasicConv2d(4 * out_channel, out_channel, 3, padding=1)
39
+ self.conv_res = BasicConv2d(in_channel, out_channel, 1)
40
+
41
+ def forward(self, x):
42
+ x0 = self.branch0(x)
43
+ x1 = self.branch1(x)
44
+ x2 = self.branch2(x)
45
+ x3 = self.branch3(x)
46
+ x_cat = torch.cat((x0, x1, x2, x3), 1)
47
+ x_cat = self.conv_cat(x_cat)
48
+
49
+ x = self.relu(x_cat + self.conv_res(x))
50
+ return x
51
+
52
+
53
+ class GlobalAvgPool(nn.Module):
54
+ def __init__(self, flatten=False):
55
+ super(GlobalAvgPool, self).__init__()
56
+ self.flatten = flatten
57
+
58
+ def forward(self, x):
59
+ if self.flatten:
60
+ in_size = x.size()
61
+ return x.view((in_size[0], in_size[1], -1)).mean(dim=2)
62
+ else:
63
+ return (
64
+ x.view(x.size(0), x.size(1), -1)
65
+ .mean(-1)
66
+ .view(x.size(0), x.size(1), 1, 1)
67
+ )
68
+
69
+
70
+ class UnionAttentionModule(nn.Module):
71
+ def __init__(self, n_channels, only_channel_tracing=False):
72
+ super(UnionAttentionModule, self).__init__()
73
+ self.GAP = GlobalAvgPool()
74
+ self.confidence_ratio = 0.1
75
+ self.bn = nn.BatchNorm2d(n_channels)
76
+ self.norm = nn.Sequential(
77
+ nn.BatchNorm2d(n_channels), nn.Dropout3d(self.confidence_ratio)
78
+ )
79
+ self.channel_q = nn.Conv2d(
80
+ in_channels=n_channels,
81
+ out_channels=n_channels,
82
+ kernel_size=1,
83
+ stride=1,
84
+ padding=0,
85
+ bias=False,
86
+ )
87
+ self.channel_k = nn.Conv2d(
88
+ in_channels=n_channels,
89
+ out_channels=n_channels,
90
+ kernel_size=1,
91
+ stride=1,
92
+ padding=0,
93
+ bias=False,
94
+ )
95
+ self.channel_v = nn.Conv2d(
96
+ in_channels=n_channels,
97
+ out_channels=n_channels,
98
+ kernel_size=1,
99
+ stride=1,
100
+ padding=0,
101
+ bias=False,
102
+ )
103
+
104
+ self.fc = nn.Conv2d(
105
+ in_channels=n_channels,
106
+ out_channels=n_channels,
107
+ kernel_size=1,
108
+ stride=1,
109
+ padding=0,
110
+ bias=False,
111
+ )
112
+
113
+ if only_channel_tracing is False:
114
+ self.spatial_q = nn.Conv2d(
115
+ in_channels=n_channels,
116
+ out_channels=1,
117
+ kernel_size=1,
118
+ stride=1,
119
+ padding=0,
120
+ bias=False,
121
+ )
122
+ self.spatial_k = nn.Conv2d(
123
+ in_channels=n_channels,
124
+ out_channels=1,
125
+ kernel_size=1,
126
+ stride=1,
127
+ padding=0,
128
+ bias=False,
129
+ )
130
+ self.spatial_v = nn.Conv2d(
131
+ in_channels=n_channels,
132
+ out_channels=1,
133
+ kernel_size=1,
134
+ stride=1,
135
+ padding=0,
136
+ bias=False,
137
+ )
138
+ self.sigmoid = nn.Sigmoid()
139
+
140
+ def masking(self, x, mask):
141
+ mask = mask.squeeze(3).squeeze(2)
142
+ threshold = torch.quantile(
143
+ mask.float(), self.confidence_ratio, dim=-1, keepdim=True
144
+ )
145
+ mask[mask <= threshold] = 0.0
146
+ mask = mask.unsqueeze(2).unsqueeze(3)
147
+ mask = mask.expand(-1, x.shape[1], x.shape[2], x.shape[3]).contiguous()
148
+ masked_x = x * mask
149
+
150
+ return masked_x
151
+
152
+ def Channel_Tracer(self, x):
153
+ avg_pool = self.GAP(x)
154
+ x_norm = self.norm(avg_pool)
155
+
156
+ q = self.channel_q(x_norm).squeeze(-1)
157
+ k = self.channel_k(x_norm).squeeze(-1)
158
+ v = self.channel_v(x_norm).squeeze(-1)
159
+
160
+ # softmax(Q*K^T)
161
+ QK_T = torch.matmul(q, k.transpose(1, 2))
162
+ alpha = F.softmax(QK_T, dim=-1)
163
+
164
+ # a*v
165
+ att = torch.matmul(alpha, v).unsqueeze(-1)
166
+ att = self.fc(att)
167
+ att = self.sigmoid(att)
168
+
169
+ output = (x * att) + x
170
+ alpha_mask = att.clone()
171
+
172
+ return output, alpha_mask
173
+
174
+ def forward(self, x):
175
+ X_c, alpha_mask = self.Channel_Tracer(x)
176
+ X_c = self.bn(X_c)
177
+ x_drop = self.masking(X_c, alpha_mask)
178
+
179
+ q = self.spatial_q(x_drop).squeeze(1)
180
+ k = self.spatial_k(x_drop).squeeze(1)
181
+ v = self.spatial_v(x_drop).squeeze(1)
182
+
183
+ # softmax(Q*K^T)
184
+ QK_T = torch.matmul(q, k.transpose(1, 2))
185
+ alpha = F.softmax(QK_T, dim=-1)
186
+
187
+ output = torch.matmul(alpha, v).unsqueeze(1) + v.unsqueeze(1)
188
+
189
+ return output
190
+
191
+
192
+ class aggregation(nn.Module):
193
+ def __init__(self, channel):
194
+ super(aggregation, self).__init__()
195
+ self.relu = nn.ReLU(True)
196
+
197
+ self.upsample = nn.Upsample(scale_factor=2, mode="bilinear", align_corners=True)
198
+ self.conv_upsample1 = BasicConv2d(channel[2], channel[1], 3, padding=1)
199
+ self.conv_upsample2 = BasicConv2d(channel[2], channel[0], 3, padding=1)
200
+ self.conv_upsample3 = BasicConv2d(channel[1], channel[0], 3, padding=1)
201
+ self.conv_upsample4 = BasicConv2d(channel[2], channel[2], 3, padding=1)
202
+ self.conv_upsample5 = BasicConv2d(
203
+ channel[2] + channel[1], channel[2] + channel[1], 3, padding=1
204
+ )
205
+
206
+ self.conv_concat2 = BasicConv2d(
207
+ (channel[2] + channel[1]), (channel[2] + channel[1]), 3, padding=1
208
+ )
209
+ self.conv_concat3 = BasicConv2d(
210
+ (channel[0] + channel[1] + channel[2]),
211
+ (channel[0] + channel[1] + channel[2]),
212
+ 3,
213
+ padding=1,
214
+ )
215
+
216
+ self.UAM = UnionAttentionModule(channel[0] + channel[1] + channel[2])
217
+
218
+ def forward(self, e4, e3, e2):
219
+ e4_1 = e4
220
+ e3_1 = self.conv_upsample1(self.upsample(e4)) * e3
221
+ e2_1 = (
222
+ self.conv_upsample2(self.upsample(self.upsample(e4)))
223
+ * self.conv_upsample3(self.upsample(e3))
224
+ * e2
225
+ )
226
+
227
+ e3_2 = torch.cat((e3_1, self.conv_upsample4(self.upsample(e4_1))), 1)
228
+ e3_2 = self.conv_concat2(e3_2)
229
+
230
+ e2_2 = torch.cat((e2_1, self.conv_upsample5(self.upsample(e3_2))), 1)
231
+ x = self.conv_concat3(e2_2)
232
+
233
+ output = self.UAM(x)
234
+
235
+ return output
236
+
237
+
238
+ class ObjectAttention(nn.Module):
239
+ def __init__(self, channel, kernel_size):
240
+ super(ObjectAttention, self).__init__()
241
+ self.channel = channel
242
+ self.DWSConv = DWSConv(
243
+ channel, channel // 2, kernel=kernel_size, padding=1, kernels_per_layer=1
244
+ )
245
+ self.DWConv1 = nn.Sequential(
246
+ DWConv(channel // 2, channel // 2, kernel=1, padding=0, dilation=1),
247
+ BasicConv2d(channel // 2, channel // 8, 1),
248
+ )
249
+ self.DWConv2 = nn.Sequential(
250
+ DWConv(channel // 2, channel // 2, kernel=3, padding=1, dilation=1),
251
+ BasicConv2d(channel // 2, channel // 8, 1),
252
+ )
253
+ self.DWConv3 = nn.Sequential(
254
+ DWConv(channel // 2, channel // 2, kernel=3, padding=3, dilation=3),
255
+ BasicConv2d(channel // 2, channel // 8, 1),
256
+ )
257
+ self.DWConv4 = nn.Sequential(
258
+ DWConv(channel // 2, channel // 2, kernel=3, padding=5, dilation=5),
259
+ BasicConv2d(channel // 2, channel // 8, 1),
260
+ )
261
+ self.conv1 = BasicConv2d(channel // 2, 1, 1)
262
+
263
+ def forward(self, decoder_map, encoder_map):
264
+ """
265
+ Args:
266
+ decoder_map: decoder representation (B, 1, H, W).
267
+ encoder_map: encoder block output (B, C, H, W).
268
+ Returns:
269
+ decoder representation: (B, 1, H, W)
270
+ """
271
+ mask_bg = -1 * torch.sigmoid(decoder_map) + 1 # Sigmoid & Reverse
272
+ mask_ob = torch.sigmoid(decoder_map) # object attention
273
+ x = mask_ob.expand(-1, self.channel, -1, -1).mul(encoder_map)
274
+
275
+ edge = mask_bg.clone()
276
+ edge[edge > 0.93] = 0
277
+ x = x + (edge * encoder_map)
278
+
279
+ x = self.DWSConv(x)
280
+ skip = x.clone()
281
+ x = (
282
+ torch.cat(
283
+ [self.DWConv1(x), self.DWConv2(x), self.DWConv3(x), self.DWConv4(x)],
284
+ dim=1,
285
+ )
286
+ + skip
287
+ )
288
+ x = torch.relu(self.conv1(x))
289
+
290
+ return x + decoder_map
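A rough shape check for the attention blocks defined above (dummy tensors; the channel counts are illustrative, not the ones used at every TRACER stage):

```python
import torch
from carvekit.ml.arch.tracerb7.att_modules import RFB_Block, ObjectAttention

feat = torch.randn(1, 80, 40, 40)
rfb = RFB_Block(in_channel=80, out_channel=32)
print(rfb(feat).shape)                        # torch.Size([1, 32, 40, 40])

oa = ObjectAttention(channel=80, kernel_size=3)
decoder_map = torch.randn(1, 1, 40, 40)       # single-channel decoder output
print(oa(decoder_map, feat).shape)            # torch.Size([1, 1, 40, 40])
```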
carvekit/ml/arch/tracerb7/conv_modules.py ADDED
@@ -0,0 +1,88 @@
1
+ """
2
+ Source url: https://github.com/Karel911/TRACER
3
+ Author: Min Seok Lee and Wooseok Shin
4
+ License: Apache License 2.0
5
+ """
6
+ import torch.nn as nn
7
+
8
+
9
+ class BasicConv2d(nn.Module):
10
+ def __init__(
11
+ self,
12
+ in_channel,
13
+ out_channel,
14
+ kernel_size,
15
+ stride=(1, 1),
16
+ padding=(0, 0),
17
+ dilation=(1, 1),
18
+ ):
19
+ super(BasicConv2d, self).__init__()
20
+ self.conv = nn.Conv2d(
21
+ in_channel,
22
+ out_channel,
23
+ kernel_size=kernel_size,
24
+ stride=stride,
25
+ padding=padding,
26
+ dilation=dilation,
27
+ bias=False,
28
+ )
29
+ self.bn = nn.BatchNorm2d(out_channel)
30
+ self.selu = nn.SELU()
31
+
32
+ def forward(self, x):
33
+ x = self.conv(x)
34
+ x = self.bn(x)
35
+ x = self.selu(x)
36
+
37
+ return x
38
+
39
+
40
+ class DWConv(nn.Module):
41
+ def __init__(self, in_channel, out_channel, kernel, dilation, padding):
42
+ super(DWConv, self).__init__()
43
+ self.out_channel = out_channel
44
+ self.DWConv = nn.Conv2d(
45
+ in_channel,
46
+ out_channel,
47
+ kernel_size=kernel,
48
+ padding=padding,
49
+ groups=in_channel,
50
+ dilation=dilation,
51
+ bias=False,
52
+ )
53
+ self.bn = nn.BatchNorm2d(out_channel)
54
+ self.selu = nn.SELU()
55
+
56
+ def forward(self, x):
57
+ x = self.DWConv(x)
58
+ out = self.selu(self.bn(x))
59
+
60
+ return out
61
+
62
+
63
+ class DWSConv(nn.Module):
64
+ def __init__(self, in_channel, out_channel, kernel, padding, kernels_per_layer):
65
+ super(DWSConv, self).__init__()
66
+ self.out_channel = out_channel
67
+ self.DWConv = nn.Conv2d(
68
+ in_channel,
69
+ in_channel * kernels_per_layer,
70
+ kernel_size=kernel,
71
+ padding=padding,
72
+ groups=in_channel,
73
+ bias=False,
74
+ )
75
+ self.bn = nn.BatchNorm2d(in_channel * kernels_per_layer)
76
+ self.selu = nn.SELU()
77
+ self.PWConv = nn.Conv2d(
78
+ in_channel * kernels_per_layer, out_channel, kernel_size=1, bias=False
79
+ )
80
+ self.bn2 = nn.BatchNorm2d(out_channel)
81
+
82
+ def forward(self, x):
83
+ x = self.DWConv(x)
84
+ x = self.selu(self.bn(x))
85
+ out = self.PWConv(x)
86
+ out = self.selu(self.bn2(out))
87
+
88
+ return out
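The shape contract of these blocks, shown with dummy sizes (illustrative only):

```python
import torch
from carvekit.ml.arch.tracerb7.conv_modules import BasicConv2d, DWSConv

x = torch.randn(2, 64, 56, 56)
print(BasicConv2d(64, 32, kernel_size=3, padding=1)(x).shape)                # torch.Size([2, 32, 56, 56])
print(DWSConv(64, 128, kernel=3, padding=1, kernels_per_layer=1)(x).shape)   # torch.Size([2, 128, 56, 56])
```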
carvekit/ml/arch/tracerb7/effi_utils.py ADDED
@@ -0,0 +1,579 @@
1
+ """
2
+ Original author: lukemelas (github username)
3
+ Github repo: https://github.com/lukemelas/EfficientNet-PyTorch
4
+ With adjustments and added comments by workingcoder (github username).
5
+ License: Apache License 2.0
6
+ Reimplemented: Min Seok Lee and Wooseok Shin
7
+ """
8
+
9
+ import collections
10
+ import re
11
+ from functools import partial
12
+
13
+ import math
14
+ import torch
15
+ from torch import nn
16
+ from torch.nn import functional as F
17
+
18
+ # Parameters for the entire model (stem, all blocks, and head)
19
+ GlobalParams = collections.namedtuple(
20
+ "GlobalParams",
21
+ [
22
+ "width_coefficient",
23
+ "depth_coefficient",
24
+ "image_size",
25
+ "dropout_rate",
26
+ "num_classes",
27
+ "batch_norm_momentum",
28
+ "batch_norm_epsilon",
29
+ "drop_connect_rate",
30
+ "depth_divisor",
31
+ "min_depth",
32
+ "include_top",
33
+ ],
34
+ )
35
+
36
+ # Parameters for an individual model block
37
+ BlockArgs = collections.namedtuple(
38
+ "BlockArgs",
39
+ [
40
+ "num_repeat",
41
+ "kernel_size",
42
+ "stride",
43
+ "expand_ratio",
44
+ "input_filters",
45
+ "output_filters",
46
+ "se_ratio",
47
+ "id_skip",
48
+ ],
49
+ )
50
+
51
+ # Set GlobalParams and BlockArgs's defaults
52
+ GlobalParams.__new__.__defaults__ = (None,) * len(GlobalParams._fields)
53
+ BlockArgs.__new__.__defaults__ = (None,) * len(BlockArgs._fields)
54
+
55
+
56
+ # An ordinary implementation of Swish function
57
+ class Swish(nn.Module):
58
+ def forward(self, x):
59
+ return x * torch.sigmoid(x)
60
+
61
+
62
+ # A memory-efficient implementation of Swish function
63
+ class SwishImplementation(torch.autograd.Function):
64
+ @staticmethod
65
+ def forward(ctx, i):
66
+ result = i * torch.sigmoid(i)
67
+ ctx.save_for_backward(i)
68
+ return result
69
+
70
+ @staticmethod
71
+ def backward(ctx, grad_output):
72
+ i = ctx.saved_tensors[0]
73
+ sigmoid_i = torch.sigmoid(i)
74
+ return grad_output * (sigmoid_i * (1 + i * (1 - sigmoid_i)))
75
+
76
+
77
+ class MemoryEfficientSwish(nn.Module):
78
+ def forward(self, x):
79
+ return SwishImplementation.apply(x)
80
+
81
+
82
+ def round_filters(filters, global_params):
83
+ """Calculate and round number of filters based on width multiplier.
84
+ Use width_coefficient, depth_divisor and min_depth of global_params.
85
+
86
+ Args:
87
+ filters (int): Filters number to be calculated.
88
+ global_params (namedtuple): Global params of the model.
89
+
90
+ Returns:
91
+ new_filters: The rounded number of filters.
92
+ """
93
+ multiplier = global_params.width_coefficient
94
+ if not multiplier:
95
+ return filters
96
+ divisor = global_params.depth_divisor
97
+ min_depth = global_params.min_depth
98
+ filters *= multiplier
99
+ min_depth = min_depth or divisor # pay attention to this line when using min_depth
100
+ # follow the formula transferred from official TensorFlow implementation
101
+ new_filters = max(min_depth, int(filters + divisor / 2) // divisor * divisor)
102
+ if new_filters < 0.9 * filters: # prevent rounding by more than 10%
103
+ new_filters += divisor
104
+ return int(new_filters)
105
+
106
+
107
+ def round_repeats(repeats, global_params):
108
+ """Calculate module's repeat number of a block based on depth multiplier.
109
+ Use depth_coefficient of global_params.
110
+
111
+ Args:
112
+ repeats (int): num_repeat to be calculated.
113
+ global_params (namedtuple): Global params of the model.
114
+
115
+ Returns:
116
+ new_repeat: The rounded number of repeats.
117
+ """
118
+ multiplier = global_params.depth_coefficient
119
+ if not multiplier:
120
+ return repeats
121
+ # follow the formula transferred from official TensorFlow implementation
122
+ return int(math.ceil(multiplier * repeats))
123
+
124
+
125
+ def drop_connect(inputs, p, training):
126
+ """Drop connect.
127
+
128
+ Args:
129
+ input (tensor: BCWH): Input of this structure.
130
+ p (float: 0.0~1.0): Probability of drop connection.
131
+ training (bool): The running mode.
132
+
133
+ Returns:
134
+ output: Output after drop connection.
135
+ """
136
+ assert 0 <= p <= 1, "p must be in range of [0,1]"
137
+
138
+ if not training:
139
+ return inputs
140
+
141
+ batch_size = inputs.shape[0]
142
+ keep_prob = 1 - p
143
+
144
+ # generate binary_tensor mask according to probability (p for 0, 1-p for 1)
145
+ random_tensor = keep_prob
146
+ random_tensor += torch.rand(
147
+ [batch_size, 1, 1, 1], dtype=inputs.dtype, device=inputs.device
148
+ )
149
+ binary_tensor = torch.floor(random_tensor)
150
+
151
+ output = inputs / keep_prob * binary_tensor
152
+ return output
153
+
154
+
155
+ def get_width_and_height_from_size(x):
156
+ """Obtain height and width from x.
157
+
158
+ Args:
159
+ x (int, tuple or list): Data size.
160
+
161
+ Returns:
162
+ size: A tuple or list (H,W).
163
+ """
164
+ if isinstance(x, int):
165
+ return x, x
166
+ if isinstance(x, list) or isinstance(x, tuple):
167
+ return x
168
+ else:
169
+ raise TypeError()
170
+
171
+
172
+ def calculate_output_image_size(input_image_size, stride):
173
+ """Calculates the output image size when using Conv2dSamePadding with a stride.
174
+ Necessary for static padding. Thanks to mannatsingh for pointing this out.
175
+
176
+ Args:
177
+ input_image_size (int, tuple or list): Size of input image.
178
+ stride (int, tuple or list): Conv2d operation's stride.
179
+
180
+ Returns:
181
+ output_image_size: A list [H,W].
182
+ """
183
+ if input_image_size is None:
184
+ return None
185
+ image_height, image_width = get_width_and_height_from_size(input_image_size)
186
+ stride = stride if isinstance(stride, int) else stride[0]
187
+ image_height = int(math.ceil(image_height / stride))
188
+ image_width = int(math.ceil(image_width / stride))
189
+ return [image_height, image_width]
190
+
191
+
192
+ # Note:
193
+ # The following 'SamePadding' functions make output size equal ceil(input size/stride).
194
+ # Only when stride equals 1 can the output size match the input size.
195
+ # Don't be confused by their function names!
196
+
197
+
198
+ def get_same_padding_conv2d(image_size=None):
199
+ """Chooses static padding if you have specified an image size, and dynamic padding otherwise.
200
+ Static padding is necessary for ONNX exporting of models.
201
+
202
+ Args:
203
+ image_size (int or tuple): Size of the image.
204
+
205
+ Returns:
206
+ Conv2dDynamicSamePadding or Conv2dStaticSamePadding.
207
+ """
208
+ if image_size is None:
209
+ return Conv2dDynamicSamePadding
210
+ else:
211
+ return partial(Conv2dStaticSamePadding, image_size=image_size)
212
+
213
+
214
+ class Conv2dDynamicSamePadding(nn.Conv2d):
215
+ """2D Convolutions like TensorFlow, for a dynamic image size.
216
+ The padding is operated in forward function by calculating dynamically.
217
+ """
218
+
219
+ # Tips for 'SAME' mode padding.
220
+ # Given the following:
221
+ # i: width or height
222
+ # s: stride
223
+ # k: kernel size
224
+ # d: dilation
225
+ # p: padding
226
+ # Output after Conv2d:
227
+ # o = floor((i+p-((k-1)*d+1))/s+1)
228
+ # If o equals i, i = floor((i+p-((k-1)*d+1))/s+1),
229
+ # => p = (i-1)*s+((k-1)*d+1)-i
230
+
231
+ def __init__(
232
+ self,
233
+ in_channels,
234
+ out_channels,
235
+ kernel_size,
236
+ stride=1,
237
+ dilation=1,
238
+ groups=1,
239
+ bias=True,
240
+ ):
241
+ super().__init__(
242
+ in_channels, out_channels, kernel_size, stride, 0, dilation, groups, bias
243
+ )
244
+ self.stride = self.stride if len(self.stride) == 2 else [self.stride[0]] * 2
245
+
246
+ def forward(self, x):
247
+ ih, iw = x.size()[-2:]
248
+ kh, kw = self.weight.size()[-2:]
249
+ sh, sw = self.stride
250
+ oh, ow = math.ceil(ih / sh), math.ceil(
251
+ iw / sw
252
+ ) # change the output size according to stride ! ! !
253
+ pad_h = max((oh - 1) * self.stride[0] + (kh - 1) * self.dilation[0] + 1 - ih, 0)
254
+ pad_w = max((ow - 1) * self.stride[1] + (kw - 1) * self.dilation[1] + 1 - iw, 0)
255
+ if pad_h > 0 or pad_w > 0:
256
+ x = F.pad(
257
+ x, [pad_w // 2, pad_w - pad_w // 2, pad_h // 2, pad_h - pad_h // 2]
258
+ )
259
+ return F.conv2d(
260
+ x,
261
+ self.weight,
262
+ self.bias,
263
+ self.stride,
264
+ self.padding,
265
+ self.dilation,
266
+ self.groups,
267
+ )
268
+
269
+
270
+ class Conv2dStaticSamePadding(nn.Conv2d):
271
+ """2D Convolutions like TensorFlow's 'SAME' mode, with the given input image size.
272
+ The padding module is calculated in the constructor, then applied in forward.
273
+ """
274
+
275
+ # With the same calculation as Conv2dDynamicSamePadding
276
+
277
+ def __init__(
278
+ self,
279
+ in_channels,
280
+ out_channels,
281
+ kernel_size,
282
+ stride=1,
283
+ image_size=None,
284
+ **kwargs
285
+ ):
286
+ super().__init__(in_channels, out_channels, kernel_size, stride, **kwargs)
287
+ self.stride = self.stride if len(self.stride) == 2 else [self.stride[0]] * 2
288
+
289
+ # Calculate padding based on image size and save it
290
+ assert image_size is not None
291
+ ih, iw = (image_size, image_size) if isinstance(image_size, int) else image_size
292
+ kh, kw = self.weight.size()[-2:]
293
+ sh, sw = self.stride
294
+ oh, ow = math.ceil(ih / sh), math.ceil(iw / sw)
295
+ pad_h = max((oh - 1) * self.stride[0] + (kh - 1) * self.dilation[0] + 1 - ih, 0)
296
+ pad_w = max((ow - 1) * self.stride[1] + (kw - 1) * self.dilation[1] + 1 - iw, 0)
297
+ if pad_h > 0 or pad_w > 0:
298
+ self.static_padding = nn.ZeroPad2d(
299
+ (pad_w // 2, pad_w - pad_w // 2, pad_h // 2, pad_h - pad_h // 2)
300
+ )
301
+ else:
302
+ self.static_padding = nn.Identity()
303
+
304
+ def forward(self, x):
305
+ x = self.static_padding(x)
306
+ x = F.conv2d(
307
+ x,
308
+ self.weight,
309
+ self.bias,
310
+ self.stride,
311
+ self.padding,
312
+ self.dilation,
313
+ self.groups,
314
+ )
315
+ return x
316
+
317
+
318
+ def get_same_padding_maxPool2d(image_size=None):
319
+ """Chooses static padding if you have specified an image size, and dynamic padding otherwise.
320
+ Static padding is necessary for ONNX exporting of models.
321
+
322
+ Args:
323
+ image_size (int or tuple): Size of the image.
324
+
325
+ Returns:
326
+ MaxPool2dDynamicSamePadding or MaxPool2dStaticSamePadding.
327
+ """
328
+ if image_size is None:
329
+ return MaxPool2dDynamicSamePadding
330
+ else:
331
+ return partial(MaxPool2dStaticSamePadding, image_size=image_size)
332
+
333
+
334
+ class MaxPool2dDynamicSamePadding(nn.MaxPool2d):
335
+ """2D MaxPooling like TensorFlow's 'SAME' mode, with a dynamic image size.
336
+ The padding is operated in forward function by calculating dynamically.
337
+ """
338
+
339
+ def __init__(
340
+ self,
341
+ kernel_size,
342
+ stride,
343
+ padding=0,
344
+ dilation=1,
345
+ return_indices=False,
346
+ ceil_mode=False,
347
+ ):
348
+ super().__init__(
349
+ kernel_size, stride, padding, dilation, return_indices, ceil_mode
350
+ )
351
+ self.stride = [self.stride] * 2 if isinstance(self.stride, int) else self.stride
352
+ self.kernel_size = (
353
+ [self.kernel_size] * 2
354
+ if isinstance(self.kernel_size, int)
355
+ else self.kernel_size
356
+ )
357
+ self.dilation = (
358
+ [self.dilation] * 2 if isinstance(self.dilation, int) else self.dilation
359
+ )
360
+
361
+ def forward(self, x):
362
+ ih, iw = x.size()[-2:]
363
+ kh, kw = self.kernel_size
364
+ sh, sw = self.stride
365
+ oh, ow = math.ceil(ih / sh), math.ceil(iw / sw)
366
+ pad_h = max((oh - 1) * self.stride[0] + (kh - 1) * self.dilation[0] + 1 - ih, 0)
367
+ pad_w = max((ow - 1) * self.stride[1] + (kw - 1) * self.dilation[1] + 1 - iw, 0)
368
+ if pad_h > 0 or pad_w > 0:
369
+ x = F.pad(
370
+ x, [pad_w // 2, pad_w - pad_w // 2, pad_h // 2, pad_h - pad_h // 2]
371
+ )
372
+ return F.max_pool2d(
373
+ x,
374
+ self.kernel_size,
375
+ self.stride,
376
+ self.padding,
377
+ self.dilation,
378
+ self.ceil_mode,
379
+ self.return_indices,
380
+ )
381
+
382
+
383
+ class MaxPool2dStaticSamePadding(nn.MaxPool2d):
384
+ """2D MaxPooling like TensorFlow's 'SAME' mode, with the given input image size.
385
+ The padding module is calculated in the constructor, then applied in forward.
386
+ """
387
+
388
+ def __init__(self, kernel_size, stride, image_size=None, **kwargs):
389
+ super().__init__(kernel_size, stride, **kwargs)
390
+ self.stride = [self.stride] * 2 if isinstance(self.stride, int) else self.stride
391
+ self.kernel_size = (
392
+ [self.kernel_size] * 2
393
+ if isinstance(self.kernel_size, int)
394
+ else self.kernel_size
395
+ )
396
+ self.dilation = (
397
+ [self.dilation] * 2 if isinstance(self.dilation, int) else self.dilation
398
+ )
399
+
400
+ # Calculate padding based on image size and save it
401
+ assert image_size is not None
402
+ ih, iw = (image_size, image_size) if isinstance(image_size, int) else image_size
403
+ kh, kw = self.kernel_size
404
+ sh, sw = self.stride
405
+ oh, ow = math.ceil(ih / sh), math.ceil(iw / sw)
406
+ pad_h = max((oh - 1) * self.stride[0] + (kh - 1) * self.dilation[0] + 1 - ih, 0)
407
+ pad_w = max((ow - 1) * self.stride[1] + (kw - 1) * self.dilation[1] + 1 - iw, 0)
408
+ if pad_h > 0 or pad_w > 0:
409
+ self.static_padding = nn.ZeroPad2d(
410
+ (pad_w // 2, pad_w - pad_w // 2, pad_h // 2, pad_h - pad_h // 2)
411
+ )
412
+ else:
413
+ self.static_padding = nn.Identity()
414
+
415
+ def forward(self, x):
416
+ x = self.static_padding(x)
417
+ x = F.max_pool2d(
418
+ x,
419
+ self.kernel_size,
420
+ self.stride,
421
+ self.padding,
422
+ self.dilation,
423
+ self.ceil_mode,
424
+ self.return_indices,
425
+ )
426
+ return x
427
+
428
+
429
+ class BlockDecoder(object):
430
+ """Block Decoder for readability,
431
+ straight from the official TensorFlow repository.
432
+ """
433
+
434
+ @staticmethod
435
+ def _decode_block_string(block_string):
436
+ """Get a block through a string notation of arguments.
437
+
438
+ Args:
439
+ block_string (str): A string notation of arguments.
440
+ Examples: 'r1_k3_s11_e1_i32_o16_se0.25_noskip'.
441
+
442
+ Returns:
443
+ BlockArgs: The namedtuple defined at the top of this file.
444
+ """
445
+ assert isinstance(block_string, str)
446
+
447
+ ops = block_string.split("_")
448
+ options = {}
449
+ for op in ops:
450
+ splits = re.split(r"(\d.*)", op)
451
+ if len(splits) >= 2:
452
+ key, value = splits[:2]
453
+ options[key] = value
454
+
455
+ # Check stride
456
+ assert ("s" in options and len(options["s"]) == 1) or (
457
+ len(options["s"]) == 2 and options["s"][0] == options["s"][1]
458
+ )
459
+
460
+ return BlockArgs(
461
+ num_repeat=int(options["r"]),
462
+ kernel_size=int(options["k"]),
463
+ stride=[int(options["s"][0])],
464
+ expand_ratio=int(options["e"]),
465
+ input_filters=int(options["i"]),
466
+ output_filters=int(options["o"]),
467
+ se_ratio=float(options["se"]) if "se" in options else None,
468
+ id_skip=("noskip" not in block_string),
469
+ )
470
+
471
+ @staticmethod
472
+ def _encode_block_string(block):
473
+ """Encode a block to a string.
474
+
475
+ Args:
476
+ block (namedtuple): A BlockArgs type argument.
477
+
478
+ Returns:
479
+ block_string: A String form of BlockArgs.
480
+ """
481
+ args = [
482
+ "r%d" % block.num_repeat,
483
+ "k%d" % block.kernel_size,
484
+ "s%d%d" % (block.strides[0], block.strides[1]),
485
+ "e%s" % block.expand_ratio,
486
+ "i%d" % block.input_filters,
487
+ "o%d" % block.output_filters,
488
+ ]
489
+ if 0 < block.se_ratio <= 1:
490
+ args.append("se%s" % block.se_ratio)
491
+ if block.id_skip is False:
492
+ args.append("noskip")
493
+ return "_".join(args)
494
+
495
+ @staticmethod
496
+ def decode(string_list):
497
+ """Decode a list of string notations to specify blocks inside the network.
498
+
499
+ Args:
500
+ string_list (list[str]): A list of strings, each string is a notation of block.
501
+
502
+ Returns:
503
+ blocks_args: A list of BlockArgs namedtuples of block args.
504
+ """
505
+ assert isinstance(string_list, list)
506
+ blocks_args = []
507
+ for block_string in string_list:
508
+ blocks_args.append(BlockDecoder._decode_block_string(block_string))
509
+ return blocks_args
510
+
511
+ @staticmethod
512
+ def encode(blocks_args):
513
+ """Encode a list of BlockArgs to a list of strings.
514
+
515
+ Args:
516
+ blocks_args (list[namedtuples]): A list of BlockArgs namedtuples of block args.
517
+
518
+ Returns:
519
+ block_strings: A list of strings, each string is a notation of block.
520
+ """
521
+ block_strings = []
522
+ for block in blocks_args:
523
+ block_strings.append(BlockDecoder._encode_block_string(block))
524
+ return block_strings
525
+
526
+
527
+ def create_block_args(
528
+ width_coefficient=None,
529
+ depth_coefficient=None,
530
+ image_size=None,
531
+ dropout_rate=0.2,
532
+ drop_connect_rate=0.2,
533
+ num_classes=1000,
534
+ include_top=True,
535
+ ):
536
+ """Create BlockArgs and GlobalParams for efficientnet model.
537
+
538
+ Args:
539
+ width_coefficient (float)
540
+ depth_coefficient (float)
541
+ image_size (int)
542
+ dropout_rate (float)
543
+ drop_connect_rate (float)
544
+ num_classes (int)
545
+
546
+ Meaning as the name suggests.
547
+
548
+ Returns:
549
+ blocks_args, global_params.
550
+ """
551
+
552
+ # Blocks args for the whole model(efficientnet-b0 by default)
553
+ # It will be modified in the construction of EfficientNet Class according to model
554
+ blocks_args = [
555
+ "r1_k3_s11_e1_i32_o16_se0.25",
556
+ "r2_k3_s22_e6_i16_o24_se0.25",
557
+ "r2_k5_s22_e6_i24_o40_se0.25",
558
+ "r3_k3_s22_e6_i40_o80_se0.25",
559
+ "r3_k5_s11_e6_i80_o112_se0.25",
560
+ "r4_k5_s22_e6_i112_o192_se0.25",
561
+ "r1_k3_s11_e6_i192_o320_se0.25",
562
+ ]
563
+ blocks_args = BlockDecoder.decode(blocks_args)
564
+
565
+ global_params = GlobalParams(
566
+ width_coefficient=width_coefficient,
567
+ depth_coefficient=depth_coefficient,
568
+ image_size=image_size,
569
+ dropout_rate=dropout_rate,
570
+ num_classes=num_classes,
571
+ batch_norm_momentum=0.99,
572
+ batch_norm_epsilon=1e-3,
573
+ drop_connect_rate=drop_connect_rate,
574
+ depth_divisor=8,
575
+ min_depth=None,
576
+ include_top=include_top,
577
+ )
578
+
579
+ return blocks_args, global_params
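The block-string notation is easiest to see with a quick round trip through the helpers in this file (a sketch, no model is built; the coefficients match the B7-style scaling used by the encoder defined later in this commit):

```python
from carvekit.ml.arch.tracerb7.effi_utils import create_block_args, round_filters

blocks_args, global_params = create_block_args(
    width_coefficient=2.0, depth_coefficient=3.1, image_size=600
)
first = blocks_args[0]   # parsed from 'r1_k3_s11_e1_i32_o16_se0.25'
print(first.kernel_size, first.stride, first.input_filters, first.output_filters)  # 3 [1] 32 16

# filters are re-scaled by the width multiplier before blocks are built
print(round_filters(32, global_params))   # 64
```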
carvekit/ml/arch/tracerb7/efficientnet.py ADDED
@@ -0,0 +1,325 @@
1
+ """
2
+ Source url: https://github.com/lukemelas/EfficientNet-PyTorch
3
+ Modified by Min Seok Lee, Wooseok Shin, Nikita Selin
4
+ License: Apache License 2.0
5
+ Changes:
6
+ - Added support for extracting edge features
7
+ - Added support for extracting object features at different levels
8
+ - Refactored the code
9
+ """
10
+ from typing import Any, List
11
+
12
+ import torch
13
+ from torch import nn
14
+ from torch.nn import functional as F
15
+
16
+ from carvekit.ml.arch.tracerb7.effi_utils import (
17
+ get_same_padding_conv2d,
18
+ calculate_output_image_size,
19
+ MemoryEfficientSwish,
20
+ drop_connect,
21
+ round_filters,
22
+ round_repeats,
23
+ Swish,
24
+ create_block_args,
25
+ )
26
+
27
+
28
+ class MBConvBlock(nn.Module):
29
+ """Mobile Inverted Residual Bottleneck Block.
30
+
31
+ Args:
32
+ block_args (namedtuple): BlockArgs, defined in utils.py.
33
+ global_params (namedtuple): GlobalParam, defined in utils.py.
34
+ image_size (tuple or list): [image_height, image_width].
35
+
36
+ References:
37
+ [1] https://arxiv.org/abs/1704.04861 (MobileNet v1)
38
+ [2] https://arxiv.org/abs/1801.04381 (MobileNet v2)
39
+ [3] https://arxiv.org/abs/1905.02244 (MobileNet v3)
40
+ """
41
+
42
+ def __init__(self, block_args, global_params, image_size=None):
43
+ super().__init__()
44
+ self._block_args = block_args
45
+ self._bn_mom = (
46
+ 1 - global_params.batch_norm_momentum
47
+ ) # pytorch's difference from tensorflow
48
+ self._bn_eps = global_params.batch_norm_epsilon
49
+ self.has_se = (self._block_args.se_ratio is not None) and (
50
+ 0 < self._block_args.se_ratio <= 1
51
+ )
52
+ self.id_skip = (
53
+ block_args.id_skip
54
+ ) # whether to use skip connection and drop connect
55
+
56
+ # Expansion phase (Inverted Bottleneck)
57
+ inp = self._block_args.input_filters # number of input channels
58
+ oup = (
59
+ self._block_args.input_filters * self._block_args.expand_ratio
60
+ ) # number of output channels
61
+ if self._block_args.expand_ratio != 1:
62
+ Conv2d = get_same_padding_conv2d(image_size=image_size)
63
+ self._expand_conv = Conv2d(
64
+ in_channels=inp, out_channels=oup, kernel_size=1, bias=False
65
+ )
66
+ self._bn0 = nn.BatchNorm2d(
67
+ num_features=oup, momentum=self._bn_mom, eps=self._bn_eps
68
+ )
69
+ # image_size = calculate_output_image_size(image_size, 1) <-- this wouldn't modify image_size
70
+
71
+ # Depthwise convolution phase
72
+ k = self._block_args.kernel_size
73
+ s = self._block_args.stride
74
+ Conv2d = get_same_padding_conv2d(image_size=image_size)
75
+ self._depthwise_conv = Conv2d(
76
+ in_channels=oup,
77
+ out_channels=oup,
78
+ groups=oup, # groups makes it depthwise
79
+ kernel_size=k,
80
+ stride=s,
81
+ bias=False,
82
+ )
83
+ self._bn1 = nn.BatchNorm2d(
84
+ num_features=oup, momentum=self._bn_mom, eps=self._bn_eps
85
+ )
86
+ image_size = calculate_output_image_size(image_size, s)
87
+
88
+ # Squeeze and Excitation layer, if desired
89
+ if self.has_se:
90
+ Conv2d = get_same_padding_conv2d(image_size=(1, 1))
91
+ num_squeezed_channels = max(
92
+ 1, int(self._block_args.input_filters * self._block_args.se_ratio)
93
+ )
94
+ self._se_reduce = Conv2d(
95
+ in_channels=oup, out_channels=num_squeezed_channels, kernel_size=1
96
+ )
97
+ self._se_expand = Conv2d(
98
+ in_channels=num_squeezed_channels, out_channels=oup, kernel_size=1
99
+ )
100
+
101
+ # Pointwise convolution phase
102
+ final_oup = self._block_args.output_filters
103
+ Conv2d = get_same_padding_conv2d(image_size=image_size)
104
+ self._project_conv = Conv2d(
105
+ in_channels=oup, out_channels=final_oup, kernel_size=1, bias=False
106
+ )
107
+ self._bn2 = nn.BatchNorm2d(
108
+ num_features=final_oup, momentum=self._bn_mom, eps=self._bn_eps
109
+ )
110
+ self._swish = MemoryEfficientSwish()
111
+
112
+ def forward(self, inputs, drop_connect_rate=None):
113
+ """MBConvBlock's forward function.
114
+
115
+ Args:
116
+ inputs (tensor): Input tensor.
117
+ drop_connect_rate (float): Drop connect rate, between 0 and 1.
118
+
119
+ Returns:
120
+ Output of this block after processing.
121
+ """
122
+
123
+ # Expansion and Depthwise Convolution
124
+ x = inputs
125
+ if self._block_args.expand_ratio != 1:
126
+ x = self._expand_conv(inputs)
127
+ x = self._bn0(x)
128
+ x = self._swish(x)
129
+
130
+ x = self._depthwise_conv(x)
131
+ x = self._bn1(x)
132
+ x = self._swish(x)
133
+
134
+ # Squeeze and Excitation
135
+ if self.has_se:
136
+ x_squeezed = F.adaptive_avg_pool2d(x, 1)
137
+ x_squeezed = self._se_reduce(x_squeezed)
138
+ x_squeezed = self._swish(x_squeezed)
139
+ x_squeezed = self._se_expand(x_squeezed)
140
+ x = torch.sigmoid(x_squeezed) * x
141
+
142
+ # Pointwise Convolution
143
+ x = self._project_conv(x)
144
+ x = self._bn2(x)
145
+
146
+ # Skip connection and drop connect
147
+ input_filters, output_filters = (
148
+ self._block_args.input_filters,
149
+ self._block_args.output_filters,
150
+ )
151
+ if (
152
+ self.id_skip
153
+ and self._block_args.stride == 1
154
+ and input_filters == output_filters
155
+ ):
156
+ # The combination of skip connection and drop connect brings about stochastic depth.
157
+ if drop_connect_rate:
158
+ x = drop_connect(x, p=drop_connect_rate, training=self.training)
159
+ x = x + inputs # skip connection
160
+ return x
161
+
162
+ def set_swish(self, memory_efficient=True):
163
+ """Sets swish function as memory efficient (for training) or standard (for export).
164
+
165
+ Args:
166
+ memory_efficient (bool): Whether to use memory-efficient version of swish.
167
+ """
168
+ self._swish = MemoryEfficientSwish() if memory_efficient else Swish()
169
+
170
+
171
+ class EfficientNet(nn.Module):
172
+ def __init__(self, blocks_args=None, global_params=None):
173
+ super().__init__()
174
+ assert isinstance(blocks_args, list), "blocks_args should be a list"
175
+ assert len(blocks_args) > 0, "blocks_args must not be empty"
176
+ self._global_params = global_params
177
+ self._blocks_args = blocks_args
178
+
179
+ # Batch norm parameters
180
+ bn_mom = 1 - self._global_params.batch_norm_momentum
181
+ bn_eps = self._global_params.batch_norm_epsilon
182
+
183
+ # Get stem static or dynamic convolution depending on image size
184
+ image_size = global_params.image_size
185
+ Conv2d = get_same_padding_conv2d(image_size=image_size)
186
+
187
+ # Stem
188
+ in_channels = 3 # rgb
189
+ out_channels = round_filters(
190
+ 32, self._global_params
191
+ ) # number of output channels
192
+ self._conv_stem = Conv2d(
193
+ in_channels, out_channels, kernel_size=3, stride=2, bias=False
194
+ )
195
+ self._bn0 = nn.BatchNorm2d(
196
+ num_features=out_channels, momentum=bn_mom, eps=bn_eps
197
+ )
198
+ image_size = calculate_output_image_size(image_size, 2)
199
+
200
+ # Build blocks
201
+ self._blocks = nn.ModuleList([])
202
+ for block_args in self._blocks_args:
203
+
204
+ # Update block input and output filters based on depth multiplier.
205
+ block_args = block_args._replace(
206
+ input_filters=round_filters(
207
+ block_args.input_filters, self._global_params
208
+ ),
209
+ output_filters=round_filters(
210
+ block_args.output_filters, self._global_params
211
+ ),
212
+ num_repeat=round_repeats(block_args.num_repeat, self._global_params),
213
+ )
214
+
215
+ # The first block needs to take care of stride and filter size increase.
216
+ self._blocks.append(
217
+ MBConvBlock(block_args, self._global_params, image_size=image_size)
218
+ )
219
+ image_size = calculate_output_image_size(image_size, block_args.stride)
220
+ if block_args.num_repeat > 1: # modify block_args to keep same output size
221
+ block_args = block_args._replace(
222
+ input_filters=block_args.output_filters, stride=1
223
+ )
224
+ for _ in range(block_args.num_repeat - 1):
225
+ self._blocks.append(
226
+ MBConvBlock(block_args, self._global_params, image_size=image_size)
227
+ )
228
+ # image_size = calculate_output_image_size(image_size, block_args.stride) # stride = 1
229
+
230
+ self._swish = MemoryEfficientSwish()
231
+
232
+ def set_swish(self, memory_efficient=True):
233
+ """Sets swish function as memory efficient (for training) or standard (for export).
234
+
235
+ Args:
236
+ memory_efficient (bool): Whether to use memory-efficient version of swish.
237
+
238
+ """
239
+ self._swish = MemoryEfficientSwish() if memory_efficient else Swish()
240
+ for block in self._blocks:
241
+ block.set_swish(memory_efficient)
242
+
243
+ def extract_endpoints(self, inputs):
244
+ endpoints = dict()
245
+
246
+ # Stem
247
+ x = self._swish(self._bn0(self._conv_stem(inputs)))
248
+ prev_x = x
249
+
250
+ # Blocks
251
+ for idx, block in enumerate(self._blocks):
252
+ drop_connect_rate = self._global_params.drop_connect_rate
253
+ if drop_connect_rate:
254
+ drop_connect_rate *= float(idx) / len(
255
+ self._blocks
256
+ ) # scale drop connect_rate
257
+ x = block(x, drop_connect_rate=drop_connect_rate)
258
+ if prev_x.size(2) > x.size(2):
259
+ endpoints["reduction_{}".format(len(endpoints) + 1)] = prev_x
260
+ prev_x = x
261
+
262
+ # Head
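+ # NOTE: this trimmed-down class never defines _conv_head / _bn1 in __init__,
+ # so extract_endpoints cannot run as written; the B7 encoder below uses
+ # initial_conv() and get_blocks() instead.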
263
+ x = self._swish(self._bn1(self._conv_head(x)))
264
+ endpoints["reduction_{}".format(len(endpoints) + 1)] = x
265
+
266
+ return endpoints
267
+
268
+ def _change_in_channels(self, in_channels):
269
+ """Adjust model's first convolution layer to in_channels, if in_channels not equals 3.
270
+
271
+ Args:
272
+ in_channels (int): Input data's channel number.
273
+ """
274
+ if in_channels != 3:
275
+ Conv2d = get_same_padding_conv2d(image_size=self._global_params.image_size)
276
+ out_channels = round_filters(32, self._global_params)
277
+ self._conv_stem = Conv2d(
278
+ in_channels, out_channels, kernel_size=3, stride=2, bias=False
279
+ )
280
+
281
+
282
+ class EfficientEncoderB7(EfficientNet):
283
+ def __init__(self):
284
+ super().__init__(
285
+ *create_block_args(
286
+ width_coefficient=2.0,
287
+ depth_coefficient=3.1,
288
+ dropout_rate=0.5,
289
+ image_size=600,
290
+ )
291
+ )
292
+ self._change_in_channels(3)
293
+ self.block_idx = [10, 17, 37, 54]
294
+ self.channels = [48, 80, 224, 640]
295
+
296
+ def initial_conv(self, inputs):
297
+ x = self._swish(self._bn0(self._conv_stem(inputs)))
298
+ return x
299
+
300
+ def get_blocks(self, x, H, W, block_idx):
301
+ features = []
302
+ for idx, block in enumerate(self._blocks):
303
+ drop_connect_rate = self._global_params.drop_connect_rate
304
+ if drop_connect_rate:
305
+ drop_connect_rate *= float(idx) / len(
306
+ self._blocks
307
+ ) # scale drop connect_rate
308
+ x = block(x, drop_connect_rate=drop_connect_rate)
309
+ if idx == block_idx[0]:
310
+ features.append(x.clone())
311
+ if idx == block_idx[1]:
312
+ features.append(x.clone())
313
+ if idx == block_idx[2]:
314
+ features.append(x.clone())
315
+ if idx == block_idx[3]:
316
+ features.append(x.clone())
317
+
318
+ return features
319
+
320
+ def forward(self, inputs: torch.Tensor) -> List[Any]:
321
+ B, C, H, W = inputs.size()
322
+ x = self.initial_conv(inputs) # Prepare input for the backbone
323
+ return self.get_blocks(
324
+ x, H, W, block_idx=self.block_idx
325
+ ) # Get backbone features and edge maps
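A smoke test for the encoder as defined in this commit (random weights; a sketch only, the input size here is arbitrary):

```python
import torch
from carvekit.ml.arch.tracerb7.efficientnet import EfficientEncoderB7

encoder = EfficientEncoderB7().eval()
with torch.no_grad():
    features = encoder(torch.randn(1, 3, 320, 320))
print([tuple(f.shape) for f in features])
# four feature maps at strides 4/8/16/32 with channels [48, 80, 224, 640]
```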
carvekit/ml/arch/tracerb7/tracer.py ADDED
@@ -0,0 +1,97 @@
1
+ """
2
+ Source url: https://github.com/Karel911/TRACER
3
+ Author: Min Seok Lee and Wooseok Shin
4
+ Modified by Nikita Selin (OPHoperHPO)[https://github.com/OPHoperHPO].
5
+ License: Apache License 2.0
6
+ Changes:
7
+ - Refactored code
8
+ - Removed unused code
9
+ - Added comments
10
+ """
11
+
12
+ import torch
13
+ import torch.nn as nn
14
+ import torch.nn.functional as F
15
+ from typing import List, Optional, Tuple
16
+
17
+ from torch import Tensor
18
+
19
+ from carvekit.ml.arch.tracerb7.efficientnet import EfficientEncoderB7
20
+ from carvekit.ml.arch.tracerb7.att_modules import (
21
+ RFB_Block,
22
+ aggregation,
23
+ ObjectAttention,
24
+ )
25
+
26
+
27
+ class TracerDecoder(nn.Module):
28
+ """Tracer Decoder"""
29
+
30
+ def __init__(
31
+ self,
32
+ encoder: EfficientEncoderB7,
33
+ features_channels: Optional[List[int]] = None,
34
+ rfb_channel: Optional[List[int]] = None,
35
+ ):
36
+ """
37
+ Initialize the tracer decoder.
38
+
39
+ Args:
40
+ encoder: The encoder to use.
41
+ features_channels: The channels of the backbone features at different stages. default: [48, 80, 224, 640]
42
+ rfb_channel: The channels of the RFB features. default: [32, 64, 128]
43
+ """
44
+ super().__init__()
45
+ if rfb_channel is None:
46
+ rfb_channel = [32, 64, 128]
47
+ if features_channels is None:
48
+ features_channels = [48, 80, 224, 640]
49
+ self.encoder = encoder
50
+ self.features_channels = features_channels
51
+
52
+ # Receptive Field Blocks
53
+ features_channels = rfb_channel
54
+ self.rfb2 = RFB_Block(self.features_channels[1], features_channels[0])
55
+ self.rfb3 = RFB_Block(self.features_channels[2], features_channels[1])
56
+ self.rfb4 = RFB_Block(self.features_channels[3], features_channels[2])
57
+
58
+ # Multi-level aggregation
59
+ self.agg = aggregation(features_channels)
60
+
61
+ # Object Attention
62
+ self.ObjectAttention2 = ObjectAttention(
63
+ channel=self.features_channels[1], kernel_size=3
64
+ )
65
+ self.ObjectAttention1 = ObjectAttention(
66
+ channel=self.features_channels[0], kernel_size=3
67
+ )
68
+
69
+ def forward(self, inputs: torch.Tensor) -> Tensor:
70
+ """
71
+ Forward pass of the tracer decoder.
72
+
73
+ Args:
74
+ inputs: Preprocessed images.
75
+
76
+ Returns:
77
+ Tensors of segmentation masks and mask of object edges.
78
+ """
79
+ features = self.encoder(inputs)
80
+ x3_rfb = self.rfb2(features[1])
81
+ x4_rfb = self.rfb3(features[2])
82
+ x5_rfb = self.rfb4(features[3])
83
+
84
+ D_0 = self.agg(x5_rfb, x4_rfb, x3_rfb)
85
+
86
+ ds_map0 = F.interpolate(D_0, scale_factor=8, mode="bilinear")
87
+
88
+ D_1 = self.ObjectAttention2(D_0, features[1])
89
+ ds_map1 = F.interpolate(D_1, scale_factor=8, mode="bilinear")
90
+
91
+ ds_map = F.interpolate(D_1, scale_factor=2, mode="bilinear")
92
+ D_2 = self.ObjectAttention1(ds_map, features[0])
93
+ ds_map2 = F.interpolate(D_2, scale_factor=4, mode="bilinear")
94
+
95
+ final_map = (ds_map2 + ds_map1 + ds_map0) / 3
96
+
97
+ return torch.sigmoid(final_map)
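Wiring the two TRACER pieces together (again a structural sketch with random weights; in practice a pretrained checkpoint such as tracer_b7.pth is loaded on top of this):

```python
import torch
from carvekit.ml.arch.tracerb7.efficientnet import EfficientEncoderB7
from carvekit.ml.arch.tracerb7.tracer import TracerDecoder

model = TracerDecoder(encoder=EfficientEncoderB7()).eval()
with torch.no_grad():
    mask = model(torch.randn(1, 3, 320, 320))
print(mask.shape, float(mask.min()), float(mask.max()))   # (1, 1, 320, 320), values in (0, 1)
```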
carvekit/ml/arch/u2net/__init__.py ADDED
File without changes
carvekit/ml/arch/u2net/__pycache__/__init__.cpython-38.pyc ADDED
Binary file (184 Bytes). View file
 
carvekit/ml/arch/u2net/__pycache__/u2net.cpython-38.pyc ADDED
Binary file (6.13 kB). View file
 
carvekit/ml/arch/u2net/u2net.py ADDED
@@ -0,0 +1,172 @@
1
+ """
2
+ Modified by Nikita Selin (OPHoperHPO)[https://github.com/OPHoperHPO].
3
+ Source url: https://github.com/xuebinqin/U-2-Net
4
+ License: Apache License 2.0
5
+ """
6
+ from typing import Union
7
+
8
+ import torch
9
+ import torch.nn as nn
10
+
11
+ import math
12
+
13
+ __all__ = ["U2NETArchitecture"]
14
+
15
+
16
+ def _upsample_like(x, size):
17
+ return nn.Upsample(size=size, mode="bilinear", align_corners=False)(x)
18
+
19
+
20
+ def _size_map(x, height):
21
+ # {height: size} for Upsample
22
+ size = list(x.shape[-2:])
23
+ sizes = {}
24
+ for h in range(1, height):
25
+ sizes[h] = size
26
+ size = [math.ceil(w / 2) for w in size]
27
+ return sizes
28
+
29
+
30
+ class REBNCONV(nn.Module):
31
+ def __init__(self, in_ch=3, out_ch=3, dilate=1):
32
+ super(REBNCONV, self).__init__()
33
+
34
+ self.conv_s1 = nn.Conv2d(
35
+ in_ch, out_ch, 3, padding=1 * dilate, dilation=1 * dilate
36
+ )
37
+ self.bn_s1 = nn.BatchNorm2d(out_ch)
38
+ self.relu_s1 = nn.ReLU(inplace=True)
39
+
40
+ def forward(self, x):
41
+ return self.relu_s1(self.bn_s1(self.conv_s1(x)))
42
+
43
+
44
+ class RSU(nn.Module):
45
+ def __init__(self, name, height, in_ch, mid_ch, out_ch, dilated=False):
46
+ super(RSU, self).__init__()
47
+ self.name = name
48
+ self.height = height
49
+ self.dilated = dilated
50
+ self._make_layers(height, in_ch, mid_ch, out_ch, dilated)
51
+
52
+ def forward(self, x):
53
+ sizes = _size_map(x, self.height)
54
+ x = self.rebnconvin(x)
55
+
56
+ # U-Net like symmetric encoder-decoder structure
57
+ def unet(x, height=1):
58
+ if height < self.height:
59
+ x1 = getattr(self, f"rebnconv{height}")(x)
60
+ if not self.dilated and height < self.height - 1:
61
+ x2 = unet(getattr(self, "downsample")(x1), height + 1)
62
+ else:
63
+ x2 = unet(x1, height + 1)
64
+
65
+ x = getattr(self, f"rebnconv{height}d")(torch.cat((x2, x1), 1))
66
+ return (
67
+ _upsample_like(x, sizes[height - 1])
68
+ if not self.dilated and height > 1
69
+ else x
70
+ )
71
+ else:
72
+ return getattr(self, f"rebnconv{height}")(x)
73
+
74
+ return x + unet(x)
75
+
76
+ def _make_layers(self, height, in_ch, mid_ch, out_ch, dilated=False):
77
+ self.add_module("rebnconvin", REBNCONV(in_ch, out_ch))
78
+ self.add_module("downsample", nn.MaxPool2d(2, stride=2, ceil_mode=True))
79
+
80
+ self.add_module("rebnconv1", REBNCONV(out_ch, mid_ch))
81
+ self.add_module("rebnconv1d", REBNCONV(mid_ch * 2, out_ch))
82
+
83
+ for i in range(2, height):
84
+ dilate = 1 if not dilated else 2 ** (i - 1)
85
+ self.add_module(f"rebnconv{i}", REBNCONV(mid_ch, mid_ch, dilate=dilate))
86
+ self.add_module(
87
+ f"rebnconv{i}d", REBNCONV(mid_ch * 2, mid_ch, dilate=dilate)
88
+ )
89
+
90
+ dilate = 2 if not dilated else 2 ** (height - 1)
91
+ self.add_module(f"rebnconv{height}", REBNCONV(mid_ch, mid_ch, dilate=dilate))
92
+
93
+
94
+ class U2NETArchitecture(nn.Module):
95
+ def __init__(self, cfg_type: Union[dict, str] = "full", out_ch: int = 1):
96
+ super(U2NETArchitecture, self).__init__()
97
+ if isinstance(cfg_type, str):
98
+ if cfg_type == "full":
99
+ layers_cfgs = {
100
+ # cfgs for building RSUs and sides
101
+ # {stage : [name, (height(L), in_ch, mid_ch, out_ch, dilated), side]}
102
+ "stage1": ["En_1", (7, 3, 32, 64), -1],
103
+ "stage2": ["En_2", (6, 64, 32, 128), -1],
104
+ "stage3": ["En_3", (5, 128, 64, 256), -1],
105
+ "stage4": ["En_4", (4, 256, 128, 512), -1],
106
+ "stage5": ["En_5", (4, 512, 256, 512, True), -1],
107
+ "stage6": ["En_6", (4, 512, 256, 512, True), 512],
108
+ "stage5d": ["De_5", (4, 1024, 256, 512, True), 512],
109
+ "stage4d": ["De_4", (4, 1024, 128, 256), 256],
110
+ "stage3d": ["De_3", (5, 512, 64, 128), 128],
111
+ "stage2d": ["De_2", (6, 256, 32, 64), 64],
112
+ "stage1d": ["De_1", (7, 128, 16, 64), 64],
113
+ }
114
+ else:
115
+ raise ValueError("Unknown U^2-Net architecture conf. name")
116
+ elif isinstance(cfg_type, dict):
117
+ layers_cfgs = cfg_type
118
+ else:
119
+ raise ValueError("Unknown U^2-Net architecture conf. type")
120
+ self.out_ch = out_ch
121
+ self._make_layers(layers_cfgs)
122
+
123
+ def forward(self, x):
124
+ sizes = _size_map(x, self.height)
125
+ maps = [] # storage for maps
126
+
127
+ # side saliency map
128
+ def unet(x, height=1):
129
+ if height < 6:
130
+ x1 = getattr(self, f"stage{height}")(x)
131
+ x2 = unet(getattr(self, "downsample")(x1), height + 1)
132
+ x = getattr(self, f"stage{height}d")(torch.cat((x2, x1), 1))
133
+ side(x, height)
134
+ return _upsample_like(x, sizes[height - 1]) if height > 1 else x
135
+ else:
136
+ x = getattr(self, f"stage{height}")(x)
137
+ side(x, height)
138
+ return _upsample_like(x, sizes[height - 1])
139
+
140
+ def side(x, h):
141
+ # side output saliency map (before sigmoid)
142
+ x = getattr(self, f"side{h}")(x)
143
+ x = _upsample_like(x, sizes[1])
144
+ maps.append(x)
145
+
146
+ def fuse():
147
+ # fuse saliency probability maps
148
+ maps.reverse()
149
+ x = torch.cat(maps, 1)
150
+ x = getattr(self, "outconv")(x)
151
+ maps.insert(0, x)
152
+ return [torch.sigmoid(x) for x in maps]
153
+
154
+ unet(x)
155
+ maps = fuse()
156
+ return maps
157
+
158
+ def _make_layers(self, cfgs):
159
+ self.height = int((len(cfgs) + 1) / 2)
160
+ self.add_module("downsample", nn.MaxPool2d(2, stride=2, ceil_mode=True))
161
+ for k, v in cfgs.items():
162
+ # build rsu block
163
+ self.add_module(k, RSU(v[0], *v[1]))
164
+ if v[2] > 0:
165
+ # build side layer
166
+ self.add_module(
167
+ f"side{v[0][-1]}", nn.Conv2d(v[2], self.out_ch, 3, padding=1)
168
+ )
169
+ # build fuse layer
170
+ self.add_module(
171
+ "outconv", nn.Conv2d(int(self.height * self.out_ch), self.out_ch, 1)
172
+ )
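A quick check of the network built from the "full" config (random weights; the fused map comes first, followed by the six side outputs):

```python
import torch
from carvekit.ml.arch.u2net.u2net import U2NETArchitecture

net = U2NETArchitecture(cfg_type="full", out_ch=1).eval()
with torch.no_grad():
    maps = net(torch.randn(1, 3, 320, 320))
print(len(maps), maps[0].shape)   # 7 torch.Size([1, 1, 320, 320])
```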
carvekit/ml/files/__init__.py ADDED
@@ -0,0 +1,7 @@
1
+ from pathlib import Path
2
+
3
+ carvekit_dir = Path.home().joinpath(".cache/carvekit")
4
+
5
+ carvekit_dir.mkdir(parents=True, exist_ok=True)
6
+
7
+ checkpoints_dir = carvekit_dir.joinpath("checkpoints")
carvekit/ml/files/__pycache__/__init__.cpython-38.pyc ADDED
Binary file (365 Bytes). View file
 
carvekit/ml/files/__pycache__/models_loc.cpython-38.pyc ADDED
Binary file (2.01 kB). View file
 
carvekit/ml/files/models_loc.py ADDED
@@ -0,0 +1,70 @@
1
+ """
2
+ Source url: https://github.com/OPHoperHPO/image-background-remove-tool
3
+ Author: Nikita Selin (OPHoperHPO)[https://github.com/OPHoperHPO].
4
+ License: Apache License 2.0
5
+ """
6
+ import pathlib
7
+ from carvekit.ml.files import checkpoints_dir
8
+ from carvekit.utils.download_models import downloader
9
+
10
+
11
+ def u2net_full_pretrained() -> pathlib.Path:
12
+ """Returns u2net pretrained model location
13
+
14
+ Returns:
15
+ pathlib.Path to model location
16
+ """
17
+ return downloader("u2net.pth")
18
+
19
+
20
+ def basnet_pretrained() -> pathlib.Path:
21
+ """Returns basnet pretrained model location
22
+
23
+ Returns:
24
+ pathlib.Path to model location
25
+ """
26
+ return downloader("basnet.pth")
27
+
28
+
29
+ def deeplab_pretrained() -> pathlib.Path:
30
+ """Returns basnet pretrained model location
31
+
32
+ Returns:
33
+ pathlib.Path to model location
34
+ """
35
+ return downloader("deeplab.pth")
36
+
37
+
38
+ def fba_pretrained() -> pathlib.Path:
39
+ """Returns basnet pretrained model location
40
+
41
+ Returns:
42
+ pathlib.Path to model location
43
+ """
44
+ return downloader("fba_matting.pth")
45
+
46
+
47
+ def tracer_b7_pretrained() -> pathlib.Path:
48
+ """Returns TRACER with EfficientNet v1 b7 encoder pretrained model location
49
+
50
+ Returns:
51
+ pathlib.Path to model location
52
+ """
53
+ return downloader("tracer_b7.pth")
54
+
55
+
56
+ def tracer_hair_pretrained() -> pathlib.Path:
57
+ """Returns TRACER with EfficientNet v1 b7 encoder model for hair segmentation location
58
+
59
+ Returns:
60
+ pathlib.Path to model location
61
+ """
62
+ return downloader("tracer_hair.pth")
63
+
64
+
65
+ def download_all():
66
+ u2net_full_pretrained()
67
+ fba_pretrained()
68
+ deeplab_pretrained()
69
+ basnet_pretrained()
70
+ tracer_b7_pretrained()
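Each helper above resolves a checkpoint path through the shared downloader, presumably caching it under the checkpoints directory defined in carvekit/ml/files/__init__.py; a usage sketch (the first call may trigger a download):

```python
from carvekit.ml.files.models_loc import fba_pretrained, tracer_b7_pretrained

tracer_ckpt = tracer_b7_pretrained()   # pathlib.Path to tracer_b7.pth
fba_ckpt = fba_pretrained()            # pathlib.Path to fba_matting.pth
print(tracer_ckpt, fba_ckpt)
```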