schirrmacher committed on
Commit
150d962
1 Parent(s): 5503d80

Upload folder using huggingface_hub

README.md CHANGED
@@ -1,3 +1,74 @@
- ---
- license: apache-2.0
- ---
+ ---
+ license: apache-2.0
+ tags:
+ - art
+ pretty_name: Open Remove Background Model
+ ---
+
+ # Open Remove Background Model (ormbg)
+
+ This model is a **fully open-source background remover** optimized for images with humans. It is based on the [Highly Accurate Dichotomous Image Segmentation research](https://github.com/xuebinqin/DIS).
+
+ This model is similar to [RMBG-1.4](https://huggingface.co/briaai/RMBG-1.4), but its training data and training process are fully open, and it is free for commercial use.
+
+ ## Inference
+
+ Below is a minimal inference sketch using [ONNX Runtime](https://onnxruntime.ai/) with the exported `models/ormbg.onnx` (the file paths and the postprocessing are assumptions; adjust them to your setup):
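+
+ ```
+ import numpy as np
+ import onnxruntime as ort
+ from PIL import Image
+
+ # Load the exported model (see "Export to ONNX" below).
+ session = ort.InferenceSession("models/ormbg.onnx")
+
+ # Preprocess: the model is trained without normalization (see dis-repo.patch),
+ # so only resize to the 1024x1024 training resolution and scale to [0, 1].
+ image = Image.open("input.png").convert("RGB")
+ im = np.asarray(image.resize((1024, 1024)), dtype=np.float32) / 255.0
+ im = im.transpose(2, 0, 1)[np.newaxis, ...]  # HWC -> NCHW
+
+ # The first output is the highest-resolution side output of ISNet.
+ mask = session.run(None, {"input": im})[0][0, 0]
+
+ # Postprocess: use the predicted mask as an alpha channel.
+ alpha = Image.fromarray((mask * 255).astype(np.uint8)).resize(image.size)
+ image.putalpha(alpha)
+ image.save("output.png")
+ ```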
+
+ ## Training
+
+ The model was trained with the [Human Segmentation Dataset](https://huggingface.co/datasets/schirrmacher/humans).
+
+ After 10,000 iterations on a single NVIDIA GeForce RTX 4090, training produced the following results:
+
+ - Training time: 8 hours
+ - Training loss: 0.1179
+ - Validation loss: 0.1284
+ - Maximum F1 score: 0.9928
+ - Mean Absolute Error: 0.005
+
+ Output model: `/models/ormbg.pth`.
+
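+ To sanity-check the trained weights, you can load them the same way `utils/pth_to_onnx.py` does (a sketch; the import and file paths assume you run from the repository root):
+
+ ```
+ import torch
+
+ from utils.isnet import ISNetDIS
+
+ net = ISNetDIS()
+ net.load_state_dict(torch.load("models/ormbg.pth", map_location="cpu"))
+ net.eval()  # put the batch-norm layers into inference mode
+ ```
+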
+ ## Want to train your own model?
+
+ Check out the _Highly Accurate Dichotomous Image Segmentation_ code:
+
+ ```
+ git clone https://github.com/xuebinqin/DIS.git
+ cd DIS
+ ```
+
+ Follow the installation instructions at https://github.com/xuebinqin/DIS?tab=readme-ov-file#1-clone-this-repo.
+ Download or create some data ([like this](https://huggingface.co/datasets/schirrmacher/humans)) and place it into the DIS project folder.
+
+ I use the following folder structure (a short snippet to create it follows the list):
+
+ - training/im (images)
+ - training/gt (ground truth)
+ - validation/im (images)
+ - validation/gt (ground truth)
+
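+ A sketch to create that layout (run inside the DIS project folder):
+
+ ```
+ import os
+
+ for split in ("training", "validation"):
+     for sub in ("im", "gt"):
+         os.makedirs(os.path.join(split, sub), exist_ok=True)
+ ```
+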
+ Apply this git patch (`dis-repo.patch`, included below) to set the dataset paths and remove image normalization:
+
+ ```
+ git apply dis-repo.patch
+ ```
+
+ Start training:
+
+ ```
+ cd IS-Net
+ python train_valid_inference_main.py
+ ```
+
+ Export to ONNX (modify paths if needed; the defaults are defined in `utils/pth_to_onnx.py`):
+
+ ```
+ python utils/pth_to_onnx.py --model_path ./models/ormbg.pth --onnx_path ./models/ormbg.onnx
+ ```
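+
+ To verify the export, a quick check (assuming the `onnx` package is installed):
+
+ ```
+ import onnx
+
+ onnx.checker.check_model(onnx.load("models/ormbg.onnx"))  # raises if the model is invalid
+ ```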
+
+ ## Support
+
+ If you identify edge cases or issues with the model, please contact me!
dis-repo.patch ADDED
@@ -0,0 +1,87 @@
+ diff --git a/IS-Net/Inference.py b/IS-Net/Inference.py
+ index 0b2907d..ca8484b 100644
+ --- a/IS-Net/Inference.py
+ +++ b/IS-Net/Inference.py
+ @@ -40,7 +40,7 @@ if __name__ == "__main__":
+ im_tensor = torch.tensor(im, dtype=torch.float32).permute(2,0,1)
+ im_tensor = F.upsample(torch.unsqueeze(im_tensor,0), input_size, mode="bilinear").type(torch.uint8)
+ image = torch.divide(im_tensor,255.0)
+ - image = normalize(image,[0.5,0.5,0.5],[1.0,1.0,1.0])
+ + #image = normalize(image,[0.5,0.5,0.5],[1.0,1.0,1.0])
+
+ if torch.cuda.is_available():
+ image=image.cuda()
+ diff --git a/IS-Net/train_valid_inference_main.py b/IS-Net/train_valid_inference_main.py
+ index 375bb26..ad9043c 100644
+ --- a/IS-Net/train_valid_inference_main.py
+ +++ b/IS-Net/train_valid_inference_main.py
+ @@ -536,10 +536,10 @@ def main(train_datasets,
+ cache_size = hypar["cache_size"],
+ cache_boost = hypar["cache_boost_train"],
+ my_transforms = [
+ - GOSRandomHFlip(), ## this line can be uncommented for horizontal flip augmetation
+ + #GOSRandomHFlip(), ## this line can be uncommented for horizontal flip augmetation
+ # GOSResize(hypar["input_size"]),
+ # GOSRandomCrop(hypar["crop_size"]), ## this line can be uncommented for randomcrop augmentation
+ - GOSNormalize([0.5,0.5,0.5],[1.0,1.0,1.0]),
+ + #GOSNormalize([0.5,0.5,0.5],[1.0,1.0,1.0]),
+ ],
+ batch_size = hypar["batch_size_train"],
+ shuffle = True)
+ @@ -547,7 +547,7 @@ def main(train_datasets,
+ cache_size = hypar["cache_size"],
+ cache_boost = hypar["cache_boost_train"],
+ my_transforms = [
+ - GOSNormalize([0.5,0.5,0.5],[1.0,1.0,1.0]),
+ + #GOSNormalize([0.5,0.5,0.5],[1.0,1.0,1.0]),
+ ],
+ batch_size = hypar["batch_size_valid"],
+ shuffle = False)
+ @@ -561,7 +561,7 @@ def main(train_datasets,
+ cache_size = hypar["cache_size"],
+ cache_boost = hypar["cache_boost_valid"],
+ my_transforms = [
+ - GOSNormalize([0.5,0.5,0.5],[1.0,1.0,1.0]),
+ + #GOSNormalize([0.5,0.5,0.5],[1.0,1.0,1.0]),
+ # GOSResize(hypar["input_size"])
+ ],
+ batch_size=hypar["batch_size_valid"],
+ @@ -618,19 +618,19 @@ if __name__ == "__main__":
+ train_datasets, valid_datasets = [], []
+ dataset_1, dataset_1 = {}, {}
+
+ - dataset_tr = {"name": "DIS5K-TR",
+ - "im_dir": "../DIS5K/DIS-TR/im",
+ - "gt_dir": "../DIS5K/DIS-TR/gt",
+ - "im_ext": ".jpg",
+ + dataset_tr = {"name": "training",
+ + "im_dir": "../training/im",
+ + "gt_dir": "../training/gt",
+ + "im_ext": ".png",
+ "gt_ext": ".png",
+ - "cache_dir":"../DIS5K-Cache/DIS-TR"}
+ + "cache_dir":"../cache/training"}
+
+ - dataset_vd = {"name": "DIS5K-VD",
+ - "im_dir": "../DIS5K/DIS-VD/im",
+ - "gt_dir": "../DIS5K/DIS-VD/gt",
+ - "im_ext": ".jpg",
+ + dataset_vd = {"name": "validation",
+ + "im_dir": "../validation/im",
+ + "gt_dir": "../validation/gt",
+ + "im_ext": ".png",
+ "gt_ext": ".png",
+ - "cache_dir":"../DIS5K-Cache/DIS-VD"}
+ + "cache_dir":"../cache/validation"}
+
+ dataset_te1 = {"name": "DIS5K-TE1",
+ "im_dir": "../DIS5K/DIS-TE1/im",
+ @@ -685,7 +685,7 @@ if __name__ == "__main__":
+ if hypar["mode"] == "train":
+ hypar["valid_out_dir"] = "" ## for "train" model leave it as "", for "valid"("inference") mode: set it according to your local directory
+ hypar["model_path"] ="../saved_models/IS-Net-test" ## model weights saving (or restoring) path
+ - hypar["restore_model"] = "" ## name of the segmentation model weights .pth for resume training process from last stop or for the inferencing
+ + hypar["restore_model"] = "isnet-base-model.pth" ## name of the segmentation model weights .pth for resume training process from last stop or for the inferencing
+ hypar["start_ite"] = 0 ## start iteration for the training, can be changed to match the restored training process
+ hypar["gt_encoder_model"] = ""
+ else: ## configure the segmentation output path and the to-be-used model weights path
models/isnet-base-model.pth ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:9e1aafea58f0b55d0c35077e0ceade6ba1ba2bce372fd4f8f77215391f3fac13
+ size 176579397
models/ormbg.onnx ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:6e91dc17c7cd8eff882d06f293e34b0ca6d33e6f5d71c87b439bd59820f03c49
+ size 176180252
models/ormbg.pth ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:ba387a8348526875024f59aa97d23af9cacfff77abf4e9af14332bf477c088fa
+ size 176719216
utils/__pycache__/isnet.cpython-312.pyc ADDED
Binary file (27.6 kB).
 
utils/isnet.py ADDED
@@ -0,0 +1,647 @@
+ import torch
+ import torch.nn as nn
+ from torchvision import models
+ import torch.nn.functional as F
+
+ # https://github.com/xuebinqin/DIS/blob/main/IS-Net/models/isnet.py
+
+ bce_loss = nn.BCELoss(reduction="mean")
+
+
+ # Deep-supervision loss: sums BCE over all side outputs; loss0 is the loss
+ # of the first (highest-resolution) output alone.
+ def muti_loss_fusion(preds, target):
+     loss0 = 0.0
+     loss = 0.0
+
+     for i in range(0, len(preds)):
+         # print("i: ", i, preds[i].shape)
+         if preds[i].shape[2] != target.shape[2] or preds[i].shape[3] != target.shape[3]:
+             # tmp_target = _upsample_like(target,preds[i])
+             tmp_target = F.interpolate(
+                 target, size=preds[i].size()[2:], mode="bilinear", align_corners=True
+             )
+             loss = loss + bce_loss(preds[i], tmp_target)
+         else:
+             loss = loss + bce_loss(preds[i], target)
+         if i == 0:
+             loss0 = loss
+     return loss0, loss
+
+
+ fea_loss = nn.MSELoss(reduction="mean")
+ kl_loss = nn.KLDivLoss(reduction="mean")
+ l1_loss = nn.L1Loss(reduction="mean")
+ smooth_l1_loss = nn.SmoothL1Loss(reduction="mean")
+
+
+ # Same as muti_loss_fusion, plus a feature loss between the decoder
+ # features (dfs) and the ground-truth-encoder features (fs).
+ def muti_loss_fusion_kl(preds, target, dfs, fs, mode="MSE"):
+     loss0 = 0.0
+     loss = 0.0
+
+     for i in range(0, len(preds)):
+         # print("i: ", i, preds[i].shape)
+         if preds[i].shape[2] != target.shape[2] or preds[i].shape[3] != target.shape[3]:
+             # tmp_target = _upsample_like(target,preds[i])
+             tmp_target = F.interpolate(
+                 target, size=preds[i].size()[2:], mode="bilinear", align_corners=True
+             )
+             loss = loss + bce_loss(preds[i], tmp_target)
+         else:
+             loss = loss + bce_loss(preds[i], target)
+         if i == 0:
+             loss0 = loss
+
+     for i in range(0, len(dfs)):
+         if mode == "MSE":
+             loss = loss + fea_loss(
+                 dfs[i], fs[i]
+             )  ### add the mse loss of features as additional constraints
+             # print("fea_loss: ", fea_loss(dfs[i],fs[i]).item())
+         elif mode == "KL":
+             loss = loss + kl_loss(F.log_softmax(dfs[i], dim=1), F.softmax(fs[i], dim=1))
+             # print("kl_loss: ", kl_loss(F.log_softmax(dfs[i],dim=1),F.softmax(fs[i],dim=1)).item())
+         elif mode == "MAE":
+             loss = loss + l1_loss(dfs[i], fs[i])
+             # print("ls_loss: ", l1_loss(dfs[i],fs[i]))
+         elif mode == "SmoothL1":
+             loss = loss + smooth_l1_loss(dfs[i], fs[i])
+             # print("SmoothL1: ", smooth_l1_loss(dfs[i],fs[i]).item())
+
+     return loss0, loss
+
+
+ class REBNCONV(nn.Module):
+     def __init__(self, in_ch=3, out_ch=3, dirate=1, stride=1):
+         super(REBNCONV, self).__init__()
+
+         self.conv_s1 = nn.Conv2d(
+             in_ch, out_ch, 3, padding=1 * dirate, dilation=1 * dirate, stride=stride
+         )
+         self.bn_s1 = nn.BatchNorm2d(out_ch)
+         self.relu_s1 = nn.ReLU(inplace=True)
+
+     def forward(self, x):
+
+         hx = x
+         xout = self.relu_s1(self.bn_s1(self.conv_s1(hx)))
+
+         return xout
+
+
+ ## upsample tensor 'src' to have the same spatial size with tensor 'tar'
+ def _upsample_like(src, tar):
+
+     src = F.interpolate(src, size=tar.shape[2:], mode="bilinear")
+
+     return src
+
+
+ ### RSU-7 ###
+ class RSU7(nn.Module):
+
+     def __init__(self, in_ch=3, mid_ch=12, out_ch=3, img_size=512):
+         super(RSU7, self).__init__()
+
+         self.in_ch = in_ch
+         self.mid_ch = mid_ch
+         self.out_ch = out_ch
+
+         self.rebnconvin = REBNCONV(in_ch, out_ch, dirate=1)  ## 1 -> 1/2
+
+         self.rebnconv1 = REBNCONV(out_ch, mid_ch, dirate=1)
+         self.pool1 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
+
+         self.rebnconv2 = REBNCONV(mid_ch, mid_ch, dirate=1)
+         self.pool2 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
+
+         self.rebnconv3 = REBNCONV(mid_ch, mid_ch, dirate=1)
+         self.pool3 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
+
+         self.rebnconv4 = REBNCONV(mid_ch, mid_ch, dirate=1)
+         self.pool4 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
+
+         self.rebnconv5 = REBNCONV(mid_ch, mid_ch, dirate=1)
+         self.pool5 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
+
+         self.rebnconv6 = REBNCONV(mid_ch, mid_ch, dirate=1)
+
+         self.rebnconv7 = REBNCONV(mid_ch, mid_ch, dirate=2)
+
+         self.rebnconv6d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
+         self.rebnconv5d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
+         self.rebnconv4d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
+         self.rebnconv3d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
+         self.rebnconv2d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
+         self.rebnconv1d = REBNCONV(mid_ch * 2, out_ch, dirate=1)
+
+     def forward(self, x):
+         b, c, h, w = x.shape
+
+         hx = x
+         hxin = self.rebnconvin(hx)
+
+         hx1 = self.rebnconv1(hxin)
+         hx = self.pool1(hx1)
+
+         hx2 = self.rebnconv2(hx)
+         hx = self.pool2(hx2)
+
+         hx3 = self.rebnconv3(hx)
+         hx = self.pool3(hx3)
+
+         hx4 = self.rebnconv4(hx)
+         hx = self.pool4(hx4)
+
+         hx5 = self.rebnconv5(hx)
+         hx = self.pool5(hx5)
+
+         hx6 = self.rebnconv6(hx)
+
+         hx7 = self.rebnconv7(hx6)
+
+         hx6d = self.rebnconv6d(torch.cat((hx7, hx6), 1))
+         hx6dup = _upsample_like(hx6d, hx5)
+
+         hx5d = self.rebnconv5d(torch.cat((hx6dup, hx5), 1))
+         hx5dup = _upsample_like(hx5d, hx4)
+
+         hx4d = self.rebnconv4d(torch.cat((hx5dup, hx4), 1))
+         hx4dup = _upsample_like(hx4d, hx3)
+
+         hx3d = self.rebnconv3d(torch.cat((hx4dup, hx3), 1))
+         hx3dup = _upsample_like(hx3d, hx2)
+
+         hx2d = self.rebnconv2d(torch.cat((hx3dup, hx2), 1))
+         hx2dup = _upsample_like(hx2d, hx1)
+
+         hx1d = self.rebnconv1d(torch.cat((hx2dup, hx1), 1))
+
+         return hx1d + hxin
+
+
+ ### RSU-6 ###
+ class RSU6(nn.Module):
+
+     def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
+         super(RSU6, self).__init__()
+
+         self.rebnconvin = REBNCONV(in_ch, out_ch, dirate=1)
+
+         self.rebnconv1 = REBNCONV(out_ch, mid_ch, dirate=1)
+         self.pool1 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
+
+         self.rebnconv2 = REBNCONV(mid_ch, mid_ch, dirate=1)
+         self.pool2 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
+
+         self.rebnconv3 = REBNCONV(mid_ch, mid_ch, dirate=1)
+         self.pool3 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
+
+         self.rebnconv4 = REBNCONV(mid_ch, mid_ch, dirate=1)
+         self.pool4 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
+
+         self.rebnconv5 = REBNCONV(mid_ch, mid_ch, dirate=1)
+
+         self.rebnconv6 = REBNCONV(mid_ch, mid_ch, dirate=2)
+
+         self.rebnconv5d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
+         self.rebnconv4d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
+         self.rebnconv3d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
+         self.rebnconv2d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
+         self.rebnconv1d = REBNCONV(mid_ch * 2, out_ch, dirate=1)
+
+     def forward(self, x):
+
+         hx = x
+
+         hxin = self.rebnconvin(hx)
+
+         hx1 = self.rebnconv1(hxin)
+         hx = self.pool1(hx1)
+
+         hx2 = self.rebnconv2(hx)
+         hx = self.pool2(hx2)
+
+         hx3 = self.rebnconv3(hx)
+         hx = self.pool3(hx3)
+
+         hx4 = self.rebnconv4(hx)
+         hx = self.pool4(hx4)
+
+         hx5 = self.rebnconv5(hx)
+
+         hx6 = self.rebnconv6(hx5)
+
+         hx5d = self.rebnconv5d(torch.cat((hx6, hx5), 1))
+         hx5dup = _upsample_like(hx5d, hx4)
+
+         hx4d = self.rebnconv4d(torch.cat((hx5dup, hx4), 1))
+         hx4dup = _upsample_like(hx4d, hx3)
+
+         hx3d = self.rebnconv3d(torch.cat((hx4dup, hx3), 1))
+         hx3dup = _upsample_like(hx3d, hx2)
+
+         hx2d = self.rebnconv2d(torch.cat((hx3dup, hx2), 1))
+         hx2dup = _upsample_like(hx2d, hx1)
+
+         hx1d = self.rebnconv1d(torch.cat((hx2dup, hx1), 1))
+
+         return hx1d + hxin
+
+
+ ### RSU-5 ###
+ class RSU5(nn.Module):
+
+     def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
+         super(RSU5, self).__init__()
+
+         self.rebnconvin = REBNCONV(in_ch, out_ch, dirate=1)
+
+         self.rebnconv1 = REBNCONV(out_ch, mid_ch, dirate=1)
+         self.pool1 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
+
+         self.rebnconv2 = REBNCONV(mid_ch, mid_ch, dirate=1)
+         self.pool2 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
+
+         self.rebnconv3 = REBNCONV(mid_ch, mid_ch, dirate=1)
+         self.pool3 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
+
+         self.rebnconv4 = REBNCONV(mid_ch, mid_ch, dirate=1)
+
+         self.rebnconv5 = REBNCONV(mid_ch, mid_ch, dirate=2)
+
+         self.rebnconv4d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
+         self.rebnconv3d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
+         self.rebnconv2d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
+         self.rebnconv1d = REBNCONV(mid_ch * 2, out_ch, dirate=1)
+
+     def forward(self, x):
+
+         hx = x
+
+         hxin = self.rebnconvin(hx)
+
+         hx1 = self.rebnconv1(hxin)
+         hx = self.pool1(hx1)
+
+         hx2 = self.rebnconv2(hx)
+         hx = self.pool2(hx2)
+
+         hx3 = self.rebnconv3(hx)
+         hx = self.pool3(hx3)
+
+         hx4 = self.rebnconv4(hx)
+
+         hx5 = self.rebnconv5(hx4)
+
+         hx4d = self.rebnconv4d(torch.cat((hx5, hx4), 1))
+         hx4dup = _upsample_like(hx4d, hx3)
+
+         hx3d = self.rebnconv3d(torch.cat((hx4dup, hx3), 1))
+         hx3dup = _upsample_like(hx3d, hx2)
+
+         hx2d = self.rebnconv2d(torch.cat((hx3dup, hx2), 1))
+         hx2dup = _upsample_like(hx2d, hx1)
+
+         hx1d = self.rebnconv1d(torch.cat((hx2dup, hx1), 1))
+
+         return hx1d + hxin
+
+
+ ### RSU-4 ###
+ class RSU4(nn.Module):
+
+     def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
+         super(RSU4, self).__init__()
+
+         self.rebnconvin = REBNCONV(in_ch, out_ch, dirate=1)
+
+         self.rebnconv1 = REBNCONV(out_ch, mid_ch, dirate=1)
+         self.pool1 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
+
+         self.rebnconv2 = REBNCONV(mid_ch, mid_ch, dirate=1)
+         self.pool2 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
+
+         self.rebnconv3 = REBNCONV(mid_ch, mid_ch, dirate=1)
+
+         self.rebnconv4 = REBNCONV(mid_ch, mid_ch, dirate=2)
+
+         self.rebnconv3d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
+         self.rebnconv2d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
+         self.rebnconv1d = REBNCONV(mid_ch * 2, out_ch, dirate=1)
+
+     def forward(self, x):
+
+         hx = x
+
+         hxin = self.rebnconvin(hx)
+
+         hx1 = self.rebnconv1(hxin)
+         hx = self.pool1(hx1)
+
+         hx2 = self.rebnconv2(hx)
+         hx = self.pool2(hx2)
+
+         hx3 = self.rebnconv3(hx)
+
+         hx4 = self.rebnconv4(hx3)
+
+         hx3d = self.rebnconv3d(torch.cat((hx4, hx3), 1))
+         hx3dup = _upsample_like(hx3d, hx2)
+
+         hx2d = self.rebnconv2d(torch.cat((hx3dup, hx2), 1))
+         hx2dup = _upsample_like(hx2d, hx1)
+
+         hx1d = self.rebnconv1d(torch.cat((hx2dup, hx1), 1))
+
+         return hx1d + hxin
+
+
+ ### RSU-4F ###
+ class RSU4F(nn.Module):
+
+     def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
+         super(RSU4F, self).__init__()
+
+         self.rebnconvin = REBNCONV(in_ch, out_ch, dirate=1)
+
+         self.rebnconv1 = REBNCONV(out_ch, mid_ch, dirate=1)
+         self.rebnconv2 = REBNCONV(mid_ch, mid_ch, dirate=2)
+         self.rebnconv3 = REBNCONV(mid_ch, mid_ch, dirate=4)
+
+         self.rebnconv4 = REBNCONV(mid_ch, mid_ch, dirate=8)
+
+         self.rebnconv3d = REBNCONV(mid_ch * 2, mid_ch, dirate=4)
+         self.rebnconv2d = REBNCONV(mid_ch * 2, mid_ch, dirate=2)
+         self.rebnconv1d = REBNCONV(mid_ch * 2, out_ch, dirate=1)
+
+     def forward(self, x):
+
+         hx = x
+
+         hxin = self.rebnconvin(hx)
+
+         hx1 = self.rebnconv1(hxin)
+         hx2 = self.rebnconv2(hx1)
+         hx3 = self.rebnconv3(hx2)
+
+         hx4 = self.rebnconv4(hx3)
+
+         hx3d = self.rebnconv3d(torch.cat((hx4, hx3), 1))
+         hx2d = self.rebnconv2d(torch.cat((hx3d, hx2), 1))
+         hx1d = self.rebnconv1d(torch.cat((hx2d, hx1), 1))
+
+         return hx1d + hxin
+
+
+ class myrebnconv(nn.Module):
+     def __init__(
+         self,
+         in_ch=3,
+         out_ch=1,
+         kernel_size=3,
+         stride=1,
+         padding=1,
+         dilation=1,
+         groups=1,
+     ):
+         super(myrebnconv, self).__init__()
+
+         self.conv = nn.Conv2d(
+             in_ch,
+             out_ch,
+             kernel_size=kernel_size,
+             stride=stride,
+             padding=padding,
+             dilation=dilation,
+             groups=groups,
+         )
+         self.bn = nn.BatchNorm2d(out_ch)
+         self.rl = nn.ReLU(inplace=True)
+
+     def forward(self, x):
+         return self.rl(self.bn(self.conv(x)))
+
+
+ class ISNetGTEncoder(nn.Module):
+
+     def __init__(self, in_ch=1, out_ch=1):
+         super(ISNetGTEncoder, self).__init__()
+
+         self.conv_in = myrebnconv(
+             in_ch, 16, 3, stride=2, padding=1
+         )  # nn.Conv2d(in_ch,64,3,stride=2,padding=1)
+
+         self.stage1 = RSU7(16, 16, 64)
+         self.pool12 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
+
+         self.stage2 = RSU6(64, 16, 64)
+         self.pool23 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
+
+         self.stage3 = RSU5(64, 32, 128)
+         self.pool34 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
+
+         self.stage4 = RSU4(128, 32, 256)
+         self.pool45 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
+
+         self.stage5 = RSU4F(256, 64, 512)
+         self.pool56 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
+
+         self.stage6 = RSU4F(512, 64, 512)
+
+         self.side1 = nn.Conv2d(64, out_ch, 3, padding=1)
+         self.side2 = nn.Conv2d(64, out_ch, 3, padding=1)
+         self.side3 = nn.Conv2d(128, out_ch, 3, padding=1)
+         self.side4 = nn.Conv2d(256, out_ch, 3, padding=1)
+         self.side5 = nn.Conv2d(512, out_ch, 3, padding=1)
+         self.side6 = nn.Conv2d(512, out_ch, 3, padding=1)
+
+     def compute_loss(self, preds, targets):
+
+         return muti_loss_fusion(preds, targets)
+
+     def forward(self, x):
+
+         hx = x
+
+         hxin = self.conv_in(hx)
+         # hx = self.pool_in(hxin)
+
+         # stage 1
+         hx1 = self.stage1(hxin)
+         hx = self.pool12(hx1)
+
+         # stage 2
+         hx2 = self.stage2(hx)
+         hx = self.pool23(hx2)
+
+         # stage 3
+         hx3 = self.stage3(hx)
+         hx = self.pool34(hx3)
+
+         # stage 4
+         hx4 = self.stage4(hx)
+         hx = self.pool45(hx4)
+
+         # stage 5
+         hx5 = self.stage5(hx)
+         hx = self.pool56(hx5)
+
+         # stage 6
+         hx6 = self.stage6(hx)
+
+         # side output
+         d1 = self.side1(hx1)
+         d1 = _upsample_like(d1, x)
+
+         d2 = self.side2(hx2)
+         d2 = _upsample_like(d2, x)
+
+         d3 = self.side3(hx3)
+         d3 = _upsample_like(d3, x)
+
+         d4 = self.side4(hx4)
+         d4 = _upsample_like(d4, x)
+
+         d5 = self.side5(hx5)
+         d5 = _upsample_like(d5, x)
+
+         d6 = self.side6(hx6)
+         d6 = _upsample_like(d6, x)
+
+         # d0 = self.outconv(torch.cat((d1,d2,d3,d4,d5,d6),1))
+
+         return [
+             torch.sigmoid(d1),
+             torch.sigmoid(d2),
+             torch.sigmoid(d3),
+             torch.sigmoid(d4),
+             torch.sigmoid(d5),
+             torch.sigmoid(d6),
+         ], [hx1, hx2, hx3, hx4, hx5, hx6]
+
+
+ class ISNetDIS(nn.Module):
+
+     def __init__(self, in_ch=3, out_ch=1):
+         super(ISNetDIS, self).__init__()
+
+         self.conv_in = nn.Conv2d(in_ch, 64, 3, stride=2, padding=1)
+         self.pool_in = nn.MaxPool2d(2, stride=2, ceil_mode=True)
+
+         self.stage1 = RSU7(64, 32, 64)
+         self.pool12 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
+
+         self.stage2 = RSU6(64, 32, 128)
+         self.pool23 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
+
+         self.stage3 = RSU5(128, 64, 256)
+         self.pool34 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
+
+         self.stage4 = RSU4(256, 128, 512)
+         self.pool45 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
+
+         self.stage5 = RSU4F(512, 256, 512)
+         self.pool56 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
+
+         self.stage6 = RSU4F(512, 256, 512)
+
+         # decoder
+         self.stage5d = RSU4F(1024, 256, 512)
+         self.stage4d = RSU4(1024, 128, 256)
+         self.stage3d = RSU5(512, 64, 128)
+         self.stage2d = RSU6(256, 32, 64)
+         self.stage1d = RSU7(128, 16, 64)
+
+         self.side1 = nn.Conv2d(64, out_ch, 3, padding=1)
+         self.side2 = nn.Conv2d(64, out_ch, 3, padding=1)
+         self.side3 = nn.Conv2d(128, out_ch, 3, padding=1)
+         self.side4 = nn.Conv2d(256, out_ch, 3, padding=1)
+         self.side5 = nn.Conv2d(512, out_ch, 3, padding=1)
+         self.side6 = nn.Conv2d(512, out_ch, 3, padding=1)
+
+         # self.outconv = nn.Conv2d(6*out_ch,out_ch,1)
+
+     def compute_loss_kl(self, preds, targets, dfs, fs, mode="MSE"):
+
+         # return muti_loss_fusion(preds,targets)
+         return muti_loss_fusion_kl(preds, targets, dfs, fs, mode=mode)
+
+     def compute_loss(self, preds, targets):
+
+         # return muti_loss_fusion(preds,targets)
+         return muti_loss_fusion(preds, targets)
+
+     def forward(self, x):
+
+         hx = x
+
+         hxin = self.conv_in(hx)
+         # hx = self.pool_in(hxin)
+
+         # stage 1
+         hx1 = self.stage1(hxin)
+         hx = self.pool12(hx1)
+
+         # stage 2
+         hx2 = self.stage2(hx)
+         hx = self.pool23(hx2)
+
+         # stage 3
+         hx3 = self.stage3(hx)
+         hx = self.pool34(hx3)
+
+         # stage 4
+         hx4 = self.stage4(hx)
+         hx = self.pool45(hx4)
+
+         # stage 5
+         hx5 = self.stage5(hx)
+         hx = self.pool56(hx5)
+
+         # stage 6
+         hx6 = self.stage6(hx)
+         hx6up = _upsample_like(hx6, hx5)
+
+         # -------------------- decoder --------------------
+         hx5d = self.stage5d(torch.cat((hx6up, hx5), 1))
+         hx5dup = _upsample_like(hx5d, hx4)
+
+         hx4d = self.stage4d(torch.cat((hx5dup, hx4), 1))
+         hx4dup = _upsample_like(hx4d, hx3)
+
+         hx3d = self.stage3d(torch.cat((hx4dup, hx3), 1))
+         hx3dup = _upsample_like(hx3d, hx2)
+
+         hx2d = self.stage2d(torch.cat((hx3dup, hx2), 1))
+         hx2dup = _upsample_like(hx2d, hx1)
+
+         hx1d = self.stage1d(torch.cat((hx2dup, hx1), 1))
+
+         # side output
+         d1 = self.side1(hx1d)
+         d1 = _upsample_like(d1, x)
+
+         d2 = self.side2(hx2d)
+         d2 = _upsample_like(d2, x)
+
+         d3 = self.side3(hx3d)
+         d3 = _upsample_like(d3, x)
+
+         d4 = self.side4(hx4d)
+         d4 = _upsample_like(d4, x)
+
+         d5 = self.side5(hx5d)
+         d5 = _upsample_like(d5, x)
+
+         d6 = self.side6(hx6)
+         d6 = _upsample_like(d6, x)
+
+         # d0 = self.outconv(torch.cat((d1,d2,d3,d4,d5,d6),1))
+
+         return [
+             torch.sigmoid(d1),
+             torch.sigmoid(d2),
+             torch.sigmoid(d3),
+             torch.sigmoid(d4),
+             torch.sigmoid(d5),
+             torch.sigmoid(d6),
+         ], [hx1d, hx2d, hx3d, hx4d, hx5d, hx6]
utils/pth_to_onnx.py ADDED
@@ -0,0 +1,59 @@
+ import torch
+ import argparse
+ from isnet import ISNetDIS
+
+
+ def export_to_onnx(model_path, onnx_path):
+
+     net = ISNetDIS()
+
+     if torch.cuda.is_available():
+         net.load_state_dict(torch.load(model_path))
+         net = net.cuda()
+     else:
+         net.load_state_dict(torch.load(model_path, map_location=torch.device("cpu")))
+
+     net.eval()
+
+     # Create a dummy input tensor. The size should match the model's input size.
+     # Adjust the dimensions as necessary; here it is assumed the input is a 3-channel image.
+     dummy_input = torch.randn(
+         1,
+         3,
+         1024,
+         1024,
+         device="cuda" if torch.cuda.is_available() else "cpu",
+     )
+
+     torch.onnx.export(
+         net,
+         dummy_input,
+         onnx_path,
+         export_params=True,
+         opset_version=10,
+         do_constant_folding=True,
+         input_names=["input"],
+         output_names=["output"],
+     )
+
+
+ if __name__ == "__main__":
+     parser = argparse.ArgumentParser(
+         description="Export a trained model to ONNX format."
+     )
+     parser.add_argument(
+         "--model_path",
+         type=str,
+         default="./models/ormbg.pth",
+         help="The path to the trained model file.",
+     )
+     parser.add_argument(
+         "--onnx_path",
+         type=str,
+         default="./models/example.onnx",
+         help="The path where the ONNX model will be saved.",
+     )
+
+     args = parser.parse_args()
+
+     export_to_onnx(args.model_path, args.onnx_path)