fmsfm committed on
Commit 1ff2d47
1 Parent(s): 555da6f

Upload 13 files

.gitattributes CHANGED
@@ -32,3 +32,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
  *.zip filter=lfs diff=lfs merge=lfs -text
  *.zst filter=lfs diff=lfs merge=lfs -text
  *tfevents* filter=lfs diff=lfs merge=lfs -text
+ RCFPyTorch0/vgg16convs.mat filter=lfs diff=lfs merge=lfs -text
RCFPyTorch0 DELETED
@@ -1 +0,0 @@
- Subproject commit 0f1f2486e5cca2f0c564fc87bdd87b182bfb03c1
RCFPyTorch0/LICENSE.md ADDED
@@ -0,0 +1 @@
+ The code is released under the [Creative Commons Attribution-NonCommercial-ShareAlike 4.0 International Public License](https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode) for non-commercial use only. Any commercial use should first obtain formal permission (Email: yun.liu@vision.ee.ethz.ch).
RCFPyTorch0/README.md ADDED
@@ -0,0 +1,68 @@
+ ## [Richer Convolutional Features for Edge Detection](http://mmcheng.net/rcfedge/)
+
+ This is the PyTorch implementation of our edge detection method, RCF.
+
+ ### Citations
+
+ If you are using the code/model/data provided here in a publication, please consider citing:
+
+     @article{liu2019richer,
+       title={Richer Convolutional Features for Edge Detection},
+       author={Liu, Yun and Cheng, Ming-Ming and Hu, Xiaowei and Bian, Jia-Wang and Zhang, Le and Bai, Xiang and Tang, Jinhui},
+       journal={IEEE Transactions on Pattern Analysis and Machine Intelligence},
+       volume={41},
+       number={8},
+       pages={1939--1946},
+       year={2019},
+       publisher={IEEE}
+     }
+
+ ### Training
+
+ 1. Clone the RCF repository:
+ ```
+ git clone https://github.com/yun-liu/RCF-PyTorch.git
+ ```
+
+ 2. Download the ImageNet-pretrained model ([Google Drive](https://drive.google.com/file/d/1szqDNG3dUO6BM3l6YBuC9vWp16n48-cK/view?usp=sharing) or [Baidu Yun](https://pan.baidu.com/s/1vfntX-cTKnk58atNW5T1lA?pwd=g5af)), and put it into the `$ROOT_DIR` folder.
+
+ 3. Download the datasets below, and extract them into the `$ROOT_DIR/data/` folder:
+
+ ```
+ wget http://mftp.mmcheng.net/liuyun/rcf/data/bsds_pascal_train_pair.lst
+ wget http://mftp.mmcheng.net/liuyun/rcf/data/HED-BSDS.tar.gz
+ wget http://mftp.mmcheng.net/liuyun/rcf/data/PASCAL.tar.gz
+ ```
+
+ 4. Run the following command to start training:
+ ```
+ python train.py --save-dir /path/to/output/directory/
+ ```
+
+ ### Testing
+
+ 1. Download the pretrained model (BSDS500+PASCAL: [Google Drive](https://drive.google.com/file/d/1oxlHQCM4mm5zhHzmE7yho_oToU5Ucckk/view?usp=sharing) or [Baidu Yun](https://pan.baidu.com/s/1Tpf_-dIxHmKwH5IeClt0Ng?pwd=03ad)), and put it into the `$ROOT_DIR` folder.
+
+ 2. Run the following command to start testing:
+ ```
+ python test.py --checkpoint bsds500_pascal_model.pth --save-dir /path/to/output/directory/
+ ```
+ This pretrained model should achieve an ODS F-measure of 0.812.
+
+ For more information about RCF and edge quality evaluation, please refer to this page: [yun-liu/RCF](https://github.com/yun-liu/RCF)
+
+ ### Edge PR Curves
+
+ We have released the code and data for plotting the edge PR curves of many existing edge detectors [here](https://github.com/yun-liu/plot-edge-pr-curves).
+
+ ### RCF Based on Other Frameworks
+
+ Caffe-based RCF: [yun-liu/RCF](https://github.com/yun-liu/RCF)
+
+ Jittor-based RCF: [yun-liu/RCF-Jittor](https://github.com/yun-liu/RCF-Jittor)
+
+ ### Acknowledgements
+
+ [1] [balajiselvaraj1601/RCF_Pytorch_Updated](https://github.com/balajiselvaraj1601/RCF_Pytorch_Updated)
+
+ [2] [meteorshowers/RCF-pytorch](https://github.com/meteorshowers/RCF-pytorch)
RCFPyTorch0/__pycache__/dataset.cpython-37.pyc ADDED
Binary file (1.7 kB).

RCFPyTorch0/__pycache__/models.cpython-37.pyc ADDED
Binary file (5.05 kB).

RCFPyTorch0/__pycache__/utils.cpython-37.pyc ADDED
Binary file (2.23 kB).

RCFPyTorch0/__pycache__/web.cpython-37.pyc ADDED
Binary file (1.57 kB).
 
RCFPyTorch0/bsds500_pascal_model.pth ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:9913d9ae1eaa4a71022e89e8c8f6e3eeab5f9bd1cb6a2cc91b1bba7bf36e898c
+ size 59235375
RCFPyTorch0/dataset.py ADDED
@@ -0,0 +1,45 @@
+ import torch
+ import cv2
+ import numpy as np
+ import os.path as osp
+
+
+ class BSDS_Dataset(torch.utils.data.Dataset):
+     def __init__(self, root='data/HED-BSDS', split='test', transform=False):
+         super(BSDS_Dataset, self).__init__()
+         self.root = root
+         self.split = split
+         self.transform = transform
+         if self.split == 'train':
+             self.file_list = osp.join(self.root, 'bsds_pascal_train_pair.lst')
+         elif self.split == 'test':
+             self.file_list = osp.join(self.root, 'test.lst')
+         else:
+             raise ValueError('Invalid split type!')
+         with open(self.file_list, 'r') as f:
+             self.file_list = f.readlines()
+         # BGR channel means used to zero-center the images (VGG/Caffe convention)
+         self.mean = np.array([104.00698793, 116.66876762, 122.67891434], dtype=np.float32)
+
+     def __len__(self):
+         return len(self.file_list)
+
+     def __getitem__(self, index):
+         if self.split == 'train':
+             img_file, label_file = self.file_list[index].split()
+             label = cv2.imread(osp.join(self.root, label_file), 0)
+             label = np.array(label, dtype=np.float32)
+             label = label[np.newaxis, :, :]
+             # 0: non-edge, 1: edge (annotator consensus >= 50%),
+             # 2: controversial pixels, ignored by the loss
+             label[label == 0] = 0
+             label[np.logical_and(label > 0, label < 127.5)] = 2
+             label[label >= 127.5] = 1
+         else:
+             img_file = self.file_list[index].rstrip()
+
+         img = cv2.imread(osp.join(self.root, img_file))
+         img = np.array(img, dtype=np.float32)
+         img = (img - self.mean).transpose((2, 0, 1))  # HWC -> CHW
+
+         if self.split == 'train':
+             return img, label
+         else:
+             return img
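For orientation, a minimal sketch of how `BSDS_Dataset` is typically consumed (it assumes the `data/` layout described in the README above; the variable names are illustrative):

```
from torch.utils.data import DataLoader
from dataset import BSDS_Dataset

# The train split yields (image, label) pairs; the test split yields images only.
train_set = BSDS_Dataset(root='data', split='train')
train_loader = DataLoader(train_set, batch_size=1, shuffle=True)

img, label = train_set[0]
# img:   (3, H, W) float32, BGR, mean-subtracted
# label: (1, H, W) float32, values in {0, 1, 2}
```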
RCFPyTorch0/models.py ADDED
@@ -0,0 +1,158 @@
+ import torch
+ import torch.nn as nn
+ import numpy as np
+ import scipy.io as sio
+ import torch.nn.functional as F
+
+
+ class RCF(nn.Module):
+     def __init__(self, pretrained=None):
+         super(RCF, self).__init__()
+         # VGG16 backbone; stage 5 uses dilated convolutions
+         self.conv1_1 = nn.Conv2d(  3,  64, 3, padding=1, dilation=1)
+         self.conv1_2 = nn.Conv2d( 64,  64, 3, padding=1, dilation=1)
+         self.conv2_1 = nn.Conv2d( 64, 128, 3, padding=1, dilation=1)
+         self.conv2_2 = nn.Conv2d(128, 128, 3, padding=1, dilation=1)
+         self.conv3_1 = nn.Conv2d(128, 256, 3, padding=1, dilation=1)
+         self.conv3_2 = nn.Conv2d(256, 256, 3, padding=1, dilation=1)
+         self.conv3_3 = nn.Conv2d(256, 256, 3, padding=1, dilation=1)
+         self.conv4_1 = nn.Conv2d(256, 512, 3, padding=1, dilation=1)
+         self.conv4_2 = nn.Conv2d(512, 512, 3, padding=1, dilation=1)
+         self.conv4_3 = nn.Conv2d(512, 512, 3, padding=1, dilation=1)
+         self.conv5_1 = nn.Conv2d(512, 512, 3, padding=2, dilation=2)
+         self.conv5_2 = nn.Conv2d(512, 512, 3, padding=2, dilation=2)
+         self.conv5_3 = nn.Conv2d(512, 512, 3, padding=2, dilation=2)
+         self.pool1 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
+         self.pool2 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
+         self.pool3 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
+         self.pool4 = nn.MaxPool2d(2, stride=1, ceil_mode=True)  # stride 1: stage 5 stays at 1/8 resolution
+         self.act = nn.ReLU(inplace=True)
+
+         # 1x1 convolutions reduce every conv feature map to 21 channels before scoring
+         self.conv1_1_down = nn.Conv2d( 64, 21, 1)
+         self.conv1_2_down = nn.Conv2d( 64, 21, 1)
+         self.conv2_1_down = nn.Conv2d(128, 21, 1)
+         self.conv2_2_down = nn.Conv2d(128, 21, 1)
+         self.conv3_1_down = nn.Conv2d(256, 21, 1)
+         self.conv3_2_down = nn.Conv2d(256, 21, 1)
+         self.conv3_3_down = nn.Conv2d(256, 21, 1)
+         self.conv4_1_down = nn.Conv2d(512, 21, 1)
+         self.conv4_2_down = nn.Conv2d(512, 21, 1)
+         self.conv4_3_down = nn.Conv2d(512, 21, 1)
+         self.conv5_1_down = nn.Conv2d(512, 21, 1)
+         self.conv5_2_down = nn.Conv2d(512, 21, 1)
+         self.conv5_3_down = nn.Conv2d(512, 21, 1)
+
+         # per-stage edge scores and the fusion layer over the five side outputs
+         self.score_dsn1 = nn.Conv2d(21, 1, 1)
+         self.score_dsn2 = nn.Conv2d(21, 1, 1)
+         self.score_dsn3 = nn.Conv2d(21, 1, 1)
+         self.score_dsn4 = nn.Conv2d(21, 1, 1)
+         self.score_dsn5 = nn.Conv2d(21, 1, 1)
+         self.score_fuse = nn.Conv2d(5, 1, 1)
+
+         # fixed (non-learned) bilinear upsampling kernels; note the hard-coded .cuda() calls
+         self.weight_deconv2 = self._make_bilinear_weights( 4, 1).cuda()
+         self.weight_deconv3 = self._make_bilinear_weights( 8, 1).cuda()
+         self.weight_deconv4 = self._make_bilinear_weights(16, 1).cuda()
+         self.weight_deconv5 = self._make_bilinear_weights(16, 1).cuda()
+
+         # init weights
+         self.apply(self._init_weights)
+         if pretrained is not None:
+             # copy VGG16 weights from the .mat file; keys look like 'conv1_1-weight'
+             vgg16 = sio.loadmat(pretrained)
+             torch_params = self.state_dict()
+             for k in vgg16.keys():
+                 name_par = k.split('-')
+                 size = len(name_par)
+                 if size == 2:
+                     name_space = name_par[0] + '.' + name_par[1]
+                     data = np.squeeze(vgg16[k])
+                     torch_params[name_space] = torch.from_numpy(data)
+             self.load_state_dict(torch_params)
+
+     def _init_weights(self, m):
+         if isinstance(m, nn.Conv2d):
+             m.weight.data.normal_(0, 0.01)
+             # the fusion conv (shape [1, 5, 1, 1]) starts as an average of the five side outputs
+             if m.weight.data.shape == torch.Size([1, 5, 1, 1]):
+                 nn.init.constant_(m.weight, 0.2)
+             if m.bias is not None:
+                 nn.init.constant_(m.bias, 0)
+
+     # Based on HED implementation @ https://github.com/xwjabc/hed
+     def _make_bilinear_weights(self, size, num_channels):
+         factor = (size + 1) // 2
+         if size % 2 == 1:
+             center = factor - 1
+         else:
+             center = factor - 0.5
+         og = np.ogrid[:size, :size]
+         filt = (1 - abs(og[0] - center) / factor) * (1 - abs(og[1] - center) / factor)
+         filt = torch.from_numpy(filt)
+         w = torch.zeros(num_channels, num_channels, size, size)
+         w.requires_grad = False
+         for i in range(num_channels):
+             for j in range(num_channels):
+                 if i == j:
+                     w[i, j] = filt
+         return w
+
+     # Based on BDCN implementation @ https://github.com/pkuCactus/BDCN
+     def _crop(self, data, img_h, img_w, crop_h, crop_w):
+         _, _, h, w = data.size()
+         assert(img_h <= h and img_w <= w)
+         data = data[:, :, crop_h:crop_h + img_h, crop_w:crop_w + img_w]
+         return data
+
+     def forward(self, x):
+         img_h, img_w = x.shape[2], x.shape[3]
+         conv1_1 = self.act(self.conv1_1(x))
+         conv1_2 = self.act(self.conv1_2(conv1_1))
+         pool1 = self.pool1(conv1_2)
+         conv2_1 = self.act(self.conv2_1(pool1))
+         conv2_2 = self.act(self.conv2_2(conv2_1))
+         pool2 = self.pool2(conv2_2)
+         conv3_1 = self.act(self.conv3_1(pool2))
+         conv3_2 = self.act(self.conv3_2(conv3_1))
+         conv3_3 = self.act(self.conv3_3(conv3_2))
+         pool3 = self.pool3(conv3_3)
+         conv4_1 = self.act(self.conv4_1(pool3))
+         conv4_2 = self.act(self.conv4_2(conv4_1))
+         conv4_3 = self.act(self.conv4_3(conv4_2))
+         pool4 = self.pool4(conv4_3)
+         conv5_1 = self.act(self.conv5_1(pool4))
+         conv5_2 = self.act(self.conv5_2(conv5_1))
+         conv5_3 = self.act(self.conv5_3(conv5_2))
+
+         conv1_1_down = self.conv1_1_down(conv1_1)
+         conv1_2_down = self.conv1_2_down(conv1_2)
+         conv2_1_down = self.conv2_1_down(conv2_1)
+         conv2_2_down = self.conv2_2_down(conv2_2)
+         conv3_1_down = self.conv3_1_down(conv3_1)
+         conv3_2_down = self.conv3_2_down(conv3_2)
+         conv3_3_down = self.conv3_3_down(conv3_3)
+         conv4_1_down = self.conv4_1_down(conv4_1)
+         conv4_2_down = self.conv4_2_down(conv4_2)
+         conv4_3_down = self.conv4_3_down(conv4_3)
+         conv5_1_down = self.conv5_1_down(conv5_1)
+         conv5_2_down = self.conv5_2_down(conv5_2)
+         conv5_3_down = self.conv5_3_down(conv5_3)
+
+         # each side output sums the reduced features of its stage
+         out1 = self.score_dsn1(conv1_1_down + conv1_2_down)
+         out2 = self.score_dsn2(conv2_1_down + conv2_2_down)
+         out3 = self.score_dsn3(conv3_1_down + conv3_2_down + conv3_3_down)
+         out4 = self.score_dsn4(conv4_1_down + conv4_2_down + conv4_3_down)
+         out5 = self.score_dsn5(conv5_1_down + conv5_2_down + conv5_3_down)
+
+         # upsample the coarser side outputs with the fixed bilinear kernels...
+         out2 = F.conv_transpose2d(out2, self.weight_deconv2, stride=2)
+         out3 = F.conv_transpose2d(out3, self.weight_deconv3, stride=4)
+         out4 = F.conv_transpose2d(out4, self.weight_deconv4, stride=8)
+         out5 = F.conv_transpose2d(out5, self.weight_deconv5, stride=8)
+
+         # ...then crop away the transposed-convolution border to match the input size
+         out2 = self._crop(out2, img_h, img_w, 1, 1)
+         out3 = self._crop(out3, img_h, img_w, 2, 2)
+         out4 = self._crop(out4, img_h, img_w, 4, 4)
+         out5 = self._crop(out5, img_h, img_w, 0, 0)
+
+         fuse = torch.cat((out1, out2, out3, out4, out5), dim=1)
+         fuse = self.score_fuse(fuse)
+         results = [out1, out2, out3, out4, out5, fuse]
+         results = [torch.sigmoid(r) for r in results]
+         return results
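A quick smoke test of the forward pass (a sketch; it assumes a CUDA device, since the bilinear kernels above are created with `.cuda()`, and the input size is just an example):

```
import torch
from models import RCF

model = RCF().cuda().eval()  # pass pretrained='vgg16convs.mat' to load the VGG16 weights
x = torch.randn(1, 3, 321, 481).cuda()  # dummy BSDS-sized image
with torch.no_grad():
    out1, out2, out3, out4, out5, fuse = model(x)
print(fuse.shape)  # torch.Size([1, 1, 321, 481]) -- sigmoid edge probabilities
```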
RCFPyTorch0/test.py ADDED
@@ -0,0 +1,90 @@
+ import os
+ import numpy as np
+ import os.path as osp
+ import cv2
+ import argparse
+ import torch
+ from torch.utils.data import DataLoader
+ import torchvision
+ from dataset import BSDS_Dataset
+ from models import RCF
+
+
+ def single_scale_test(model, test_loader, test_list, save_dir):
+     model.eval()
+     if not osp.isdir(save_dir):
+         os.makedirs(save_dir)
+     for idx, image in enumerate(test_loader):
+         image = image.cuda()
+         _, _, H, W = image.shape
+         results = model(image)
+         all_res = torch.zeros((len(results), 1, H, W))
+         for i in range(len(results)):
+             all_res[i, 0, :, :] = results[i]
+         filename = osp.splitext(test_list[idx])[0]
+         # save all side outputs plus the fused map, inverted so edges are dark on white
+         torchvision.utils.save_image(1 - all_res, osp.join(save_dir, '%s.jpg' % filename))
+         fuse_res = torch.squeeze(results[-1].detach()).cpu().numpy()
+         fuse_res = ((1 - fuse_res) * 255).astype(np.uint8)
+         cv2.imwrite(osp.join(save_dir, '%s_ss.png' % filename), fuse_res)
+         # print('\rRunning single-scale test [%d/%d]' % (idx + 1, len(test_loader)), end='')
+     print('Running single-scale test done')
+
+
+ def multi_scale_test(model, test_loader, test_list, save_dir):
+     model.eval()
+     if not osp.isdir(save_dir):
+         os.makedirs(save_dir)
+     scale = [0.5, 1, 1.5]
+     for idx, image in enumerate(test_loader):
+         in_ = image[0].numpy().transpose((1, 2, 0))
+         _, _, H, W = image.shape
+         ms_fuse = np.zeros((H, W), np.float32)
+         for k in range(len(scale)):
+             im_ = cv2.resize(in_, None, fx=scale[k], fy=scale[k], interpolation=cv2.INTER_LINEAR)
+             im_ = im_.transpose((2, 0, 1))
+             results = model(torch.unsqueeze(torch.from_numpy(im_).cuda(), 0))
+             fuse_res = torch.squeeze(results[-1].detach()).cpu().numpy()
+             fuse_res = cv2.resize(fuse_res, (W, H), interpolation=cv2.INTER_LINEAR)
+             ms_fuse += fuse_res
+         # average the fused edge maps over the three scales
+         ms_fuse = ms_fuse / len(scale)
+         ### rescale trick
+         # ms_fuse = (ms_fuse - ms_fuse.min()) / (ms_fuse.max() - ms_fuse.min())
+         filename = osp.splitext(test_list[idx])[0]
+         ms_fuse = ((1 - ms_fuse) * 255).astype(np.uint8)
+         cv2.imwrite(osp.join(save_dir, '%s_ms.png' % filename), ms_fuse)
+         # print('\rRunning multi-scale test [%d/%d]' % (idx + 1, len(test_loader)), end='')
+     print('Running multi-scale test done')
+
+
+ if __name__ == '__main__':
+     parser = argparse.ArgumentParser(description='PyTorch Testing')
+     parser.add_argument('--gpu', default='0', type=str, help='GPU ID')
+     parser.add_argument('--checkpoint', default=None, type=str, help='path to latest checkpoint')
+     parser.add_argument('--save-dir', help='output folder', default='results/RCF')
+     parser.add_argument('--dataset', help='root folder of dataset', default='data/HED-BSDS')
+     args = parser.parse_args()
+
+     os.environ['CUDA_DEVICE_ORDER'] = 'PCI_BUS_ID'
+     os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu
+
+     if not osp.isdir(args.save_dir):
+         os.makedirs(args.save_dir)
+
+     test_dataset = BSDS_Dataset(root=args.dataset, split='test')
+     test_loader = DataLoader(test_dataset, batch_size=1, num_workers=1, drop_last=False, shuffle=False)
+     test_list = [osp.split(i.rstrip())[1] for i in test_dataset.file_list]
+     assert len(test_list) == len(test_loader)
+
+     model = RCF().cuda()
+
+     # guard against --checkpoint being omitted (osp.isfile(None) would raise)
+     if args.checkpoint is not None and osp.isfile(args.checkpoint):
+         print("=> loading checkpoint from '{}'".format(args.checkpoint))
+         checkpoint = torch.load(args.checkpoint)
+         model.load_state_dict(checkpoint)
+         print("=> checkpoint loaded")
+     else:
+         print("=> no checkpoint found at '{}'".format(args.checkpoint))
+
+     print('Performing testing...')
+     single_scale_test(model, test_loader, test_list, args.save_dir)
+     multi_scale_test(model, test_loader, test_list, args.save_dir)
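Putting the pieces together, a typical invocation of this script (paths are illustrative; the flags are the ones defined in the argparse block above):

```
python test.py --gpu 0 --checkpoint bsds500_pascal_model.pth --save-dir results/RCF --dataset data/HED-BSDS
```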
RCFPyTorch0/train.py ADDED
@@ -0,0 +1,209 @@
+ import os
+ import numpy as np
+ import os.path as osp
+ import cv2
+ import argparse
+ import time
+ import torch
+ from torch.utils.data import DataLoader
+ import torchvision
+ from dataset import BSDS_Dataset
+ from models import RCF
+ from utils import Logger, Averagvalue, Cross_entropy_loss
+
+
+ def train(args, model, train_loader, optimizer, epoch, logger):
+     batch_time = Averagvalue()
+     losses = Averagvalue()
+     model.train()
+     end = time.time()
+     counter = 0
+     for i, (image, label) in enumerate(train_loader):
+         image, label = image.cuda(), label.cuda()
+         outputs = model(image)
+         # sum the balanced cross-entropy loss over all side outputs and the fused output
+         loss = torch.zeros(1).cuda()
+         for o in outputs:
+             loss = loss + Cross_entropy_loss(o, label)
+         counter += 1
+         loss = loss / args.iter_size
+         loss.backward()
+         # accumulate gradients over iter_size batches before each optimizer step
+         # (effective batch size = batch_size * iter_size)
+         if counter == args.iter_size:
+             optimizer.step()
+             optimizer.zero_grad()
+             counter = 0
+         # measure elapsed time and record loss
+         losses.update(loss.item(), image.size(0))
+         batch_time.update(time.time() - end)
+         if i % args.print_freq == 0:
+             logger.info('Epoch: [{0}/{1}][{2}/{3}] '.format(epoch + 1, args.max_epoch, i, len(train_loader)) + \
+                         'Time {batch_time.val:.3f} (avg: {batch_time.avg:.3f}) '.format(batch_time=batch_time) + \
+                         'Loss {loss.val:f} (avg: {loss.avg:f}) '.format(loss=losses))
+         end = time.time()
+
+
+ def single_scale_test(model, test_loader, test_list, save_dir):
+     model.eval()
+     if not osp.isdir(save_dir):
+         os.makedirs(save_dir)
+     for idx, image in enumerate(test_loader):
+         image = image.cuda()
+         _, _, H, W = image.shape
+         results = model(image)
+         all_res = torch.zeros((len(results), 1, H, W))
+         for i in range(len(results)):
+             all_res[i, 0, :, :] = results[i]
+         filename = osp.splitext(test_list[idx])[0]
+         torchvision.utils.save_image(1 - all_res, osp.join(save_dir, '%s.jpg' % filename))
+         fuse_res = torch.squeeze(results[-1].detach()).cpu().numpy()
+         fuse_res = ((1 - fuse_res) * 255).astype(np.uint8)
+         cv2.imwrite(osp.join(save_dir, '%s_ss.png' % filename), fuse_res)
+         # print('\rRunning single-scale test [%d/%d]' % (idx + 1, len(test_loader)), end='')
+     logger.info('Running single-scale test done')
+
+
+ def multi_scale_test(model, test_loader, test_list, save_dir):
+     model.eval()
+     if not osp.isdir(save_dir):
+         os.makedirs(save_dir)
+     scale = [0.5, 1, 1.5]
+     for idx, image in enumerate(test_loader):
+         in_ = image[0].numpy().transpose((1, 2, 0))
+         _, _, H, W = image.shape
+         ms_fuse = np.zeros((H, W), np.float32)
+         for k in range(len(scale)):
+             im_ = cv2.resize(in_, None, fx=scale[k], fy=scale[k], interpolation=cv2.INTER_LINEAR)
+             im_ = im_.transpose((2, 0, 1))
+             results = model(torch.unsqueeze(torch.from_numpy(im_).cuda(), 0))
+             fuse_res = torch.squeeze(results[-1].detach()).cpu().numpy()
+             fuse_res = cv2.resize(fuse_res, (W, H), interpolation=cv2.INTER_LINEAR)
+             ms_fuse += fuse_res
+         # average the fused edge maps over the three scales
+         ms_fuse = ms_fuse / len(scale)
+         ### rescale trick
+         # ms_fuse = (ms_fuse - ms_fuse.min()) / (ms_fuse.max() - ms_fuse.min())
+         filename = osp.splitext(test_list[idx])[0]
+         ms_fuse = ((1 - ms_fuse) * 255).astype(np.uint8)
+         cv2.imwrite(osp.join(save_dir, '%s_ms.png' % filename), ms_fuse)
+         # print('\rRunning multi-scale test [%d/%d]' % (idx + 1, len(test_loader)), end='')
+     logger.info('Running multi-scale test done')
+
+
+ if __name__ == '__main__':
+     parser = argparse.ArgumentParser(description='PyTorch Training')
+     parser.add_argument('--batch-size', default=1, type=int, help='batch size')
+     parser.add_argument('--lr', default=1e-6, type=float, help='initial learning rate')
+     parser.add_argument('--momentum', default=0.9, type=float, help='momentum')
+     parser.add_argument('--weight-decay', default=2e-4, type=float, help='weight decay')
+     parser.add_argument('--stepsize', default=3, type=int, help='learning rate step size')
+     parser.add_argument('--gamma', default=0.1, type=float, help='learning rate decay rate')
+     parser.add_argument('--max-epoch', default=10, type=int, help='the number of training epochs')
+     parser.add_argument('--iter-size', default=10, type=int, help='iter size')
+     parser.add_argument('--start-epoch', default=0, type=int, help='manual epoch number')
+     parser.add_argument('--print-freq', default=200, type=int, help='print frequency')
+     parser.add_argument('--gpu', default='0', type=str, help='GPU ID')
+     parser.add_argument('--resume', default=None, type=str, help='path to latest checkpoint')
+     parser.add_argument('--save-dir', help='output folder', default='results/RCF')
+     parser.add_argument('--dataset', help='root folder of dataset', default='data')
+     args = parser.parse_args()
+
+     os.environ['CUDA_DEVICE_ORDER'] = 'PCI_BUS_ID'
+     os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu
+
+     if not osp.isdir(args.save_dir):
+         os.makedirs(args.save_dir)
+
+     logger = Logger(osp.join(args.save_dir, 'log.txt'))
+     logger.info('Called with args:')
+     for (key, value) in vars(args).items():
+         logger.info('{0:15} | {1}'.format(key, value))
+
+     train_dataset = BSDS_Dataset(root=args.dataset, split='train')
+     test_dataset = BSDS_Dataset(root=osp.join(args.dataset, 'HED-BSDS'), split='test')
+     train_loader = DataLoader(train_dataset, batch_size=args.batch_size, num_workers=4, drop_last=True, shuffle=True)
+     test_loader = DataLoader(test_dataset, batch_size=args.batch_size, num_workers=4, drop_last=False, shuffle=False)
+     test_list = [osp.split(i.rstrip())[1] for i in test_dataset.file_list]
+     assert len(test_list) == len(test_loader)
+
+     model = RCF(pretrained='vgg16convs.mat').cuda()
+     # group parameters so each layer type gets its own learning rate multiplier
+     # (biases use twice the weight learning rate and no weight decay, Caffe-style)
+     parameters = {'conv1-4.weight': [], 'conv1-4.bias': [], 'conv5.weight': [], 'conv5.bias': [],
+                   'conv_down_1-5.weight': [], 'conv_down_1-5.bias': [], 'score_dsn_1-5.weight': [],
+                   'score_dsn_1-5.bias': [], 'score_fuse.weight': [], 'score_fuse.bias': []}
+     for pname, p in model.named_parameters():
+         if pname in ['conv1_1.weight', 'conv1_2.weight',
+                      'conv2_1.weight', 'conv2_2.weight',
+                      'conv3_1.weight', 'conv3_2.weight', 'conv3_3.weight',
+                      'conv4_1.weight', 'conv4_2.weight', 'conv4_3.weight']:
+             parameters['conv1-4.weight'].append(p)
+         elif pname in ['conv1_1.bias', 'conv1_2.bias',
+                        'conv2_1.bias', 'conv2_2.bias',
+                        'conv3_1.bias', 'conv3_2.bias', 'conv3_3.bias',
+                        'conv4_1.bias', 'conv4_2.bias', 'conv4_3.bias']:
+             parameters['conv1-4.bias'].append(p)
+         elif pname in ['conv5_1.weight', 'conv5_2.weight', 'conv5_3.weight']:
+             parameters['conv5.weight'].append(p)
+         elif pname in ['conv5_1.bias', 'conv5_2.bias', 'conv5_3.bias']:
+             parameters['conv5.bias'].append(p)
+         elif pname in ['conv1_1_down.weight', 'conv1_2_down.weight',
+                        'conv2_1_down.weight', 'conv2_2_down.weight',
+                        'conv3_1_down.weight', 'conv3_2_down.weight', 'conv3_3_down.weight',
+                        'conv4_1_down.weight', 'conv4_2_down.weight', 'conv4_3_down.weight',
+                        'conv5_1_down.weight', 'conv5_2_down.weight', 'conv5_3_down.weight']:
+             parameters['conv_down_1-5.weight'].append(p)
+         elif pname in ['conv1_1_down.bias', 'conv1_2_down.bias',
+                        'conv2_1_down.bias', 'conv2_2_down.bias',
+                        'conv3_1_down.bias', 'conv3_2_down.bias', 'conv3_3_down.bias',
+                        'conv4_1_down.bias', 'conv4_2_down.bias', 'conv4_3_down.bias',
+                        'conv5_1_down.bias', 'conv5_2_down.bias', 'conv5_3_down.bias']:
+             parameters['conv_down_1-5.bias'].append(p)
+         elif pname in ['score_dsn1.weight', 'score_dsn2.weight', 'score_dsn3.weight', 'score_dsn4.weight', 'score_dsn5.weight']:
+             parameters['score_dsn_1-5.weight'].append(p)
+         elif pname in ['score_dsn1.bias', 'score_dsn2.bias', 'score_dsn3.bias', 'score_dsn4.bias', 'score_dsn5.bias']:
+             parameters['score_dsn_1-5.bias'].append(p)
+         elif pname in ['score_fuse.weight']:
+             parameters['score_fuse.weight'].append(p)
+         elif pname in ['score_fuse.bias']:
+             parameters['score_fuse.bias'].append(p)
+
+     optimizer = torch.optim.SGD([
+         {'params': parameters['conv1-4.weight'],       'lr': args.lr*1,     'weight_decay': args.weight_decay},
+         {'params': parameters['conv1-4.bias'],         'lr': args.lr*2,     'weight_decay': 0.},
+         {'params': parameters['conv5.weight'],         'lr': args.lr*100,   'weight_decay': args.weight_decay},
+         {'params': parameters['conv5.bias'],           'lr': args.lr*200,   'weight_decay': 0.},
+         {'params': parameters['conv_down_1-5.weight'], 'lr': args.lr*0.1,   'weight_decay': args.weight_decay},
+         {'params': parameters['conv_down_1-5.bias'],   'lr': args.lr*0.2,   'weight_decay': 0.},
+         {'params': parameters['score_dsn_1-5.weight'], 'lr': args.lr*0.01,  'weight_decay': args.weight_decay},
+         {'params': parameters['score_dsn_1-5.bias'],   'lr': args.lr*0.02,  'weight_decay': 0.},
+         {'params': parameters['score_fuse.weight'],    'lr': args.lr*0.001, 'weight_decay': args.weight_decay},
+         {'params': parameters['score_fuse.bias'],      'lr': args.lr*0.002, 'weight_decay': 0.},
+     ], lr=args.lr, momentum=args.momentum, weight_decay=args.weight_decay)
+     lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=args.stepsize, gamma=args.gamma)
+
+     if args.resume is not None:
+         if osp.isfile(args.resume):
+             logger.info("=> loading checkpoint from '{}'".format(args.resume))
+             checkpoint = torch.load(args.resume)
+             model.load_state_dict(checkpoint['state_dict'])
+             optimizer.load_state_dict(checkpoint['optimizer'])
+             lr_scheduler.load_state_dict(checkpoint['lr_scheduler'])
+             args.start_epoch = checkpoint['epoch'] + 1
+             logger.info("=> checkpoint loaded")
+         else:
+             logger.info("=> no checkpoint found at '{}'".format(args.resume))
+
+     for epoch in range(args.start_epoch, args.max_epoch):
+         logger.info('Starting epoch %d...' % (epoch + 1))
+         train(args, model, train_loader, optimizer, epoch, logger)
+         # test after every epoch
+         save_dir = osp.join(args.save_dir, 'epoch%d-test' % (epoch + 1))
+         single_scale_test(model, test_loader, test_list, save_dir)
+         multi_scale_test(model, test_loader, test_list, save_dir)
+         # save checkpoint
+         save_file = osp.join(args.save_dir, 'checkpoint_epoch{}.pth'.format(epoch + 1))
+         torch.save({
+             'epoch': epoch,
+             'args': args,
+             'state_dict': model.state_dict(),
+             'optimizer': optimizer.state_dict(),
+             'lr_scheduler': lr_scheduler.state_dict(),
+         }, save_file)
+         lr_scheduler.step()  # decay the learning rate every `stepsize` epochs
+
+     logger.close()
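A typical training invocation, assuming `vgg16convs.mat` sits in the working directory and the datasets are under `data/` as in the README (paths are illustrative):

```
python train.py --gpu 0 --batch-size 1 --iter-size 10 --lr 1e-6 --save-dir results/RCF --dataset data
```

With `--batch-size 1` and `--iter-size 10`, gradients are accumulated over 10 images before each SGD step, for an effective batch size of 10.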
RCFPyTorch0/utils.py ADDED
@@ -0,0 +1,54 @@
+ import logging
+ import torch
+ import torch.nn.functional as F
+
+
+ class Logger(object):
+     def __init__(self, path='log.txt'):
+         self.logger = logging.getLogger('Logger')
+         self.file_handler = logging.FileHandler(path, 'w')
+         self.stdout_handler = logging.StreamHandler()
+         self.logger.addHandler(self.file_handler)
+         self.logger.addHandler(self.stdout_handler)
+         self.stdout_handler.setFormatter(logging.Formatter('%(asctime)s %(levelname)s %(message)s'))
+         self.file_handler.setFormatter(logging.Formatter('%(asctime)s %(levelname)s %(message)s'))
+         self.logger.setLevel(logging.INFO)
+
+     def info(self, txt):
+         self.logger.info(txt)
+
+     def close(self):
+         self.file_handler.close()
+         self.stdout_handler.close()
+
+
+ class Averagvalue(object):
+     """Computes and stores the average and current value"""
+     def __init__(self):
+         self.reset()
+
+     def reset(self):
+         self.val = 0
+         self.avg = 0
+         self.sum = 0
+         self.count = 0
+
+     def update(self, val, n=1):
+         self.val = val
+         self.sum += val * n
+         self.count += n
+         self.avg = self.sum / self.count
+
+
+ def Cross_entropy_loss(prediction, label):
+     # class-balanced weights: rare edge pixels are up-weighted, abundant
+     # non-edge pixels are down-weighted, and ambiguous pixels (label 2) are ignored
+     mask = label.clone()
+     num_positive = torch.sum((mask == 1).float()).float()
+     num_negative = torch.sum((mask == 0).float()).float()
+
+     mask[mask == 1] = 1.0 * num_negative / (num_positive + num_negative)
+     mask[mask == 0] = 1.1 * num_positive / (num_positive + num_negative)
+     mask[mask == 2] = 0
+     # clamp the target so ignored pixels (label 2) pass BCE's [0, 1] range check;
+     # their weight is 0, so the clamped value does not affect the loss
+     cost = F.binary_cross_entropy(prediction, torch.clamp(label, 0, 1), weight=mask, reduction='none')
+     return torch.sum(cost)
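For intuition, a quick worked example of the class balancing (numbers are illustrative): if a label map has 100 edge pixels and 900 non-edge pixels, each edge pixel gets weight 900/1000 = 0.9 while each non-edge pixel gets weight 1.1 * 100/1000 = 0.11, so the two classes contribute comparably to the summed loss (100 * 0.9 = 90 vs. 900 * 0.11 = 99) despite the heavy imbalance; pixels labeled 2 (ambiguous annotations) get weight 0 and drop out of the loss entirely.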
RCFPyTorch0/vgg16convs.mat ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:bce56b30c32d4c72954355fe970c87dceba15bc180aa89524960fda1e0e32cd9
+ size 58860856