gaviego commited on
Commit
1ba239b
1 Parent(s): d25fd03

adding model

Browse files
Files changed (6) hide show
  1. .gitignore +163 -0
  2. app.py +141 -1
  3. data_loader_cache.py +385 -0
  4. models/__init__.py +1 -0
  5. models/isnet.py +610 -0
  6. saved_models/isnet.pth +3 -0
.gitignore ADDED
@@ -0,0 +1,163 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Byte-compiled / optimized / DLL files
2
+ __pycache__/
3
+ *.py[cod]
4
+ *$py.class
5
+
6
+ # C extensions
7
+ *.so
8
+
9
+ # Distribution / packaging
10
+ .Python
11
+ build/
12
+ develop-eggs/
13
+ dist/
14
+ downloads/
15
+ eggs/
16
+ .eggs/
17
+ lib/
18
+ lib64/
19
+ parts/
20
+ sdist/
21
+ var/
22
+ wheels/
23
+ share/python-wheels/
24
+ *.egg-info/
25
+ .installed.cfg
26
+ *.egg
27
+ MANIFEST
28
+
29
+ # PyInstaller
30
+ # Usually these files are written by a python script from a template
31
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
32
+ *.manifest
33
+ *.spec
34
+
35
+ # Installer logs
36
+ pip-log.txt
37
+ pip-delete-this-directory.txt
38
+
39
+ # Unit test / coverage reports
40
+ htmlcov/
41
+ .tox/
42
+ .nox/
43
+ .coverage
44
+ .coverage.*
45
+ .cache
46
+ nosetests.xml
47
+ coverage.xml
48
+ *.cover
49
+ *.py,cover
50
+ .hypothesis/
51
+ .pytest_cache/
52
+ cover/
53
+
54
+ # Translations
55
+ *.mo
56
+ *.pot
57
+
58
+ # Django stuff:
59
+ *.log
60
+ local_settings.py
61
+ db.sqlite3
62
+ db.sqlite3-journal
63
+
64
+ # Flask stuff:
65
+ instance/
66
+ .webassets-cache
67
+
68
+ # Scrapy stuff:
69
+ .scrapy
70
+
71
+ # Sphinx documentation
72
+ docs/_build/
73
+
74
+ # PyBuilder
75
+ .pybuilder/
76
+ target/
77
+
78
+ # Jupyter Notebook
79
+ .ipynb_checkpoints
80
+
81
+ # IPython
82
+ profile_default/
83
+ ipython_config.py
84
+
85
+ # pyenv
86
+ # For a library or package, you might want to ignore these files since the code is
87
+ # intended to run in multiple environments; otherwise, check them in:
88
+ # .python-version
89
+
90
+ # pipenv
91
+ # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
92
+ # However, in case of collaboration, if having platform-specific dependencies or dependencies
93
+ # having no cross-platform support, pipenv may install dependencies that don't work, or not
94
+ # install all needed dependencies.
95
+ #Pipfile.lock
96
+
97
+ # poetry
98
+ # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
99
+ # This is especially recommended for binary packages to ensure reproducibility, and is more
100
+ # commonly ignored for libraries.
101
+ # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
102
+ #poetry.lock
103
+
104
+ # pdm
105
+ # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
106
+ #pdm.lock
107
+ # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
108
+ # in version control.
109
+ # https://pdm.fming.dev/#use-with-ide
110
+ .pdm.toml
111
+
112
+ # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
113
+ __pypackages__/
114
+
115
+ # Celery stuff
116
+ celerybeat-schedule
117
+ celerybeat.pid
118
+
119
+ # SageMath parsed files
120
+ *.sage.py
121
+
122
+ # Environments
123
+ .env
124
+ .venv
125
+ env/
126
+ venv/
127
+ ENV/
128
+ env.bak/
129
+ venv.bak/
130
+
131
+ # Spyder project settings
132
+ .spyderproject
133
+ .spyproject
134
+
135
+ # Rope project settings
136
+ .ropeproject
137
+
138
+ # mkdocs documentation
139
+ /site
140
+
141
+ # mypy
142
+ .mypy_cache/
143
+ .dmypy.json
144
+ dmypy.json
145
+
146
+ # Pyre type checker
147
+ .pyre/
148
+
149
+ # pytype static type analyzer
150
+ .pytype/
151
+
152
+ # Cython debug symbols
153
+ cython_debug/
154
+
155
+ # PyCharm
156
+ # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
157
+ # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
158
+ # and can be added to the global gitignore or merged into this file. For a more nuclear
159
+ # option (not recommended) you can uncomment the following to ignore the entire idea folder.
160
+ #.idea/
161
+
162
+ *.jpeg
163
+ *.png
app.py CHANGED
@@ -1,10 +1,150 @@
1
  import gradio as gr
 
 
 
2
  from PIL import Image
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3
 
4
  def bw(image_file:Image):
5
  img = Image.open(image_file)
6
  img = img.convert("L")
7
  return img
8
 
9
- iface = gr.Interface(fn=bw, inputs=gr.Image(type='filepath'), outputs=["image"])
 
 
 
 
 
10
  iface.launch()
1
  import gradio as gr
2
+ import cv2
3
+ import gradio as gr
4
+ import os
5
  from PIL import Image
6
+ import numpy as np
7
+ import torch
8
+ from torch.autograd import Variable
9
+ from torchvision import transforms
10
+ import torch.nn.functional as F
11
+ import matplotlib.pyplot as plt
12
+ import warnings
13
+ warnings.filterwarnings("ignore")
14
+
15
+ # os.system("git clone https://github.com/xuebinqin/DIS")
16
+ # os.system("mv DIS/IS-Net/* .")
17
+
18
+ # project imports
19
+ from data_loader_cache import normalize, im_reader, im_preprocess
20
+ from models import *
21
+
22
+ #Helpers
23
+ device = 'cuda' if torch.cuda.is_available() else 'cpu'
24
+
25
+ # Download official weights
26
+ # if not os.path.exists("saved_models"):
27
+ # os.mkdir("saved_models")
28
+ # MODEL_PATH_URL = "https://drive.google.com/uc?id=1KyMpRjewZdyYfxHPYcd-ZbanIXtin0Sn"
29
+ # gdown.download(MODEL_PATH_URL, "saved_models/isnet.pth", use_cookies=False)
30
+
31
+ class GOSNormalize(object):
32
+ '''
33
+ Normalize the Image using torch.transforms
34
+ '''
35
+ def __init__(self, mean=[0.485,0.456,0.406], std=[0.229,0.224,0.225]):
36
+ self.mean = mean
37
+ self.std = std
38
+
39
+ def __call__(self,image):
40
+ image = normalize(image,self.mean,self.std)
41
+ return image
42
+
43
+
44
+ transform = transforms.Compose([GOSNormalize([0.5,0.5,0.5],[1.0,1.0,1.0])])
45
+
46
+ def load_image(im_path, hypar):
47
+ im = im_reader(im_path)
48
+ im, im_shp = im_preprocess(im, hypar["cache_size"])
49
+ im = torch.divide(im,255.0)
50
+ shape = torch.from_numpy(np.array(im_shp))
51
+ return transform(im).unsqueeze(0), shape.unsqueeze(0) # make a batch of image, shape
52
+
53
+
54
+ def build_model(hypar,device):
55
+ net = hypar["model"]#GOSNETINC(3,1)
56
+
57
+ # convert to half precision
58
+ if(hypar["model_digit"]=="half"):
59
+ net.half()
60
+ for layer in net.modules():
61
+ if isinstance(layer, nn.BatchNorm2d):
62
+ layer.float()
63
+
64
+ net.to(device)
65
+
66
+ if(hypar["restore_model"]!=""):
67
+ net.load_state_dict(torch.load(hypar["model_path"]+"/"+hypar["restore_model"], map_location=device))
68
+ net.to(device)
69
+ net.eval()
70
+ return net
71
+
72
+
73
+ def predict(net, inputs_val, shapes_val, hypar, device):
74
+ '''
75
+ Given an Image, predict the mask
76
+ '''
77
+ net.eval()
78
+
79
+ if(hypar["model_digit"]=="full"):
80
+ inputs_val = inputs_val.type(torch.FloatTensor)
81
+ else:
82
+ inputs_val = inputs_val.type(torch.HalfTensor)
83
+
84
+
85
+ inputs_val_v = Variable(inputs_val, requires_grad=False).to(device) # wrap inputs in Variable
86
+
87
+ ds_val = net(inputs_val_v)[0] # list of 6 results
88
+
89
+ pred_val = ds_val[0][0,:,:,:] # B x 1 x H x W # we want the first one which is the most accurate prediction
90
+
91
+ ## recover the prediction spatial size to the orignal image size
92
+ pred_val = torch.squeeze(F.upsample(torch.unsqueeze(pred_val,0),(shapes_val[0][0],shapes_val[0][1]),mode='bilinear'))
93
+
94
+ ma = torch.max(pred_val)
95
+ mi = torch.min(pred_val)
96
+ pred_val = (pred_val-mi)/(ma-mi) # max = 1
97
+
98
+ if device == 'cuda': torch.cuda.empty_cache()
99
+ return (pred_val.detach().cpu().numpy()*255).astype(np.uint8) # it is the mask we need
100
+
101
+ # Set Parameters
102
+ hypar = {} # paramters for inferencing
103
+
104
+
105
+ hypar["model_path"] ="./saved_models" ## load trained weights from this path
106
+ hypar["restore_model"] = "isnet.pth" ## name of the to-be-loaded weights
107
+ hypar["interm_sup"] = False ## indicate if activate intermediate feature supervision
108
+
109
+ ## choose floating point accuracy --
110
+ hypar["model_digit"] = "full" ## indicates "half" or "full" accuracy of float number
111
+ hypar["seed"] = 0
112
+
113
+ hypar["cache_size"] = [1024, 1024] ## cached input spatial resolution, can be configured into different size
114
+
115
+ ## data augmentation parameters ---
116
+ hypar["input_size"] = [1024, 1024] ## mdoel input spatial size, usually use the same value hypar["cache_size"], which means we don't further resize the images
117
+ hypar["crop_size"] = [1024, 1024] ## random crop size from the input, it is usually set as smaller than hypar["cache_size"], e.g., [920,920] for data augmentation
118
+
119
+ hypar["model"] = ISNetDIS()
120
+
121
+ # Build Model
122
+ net = build_model(hypar, device)
123
+
124
+
125
+ def inference(image: Image):
126
+ image_path = image
127
+
128
+ image_tensor, orig_size = load_image(image_path, hypar)
129
+ mask = predict(net, image_tensor, orig_size, hypar, device)
130
+
131
+ pil_mask = Image.fromarray(mask).convert('L')
132
+ im_rgb = Image.open(image).convert("RGB")
133
+
134
+ im_rgba = im_rgb.copy()
135
+ im_rgba.putalpha(pil_mask)
136
+
137
+ return im_rgba
138
 
139
  def bw(image_file:Image):
140
  img = Image.open(image_file)
141
  img = img.convert("L")
142
  return img
143
 
144
+ iface = gr.Interface(fn=inference,
145
+ inputs=gr.Image(type='filepath'),
146
+ outputs=["image"],
147
+ title="Remove Background",
148
+ description="Uses <a href='https://github.com/xuebinqin/DIS'>DIS</a> to remove background"
149
+ )
150
  iface.launch()
data_loader_cache.py ADDED
@@ -0,0 +1,385 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ## data loader
2
+ ## Ackownledgement:
3
+ ## We would like to thank Dr. Ibrahim Almakky (https://scholar.google.co.uk/citations?user=T9MTcK0AAAAJ&hl=en)
4
+ ## for his helps in implementing cache machanism of our DIS dataloader.
5
+ from __future__ import print_function, division
6
+
7
+ import numpy as np
8
+ import random
9
+ from copy import deepcopy
10
+ import json
11
+ from tqdm import tqdm
12
+ from skimage import io
13
+ import os
14
+ from glob import glob
15
+
16
+ import torch
17
+ from torch.utils.data import Dataset, DataLoader
18
+ from torchvision import transforms, utils
19
+ from torchvision.transforms.functional import normalize
20
+ import torch.nn.functional as F
21
+
22
+ #### --------------------- DIS dataloader cache ---------------------####
23
+
24
+ def get_im_gt_name_dict(datasets, flag='valid'):
25
+ print("------------------------------", flag, "--------------------------------")
26
+ name_im_gt_list = []
27
+ for i in range(len(datasets)):
28
+ print("--->>>", flag, " dataset ",i,"/",len(datasets)," ",datasets[i]["name"],"<<<---")
29
+ tmp_im_list, tmp_gt_list = [], []
30
+ tmp_im_list = glob(datasets[i]["im_dir"]+os.sep+'*'+datasets[i]["im_ext"])
31
+
32
+ # img_name_dict[im_dirs[i][0]] = tmp_im_list
33
+ print('-im-',datasets[i]["name"],datasets[i]["im_dir"], ': ',len(tmp_im_list))
34
+
35
+ if(datasets[i]["gt_dir"]==""):
36
+ print('-gt-', datasets[i]["name"], datasets[i]["gt_dir"], ': ', 'No Ground Truth Found')
37
+ tmp_gt_list = []
38
+ else:
39
+ tmp_gt_list = [datasets[i]["gt_dir"]+os.sep+x.split(os.sep)[-1].split(datasets[i]["im_ext"])[0]+datasets[i]["gt_ext"] for x in tmp_im_list]
40
+
41
+ # lbl_name_dict[im_dirs[i][0]] = tmp_gt_list
42
+ print('-gt-', datasets[i]["name"],datasets[i]["gt_dir"], ': ',len(tmp_gt_list))
43
+
44
+
45
+ if flag=="train": ## combine multiple training sets into one dataset
46
+ if len(name_im_gt_list)==0:
47
+ name_im_gt_list.append({"dataset_name":datasets[i]["name"],
48
+ "im_path":tmp_im_list,
49
+ "gt_path":tmp_gt_list,
50
+ "im_ext":datasets[i]["im_ext"],
51
+ "gt_ext":datasets[i]["gt_ext"],
52
+ "cache_dir":datasets[i]["cache_dir"]})
53
+ else:
54
+ name_im_gt_list[0]["dataset_name"] = name_im_gt_list[0]["dataset_name"] + "_" + datasets[i]["name"]
55
+ name_im_gt_list[0]["im_path"] = name_im_gt_list[0]["im_path"] + tmp_im_list
56
+ name_im_gt_list[0]["gt_path"] = name_im_gt_list[0]["gt_path"] + tmp_gt_list
57
+ if datasets[i]["im_ext"]!=".jpg" or datasets[i]["gt_ext"]!=".png":
58
+ print("Error: Please make sure all you images and ground truth masks are in jpg and png format respectively !!!")
59
+ exit()
60
+ name_im_gt_list[0]["im_ext"] = ".jpg"
61
+ name_im_gt_list[0]["gt_ext"] = ".png"
62
+ name_im_gt_list[0]["cache_dir"] = os.sep.join(datasets[i]["cache_dir"].split(os.sep)[0:-1])+os.sep+name_im_gt_list[0]["dataset_name"]
63
+ else: ## keep different validation or inference datasets as separate ones
64
+ name_im_gt_list.append({"dataset_name":datasets[i]["name"],
65
+ "im_path":tmp_im_list,
66
+ "gt_path":tmp_gt_list,
67
+ "im_ext":datasets[i]["im_ext"],
68
+ "gt_ext":datasets[i]["gt_ext"],
69
+ "cache_dir":datasets[i]["cache_dir"]})
70
+
71
+ return name_im_gt_list
72
+
73
+ def create_dataloaders(name_im_gt_list, cache_size=[], cache_boost=True, my_transforms=[], batch_size=1, shuffle=False):
74
+ ## model="train": return one dataloader for training
75
+ ## model="valid": return a list of dataloaders for validation or testing
76
+
77
+ gos_dataloaders = []
78
+ gos_datasets = []
79
+
80
+ if(len(name_im_gt_list)==0):
81
+ return gos_dataloaders, gos_datasets
82
+
83
+ num_workers_ = 1
84
+ if(batch_size>1):
85
+ num_workers_ = 2
86
+ if(batch_size>4):
87
+ num_workers_ = 4
88
+ if(batch_size>8):
89
+ num_workers_ = 8
90
+
91
+ for i in range(0,len(name_im_gt_list)):
92
+ gos_dataset = GOSDatasetCache([name_im_gt_list[i]],
93
+ cache_size = cache_size,
94
+ cache_path = name_im_gt_list[i]["cache_dir"],
95
+ cache_boost = cache_boost,
96
+ transform = transforms.Compose(my_transforms))
97
+ gos_dataloaders.append(DataLoader(gos_dataset, batch_size=batch_size, shuffle=shuffle, num_workers=num_workers_))
98
+ gos_datasets.append(gos_dataset)
99
+
100
+ return gos_dataloaders, gos_datasets
101
+
102
+ def im_reader(im_path):
103
+ return io.imread(im_path)
104
+
105
+ def im_preprocess(im,size):
106
+ if len(im.shape) < 3:
107
+ im = im[:, :, np.newaxis]
108
+ if im.shape[2] == 1:
109
+ im = np.repeat(im, 3, axis=2)
110
+ im_tensor = torch.tensor(im.copy(), dtype=torch.float32)
111
+ im_tensor = torch.transpose(torch.transpose(im_tensor,1,2),0,1)
112
+ if(len(size)<2):
113
+ return im_tensor, im.shape[0:2]
114
+ else:
115
+ im_tensor = torch.unsqueeze(im_tensor,0)
116
+ im_tensor = F.upsample(im_tensor, size, mode="bilinear")
117
+ im_tensor = torch.squeeze(im_tensor,0)
118
+
119
+ return im_tensor.type(torch.uint8), im.shape[0:2]
120
+
121
+ def gt_preprocess(gt,size):
122
+ if len(gt.shape) > 2:
123
+ gt = gt[:, :, 0]
124
+
125
+ gt_tensor = torch.unsqueeze(torch.tensor(gt, dtype=torch.uint8),0)
126
+
127
+ if(len(size)<2):
128
+ return gt_tensor.type(torch.uint8), gt.shape[0:2]
129
+ else:
130
+ gt_tensor = torch.unsqueeze(torch.tensor(gt_tensor, dtype=torch.float32),0)
131
+ gt_tensor = F.upsample(gt_tensor, size, mode="bilinear")
132
+ gt_tensor = torch.squeeze(gt_tensor,0)
133
+
134
+ return gt_tensor.type(torch.uint8), gt.shape[0:2]
135
+ # return gt_tensor, gt.shape[0:2]
136
+
137
+ class GOSRandomHFlip(object):
138
+ def __init__(self,prob=0.5):
139
+ self.prob = prob
140
+ def __call__(self,sample):
141
+ imidx, image, label, shape = sample['imidx'], sample['image'], sample['label'], sample['shape']
142
+
143
+ # random horizontal flip
144
+ if random.random() >= self.prob:
145
+ image = torch.flip(image,dims=[2])
146
+ label = torch.flip(label,dims=[2])
147
+
148
+ return {'imidx':imidx,'image':image, 'label':label, 'shape':shape}
149
+
150
+ class GOSResize(object):
151
+ def __init__(self,size=[320,320]):
152
+ self.size = size
153
+ def __call__(self,sample):
154
+ imidx, image, label, shape = sample['imidx'], sample['image'], sample['label'], sample['shape']
155
+
156
+ # import time
157
+ # start = time.time()
158
+
159
+ image = torch.squeeze(F.upsample(torch.unsqueeze(image,0),self.size,mode='bilinear'),dim=0)
160
+ label = torch.squeeze(F.upsample(torch.unsqueeze(label,0),self.size,mode='bilinear'),dim=0)
161
+
162
+ # print("time for resize: ", time.time()-start)
163
+
164
+ return {'imidx':imidx,'image':image, 'label':label, 'shape':shape}
165
+
166
+ class GOSRandomCrop(object):
167
+ def __init__(self,size=[288,288]):
168
+ self.size = size
169
+ def __call__(self,sample):
170
+ imidx, image, label, shape = sample['imidx'], sample['image'], sample['label'], sample['shape']
171
+
172
+ h, w = image.shape[1:]
173
+ new_h, new_w = self.size
174
+
175
+ top = np.random.randint(0, h - new_h)
176
+ left = np.random.randint(0, w - new_w)
177
+
178
+ image = image[:,top:top+new_h,left:left+new_w]
179
+ label = label[:,top:top+new_h,left:left+new_w]
180
+
181
+ return {'imidx':imidx,'image':image, 'label':label, 'shape':shape}
182
+
183
+
184
+ class GOSNormalize(object):
185
+ def __init__(self, mean=[0.485,0.456,0.406], std=[0.229,0.224,0.225]):
186
+ self.mean = mean
187
+ self.std = std
188
+
189
+ def __call__(self,sample):
190
+
191
+ imidx, image, label, shape = sample['imidx'], sample['image'], sample['label'], sample['shape']
192
+ image = normalize(image,self.mean,self.std)
193
+
194
+ return {'imidx':imidx,'image':image, 'label':label, 'shape':shape}
195
+
196
+
197
+ class GOSDatasetCache(Dataset):
198
+
199
+ def __init__(self, name_im_gt_list, cache_size=[], cache_path='./cache', cache_file_name='dataset.json', cache_boost=False, transform=None):
200
+
201
+
202
+ self.cache_size = cache_size
203
+ self.cache_path = cache_path
204
+ self.cache_file_name = cache_file_name
205
+ self.cache_boost_name = ""
206
+
207
+ self.cache_boost = cache_boost
208
+ # self.ims_npy = None
209
+ # self.gts_npy = None
210
+
211
+ ## cache all the images and ground truth into a single pytorch tensor
212
+ self.ims_pt = None
213
+ self.gts_pt = None
214
+
215
+ ## we will cache the npy as well regardless of the cache_boost
216
+ # if(self.cache_boost):
217
+ self.cache_boost_name = cache_file_name.split('.json')[0]
218
+
219
+ self.transform = transform
220
+
221
+ self.dataset = {}
222
+
223
+ ## combine different datasets into one
224
+ dataset_names = []
225
+ dt_name_list = [] # dataset name per image
226
+ im_name_list = [] # image name
227
+ im_path_list = [] # im path
228
+ gt_path_list = [] # gt path
229
+ im_ext_list = [] # im ext
230
+ gt_ext_list = [] # gt ext
231
+ for i in range(0,len(name_im_gt_list)):
232
+ dataset_names.append(name_im_gt_list[i]["dataset_name"])
233
+ # dataset name repeated based on the number of images in this dataset
234
+ dt_name_list.extend([name_im_gt_list[i]["dataset_name"] for x in name_im_gt_list[i]["im_path"]])
235
+ im_name_list.extend([x.split(os.sep)[-1].split(name_im_gt_list[i]["im_ext"])[0] for x in name_im_gt_list[i]["im_path"]])
236
+ im_path_list.extend(name_im_gt_list[i]["im_path"])
237
+ gt_path_list.extend(name_im_gt_list[i]["gt_path"])
238
+ im_ext_list.extend([name_im_gt_list[i]["im_ext"] for x in name_im_gt_list[i]["im_path"]])
239
+ gt_ext_list.extend([name_im_gt_list[i]["gt_ext"] for x in name_im_gt_list[i]["gt_path"]])
240
+
241
+
242
+ self.dataset["data_name"] = dt_name_list
243
+ self.dataset["im_name"] = im_name_list
244
+ self.dataset["im_path"] = im_path_list
245
+ self.dataset["ori_im_path"] = deepcopy(im_path_list)
246
+ self.dataset["gt_path"] = gt_path_list
247
+ self.dataset["ori_gt_path"] = deepcopy(gt_path_list)
248
+ self.dataset["im_shp"] = []
249
+ self.dataset["gt_shp"] = []
250
+ self.dataset["im_ext"] = im_ext_list
251
+ self.dataset["gt_ext"] = gt_ext_list
252
+
253
+
254
+ self.dataset["ims_pt_dir"] = ""
255
+ self.dataset["gts_pt_dir"] = ""
256
+
257
+ self.dataset = self.manage_cache(dataset_names)
258
+
259
+ def manage_cache(self,dataset_names):
260
+ if not os.path.exists(self.cache_path): # create the folder for cache
261
+ os.makedirs(self.cache_path)
262
+ cache_folder = os.path.join(self.cache_path, "_".join(dataset_names)+"_"+"x".join([str(x) for x in self.cache_size]))
263
+ if not os.path.exists(cache_folder): # check if the cache files are there, if not then cache
264
+ return self.cache(cache_folder)
265
+ return self.load_cache(cache_folder)
266
+
267
+ def cache(self,cache_folder):
268
+ os.mkdir(cache_folder)
269
+ cached_dataset = deepcopy(self.dataset)
270
+
271
+ # ims_list = []
272
+ # gts_list = []
273
+ ims_pt_list = []
274
+ gts_pt_list = []
275
+ for i, im_path in tqdm(enumerate(self.dataset["im_path"]), total=len(self.dataset["im_path"])):
276
+
277
+ im_id = cached_dataset["im_name"][i]
278
+ print("im_path: ", im_path)
279
+ im = im_reader(im_path)
280
+ im, im_shp = im_preprocess(im,self.cache_size)
281
+ im_cache_file = os.path.join(cache_folder,self.dataset["data_name"][i]+"_"+im_id + "_im.pt")
282
+ torch.save(im,im_cache_file)
283
+
284
+ cached_dataset["im_path"][i] = im_cache_file
285
+ if(self.cache_boost):
286
+ ims_pt_list.append(torch.unsqueeze(im,0))
287
+ # ims_list.append(im.cpu().data.numpy().astype(np.uint8))
288
+
289
+ gt = np.zeros(im.shape[0:2])
290
+ if len(self.dataset["gt_path"])!=0:
291
+ gt = im_reader(self.dataset["gt_path"][i])
292
+ gt, gt_shp = gt_preprocess(gt,self.cache_size)
293
+ gt_cache_file = os.path.join(cache_folder,self.dataset["data_name"][i]+"_"+im_id + "_gt.pt")
294
+ torch.save(gt,gt_cache_file)
295
+ if len(self.dataset["gt_path"])>0:
296
+ cached_dataset["gt_path"][i] = gt_cache_file
297
+ else:
298
+ cached_dataset["gt_path"].append(gt_cache_file)
299
+ if(self.cache_boost):
300
+ gts_pt_list.append(torch.unsqueeze(gt,0))
301
+ # gts_list.append(gt.cpu().data.numpy().astype(np.uint8))
302
+
303
+ # im_shp_cache_file = os.path.join(cache_folder,im_id + "_im_shp.pt")
304
+ # torch.save(gt_shp, shp_cache_file)
305
+ cached_dataset["im_shp"].append(im_shp)
306
+ # self.dataset["im_shp"].append(im_shp)
307
+
308
+ # shp_cache_file = os.path.join(cache_folder,im_id + "_gt_shp.pt")
309
+ # torch.save(gt_shp, shp_cache_file)
310
+ cached_dataset["gt_shp"].append(gt_shp)
311
+ # self.dataset["gt_shp"].append(gt_shp)
312
+
313
+ if(self.cache_boost):
314
+ cached_dataset["ims_pt_dir"] = os.path.join(cache_folder, self.cache_boost_name+'_ims.pt')
315
+ cached_dataset["gts_pt_dir"] = os.path.join(cache_folder, self.cache_boost_name+'_gts.pt')
316
+ self.ims_pt = torch.cat(ims_pt_list,dim=0)
317
+ self.gts_pt = torch.cat(gts_pt_list,dim=0)
318
+ torch.save(torch.cat(ims_pt_list,dim=0),cached_dataset["ims_pt_dir"])
319
+ torch.save(torch.cat(gts_pt_list,dim=0),cached_dataset["gts_pt_dir"])
320
+
321
+ try:
322
+ json_file = open(os.path.join(cache_folder, self.cache_file_name),"w")
323
+ json.dump(cached_dataset, json_file)
324
+ json_file.close()
325
+ except Exception:
326
+ raise FileNotFoundError("Cannot create JSON")
327
+ return cached_dataset
328
+
329
+ def load_cache(self, cache_folder):
330
+ json_file = open(os.path.join(cache_folder,self.cache_file_name),"r")
331
+ dataset = json.load(json_file)
332
+ json_file.close()
333
+ ## if cache_boost is true, we will load the image npy and ground truth npy into the RAM
334
+ ## otherwise the pytorch tensor will be loaded
335
+ if(self.cache_boost):
336
+ # self.ims_npy = np.load(dataset["ims_npy_dir"])
337
+ # self.gts_npy = np.load(dataset["gts_npy_dir"])
338
+ self.ims_pt = torch.load(dataset["ims_pt_dir"], map_location='cpu')
339
+ self.gts_pt = torch.load(dataset["gts_pt_dir"], map_location='cpu')
340
+ return dataset
341
+
342
+ def __len__(self):
343
+ return len(self.dataset["im_path"])
344
+
345
+ def __getitem__(self, idx):
346
+
347
+ im = None
348
+ gt = None
349
+ if(self.cache_boost and self.ims_pt is not None):
350
+
351
+ # start = time.time()
352
+ im = self.ims_pt[idx]#.type(torch.float32)
353
+ gt = self.gts_pt[idx]#.type(torch.float32)
354
+ # print(idx, 'time for pt loading: ', time.time()-start)
355
+
356
+ else:
357
+ # import time
358
+ # start = time.time()
359
+ # print("tensor***")
360
+ im_pt_path = os.path.join(self.cache_path,os.sep.join(self.dataset["im_path"][idx].split(os.sep)[-2:]))
361
+ im = torch.load(im_pt_path)#(self.dataset["im_path"][idx])
362
+ gt_pt_path = os.path.join(self.cache_path,os.sep.join(self.dataset["gt_path"][idx].split(os.sep)[-2:]))
363
+ gt = torch.load(gt_pt_path)#(self.dataset["gt_path"][idx])
364
+ # print(idx,'time for tensor loading: ', time.time()-start)
365
+
366
+
367
+ im_shp = self.dataset["im_shp"][idx]
368
+ # print("time for loading im and gt: ", time.time()-start)
369
+
370
+ # start_time = time.time()
371
+ im = torch.divide(im,255.0)
372
+ gt = torch.divide(gt,255.0)
373
+ # print(idx, 'time for normalize torch divide: ', time.time()-start_time)
374
+
375
+ sample = {
376
+ "imidx": torch.from_numpy(np.array(idx)),
377
+ "image": im,
378
+ "label": gt,
379
+ "shape": torch.from_numpy(np.array(im_shp)),
380
+ }
381
+
382
+ if self.transform:
383
+ sample = self.transform(sample)
384
+
385
+ return sample
models/__init__.py ADDED
@@ -0,0 +1 @@
 
1
+ from models.isnet import ISNetGTEncoder, ISNetDIS
models/isnet.py ADDED
@@ -0,0 +1,610 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ from torchvision import models
4
+ import torch.nn.functional as F
5
+
6
+
7
+ bce_loss = nn.BCELoss(size_average=True)
8
+ def muti_loss_fusion(preds, target):
9
+ loss0 = 0.0
10
+ loss = 0.0
11
+
12
+ for i in range(0,len(preds)):
13
+ # print("i: ", i, preds[i].shape)
14
+ if(preds[i].shape[2]!=target.shape[2] or preds[i].shape[3]!=target.shape[3]):
15
+ # tmp_target = _upsample_like(target,preds[i])
16
+ tmp_target = F.interpolate(target, size=preds[i].size()[2:], mode='bilinear', align_corners=True)
17
+ loss = loss + bce_loss(preds[i],tmp_target)
18
+ else:
19
+ loss = loss + bce_loss(preds[i],target)
20
+ if(i==0):
21
+ loss0 = loss
22
+ return loss0, loss
23
+
24
+ fea_loss = nn.MSELoss(size_average=True)
25
+ kl_loss = nn.KLDivLoss(size_average=True)
26
+ l1_loss = nn.L1Loss(size_average=True)
27
+ smooth_l1_loss = nn.SmoothL1Loss(size_average=True)
28
+ def muti_loss_fusion_kl(preds, target, dfs, fs, mode='MSE'):
29
+ loss0 = 0.0
30
+ loss = 0.0
31
+
32
+ for i in range(0,len(preds)):
33
+ # print("i: ", i, preds[i].shape)
34
+ if(preds[i].shape[2]!=target.shape[2] or preds[i].shape[3]!=target.shape[3]):
35
+ # tmp_target = _upsample_like(target,preds[i])
36
+ tmp_target = F.interpolate(target, size=preds[i].size()[2:], mode='bilinear', align_corners=True)
37
+ loss = loss + bce_loss(preds[i],tmp_target)
38
+ else:
39
+ loss = loss + bce_loss(preds[i],target)
40
+ if(i==0):
41
+ loss0 = loss
42
+
43
+ for i in range(0,len(dfs)):
44
+ if(mode=='MSE'):
45
+ loss = loss + fea_loss(dfs[i],fs[i]) ### add the mse loss of features as additional constraints
46
+ # print("fea_loss: ", fea_loss(dfs[i],fs[i]).item())
47
+ elif(mode=='KL'):
48
+ loss = loss + kl_loss(F.log_softmax(dfs[i],dim=1),F.softmax(fs[i],dim=1))
49
+ # print("kl_loss: ", kl_loss(F.log_softmax(dfs[i],dim=1),F.softmax(fs[i],dim=1)).item())
50
+ elif(mode=='MAE'):
51
+ loss = loss + l1_loss(dfs[i],fs[i])
52
+ # print("ls_loss: ", l1_loss(dfs[i],fs[i]))
53
+ elif(mode=='SmoothL1'):
54
+ loss = loss + smooth_l1_loss(dfs[i],fs[i])
55
+ # print("SmoothL1: ", smooth_l1_loss(dfs[i],fs[i]).item())
56
+
57
+ return loss0, loss
58
+
59
+ class REBNCONV(nn.Module):
60
+ def __init__(self,in_ch=3,out_ch=3,dirate=1,stride=1):
61
+ super(REBNCONV,self).__init__()
62
+
63
+ self.conv_s1 = nn.Conv2d(in_ch,out_ch,3,padding=1*dirate,dilation=1*dirate,stride=stride)
64
+ self.bn_s1 = nn.BatchNorm2d(out_ch)
65
+ self.relu_s1 = nn.ReLU(inplace=True)
66
+
67
+ def forward(self,x):
68
+
69
+ hx = x
70
+ xout = self.relu_s1(self.bn_s1(self.conv_s1(hx)))
71
+
72
+ return xout
73
+
74
+ ## upsample tensor 'src' to have the same spatial size with tensor 'tar'
75
+ def _upsample_like(src,tar):
76
+
77
+ src = F.upsample(src,size=tar.shape[2:],mode='bilinear')
78
+
79
+ return src
80
+
81
+
82
+ ### RSU-7 ###
83
+ class RSU7(nn.Module):
84
+
85
+ def __init__(self, in_ch=3, mid_ch=12, out_ch=3, img_size=512):
86
+ super(RSU7,self).__init__()
87
+
88
+ self.in_ch = in_ch
89
+ self.mid_ch = mid_ch
90
+ self.out_ch = out_ch
91
+
92
+ self.rebnconvin = REBNCONV(in_ch,out_ch,dirate=1) ## 1 -> 1/2
93
+
94
+ self.rebnconv1 = REBNCONV(out_ch,mid_ch,dirate=1)
95
+ self.pool1 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
96
+
97
+ self.rebnconv2 = REBNCONV(mid_ch,mid_ch,dirate=1)
98
+ self.pool2 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
99
+
100
+ self.rebnconv3 = REBNCONV(mid_ch,mid_ch,dirate=1)
101
+ self.pool3 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
102
+
103
+ self.rebnconv4 = REBNCONV(mid_ch,mid_ch,dirate=1)
104
+ self.pool4 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
105
+
106
+ self.rebnconv5 = REBNCONV(mid_ch,mid_ch,dirate=1)
107
+ self.pool5 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
108
+
109
+ self.rebnconv6 = REBNCONV(mid_ch,mid_ch,dirate=1)
110
+
111
+ self.rebnconv7 = REBNCONV(mid_ch,mid_ch,dirate=2)
112
+
113
+ self.rebnconv6d = REBNCONV(mid_ch*2,mid_ch,dirate=1)
114
+ self.rebnconv5d = REBNCONV(mid_ch*2,mid_ch,dirate=1)
115
+ self.rebnconv4d = REBNCONV(mid_ch*2,mid_ch,dirate=1)
116
+ self.rebnconv3d = REBNCONV(mid_ch*2,mid_ch,dirate=1)
117
+ self.rebnconv2d = REBNCONV(mid_ch*2,mid_ch,dirate=1)
118
+ self.rebnconv1d = REBNCONV(mid_ch*2,out_ch,dirate=1)
119
+
120
+ def forward(self,x):
121
+ b, c, h, w = x.shape
122
+
123
+ hx = x
124
+ hxin = self.rebnconvin(hx)
125
+
126
+ hx1 = self.rebnconv1(hxin)
127
+ hx = self.pool1(hx1)
128
+
129
+ hx2 = self.rebnconv2(hx)
130
+ hx = self.pool2(hx2)
131
+
132
+ hx3 = self.rebnconv3(hx)
133
+ hx = self.pool3(hx3)
134
+
135
+ hx4 = self.rebnconv4(hx)
136
+ hx = self.pool4(hx4)
137
+
138
+ hx5 = self.rebnconv5(hx)
139
+ hx = self.pool5(hx5)
140
+
141
+ hx6 = self.rebnconv6(hx)
142
+
143
+ hx7 = self.rebnconv7(hx6)
144
+
145
+ hx6d = self.rebnconv6d(torch.cat((hx7,hx6),1))
146
+ hx6dup = _upsample_like(hx6d,hx5)
147
+
148
+ hx5d = self.rebnconv5d(torch.cat((hx6dup,hx5),1))
149
+ hx5dup = _upsample_like(hx5d,hx4)
150
+
151
+ hx4d = self.rebnconv4d(torch.cat((hx5dup,hx4),1))
152
+ hx4dup = _upsample_like(hx4d,hx3)
153
+
154
+ hx3d = self.rebnconv3d(torch.cat((hx4dup,hx3),1))
155
+ hx3dup = _upsample_like(hx3d,hx2)
156
+
157
+ hx2d = self.rebnconv2d(torch.cat((hx3dup,hx2),1))
158
+ hx2dup = _upsample_like(hx2d,hx1)
159
+
160
+ hx1d = self.rebnconv1d(torch.cat((hx2dup,hx1),1))
161
+
162
+ return hx1d + hxin
163
+
164
+
165
+ ### RSU-6 ###
166
+ class RSU6(nn.Module):
167
+
168
+ def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
169
+ super(RSU6,self).__init__()
170
+
171
+ self.rebnconvin = REBNCONV(in_ch,out_ch,dirate=1)
172
+
173
+ self.rebnconv1 = REBNCONV(out_ch,mid_ch,dirate=1)
174
+ self.pool1 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
175
+
176
+ self.rebnconv2 = REBNCONV(mid_ch,mid_ch,dirate=1)
177
+ self.pool2 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
178
+
179
+ self.rebnconv3 = REBNCONV(mid_ch,mid_ch,dirate=1)
180
+ self.pool3 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
181
+
182
+ self.rebnconv4 = REBNCONV(mid_ch,mid_ch,dirate=1)
183
+ self.pool4 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
184
+
185
+ self.rebnconv5 = REBNCONV(mid_ch,mid_ch,dirate=1)
186
+
187
+ self.rebnconv6 = REBNCONV(mid_ch,mid_ch,dirate=2)
188
+
189
+ self.rebnconv5d = REBNCONV(mid_ch*2,mid_ch,dirate=1)
190
+ self.rebnconv4d = REBNCONV(mid_ch*2,mid_ch,dirate=1)
191
+ self.rebnconv3d = REBNCONV(mid_ch*2,mid_ch,dirate=1)
192
+ self.rebnconv2d = REBNCONV(mid_ch*2,mid_ch,dirate=1)
193
+ self.rebnconv1d = REBNCONV(mid_ch*2,out_ch,dirate=1)
194
+
195
+ def forward(self,x):
196
+
197
+ hx = x
198
+
199
+ hxin = self.rebnconvin(hx)
200
+
201
+ hx1 = self.rebnconv1(hxin)
202
+ hx = self.pool1(hx1)
203
+
204
+ hx2 = self.rebnconv2(hx)
205
+ hx = self.pool2(hx2)
206
+
207
+ hx3 = self.rebnconv3(hx)
208
+ hx = self.pool3(hx3)
209
+
210
+ hx4 = self.rebnconv4(hx)
211
+ hx = self.pool4(hx4)
212
+
213
+ hx5 = self.rebnconv5(hx)
214
+
215
+ hx6 = self.rebnconv6(hx5)
216
+
217
+
218
+ hx5d = self.rebnconv5d(torch.cat((hx6,hx5),1))
219
+ hx5dup = _upsample_like(hx5d,hx4)
220
+
221
+ hx4d = self.rebnconv4d(torch.cat((hx5dup,hx4),1))
222
+ hx4dup = _upsample_like(hx4d,hx3)
223
+
224
+ hx3d = self.rebnconv3d(torch.cat((hx4dup,hx3),1))
225
+ hx3dup = _upsample_like(hx3d,hx2)
226
+
227
+ hx2d = self.rebnconv2d(torch.cat((hx3dup,hx2),1))
228
+ hx2dup = _upsample_like(hx2d,hx1)
229
+
230
+ hx1d = self.rebnconv1d(torch.cat((hx2dup,hx1),1))
231
+
232
+ return hx1d + hxin
233
+
234
+ ### RSU-5 ###
235
+ class RSU5(nn.Module):
236
+
237
+ def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
238
+ super(RSU5,self).__init__()
239
+
240
+ self.rebnconvin = REBNCONV(in_ch,out_ch,dirate=1)
241
+
242
+ self.rebnconv1 = REBNCONV(out_ch,mid_ch,dirate=1)
243
+ self.pool1 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
244
+
245
+ self.rebnconv2 = REBNCONV(mid_ch,mid_ch,dirate=1)
246
+ self.pool2 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
247
+
248
+ self.rebnconv3 = REBNCONV(mid_ch,mid_ch,dirate=1)
249
+ self.pool3 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
250
+
251
+ self.rebnconv4 = REBNCONV(mid_ch,mid_ch,dirate=1)
252
+
253
+ self.rebnconv5 = REBNCONV(mid_ch,mid_ch,dirate=2)
254
+
255
+ self.rebnconv4d = REBNCONV(mid_ch*2,mid_ch,dirate=1)
256
+ self.rebnconv3d = REBNCONV(mid_ch*2,mid_ch,dirate=1)
257
+ self.rebnconv2d = REBNCONV(mid_ch*2,mid_ch,dirate=1)
258
+ self.rebnconv1d = REBNCONV(mid_ch*2,out_ch,dirate=1)
259
+
260
+ def forward(self,x):
261
+
262
+ hx = x
263
+
264
+ hxin = self.rebnconvin(hx)
265
+
266
+ hx1 = self.rebnconv1(hxin)
267
+ hx = self.pool1(hx1)
268
+
269
+ hx2 = self.rebnconv2(hx)
270
+ hx = self.pool2(hx2)
271
+
272
+ hx3 = self.rebnconv3(hx)
273
+ hx = self.pool3(hx3)
274
+
275
+ hx4 = self.rebnconv4(hx)
276
+
277
+ hx5 = self.rebnconv5(hx4)
278
+
279
+ hx4d = self.rebnconv4d(torch.cat((hx5,hx4),1))
280
+ hx4dup = _upsample_like(hx4d,hx3)
281
+
282
+ hx3d = self.rebnconv3d(torch.cat((hx4dup,hx3),1))
283
+ hx3dup = _upsample_like(hx3d,hx2)
284
+
285
+ hx2d = self.rebnconv2d(torch.cat((hx3dup,hx2),1))
286
+ hx2dup = _upsample_like(hx2d,hx1)
287
+
288
+ hx1d = self.rebnconv1d(torch.cat((hx2dup,hx1),1))
289
+
290
+ return hx1d + hxin
291
+
292
+ ### RSU-4 ###
293
+ class RSU4(nn.Module):
294
+
295
+ def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
296
+ super(RSU4,self).__init__()
297
+
298
+ self.rebnconvin = REBNCONV(in_ch,out_ch,dirate=1)
299
+
300
+ self.rebnconv1 = REBNCONV(out_ch,mid_ch,dirate=1)
301
+ self.pool1 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
302
+
303
+ self.rebnconv2 = REBNCONV(mid_ch,mid_ch,dirate=1)
304
+ self.pool2 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
305
+
306
+ self.rebnconv3 = REBNCONV(mid_ch,mid_ch,dirate=1)
307
+
308
+ self.rebnconv4 = REBNCONV(mid_ch,mid_ch,dirate=2)
309
+
310
+ self.rebnconv3d = REBNCONV(mid_ch*2,mid_ch,dirate=1)
311
+ self.rebnconv2d = REBNCONV(mid_ch*2,mid_ch,dirate=1)
312
+ self.rebnconv1d = REBNCONV(mid_ch*2,out_ch,dirate=1)
313
+
314
+ def forward(self,x):
315
+
316
+ hx = x
317
+
318
+ hxin = self.rebnconvin(hx)
319
+
320
+ hx1 = self.rebnconv1(hxin)
321
+ hx = self.pool1(hx1)
322
+
323
+ hx2 = self.rebnconv2(hx)
324
+ hx = self.pool2(hx2)
325
+
326
+ hx3 = self.rebnconv3(hx)
327
+
328
+ hx4 = self.rebnconv4(hx3)
329
+
330
+ hx3d = self.rebnconv3d(torch.cat((hx4,hx3),1))
331
+ hx3dup = _upsample_like(hx3d,hx2)
332
+
333
+ hx2d = self.rebnconv2d(torch.cat((hx3dup,hx2),1))
334
+ hx2dup = _upsample_like(hx2d,hx1)
335
+
336
+ hx1d = self.rebnconv1d(torch.cat((hx2dup,hx1),1))
337
+
338
+ return hx1d + hxin
339
+
340
+ ### RSU-4F ###
341
+ class RSU4F(nn.Module):
342
+
343
+ def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
344
+ super(RSU4F,self).__init__()
345
+
346
+ self.rebnconvin = REBNCONV(in_ch,out_ch,dirate=1)
347
+
348
+ self.rebnconv1 = REBNCONV(out_ch,mid_ch,dirate=1)
349
+ self.rebnconv2 = REBNCONV(mid_ch,mid_ch,dirate=2)
350
+ self.rebnconv3 = REBNCONV(mid_ch,mid_ch,dirate=4)
351
+
352
+ self.rebnconv4 = REBNCONV(mid_ch,mid_ch,dirate=8)
353
+
354
+ self.rebnconv3d = REBNCONV(mid_ch*2,mid_ch,dirate=4)
355
+ self.rebnconv2d = REBNCONV(mid_ch*2,mid_ch,dirate=2)
356
+ self.rebnconv1d = REBNCONV(mid_ch*2,out_ch,dirate=1)
357
+
358
+ def forward(self,x):
359
+
360
+ hx = x
361
+
362
+ hxin = self.rebnconvin(hx)
363
+
364
+ hx1 = self.rebnconv1(hxin)
365
+ hx2 = self.rebnconv2(hx1)
366
+ hx3 = self.rebnconv3(hx2)
367
+
368
+ hx4 = self.rebnconv4(hx3)
369
+
370
+ hx3d = self.rebnconv3d(torch.cat((hx4,hx3),1))
371
+ hx2d = self.rebnconv2d(torch.cat((hx3d,hx2),1))
372
+ hx1d = self.rebnconv1d(torch.cat((hx2d,hx1),1))
373
+
374
+ return hx1d + hxin
375
+
376
+
377
+ class myrebnconv(nn.Module):
378
+ def __init__(self, in_ch=3,
379
+ out_ch=1,
380
+ kernel_size=3,
381
+ stride=1,
382
+ padding=1,
383
+ dilation=1,
384
+ groups=1):
385
+ super(myrebnconv,self).__init__()
386
+
387
+ self.conv = nn.Conv2d(in_ch,
388
+ out_ch,
389
+ kernel_size=kernel_size,
390
+ stride=stride,
391
+ padding=padding,
392
+ dilation=dilation,
393
+ groups=groups)
394
+ self.bn = nn.BatchNorm2d(out_ch)
395
+ self.rl = nn.ReLU(inplace=True)
396
+
397
+ def forward(self,x):
398
+ return self.rl(self.bn(self.conv(x)))
399
+
400
+
401
+ class ISNetGTEncoder(nn.Module):
402
+
403
+ def __init__(self,in_ch=1,out_ch=1):
404
+ super(ISNetGTEncoder,self).__init__()
405
+
406
+ self.conv_in = myrebnconv(in_ch,16,3,stride=2,padding=1) # nn.Conv2d(in_ch,64,3,stride=2,padding=1)
407
+
408
+ self.stage1 = RSU7(16,16,64)
409
+ self.pool12 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
410
+
411
+ self.stage2 = RSU6(64,16,64)
412
+ self.pool23 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
413
+
414
+ self.stage3 = RSU5(64,32,128)
415
+ self.pool34 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
416
+
417
+ self.stage4 = RSU4(128,32,256)
418
+ self.pool45 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
419
+
420
+ self.stage5 = RSU4F(256,64,512)
421
+ self.pool56 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
422
+
423
+ self.stage6 = RSU4F(512,64,512)
424
+
425
+
426
+ self.side1 = nn.Conv2d(64,out_ch,3,padding=1)
427
+ self.side2 = nn.Conv2d(64,out_ch,3,padding=1)
428
+ self.side3 = nn.Conv2d(128,out_ch,3,padding=1)
429
+ self.side4 = nn.Conv2d(256,out_ch,3,padding=1)
430
+ self.side5 = nn.Conv2d(512,out_ch,3,padding=1)
431
+ self.side6 = nn.Conv2d(512,out_ch,3,padding=1)
432
+
433
+ def compute_loss(self, preds, targets):
434
+
435
+ return muti_loss_fusion(preds,targets)
436
+
437
+ def forward(self,x):
438
+
439
+ hx = x
440
+
441
+ hxin = self.conv_in(hx)
442
+ # hx = self.pool_in(hxin)
443
+
444
+ #stage 1
445
+ hx1 = self.stage1(hxin)
446
+ hx = self.pool12(hx1)
447
+
448
+ #stage 2
449
+ hx2 = self.stage2(hx)
450
+ hx = self.pool23(hx2)
451
+
452
+ #stage 3
453
+ hx3 = self.stage3(hx)
454
+ hx = self.pool34(hx3)
455
+
456
+ #stage 4
457
+ hx4 = self.stage4(hx)
458
+ hx = self.pool45(hx4)
459
+
460
+ #stage 5
461
+ hx5 = self.stage5(hx)
462
+ hx = self.pool56(hx5)
463
+
464
+ #stage 6
465
+ hx6 = self.stage6(hx)
466
+
467
+
468
+ #side output
469
+ d1 = self.side1(hx1)
470
+ d1 = _upsample_like(d1,x)
471
+
472
+ d2 = self.side2(hx2)
473
+ d2 = _upsample_like(d2,x)
474
+
475
+ d3 = self.side3(hx3)
476
+ d3 = _upsample_like(d3,x)
477
+
478
+ d4 = self.side4(hx4)
479
+ d4 = _upsample_like(d4,x)
480
+
481
+ d5 = self.side5(hx5)
482
+ d5 = _upsample_like(d5,x)
483
+
484
+ d6 = self.side6(hx6)
485
+ d6 = _upsample_like(d6,x)
486
+
487
+ # d0 = self.outconv(torch.cat((d1,d2,d3,d4,d5,d6),1))
488
+
489
+ return [F.sigmoid(d1), F.sigmoid(d2), F.sigmoid(d3), F.sigmoid(d4), F.sigmoid(d5), F.sigmoid(d6)], [hx1,hx2,hx3,hx4,hx5,hx6]
490
+
491
+ class ISNetDIS(nn.Module):
492
+
493
+ def __init__(self,in_ch=3,out_ch=1):
494
+ super(ISNetDIS,self).__init__()
495
+
496
+ self.conv_in = nn.Conv2d(in_ch,64,3,stride=2,padding=1)
497
+ self.pool_in = nn.MaxPool2d(2,stride=2,ceil_mode=True)
498
+
499
+ self.stage1 = RSU7(64,32,64)
500
+ self.pool12 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
501
+
502
+ self.stage2 = RSU6(64,32,128)
503
+ self.pool23 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
504
+
505
+ self.stage3 = RSU5(128,64,256)
506
+ self.pool34 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
507
+
508
+ self.stage4 = RSU4(256,128,512)
509
+ self.pool45 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
510
+
511
+ self.stage5 = RSU4F(512,256,512)
512
+ self.pool56 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
513
+
514
+ self.stage6 = RSU4F(512,256,512)
515
+
516
+ # decoder
517
+ self.stage5d = RSU4F(1024,256,512)
518
+ self.stage4d = RSU4(1024,128,256)
519
+ self.stage3d = RSU5(512,64,128)
520
+ self.stage2d = RSU6(256,32,64)
521
+ self.stage1d = RSU7(128,16,64)
522
+
523
+ self.side1 = nn.Conv2d(64,out_ch,3,padding=1)
524
+ self.side2 = nn.Conv2d(64,out_ch,3,padding=1)
525
+ self.side3 = nn.Conv2d(128,out_ch,3,padding=1)
526
+ self.side4 = nn.Conv2d(256,out_ch,3,padding=1)
527
+ self.side5 = nn.Conv2d(512,out_ch,3,padding=1)
528
+ self.side6 = nn.Conv2d(512,out_ch,3,padding=1)
529
+
530
+ # self.outconv = nn.Conv2d(6*out_ch,out_ch,1)
531
+
532
+ def compute_loss_kl(self, preds, targets, dfs, fs, mode='MSE'):
533
+
534
+ # return muti_loss_fusion(preds,targets)
535
+ return muti_loss_fusion_kl(preds, targets, dfs, fs, mode=mode)
536
+
537
+ def compute_loss(self, preds, targets):
538
+
539
+ # return muti_loss_fusion(preds,targets)
540
+ return muti_loss_fusion(preds, targets)
541
+
542
+ def forward(self,x):
543
+
544
+ hx = x
545
+
546
+ hxin = self.conv_in(hx)
547
+ #hx = self.pool_in(hxin)
548
+
549
+ #stage 1
550
+ hx1 = self.stage1(hxin)
551
+ hx = self.pool12(hx1)
552
+
553
+ #stage 2
554
+ hx2 = self.stage2(hx)
555
+ hx = self.pool23(hx2)
556
+
557
+ #stage 3
558
+ hx3 = self.stage3(hx)
559
+ hx = self.pool34(hx3)
560
+
561
+ #stage 4
562
+ hx4 = self.stage4(hx)
563
+ hx = self.pool45(hx4)
564
+
565
+ #stage 5
566
+ hx5 = self.stage5(hx)
567
+ hx = self.pool56(hx5)
568
+
569
+ #stage 6
570
+ hx6 = self.stage6(hx)
571
+ hx6up = _upsample_like(hx6,hx5)
572
+
573
+ #-------------------- decoder --------------------
574
+ hx5d = self.stage5d(torch.cat((hx6up,hx5),1))
575
+ hx5dup = _upsample_like(hx5d,hx4)
576
+
577
+ hx4d = self.stage4d(torch.cat((hx5dup,hx4),1))
578
+ hx4dup = _upsample_like(hx4d,hx3)
579
+
580
+ hx3d = self.stage3d(torch.cat((hx4dup,hx3),1))
581
+ hx3dup = _upsample_like(hx3d,hx2)
582
+
583
+ hx2d = self.stage2d(torch.cat((hx3dup,hx2),1))
584
+ hx2dup = _upsample_like(hx2d,hx1)
585
+
586
+ hx1d = self.stage1d(torch.cat((hx2dup,hx1),1))
587
+
588
+
589
+ #side output
590
+ d1 = self.side1(hx1d)
591
+ d1 = _upsample_like(d1,x)
592
+
593
+ d2 = self.side2(hx2d)
594
+ d2 = _upsample_like(d2,x)
595
+
596
+ d3 = self.side3(hx3d)
597
+ d3 = _upsample_like(d3,x)
598
+
599
+ d4 = self.side4(hx4d)
600
+ d4 = _upsample_like(d4,x)
601
+
602
+ d5 = self.side5(hx5d)
603
+ d5 = _upsample_like(d5,x)
604
+
605
+ d6 = self.side6(hx6)
606
+ d6 = _upsample_like(d6,x)
607
+
608
+ # d0 = self.outconv(torch.cat((d1,d2,d3,d4,d5,d6),1))
609
+
610
+ return [F.sigmoid(d1), F.sigmoid(d2), F.sigmoid(d3), F.sigmoid(d4), F.sigmoid(d5), F.sigmoid(d6)],[hx1d,hx2d,hx3d,hx4d,hx5d,hx6]
saved_models/isnet.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:9e1aafea58f0b55d0c35077e0ceade6ba1ba2bce372fd4f8f77215391f3fac13
3
+ size 176579397