Superlang commited on
Commit
4d0b7ae
1 Parent(s): 50a8070
.gitignore ADDED
@@ -0,0 +1,142 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Byte-compiled / optimized / DLL files
2
+ __pycache__/
3
+ *.py[cod]
4
+ *$py.class
5
+
6
+ # C extensions
7
+ *.so
8
+
9
+ .idea
10
+
11
+ *.pth
12
+
13
+ # Distribution / packaging
14
+ .Python
15
+ build/
16
+ develop-eggs/
17
+ dist/
18
+ downloads/
19
+ eggs/
20
+ .eggs/
21
+ lib/
22
+ lib64/
23
+ parts/
24
+ sdist/
25
+ var/
26
+ wheels/
27
+ share/python-wheels/
28
+ *.egg-info/
29
+ .installed.cfg
30
+ *.egg
31
+ MANIFEST
32
+
33
+ # PyInstaller
34
+ # Usually these files are written by a python script from a template
35
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
36
+ *.manifest
37
+ *.spec
38
+
39
+ # Installer logs
40
+ pip-log.txt
41
+ pip-delete-this-directory.txt
42
+
43
+ # Unit test / coverage reports
44
+ htmlcov/
45
+ .tox/
46
+ .nox/
47
+ .coverage
48
+ .coverage.*
49
+ .cache
50
+ nosetests.xml
51
+ coverage.xml
52
+ *.cover
53
+ *.py,cover
54
+ .hypothesis/
55
+ .pytest_cache/
56
+ cover/
57
+
58
+ # Translations
59
+ *.mo
60
+ *.pot
61
+
62
+ # Django stuff:
63
+ *.log
64
+ local_settings.py
65
+ db.sqlite3
66
+ db.sqlite3-journal
67
+
68
+ # Flask stuff:
69
+ instance/
70
+ .webassets-cache
71
+
72
+ # Scrapy stuff:
73
+ .scrapy
74
+
75
+ # Sphinx documentation
76
+ docs/_build/
77
+
78
+ # PyBuilder
79
+ .pybuilder/
80
+ target/
81
+
82
+ # Jupyter Notebook
83
+ .ipynb_checkpoints
84
+
85
+ # IPython
86
+ profile_default/
87
+ ipython_config.py
88
+
89
+ # pyenv
90
+ # For a library or package, you might want to ignore these files since the code is
91
+ # intended to run in multiple environments; otherwise, check them in:
92
+ # .python-version
93
+
94
+ # pipenv
95
+ # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
96
+ # However, in case of collaboration, if having platform-specific dependencies or dependencies
97
+ # having no cross-platform support, pipenv may install dependencies that don't work, or not
98
+ # install all needed dependencies.
99
+ #Pipfile.lock
100
+
101
+ # PEP 582; used by e.g. github.com/David-OConnor/pyflow
102
+ __pypackages__/
103
+
104
+ # Celery stuff
105
+ celerybeat-schedule
106
+ celerybeat.pid
107
+
108
+ # SageMath parsed files
109
+ *.sage.py
110
+
111
+ # Environments
112
+ .env
113
+ .venv
114
+ env/
115
+ venv/
116
+ ENV/
117
+ env.bak/
118
+ venv.bak/
119
+
120
+ # Spyder project settings
121
+ .spyderproject
122
+ .spyproject
123
+
124
+ # Rope project settings
125
+ .ropeproject
126
+
127
+ # mkdocs documentation
128
+ /site
129
+
130
+ # mypy
131
+ .mypy_cache/
132
+ .dmypy.json
133
+ dmypy.json
134
+
135
+ # Pyre type checker
136
+ .pyre/
137
+
138
+ # pytype static type analyzer
139
+ .pytype/
140
+
141
+ # Cython debug symbols
142
+ cython_debug/
DIS/Inference.py ADDED
@@ -0,0 +1,53 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import time
3
+ import numpy as np
4
+ from skimage import io
5
+ import time
6
+ from glob import glob
7
+ from tqdm import tqdm
8
+
9
+ import torch, gc
10
+ import torch.nn as nn
11
+ from torch.autograd import Variable
12
+ import torch.optim as optim
13
+ import torch.nn.functional as F
14
+ from torchvision.transforms.functional import normalize
15
+
16
+ from models import *
17
+
18
+
19
+ if __name__ == "__main__":
20
+ dataset_path="../demo_datasets/your_dataset" #Your dataset path
21
+ model_path="../saved_models/IS-Net/isnet-general-use.pth" # the model path
22
+ result_path="../demo_datasets/your_dataset_result" #The folder path that you want to save the results
23
+ input_size=[1024,1024]
24
+ net=ISNetDIS()
25
+
26
+ if torch.cuda.is_available():
27
+ net.load_state_dict(torch.load(model_path))
28
+ net=net.cuda()
29
+ else:
30
+ net.load_state_dict(torch.load(model_path,map_location="cpu"))
31
+ net.eval()
32
+ im_list = glob(dataset_path+"/*.jpg")+glob(dataset_path+"/*.JPG")+glob(dataset_path+"/*.jpeg")+glob(dataset_path+"/*.JPEG")+glob(dataset_path+"/*.png")+glob(dataset_path+"/*.PNG")+glob(dataset_path+"/*.bmp")+glob(dataset_path+"/*.BMP")+glob(dataset_path+"/*.tiff")+glob(dataset_path+"/*.TIFF")
33
+ with torch.no_grad():
34
+ for i, im_path in tqdm(enumerate(im_list), total=len(im_list)):
35
+ print("im_path: ", im_path)
36
+ im = io.imread(im_path)
37
+ if len(im.shape) < 3:
38
+ im = im[:, :, np.newaxis]
39
+ im_shp=im.shape[0:2]
40
+ im_tensor = torch.tensor(im, dtype=torch.float32).permute(2,0,1)
41
+ im_tensor = F.upsample(torch.unsqueeze(im_tensor,0), input_size, mode="bilinear").type(torch.uint8)
42
+ image = torch.divide(im_tensor,255.0)
43
+ image = normalize(image,[0.5,0.5,0.5],[1.0,1.0,1.0])
44
+
45
+ if torch.cuda.is_available():
46
+ image=image.cuda()
47
+ result=net(image)
48
+ result=torch.squeeze(F.upsample(result[0][0],im_shp,mode='bilinear'),0)
49
+ ma = torch.max(result)
50
+ mi = torch.min(result)
51
+ result = (result-mi)/(ma-mi)
52
+ im_name=im_path.split('/')[-1].split('.')[0]
53
+ io.imsave(os.path.join(result_path,im_name+".png"),(result*255).permute(1,2,0).cpu().data.numpy().astype(np.uint8))
DIS/IsNetPipeLine.py ADDED
@@ -0,0 +1,131 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ reference: https://github.com/xuebinqin/DIS
3
+ """
4
+
5
+ import PIL.Image
6
+ import numpy as np
7
+ import torch
8
+ import torch.nn.functional as F
9
+ from PIL import Image
10
+ from torch import nn
11
+ from torch.autograd import Variable
12
+ from torchvision import transforms
13
+ from torchvision.transforms.functional import normalize
14
+
15
+ from .models import ISNetDIS
16
+
17
+ # Helpers
18
+ device = 'cuda' if torch.cuda.is_available() else 'cpu'
19
+
20
+
21
+ class GOSNormalize(object):
22
+ """
23
+ Normalize the Image using torch.transforms
24
+ """
25
+
26
+ def __init__(self, mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]):
27
+ self.mean = mean
28
+ self.std = std
29
+
30
+ def __call__(self, image):
31
+ image = normalize(image, self.mean, self.std)
32
+ return image
33
+
34
+
35
+ def im_preprocess(im, size):
36
+ if len(im.shape) < 3:
37
+ im = im[:, :, np.newaxis]
38
+ if im.shape[2] == 1:
39
+ im = np.repeat(im, 3, axis=2)
40
+ im_tensor = torch.tensor(im.copy(), dtype=torch.float32)
41
+ im_tensor = torch.transpose(torch.transpose(im_tensor, 1, 2), 0, 1)
42
+ if len(size) < 2:
43
+ return im_tensor, im.shape[0:2]
44
+ else:
45
+ im_tensor = torch.unsqueeze(im_tensor, 0)
46
+ im_tensor = F.upsample(im_tensor, size, mode="bilinear")
47
+ im_tensor = torch.squeeze(im_tensor, 0)
48
+
49
+ return im_tensor.type(torch.uint8), im.shape[0:2]
50
+
51
+
52
+ class IsNetPipeLine:
53
+ def __init__(self, model_path=None, model_digit="full"):
54
+ self.model_digit = model_digit
55
+ self.model = ISNetDIS()
56
+ self.cache_size = [1024, 1024]
57
+ self.transform = transforms.Compose([
58
+ GOSNormalize([0.5, 0.5, 0.5], [1.0, 1.0, 1.0])
59
+ ])
60
+
61
+ # Build Model
62
+ self.build_model(model_path)
63
+
64
+ def load_image(self, image: PIL.Image.Image):
65
+ im = np.array(image.convert("RGB"))
66
+ im, im_shp = im_preprocess(im, self.cache_size)
67
+ im = torch.divide(im, 255.0)
68
+ shape = torch.from_numpy(np.array(im_shp))
69
+ return self.transform(im).unsqueeze(0), shape.unsqueeze(0) # make a batch of image, shape
70
+
71
+ def build_model(self, model_path=None):
72
+ if model_path is not None:
73
+ self.model.load_state_dict(torch.load(model_path, map_location=device))
74
+
75
+ # convert to half precision
76
+ if self.model_digit == "half":
77
+ self.model.half()
78
+ for layer in self.model.modules():
79
+ if isinstance(layer, nn.BatchNorm2d):
80
+ layer.float()
81
+ self.model.to(device)
82
+ self.model.eval()
83
+
84
+ def __call__(self, image: PIL.Image.Image):
85
+ image_tensor, orig_size = self.load_image(image)
86
+ mask = self.predict(image_tensor, orig_size)
87
+
88
+ pil_mask = Image.fromarray(mask).convert('L')
89
+ im_rgb = image.convert("RGB")
90
+
91
+ im_rgba = im_rgb.copy()
92
+ im_rgba.putalpha(pil_mask)
93
+
94
+ return [im_rgba, pil_mask]
95
+
96
+ def predict(self, inputs_val: torch.Tensor, shapes_val):
97
+ """
98
+ Given an Image, predict the mask
99
+ """
100
+
101
+ if self.model_digit == "full":
102
+ inputs_val = inputs_val.type(torch.FloatTensor)
103
+ else:
104
+ inputs_val = inputs_val.type(torch.HalfTensor)
105
+
106
+ inputs_val_v = Variable(inputs_val, requires_grad=False).to(device) # wrap inputs in Variable
107
+
108
+ ds_val = self.model(inputs_val_v)[0] # list of 6 results
109
+
110
+ # B x 1 x H x W # we want the first one which is the most accurate prediction
111
+ pred_val = ds_val[0][0, :, :, :]
112
+
113
+ # recover the prediction spatial size to the orignal image size
114
+ pred_val = torch.squeeze(
115
+ F.upsample(torch.unsqueeze(pred_val, 0), (shapes_val[0][0], shapes_val[0][1]), mode='bilinear'))
116
+
117
+ ma = torch.max(pred_val)
118
+ mi = torch.min(pred_val)
119
+ pred_val = (pred_val - mi) / (ma - mi) # max = 1
120
+
121
+ if device == 'cuda':
122
+ torch.cuda.empty_cache()
123
+ return (pred_val.detach().cpu().numpy() * 255).astype(np.uint8) # it is the mask we need
124
+
125
+
126
+ # a = IsNetPipeLine(model_path="save_models/isnet.pth")
127
+ # input_image = Image.open("image_0mx.png")
128
+ # rgb, mask = a(input_image)
129
+ #
130
+ # rgb.save("rgb.png")
131
+ # mask.save("mask.png")
DIS/models/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ from .isnet import ISNetGTEncoder, ISNetDIS
DIS/models/isnet.py ADDED
@@ -0,0 +1,608 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+
5
+ bce_loss = nn.BCELoss(size_average=True)
6
+
7
+
8
+ def muti_loss_fusion(preds, target):
9
+ loss0 = 0.0
10
+ loss = 0.0
11
+
12
+ for i in range(0, len(preds)):
13
+ # print("i: ", i, preds[i].shape)
14
+ if (preds[i].shape[2] != target.shape[2] or preds[i].shape[3] != target.shape[3]):
15
+ # tmp_target = _upsample_like(target,preds[i])
16
+ tmp_target = F.interpolate(target, size=preds[i].size()[2:], mode='bilinear', align_corners=True)
17
+ loss = loss + bce_loss(preds[i], tmp_target)
18
+ else:
19
+ loss = loss + bce_loss(preds[i], target)
20
+ if (i == 0):
21
+ loss0 = loss
22
+ return loss0, loss
23
+
24
+
25
+ fea_loss = nn.MSELoss(size_average=True)
26
+ kl_loss = nn.KLDivLoss(size_average=True)
27
+ l1_loss = nn.L1Loss(size_average=True)
28
+ smooth_l1_loss = nn.SmoothL1Loss(size_average=True)
29
+
30
+
31
+ def muti_loss_fusion_kl(preds, target, dfs, fs, mode='MSE'):
32
+ loss0 = 0.0
33
+ loss = 0.0
34
+
35
+ for i in range(0, len(preds)):
36
+ # print("i: ", i, preds[i].shape)
37
+ if (preds[i].shape[2] != target.shape[2] or preds[i].shape[3] != target.shape[3]):
38
+ # tmp_target = _upsample_like(target,preds[i])
39
+ tmp_target = F.interpolate(target, size=preds[i].size()[2:], mode='bilinear', align_corners=True)
40
+ loss = loss + bce_loss(preds[i], tmp_target)
41
+ else:
42
+ loss = loss + bce_loss(preds[i], target)
43
+ if (i == 0):
44
+ loss0 = loss
45
+
46
+ for i in range(0, len(dfs)):
47
+ if (mode == 'MSE'):
48
+ loss = loss + fea_loss(dfs[i], fs[i]) ### add the mse loss of features as additional constraints
49
+ # print("fea_loss: ", fea_loss(dfs[i],fs[i]).item())
50
+ elif (mode == 'KL'):
51
+ loss = loss + kl_loss(F.log_softmax(dfs[i], dim=1), F.softmax(fs[i], dim=1))
52
+ # print("kl_loss: ", kl_loss(F.log_softmax(dfs[i],dim=1),F.softmax(fs[i],dim=1)).item())
53
+ elif (mode == 'MAE'):
54
+ loss = loss + l1_loss(dfs[i], fs[i])
55
+ # print("ls_loss: ", l1_loss(dfs[i],fs[i]))
56
+ elif (mode == 'SmoothL1'):
57
+ loss = loss + smooth_l1_loss(dfs[i], fs[i])
58
+ # print("SmoothL1: ", smooth_l1_loss(dfs[i],fs[i]).item())
59
+
60
+ return loss0, loss
61
+
62
+
63
+ class REBNCONV(nn.Module):
64
+ def __init__(self, in_ch=3, out_ch=3, dirate=1, stride=1):
65
+ super(REBNCONV, self).__init__()
66
+
67
+ self.conv_s1 = nn.Conv2d(in_ch, out_ch, 3, padding=1 * dirate, dilation=1 * dirate, stride=stride)
68
+ self.bn_s1 = nn.BatchNorm2d(out_ch)
69
+ self.relu_s1 = nn.ReLU(inplace=True)
70
+
71
+ def forward(self, x):
72
+ hx = x
73
+ xout = self.relu_s1(self.bn_s1(self.conv_s1(hx)))
74
+
75
+ return xout
76
+
77
+
78
+ ## upsample tensor 'src' to have the same spatial size with tensor 'tar'
79
+ def _upsample_like(src, tar):
80
+ src = F.upsample(src, size=tar.shape[2:], mode='bilinear')
81
+
82
+ return src
83
+
84
+
85
+ ### RSU-7 ###
86
+ class RSU7(nn.Module):
87
+
88
+ def __init__(self, in_ch=3, mid_ch=12, out_ch=3, img_size=512):
89
+ super(RSU7, self).__init__()
90
+
91
+ self.in_ch = in_ch
92
+ self.mid_ch = mid_ch
93
+ self.out_ch = out_ch
94
+
95
+ self.rebnconvin = REBNCONV(in_ch, out_ch, dirate=1) ## 1 -> 1/2
96
+
97
+ self.rebnconv1 = REBNCONV(out_ch, mid_ch, dirate=1)
98
+ self.pool1 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
99
+
100
+ self.rebnconv2 = REBNCONV(mid_ch, mid_ch, dirate=1)
101
+ self.pool2 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
102
+
103
+ self.rebnconv3 = REBNCONV(mid_ch, mid_ch, dirate=1)
104
+ self.pool3 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
105
+
106
+ self.rebnconv4 = REBNCONV(mid_ch, mid_ch, dirate=1)
107
+ self.pool4 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
108
+
109
+ self.rebnconv5 = REBNCONV(mid_ch, mid_ch, dirate=1)
110
+ self.pool5 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
111
+
112
+ self.rebnconv6 = REBNCONV(mid_ch, mid_ch, dirate=1)
113
+
114
+ self.rebnconv7 = REBNCONV(mid_ch, mid_ch, dirate=2)
115
+
116
+ self.rebnconv6d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
117
+ self.rebnconv5d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
118
+ self.rebnconv4d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
119
+ self.rebnconv3d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
120
+ self.rebnconv2d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
121
+ self.rebnconv1d = REBNCONV(mid_ch * 2, out_ch, dirate=1)
122
+
123
+ def forward(self, x):
124
+ b, c, h, w = x.shape
125
+
126
+ hx = x
127
+ hxin = self.rebnconvin(hx)
128
+
129
+ hx1 = self.rebnconv1(hxin)
130
+ hx = self.pool1(hx1)
131
+
132
+ hx2 = self.rebnconv2(hx)
133
+ hx = self.pool2(hx2)
134
+
135
+ hx3 = self.rebnconv3(hx)
136
+ hx = self.pool3(hx3)
137
+
138
+ hx4 = self.rebnconv4(hx)
139
+ hx = self.pool4(hx4)
140
+
141
+ hx5 = self.rebnconv5(hx)
142
+ hx = self.pool5(hx5)
143
+
144
+ hx6 = self.rebnconv6(hx)
145
+
146
+ hx7 = self.rebnconv7(hx6)
147
+
148
+ hx6d = self.rebnconv6d(torch.cat((hx7, hx6), 1))
149
+ hx6dup = _upsample_like(hx6d, hx5)
150
+
151
+ hx5d = self.rebnconv5d(torch.cat((hx6dup, hx5), 1))
152
+ hx5dup = _upsample_like(hx5d, hx4)
153
+
154
+ hx4d = self.rebnconv4d(torch.cat((hx5dup, hx4), 1))
155
+ hx4dup = _upsample_like(hx4d, hx3)
156
+
157
+ hx3d = self.rebnconv3d(torch.cat((hx4dup, hx3), 1))
158
+ hx3dup = _upsample_like(hx3d, hx2)
159
+
160
+ hx2d = self.rebnconv2d(torch.cat((hx3dup, hx2), 1))
161
+ hx2dup = _upsample_like(hx2d, hx1)
162
+
163
+ hx1d = self.rebnconv1d(torch.cat((hx2dup, hx1), 1))
164
+
165
+ return hx1d + hxin
166
+
167
+
168
+ ### RSU-6 ###
169
+ class RSU6(nn.Module):
170
+
171
+ def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
172
+ super(RSU6, self).__init__()
173
+
174
+ self.rebnconvin = REBNCONV(in_ch, out_ch, dirate=1)
175
+
176
+ self.rebnconv1 = REBNCONV(out_ch, mid_ch, dirate=1)
177
+ self.pool1 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
178
+
179
+ self.rebnconv2 = REBNCONV(mid_ch, mid_ch, dirate=1)
180
+ self.pool2 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
181
+
182
+ self.rebnconv3 = REBNCONV(mid_ch, mid_ch, dirate=1)
183
+ self.pool3 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
184
+
185
+ self.rebnconv4 = REBNCONV(mid_ch, mid_ch, dirate=1)
186
+ self.pool4 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
187
+
188
+ self.rebnconv5 = REBNCONV(mid_ch, mid_ch, dirate=1)
189
+
190
+ self.rebnconv6 = REBNCONV(mid_ch, mid_ch, dirate=2)
191
+
192
+ self.rebnconv5d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
193
+ self.rebnconv4d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
194
+ self.rebnconv3d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
195
+ self.rebnconv2d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
196
+ self.rebnconv1d = REBNCONV(mid_ch * 2, out_ch, dirate=1)
197
+
198
+ def forward(self, x):
199
+ hx = x
200
+
201
+ hxin = self.rebnconvin(hx)
202
+
203
+ hx1 = self.rebnconv1(hxin)
204
+ hx = self.pool1(hx1)
205
+
206
+ hx2 = self.rebnconv2(hx)
207
+ hx = self.pool2(hx2)
208
+
209
+ hx3 = self.rebnconv3(hx)
210
+ hx = self.pool3(hx3)
211
+
212
+ hx4 = self.rebnconv4(hx)
213
+ hx = self.pool4(hx4)
214
+
215
+ hx5 = self.rebnconv5(hx)
216
+
217
+ hx6 = self.rebnconv6(hx5)
218
+
219
+ hx5d = self.rebnconv5d(torch.cat((hx6, hx5), 1))
220
+ hx5dup = _upsample_like(hx5d, hx4)
221
+
222
+ hx4d = self.rebnconv4d(torch.cat((hx5dup, hx4), 1))
223
+ hx4dup = _upsample_like(hx4d, hx3)
224
+
225
+ hx3d = self.rebnconv3d(torch.cat((hx4dup, hx3), 1))
226
+ hx3dup = _upsample_like(hx3d, hx2)
227
+
228
+ hx2d = self.rebnconv2d(torch.cat((hx3dup, hx2), 1))
229
+ hx2dup = _upsample_like(hx2d, hx1)
230
+
231
+ hx1d = self.rebnconv1d(torch.cat((hx2dup, hx1), 1))
232
+
233
+ return hx1d + hxin
234
+
235
+
236
+ ### RSU-5 ###
237
+ class RSU5(nn.Module):
238
+
239
+ def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
240
+ super(RSU5, self).__init__()
241
+
242
+ self.rebnconvin = REBNCONV(in_ch, out_ch, dirate=1)
243
+
244
+ self.rebnconv1 = REBNCONV(out_ch, mid_ch, dirate=1)
245
+ self.pool1 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
246
+
247
+ self.rebnconv2 = REBNCONV(mid_ch, mid_ch, dirate=1)
248
+ self.pool2 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
249
+
250
+ self.rebnconv3 = REBNCONV(mid_ch, mid_ch, dirate=1)
251
+ self.pool3 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
252
+
253
+ self.rebnconv4 = REBNCONV(mid_ch, mid_ch, dirate=1)
254
+
255
+ self.rebnconv5 = REBNCONV(mid_ch, mid_ch, dirate=2)
256
+
257
+ self.rebnconv4d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
258
+ self.rebnconv3d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
259
+ self.rebnconv2d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
260
+ self.rebnconv1d = REBNCONV(mid_ch * 2, out_ch, dirate=1)
261
+
262
+ def forward(self, x):
263
+ hx = x
264
+
265
+ hxin = self.rebnconvin(hx)
266
+
267
+ hx1 = self.rebnconv1(hxin)
268
+ hx = self.pool1(hx1)
269
+
270
+ hx2 = self.rebnconv2(hx)
271
+ hx = self.pool2(hx2)
272
+
273
+ hx3 = self.rebnconv3(hx)
274
+ hx = self.pool3(hx3)
275
+
276
+ hx4 = self.rebnconv4(hx)
277
+
278
+ hx5 = self.rebnconv5(hx4)
279
+
280
+ hx4d = self.rebnconv4d(torch.cat((hx5, hx4), 1))
281
+ hx4dup = _upsample_like(hx4d, hx3)
282
+
283
+ hx3d = self.rebnconv3d(torch.cat((hx4dup, hx3), 1))
284
+ hx3dup = _upsample_like(hx3d, hx2)
285
+
286
+ hx2d = self.rebnconv2d(torch.cat((hx3dup, hx2), 1))
287
+ hx2dup = _upsample_like(hx2d, hx1)
288
+
289
+ hx1d = self.rebnconv1d(torch.cat((hx2dup, hx1), 1))
290
+
291
+ return hx1d + hxin
292
+
293
+
294
+ ### RSU-4 ###
295
+ class RSU4(nn.Module):
296
+
297
+ def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
298
+ super(RSU4, self).__init__()
299
+
300
+ self.rebnconvin = REBNCONV(in_ch, out_ch, dirate=1)
301
+
302
+ self.rebnconv1 = REBNCONV(out_ch, mid_ch, dirate=1)
303
+ self.pool1 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
304
+
305
+ self.rebnconv2 = REBNCONV(mid_ch, mid_ch, dirate=1)
306
+ self.pool2 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
307
+
308
+ self.rebnconv3 = REBNCONV(mid_ch, mid_ch, dirate=1)
309
+
310
+ self.rebnconv4 = REBNCONV(mid_ch, mid_ch, dirate=2)
311
+
312
+ self.rebnconv3d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
313
+ self.rebnconv2d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
314
+ self.rebnconv1d = REBNCONV(mid_ch * 2, out_ch, dirate=1)
315
+
316
+ def forward(self, x):
317
+ hx = x
318
+
319
+ hxin = self.rebnconvin(hx)
320
+
321
+ hx1 = self.rebnconv1(hxin)
322
+ hx = self.pool1(hx1)
323
+
324
+ hx2 = self.rebnconv2(hx)
325
+ hx = self.pool2(hx2)
326
+
327
+ hx3 = self.rebnconv3(hx)
328
+
329
+ hx4 = self.rebnconv4(hx3)
330
+
331
+ hx3d = self.rebnconv3d(torch.cat((hx4, hx3), 1))
332
+ hx3dup = _upsample_like(hx3d, hx2)
333
+
334
+ hx2d = self.rebnconv2d(torch.cat((hx3dup, hx2), 1))
335
+ hx2dup = _upsample_like(hx2d, hx1)
336
+
337
+ hx1d = self.rebnconv1d(torch.cat((hx2dup, hx1), 1))
338
+
339
+ return hx1d + hxin
340
+
341
+
342
+ ### RSU-4F ###
343
+ class RSU4F(nn.Module):
344
+
345
+ def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
346
+ super(RSU4F, self).__init__()
347
+
348
+ self.rebnconvin = REBNCONV(in_ch, out_ch, dirate=1)
349
+
350
+ self.rebnconv1 = REBNCONV(out_ch, mid_ch, dirate=1)
351
+ self.rebnconv2 = REBNCONV(mid_ch, mid_ch, dirate=2)
352
+ self.rebnconv3 = REBNCONV(mid_ch, mid_ch, dirate=4)
353
+
354
+ self.rebnconv4 = REBNCONV(mid_ch, mid_ch, dirate=8)
355
+
356
+ self.rebnconv3d = REBNCONV(mid_ch * 2, mid_ch, dirate=4)
357
+ self.rebnconv2d = REBNCONV(mid_ch * 2, mid_ch, dirate=2)
358
+ self.rebnconv1d = REBNCONV(mid_ch * 2, out_ch, dirate=1)
359
+
360
+ def forward(self, x):
361
+ hx = x
362
+
363
+ hxin = self.rebnconvin(hx)
364
+
365
+ hx1 = self.rebnconv1(hxin)
366
+ hx2 = self.rebnconv2(hx1)
367
+ hx3 = self.rebnconv3(hx2)
368
+
369
+ hx4 = self.rebnconv4(hx3)
370
+
371
+ hx3d = self.rebnconv3d(torch.cat((hx4, hx3), 1))
372
+ hx2d = self.rebnconv2d(torch.cat((hx3d, hx2), 1))
373
+ hx1d = self.rebnconv1d(torch.cat((hx2d, hx1), 1))
374
+
375
+ return hx1d + hxin
376
+
377
+
378
+ class myrebnconv(nn.Module):
379
+ def __init__(self, in_ch=3,
380
+ out_ch=1,
381
+ kernel_size=3,
382
+ stride=1,
383
+ padding=1,
384
+ dilation=1,
385
+ groups=1):
386
+ super(myrebnconv, self).__init__()
387
+
388
+ self.conv = nn.Conv2d(in_ch,
389
+ out_ch,
390
+ kernel_size=kernel_size,
391
+ stride=stride,
392
+ padding=padding,
393
+ dilation=dilation,
394
+ groups=groups)
395
+ self.bn = nn.BatchNorm2d(out_ch)
396
+ self.rl = nn.ReLU(inplace=True)
397
+
398
+ def forward(self, x):
399
+ return self.rl(self.bn(self.conv(x)))
400
+
401
+
402
+ class ISNetGTEncoder(nn.Module):
403
+
404
+ def __init__(self, in_ch=1, out_ch=1):
405
+ super(ISNetGTEncoder, self).__init__()
406
+
407
+ self.conv_in = myrebnconv(in_ch, 16, 3, stride=2, padding=1) # nn.Conv2d(in_ch,64,3,stride=2,padding=1)
408
+
409
+ self.stage1 = RSU7(16, 16, 64)
410
+ self.pool12 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
411
+
412
+ self.stage2 = RSU6(64, 16, 64)
413
+ self.pool23 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
414
+
415
+ self.stage3 = RSU5(64, 32, 128)
416
+ self.pool34 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
417
+
418
+ self.stage4 = RSU4(128, 32, 256)
419
+ self.pool45 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
420
+
421
+ self.stage5 = RSU4F(256, 64, 512)
422
+ self.pool56 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
423
+
424
+ self.stage6 = RSU4F(512, 64, 512)
425
+
426
+ self.side1 = nn.Conv2d(64, out_ch, 3, padding=1)
427
+ self.side2 = nn.Conv2d(64, out_ch, 3, padding=1)
428
+ self.side3 = nn.Conv2d(128, out_ch, 3, padding=1)
429
+ self.side4 = nn.Conv2d(256, out_ch, 3, padding=1)
430
+ self.side5 = nn.Conv2d(512, out_ch, 3, padding=1)
431
+ self.side6 = nn.Conv2d(512, out_ch, 3, padding=1)
432
+
433
+ def compute_loss(self, preds, targets):
434
+ return muti_loss_fusion(preds, targets)
435
+
436
+ def forward(self, x):
437
+ hx = x
438
+
439
+ hxin = self.conv_in(hx)
440
+ # hx = self.pool_in(hxin)
441
+
442
+ # stage 1
443
+ hx1 = self.stage1(hxin)
444
+ hx = self.pool12(hx1)
445
+
446
+ # stage 2
447
+ hx2 = self.stage2(hx)
448
+ hx = self.pool23(hx2)
449
+
450
+ # stage 3
451
+ hx3 = self.stage3(hx)
452
+ hx = self.pool34(hx3)
453
+
454
+ # stage 4
455
+ hx4 = self.stage4(hx)
456
+ hx = self.pool45(hx4)
457
+
458
+ # stage 5
459
+ hx5 = self.stage5(hx)
460
+ hx = self.pool56(hx5)
461
+
462
+ # stage 6
463
+ hx6 = self.stage6(hx)
464
+
465
+ # side output
466
+ d1 = self.side1(hx1)
467
+ d1 = _upsample_like(d1, x)
468
+
469
+ d2 = self.side2(hx2)
470
+ d2 = _upsample_like(d2, x)
471
+
472
+ d3 = self.side3(hx3)
473
+ d3 = _upsample_like(d3, x)
474
+
475
+ d4 = self.side4(hx4)
476
+ d4 = _upsample_like(d4, x)
477
+
478
+ d5 = self.side5(hx5)
479
+ d5 = _upsample_like(d5, x)
480
+
481
+ d6 = self.side6(hx6)
482
+ d6 = _upsample_like(d6, x)
483
+
484
+ # d0 = self.outconv(torch.cat((d1,d2,d3,d4,d5,d6),1))
485
+
486
+ return [F.sigmoid(d1), F.sigmoid(d2), F.sigmoid(d3), F.sigmoid(d4), F.sigmoid(d5), F.sigmoid(d6)], [hx1, hx2,
487
+ hx3, hx4,
488
+ hx5, hx6]
489
+
490
+
491
+ class ISNetDIS(nn.Module):
492
+
493
+ def __init__(self, in_ch=3, out_ch=1):
494
+ super(ISNetDIS, self).__init__()
495
+
496
+ self.conv_in = nn.Conv2d(in_ch, 64, 3, stride=2, padding=1)
497
+ self.pool_in = nn.MaxPool2d(2, stride=2, ceil_mode=True)
498
+
499
+ self.stage1 = RSU7(64, 32, 64)
500
+ self.pool12 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
501
+
502
+ self.stage2 = RSU6(64, 32, 128)
503
+ self.pool23 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
504
+
505
+ self.stage3 = RSU5(128, 64, 256)
506
+ self.pool34 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
507
+
508
+ self.stage4 = RSU4(256, 128, 512)
509
+ self.pool45 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
510
+
511
+ self.stage5 = RSU4F(512, 256, 512)
512
+ self.pool56 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
513
+
514
+ self.stage6 = RSU4F(512, 256, 512)
515
+
516
+ # decoder
517
+ self.stage5d = RSU4F(1024, 256, 512)
518
+ self.stage4d = RSU4(1024, 128, 256)
519
+ self.stage3d = RSU5(512, 64, 128)
520
+ self.stage2d = RSU6(256, 32, 64)
521
+ self.stage1d = RSU7(128, 16, 64)
522
+
523
+ self.side1 = nn.Conv2d(64, out_ch, 3, padding=1)
524
+ self.side2 = nn.Conv2d(64, out_ch, 3, padding=1)
525
+ self.side3 = nn.Conv2d(128, out_ch, 3, padding=1)
526
+ self.side4 = nn.Conv2d(256, out_ch, 3, padding=1)
527
+ self.side5 = nn.Conv2d(512, out_ch, 3, padding=1)
528
+ self.side6 = nn.Conv2d(512, out_ch, 3, padding=1)
529
+
530
+ # self.outconv = nn.Conv2d(6*out_ch,out_ch,1)
531
+
532
+ def compute_loss_kl(self, preds, targets, dfs, fs, mode='MSE'):
533
+ # return muti_loss_fusion(preds,targets)
534
+ return muti_loss_fusion_kl(preds, targets, dfs, fs, mode=mode)
535
+
536
+ def compute_loss(self, preds, targets):
537
+ # return muti_loss_fusion(preds,targets)
538
+ return muti_loss_fusion(preds, targets)
539
+
540
+ def forward(self, x):
541
+ hx = x
542
+
543
+ hxin = self.conv_in(hx)
544
+ # hx = self.pool_in(hxin)
545
+
546
+ # stage 1
547
+ hx1 = self.stage1(hxin)
548
+ hx = self.pool12(hx1)
549
+
550
+ # stage 2
551
+ hx2 = self.stage2(hx)
552
+ hx = self.pool23(hx2)
553
+
554
+ # stage 3
555
+ hx3 = self.stage3(hx)
556
+ hx = self.pool34(hx3)
557
+
558
+ # stage 4
559
+ hx4 = self.stage4(hx)
560
+ hx = self.pool45(hx4)
561
+
562
+ # stage 5
563
+ hx5 = self.stage5(hx)
564
+ hx = self.pool56(hx5)
565
+
566
+ # stage 6
567
+ hx6 = self.stage6(hx)
568
+ hx6up = _upsample_like(hx6, hx5)
569
+
570
+ # -------------------- decoder --------------------
571
+ hx5d = self.stage5d(torch.cat((hx6up, hx5), 1))
572
+ hx5dup = _upsample_like(hx5d, hx4)
573
+
574
+ hx4d = self.stage4d(torch.cat((hx5dup, hx4), 1))
575
+ hx4dup = _upsample_like(hx4d, hx3)
576
+
577
+ hx3d = self.stage3d(torch.cat((hx4dup, hx3), 1))
578
+ hx3dup = _upsample_like(hx3d, hx2)
579
+
580
+ hx2d = self.stage2d(torch.cat((hx3dup, hx2), 1))
581
+ hx2dup = _upsample_like(hx2d, hx1)
582
+
583
+ hx1d = self.stage1d(torch.cat((hx2dup, hx1), 1))
584
+
585
+ # side output
586
+ d1 = self.side1(hx1d)
587
+ d1 = _upsample_like(d1, x)
588
+
589
+ d2 = self.side2(hx2d)
590
+ d2 = _upsample_like(d2, x)
591
+
592
+ d3 = self.side3(hx3d)
593
+ d3 = _upsample_like(d3, x)
594
+
595
+ d4 = self.side4(hx4d)
596
+ d4 = _upsample_like(d4, x)
597
+
598
+ d5 = self.side5(hx5d)
599
+ d5 = _upsample_like(d5, x)
600
+
601
+ d6 = self.side6(hx6)
602
+ d6 = _upsample_like(d6, x)
603
+
604
+ # d0 = self.outconv(torch.cat((d1,d2,d3,d4,d5,d6),1))
605
+
606
+ return [F.sigmoid(d1), F.sigmoid(d2), F.sigmoid(d3), F.sigmoid(d4), F.sigmoid(d5), F.sigmoid(d6)], [hx1d, hx2d,
607
+ hx3d, hx4d,
608
+ hx5d, hx6]
DIS/pytorch18.yml ADDED
@@ -0,0 +1,92 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ name: pytorch18
2
+ channels:
3
+ - conda-forge
4
+ - anaconda
5
+ - pytorch
6
+ - defaults
7
+ dependencies:
8
+ - _libgcc_mutex=0.1=main
9
+ - _openmp_mutex=4.5=1_gnu
10
+ - blas=1.0=mkl
11
+ - brotli=1.0.9=he6710b0_2
12
+ - bzip2=1.0.8=h7b6447c_0
13
+ - ca-certificates=2022.2.1=h06a4308_0
14
+ - certifi=2021.10.8=py37h06a4308_2
15
+ - cloudpickle=2.0.0=pyhd3eb1b0_0
16
+ - colorama=0.4.4=pyhd3eb1b0_0
17
+ - cudatoolkit=10.2.89=hfd86e86_1
18
+ - cycler=0.11.0=pyhd3eb1b0_0
19
+ - cytoolz=0.11.0=py37h7b6447c_0
20
+ - dask-core=2021.10.0=pyhd3eb1b0_0
21
+ - ffmpeg=4.3=hf484d3e_0
22
+ - fonttools=4.25.0=pyhd3eb1b0_0
23
+ - freetype=2.11.0=h70c0345_0
24
+ - fsspec=2022.2.0=pyhd3eb1b0_0
25
+ - gmp=6.2.1=h2531618_2
26
+ - gnutls=3.6.15=he1e5248_0
27
+ - imageio=2.9.0=pyhd3eb1b0_0
28
+ - intel-openmp=2021.4.0=h06a4308_3561
29
+ - jpeg=9b=h024ee3a_2
30
+ - kiwisolver=1.3.2=py37h295c915_0
31
+ - lame=3.100=h7b6447c_0
32
+ - lcms2=2.12=h3be6417_0
33
+ - ld_impl_linux-64=2.35.1=h7274673_9
34
+ - libffi=3.3=he6710b0_2
35
+ - libgcc-ng=9.3.0=h5101ec6_17
36
+ - libgfortran-ng=7.5.0=ha8ba4b0_17
37
+ - libgfortran4=7.5.0=ha8ba4b0_17
38
+ - libgomp=9.3.0=h5101ec6_17
39
+ - libiconv=1.15=h63c8f33_5
40
+ - libidn2=2.3.2=h7f8727e_0
41
+ - libpng=1.6.37=hbc83047_0
42
+ - libstdcxx-ng=9.3.0=hd4cf53a_17
43
+ - libtasn1=4.16.0=h27cfd23_0
44
+ - libtiff=4.2.0=h85742a9_0
45
+ - libunistring=0.9.10=h27cfd23_0
46
+ - libuv=1.40.0=h7b6447c_0
47
+ - libwebp-base=1.2.2=h7f8727e_0
48
+ - locket=0.2.1=py37h06a4308_2
49
+ - lz4-c=1.9.3=h295c915_1
50
+ - matplotlib-base=3.5.1=py37ha18d171_1
51
+ - mkl=2021.4.0=h06a4308_640
52
+ - mkl-service=2.4.0=py37h7f8727e_0
53
+ - mkl_fft=1.3.1=py37hd3c417c_0
54
+ - mkl_random=1.2.2=py37h51133e4_0
55
+ - munkres=1.1.4=py_0
56
+ - ncurses=6.3=h7f8727e_2
57
+ - nettle=3.7.3=hbbd107a_1
58
+ - networkx=2.6.3=pyhd3eb1b0_0
59
+ - ninja=1.10.2=py37hd09550d_3
60
+ - numpy=1.21.2=py37h20f2e39_0
61
+ - numpy-base=1.21.2=py37h79a1101_0
62
+ - olefile=0.46=py37_0
63
+ - openh264=2.1.1=h4ff587b_0
64
+ - openssl=1.1.1n=h7f8727e_0
65
+ - packaging=21.3=pyhd3eb1b0_0
66
+ - partd=1.2.0=pyhd3eb1b0_1
67
+ - pillow=8.0.0=py37h9a89aac_0
68
+ - pip=21.2.2=py37h06a4308_0
69
+ - pyparsing=3.0.4=pyhd3eb1b0_0
70
+ - python=3.7.11=h12debd9_0
71
+ - python-dateutil=2.8.2=pyhd3eb1b0_0
72
+ - pytorch=1.8.0=py3.7_cuda10.2_cudnn7.6.5_0
73
+ - pywavelets=1.1.1=py37h7b6447c_2
74
+ - pyyaml=6.0=py37h7f8727e_1
75
+ - readline=8.1.2=h7f8727e_1
76
+ - scikit-image=0.15.0=py37hb3f55d8_2
77
+ - scipy=1.7.3=py37hc147768_0
78
+ - setuptools=58.0.4=py37h06a4308_0
79
+ - six=1.16.0=pyhd3eb1b0_1
80
+ - sqlite=3.38.0=hc218d9a_0
81
+ - tk=8.6.11=h1ccaba5_0
82
+ - toolz=0.11.2=pyhd3eb1b0_0
83
+ - torchaudio=0.8.0=py37
84
+ - torchvision=0.9.0=py37_cu102
85
+ - tqdm=4.63.0=pyhd8ed1ab_0
86
+ - typing_extensions=3.10.0.2=pyh06a4308_0
87
+ - wheel=0.37.1=pyhd3eb1b0_0
88
+ - xz=5.2.5=h7b6447c_0
89
+ - yaml=0.2.5=h7b6447c_0
90
+ - zlib=1.2.11=h7f8727e_4
91
+ - zstd=1.4.9=haebb681_0
92
+ prefix: /home/solar/anaconda3/envs/pytorch18
app.py ADDED
@@ -0,0 +1,36 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ reference: https://github.com/xuebinqin/DIS
3
+ """
4
+
5
+ import os
6
+
7
+ import gdown
8
+ import gradio as gr
9
+
10
+ from DIS.IsNetPipeLine import IsNetPipeLine
11
+
12
+ save_model_path = "DIS/save_models"
13
+ model_name = os.path.join(save_model_path, "isnet.pth")
14
+ # Download official weights
15
+ if not os.path.exists(model_name):
16
+ if not os.path.exists(save_model_path):
17
+ os.mkdir(save_model_path)
18
+ MODEL_PATH_URL = "https://huggingface.co/Superlang/ImageProcess/resolve/main/isnet.pth"
19
+ gdown.download(MODEL_PATH_URL, model_name, use_cookies=False)
20
+
21
+ pipe = IsNetPipeLine(model_path=model_name)
22
+
23
+
24
+ def inference(image):
25
+ return pipe(image)
26
+
27
+
28
+ title = "remove background"
29
+ interface = gr.Interface(
30
+ fn=inference,
31
+ inputs=gr.Image(type='pil'),
32
+ outputs=["image", "image"],
33
+ title=title,
34
+ allow_flagging='never',
35
+ cache_examples=True,
36
+ ).queue(concurrency_count=1, api_open=True).launch(show_api=True, show_error=True)
requirements.txt ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ torch~=2.0.0
2
+ numpy~=1.23.3
3
+ scikit-image~=0.19.2
4
+ tqdm~=4.65.0
5
+ torchvision~=0.15.1
6
+ Pillow~=9.4.0
7
+ gdown~=4.7.1
8
+ gradio~=3.23.0