Spaces:
Runtime error
Runtime error
Upload 13 files
Browse files- .gitattributes +1 -0
- RCFPyTorch0 +0 -1
- RCFPyTorch0/LICENSE.md +1 -0
- RCFPyTorch0/README.md +68 -0
- RCFPyTorch0/__pycache__/dataset.cpython-37.pyc +0 -0
- RCFPyTorch0/__pycache__/models.cpython-37.pyc +0 -0
- RCFPyTorch0/__pycache__/utils.cpython-37.pyc +0 -0
- RCFPyTorch0/__pycache__/web.cpython-37.pyc +0 -0
- RCFPyTorch0/bsds500_pascal_model.pth +3 -0
- RCFPyTorch0/dataset.py +45 -0
- RCFPyTorch0/models.py +158 -0
- RCFPyTorch0/test.py +90 -0
- RCFPyTorch0/train.py +209 -0
- RCFPyTorch0/utils.py +54 -0
- RCFPyTorch0/vgg16convs.mat +3 -0
.gitattributes
CHANGED
@@ -32,3 +32,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
|
32 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
33 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
34 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
|
|
|
32 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
33 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
34 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
35 |
+
RCFPyTorch0/vgg16convs.mat filter=lfs diff=lfs merge=lfs -text
|
RCFPyTorch0
DELETED
@@ -1 +0,0 @@
|
|
1 |
-
Subproject commit 0f1f2486e5cca2f0c564fc87bdd87b182bfb03c1
|
|
|
|
RCFPyTorch0/LICENSE.md
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
The code is released under the [Creative Commons Attribution-NonCommercial-ShareAlike 4.0 International Public License](https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode) for NonCommercial use only. Any commercial use should get formal permission first (Email: yun.liu@vision.ee.ethz.ch).
|
RCFPyTorch0/README.md
ADDED
@@ -0,0 +1,68 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
## [Richer Convolutional Features for Edge Detection](http://mmcheng.net/rcfedge/)
|
2 |
+
|
3 |
+
This is the PyTorch implementation of our edge detection method, RCF.
|
4 |
+
|
5 |
+
### Citations
|
6 |
+
|
7 |
+
If you are using the code/model/data provided here in a publication, please consider citing:
|
8 |
+
|
9 |
+
@article{liu2019richer,
|
10 |
+
title={Richer Convolutional Features for Edge Detection},
|
11 |
+
author={Liu, Yun and Cheng, Ming-Ming and Hu, Xiaowei and Bian, Jia-Wang and Zhang, Le and Bai, Xiang and Tang, Jinhui},
|
12 |
+
journal={IEEE Transactions on Pattern Analysis and Machine Intelligence},
|
13 |
+
volume={41},
|
14 |
+
number={8},
|
15 |
+
pages={1939--1946},
|
16 |
+
year={2019},
|
17 |
+
publisher={IEEE}
|
18 |
+
}
|
19 |
+
|
20 |
+
### Training
|
21 |
+
|
22 |
+
1. Clone the RCF repository:
|
23 |
+
```
|
24 |
+
git clone https://github.com/yun-liu/RCF-PyTorch.git
|
25 |
+
```
|
26 |
+
|
27 |
+
2. Download the ImageNet-pretrained model ([Google Drive](https://drive.google.com/file/d/1szqDNG3dUO6BM3l6YBuC9vWp16n48-cK/view?usp=sharing) or [Baidu Yun](https://pan.baidu.com/s/1vfntX-cTKnk58atNW5T1lA?pwd=g5af)), and put it into the `$ROOT_DIR` folder.
|
28 |
+
|
29 |
+
3. Download the datasets as below, and extract these datasets to the `$ROOT_DIR/data/` folder.
|
30 |
+
|
31 |
+
```
|
32 |
+
wget http://mftp.mmcheng.net/liuyun/rcf/data/bsds_pascal_train_pair.lst
|
33 |
+
wget http://mftp.mmcheng.net/liuyun/rcf/data/HED-BSDS.tar.gz
|
34 |
+
wget http://mftp.mmcheng.net/liuyun/rcf/data/PASCAL.tar.gz
|
35 |
+
```
|
36 |
+
|
37 |
+
4. Run the following command to start the training:
|
38 |
+
```
|
39 |
+
python train.py --save-dir /path/to/output/directory/
|
40 |
+
```
|
41 |
+
|
42 |
+
### Testing
|
43 |
+
|
44 |
+
1. Download the pretrained model (BSDS500+PASCAL: [Google Drive](https://drive.google.com/file/d/1oxlHQCM4mm5zhHzmE7yho_oToU5Ucckk/view?usp=sharing) or [Baidu Yun](https://pan.baidu.com/s/1Tpf_-dIxHmKwH5IeClt0Ng?pwd=03ad)), and put it into the `$ROOT_DIR` folder.
|
45 |
+
|
46 |
+
2. Run the following command to start the testing:
|
47 |
+
```
|
48 |
+
python test.py --checkpoint bsds500_pascal_model.pth --save-dir /path/to/output/directory/
|
49 |
+
```
|
50 |
+
This pretrained model should achieve an ODS F-measure of 0.812.
|
51 |
+
|
52 |
+
For more information about RCF and edge quality evaluation, please refer to this page: [yun-liu/RCF](https://github.com/yun-liu/RCF)
|
53 |
+
|
54 |
+
### Edge PR Curves
|
55 |
+
|
56 |
+
We have released the code and data for plotting the edge PR curves of many existing edge detectors [here](https://github.com/yun-liu/plot-edge-pr-curves).
|
57 |
+
|
58 |
+
### RCF based on other frameworks
|
59 |
+
|
60 |
+
Caffe based RCF: [yun-liu/RCF](https://github.com/yun-liu/RCF)
|
61 |
+
|
62 |
+
Jittor based RCF: [yun-liu/RCF-Jittor](https://github.com/yun-liu/RCF-Jittor)
|
63 |
+
|
64 |
+
### Acknowledgements
|
65 |
+
|
66 |
+
[1] [balajiselvaraj1601/RCF_Pytorch_Updated](https://github.com/balajiselvaraj1601/RCF_Pytorch_Updated)
|
67 |
+
|
68 |
+
[2] [meteorshowers/RCF-pytorch](https://github.com/meteorshowers/RCF-pytorch)
|
RCFPyTorch0/__pycache__/dataset.cpython-37.pyc
ADDED
Binary file (1.7 kB). View file
|
|
RCFPyTorch0/__pycache__/models.cpython-37.pyc
ADDED
Binary file (5.05 kB). View file
|
|
RCFPyTorch0/__pycache__/utils.cpython-37.pyc
ADDED
Binary file (2.23 kB). View file
|
|
RCFPyTorch0/__pycache__/web.cpython-37.pyc
ADDED
Binary file (1.57 kB). View file
|
|
RCFPyTorch0/bsds500_pascal_model.pth
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:9913d9ae1eaa4a71022e89e8c8f6e3eeab5f9bd1cb6a2cc91b1bba7bf36e898c
|
3 |
+
size 59235375
|
RCFPyTorch0/dataset.py
ADDED
@@ -0,0 +1,45 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import cv2
|
3 |
+
import numpy as np
|
4 |
+
import os.path as osp
|
5 |
+
|
6 |
+
|
7 |
+
class BSDS_Dataset(torch.utils.data.Dataset):
|
8 |
+
def __init__(self, root='data/HED-BSDS', split='test', transform=False):
|
9 |
+
super(BSDS_Dataset, self).__init__()
|
10 |
+
self.root = root
|
11 |
+
self.split = split
|
12 |
+
self.transform = transform
|
13 |
+
if self.split == 'train':
|
14 |
+
self.file_list = osp.join(self.root, 'bsds_pascal_train_pair.lst')
|
15 |
+
elif self.split == 'test':
|
16 |
+
self.file_list = osp.join(self.root, 'test.lst')
|
17 |
+
else:
|
18 |
+
raise ValueError('Invalid split type!')
|
19 |
+
with open(self.file_list, 'r') as f:
|
20 |
+
self.file_list = f.readlines()
|
21 |
+
self.mean = np.array([104.00698793, 116.66876762, 122.67891434], dtype=np.float32)
|
22 |
+
|
23 |
+
def __len__(self):
|
24 |
+
return len(self.file_list)
|
25 |
+
|
26 |
+
def __getitem__(self, index):
|
27 |
+
if self.split == 'train':
|
28 |
+
img_file, label_file = self.file_list[index].split()
|
29 |
+
label = cv2.imread(osp.join(self.root, label_file), 0)
|
30 |
+
label = np.array(label, dtype=np.float32)
|
31 |
+
label = label[np.newaxis, :, :]
|
32 |
+
label[label == 0] = 0
|
33 |
+
label[np.logical_and(label > 0, label < 127.5)] = 2
|
34 |
+
label[label >= 127.5] = 1
|
35 |
+
else:
|
36 |
+
img_file = self.file_list[index].rstrip()
|
37 |
+
|
38 |
+
img = cv2.imread(osp.join(self.root, img_file))
|
39 |
+
img = np.array(img, dtype=np.float32)
|
40 |
+
img = (img - self.mean).transpose((2, 0, 1))
|
41 |
+
|
42 |
+
if self.split == 'train':
|
43 |
+
return img, label
|
44 |
+
else:
|
45 |
+
return img
|
RCFPyTorch0/models.py
ADDED
@@ -0,0 +1,158 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torch.nn as nn
|
3 |
+
import numpy as np
|
4 |
+
import scipy.io as sio
|
5 |
+
import torch.nn.functional as F
|
6 |
+
|
7 |
+
|
8 |
+
class RCF(nn.Module):
|
9 |
+
def __init__(self, pretrained=None):
|
10 |
+
super(RCF, self).__init__()
|
11 |
+
self.conv1_1 = nn.Conv2d( 3, 64, 3, padding=1, dilation=1)
|
12 |
+
self.conv1_2 = nn.Conv2d( 64, 64, 3, padding=1, dilation=1)
|
13 |
+
self.conv2_1 = nn.Conv2d( 64, 128, 3, padding=1, dilation=1)
|
14 |
+
self.conv2_2 = nn.Conv2d(128, 128, 3, padding=1, dilation=1)
|
15 |
+
self.conv3_1 = nn.Conv2d(128, 256, 3, padding=1, dilation=1)
|
16 |
+
self.conv3_2 = nn.Conv2d(256, 256, 3, padding=1, dilation=1)
|
17 |
+
self.conv3_3 = nn.Conv2d(256, 256, 3, padding=1, dilation=1)
|
18 |
+
self.conv4_1 = nn.Conv2d(256, 512, 3, padding=1, dilation=1)
|
19 |
+
self.conv4_2 = nn.Conv2d(512, 512, 3, padding=1, dilation=1)
|
20 |
+
self.conv4_3 = nn.Conv2d(512, 512, 3, padding=1, dilation=1)
|
21 |
+
self.conv5_1 = nn.Conv2d(512, 512, 3, padding=2, dilation=2)
|
22 |
+
self.conv5_2 = nn.Conv2d(512, 512, 3, padding=2, dilation=2)
|
23 |
+
self.conv5_3 = nn.Conv2d(512, 512, 3, padding=2, dilation=2)
|
24 |
+
self.pool1 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
|
25 |
+
self.pool2 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
|
26 |
+
self.pool3 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
|
27 |
+
self.pool4 = nn.MaxPool2d(2, stride=1, ceil_mode=True)
|
28 |
+
self.act = nn.ReLU(inplace=True)
|
29 |
+
|
30 |
+
self.conv1_1_down = nn.Conv2d( 64, 21, 1)
|
31 |
+
self.conv1_2_down = nn.Conv2d( 64, 21, 1)
|
32 |
+
self.conv2_1_down = nn.Conv2d(128, 21, 1)
|
33 |
+
self.conv2_2_down = nn.Conv2d(128, 21, 1)
|
34 |
+
self.conv3_1_down = nn.Conv2d(256, 21, 1)
|
35 |
+
self.conv3_2_down = nn.Conv2d(256, 21, 1)
|
36 |
+
self.conv3_3_down = nn.Conv2d(256, 21, 1)
|
37 |
+
self.conv4_1_down = nn.Conv2d(512, 21, 1)
|
38 |
+
self.conv4_2_down = nn.Conv2d(512, 21, 1)
|
39 |
+
self.conv4_3_down = nn.Conv2d(512, 21, 1)
|
40 |
+
self.conv5_1_down = nn.Conv2d(512, 21, 1)
|
41 |
+
self.conv5_2_down = nn.Conv2d(512, 21, 1)
|
42 |
+
self.conv5_3_down = nn.Conv2d(512, 21, 1)
|
43 |
+
|
44 |
+
self.score_dsn1 = nn.Conv2d(21, 1, 1)
|
45 |
+
self.score_dsn2 = nn.Conv2d(21, 1, 1)
|
46 |
+
self.score_dsn3 = nn.Conv2d(21, 1, 1)
|
47 |
+
self.score_dsn4 = nn.Conv2d(21, 1, 1)
|
48 |
+
self.score_dsn5 = nn.Conv2d(21, 1, 1)
|
49 |
+
self.score_fuse = nn.Conv2d(5, 1, 1)
|
50 |
+
|
51 |
+
self.weight_deconv2 = self._make_bilinear_weights( 4, 1).cuda()
|
52 |
+
self.weight_deconv3 = self._make_bilinear_weights( 8, 1).cuda()
|
53 |
+
self.weight_deconv4 = self._make_bilinear_weights(16, 1).cuda()
|
54 |
+
self.weight_deconv5 = self._make_bilinear_weights(16, 1).cuda()
|
55 |
+
|
56 |
+
# init weights
|
57 |
+
self.apply(self._init_weights)
|
58 |
+
if pretrained is not None:
|
59 |
+
vgg16 = sio.loadmat(pretrained)
|
60 |
+
torch_params = self.state_dict()
|
61 |
+
|
62 |
+
for k in vgg16.keys():
|
63 |
+
name_par = k.split('-')
|
64 |
+
size = len(name_par)
|
65 |
+
if size == 2:
|
66 |
+
name_space = name_par[0] + '.' + name_par[1]
|
67 |
+
data = np.squeeze(vgg16[k])
|
68 |
+
torch_params[name_space] = torch.from_numpy(data)
|
69 |
+
self.load_state_dict(torch_params)
|
70 |
+
|
71 |
+
def _init_weights(self, m):
|
72 |
+
if isinstance(m, nn.Conv2d):
|
73 |
+
m.weight.data.normal_(0, 0.01)
|
74 |
+
if m.weight.data.shape == torch.Size([1, 5, 1, 1]):
|
75 |
+
nn.init.constant_(m.weight, 0.2)
|
76 |
+
if m.bias is not None:
|
77 |
+
nn.init.constant_(m.bias, 0)
|
78 |
+
|
79 |
+
# Based on HED implementation @ https://github.com/xwjabc/hed
|
80 |
+
def _make_bilinear_weights(self, size, num_channels):
|
81 |
+
factor = (size + 1) // 2
|
82 |
+
if size % 2 == 1:
|
83 |
+
center = factor - 1
|
84 |
+
else:
|
85 |
+
center = factor - 0.5
|
86 |
+
og = np.ogrid[:size, :size]
|
87 |
+
filt = (1 - abs(og[0] - center) / factor) * (1 - abs(og[1] - center) / factor)
|
88 |
+
filt = torch.from_numpy(filt)
|
89 |
+
w = torch.zeros(num_channels, num_channels, size, size)
|
90 |
+
w.requires_grad = False
|
91 |
+
for i in range(num_channels):
|
92 |
+
for j in range(num_channels):
|
93 |
+
if i == j:
|
94 |
+
w[i, j] = filt
|
95 |
+
return w
|
96 |
+
|
97 |
+
# Based on BDCN implementation @ https://github.com/pkuCactus/BDCN
|
98 |
+
def _crop(self, data, img_h, img_w, crop_h, crop_w):
|
99 |
+
_, _, h, w = data.size()
|
100 |
+
assert(img_h <= h and img_w <= w)
|
101 |
+
data = data[:, :, crop_h:crop_h + img_h, crop_w:crop_w + img_w]
|
102 |
+
return data
|
103 |
+
|
104 |
+
def forward(self, x):
|
105 |
+
img_h, img_w = x.shape[2], x.shape[3]
|
106 |
+
conv1_1 = self.act(self.conv1_1(x))
|
107 |
+
conv1_2 = self.act(self.conv1_2(conv1_1))
|
108 |
+
pool1 = self.pool1(conv1_2)
|
109 |
+
conv2_1 = self.act(self.conv2_1(pool1))
|
110 |
+
conv2_2 = self.act(self.conv2_2(conv2_1))
|
111 |
+
pool2 = self.pool2(conv2_2)
|
112 |
+
conv3_1 = self.act(self.conv3_1(pool2))
|
113 |
+
conv3_2 = self.act(self.conv3_2(conv3_1))
|
114 |
+
conv3_3 = self.act(self.conv3_3(conv3_2))
|
115 |
+
pool3 = self.pool3(conv3_3)
|
116 |
+
conv4_1 = self.act(self.conv4_1(pool3))
|
117 |
+
conv4_2 = self.act(self.conv4_2(conv4_1))
|
118 |
+
conv4_3 = self.act(self.conv4_3(conv4_2))
|
119 |
+
pool4 = self.pool4(conv4_3)
|
120 |
+
conv5_1 = self.act(self.conv5_1(pool4))
|
121 |
+
conv5_2 = self.act(self.conv5_2(conv5_1))
|
122 |
+
conv5_3 = self.act(self.conv5_3(conv5_2))
|
123 |
+
|
124 |
+
conv1_1_down = self.conv1_1_down(conv1_1)
|
125 |
+
conv1_2_down = self.conv1_2_down(conv1_2)
|
126 |
+
conv2_1_down = self.conv2_1_down(conv2_1)
|
127 |
+
conv2_2_down = self.conv2_2_down(conv2_2)
|
128 |
+
conv3_1_down = self.conv3_1_down(conv3_1)
|
129 |
+
conv3_2_down = self.conv3_2_down(conv3_2)
|
130 |
+
conv3_3_down = self.conv3_3_down(conv3_3)
|
131 |
+
conv4_1_down = self.conv4_1_down(conv4_1)
|
132 |
+
conv4_2_down = self.conv4_2_down(conv4_2)
|
133 |
+
conv4_3_down = self.conv4_3_down(conv4_3)
|
134 |
+
conv5_1_down = self.conv5_1_down(conv5_1)
|
135 |
+
conv5_2_down = self.conv5_2_down(conv5_2)
|
136 |
+
conv5_3_down = self.conv5_3_down(conv5_3)
|
137 |
+
|
138 |
+
out1 = self.score_dsn1(conv1_1_down + conv1_2_down)
|
139 |
+
out2 = self.score_dsn2(conv2_1_down + conv2_2_down)
|
140 |
+
out3 = self.score_dsn3(conv3_1_down + conv3_2_down + conv3_3_down)
|
141 |
+
out4 = self.score_dsn4(conv4_1_down + conv4_2_down + conv4_3_down)
|
142 |
+
out5 = self.score_dsn5(conv5_1_down + conv5_2_down + conv5_3_down)
|
143 |
+
|
144 |
+
out2 = F.conv_transpose2d(out2, self.weight_deconv2, stride=2)
|
145 |
+
out3 = F.conv_transpose2d(out3, self.weight_deconv3, stride=4)
|
146 |
+
out4 = F.conv_transpose2d(out4, self.weight_deconv4, stride=8)
|
147 |
+
out5 = F.conv_transpose2d(out5, self.weight_deconv5, stride=8)
|
148 |
+
|
149 |
+
out2 = self._crop(out2, img_h, img_w, 1, 1)
|
150 |
+
out3 = self._crop(out3, img_h, img_w, 2, 2)
|
151 |
+
out4 = self._crop(out4, img_h, img_w, 4, 4)
|
152 |
+
out5 = self._crop(out5, img_h, img_w, 0, 0)
|
153 |
+
|
154 |
+
fuse = torch.cat((out1, out2, out3, out4, out5), dim=1)
|
155 |
+
fuse = self.score_fuse(fuse)
|
156 |
+
results = [out1, out2, out3, out4, out5, fuse]
|
157 |
+
results = [torch.sigmoid(r) for r in results]
|
158 |
+
return results
|
RCFPyTorch0/test.py
ADDED
@@ -0,0 +1,90 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import numpy as np
|
3 |
+
import os.path as osp
|
4 |
+
import cv2
|
5 |
+
import argparse
|
6 |
+
import torch
|
7 |
+
from torch.utils.data import DataLoader
|
8 |
+
import torchvision
|
9 |
+
from dataset import BSDS_Dataset
|
10 |
+
from models import RCF
|
11 |
+
|
12 |
+
|
13 |
+
def single_scale_test(model, test_loader, test_list, save_dir):
|
14 |
+
model.eval()
|
15 |
+
if not osp.isdir(save_dir):
|
16 |
+
os.makedirs(save_dir)
|
17 |
+
for idx, image in enumerate(test_loader):
|
18 |
+
image = image.cuda()
|
19 |
+
_, _, H, W = image.shape
|
20 |
+
results = model(image)
|
21 |
+
all_res = torch.zeros((len(results), 1, H, W))
|
22 |
+
for i in range(len(results)):
|
23 |
+
all_res[i, 0, :, :] = results[i]
|
24 |
+
filename = osp.splitext(test_list[idx])[0]
|
25 |
+
torchvision.utils.save_image(1 - all_res, osp.join(save_dir, '%s.jpg' % filename))
|
26 |
+
fuse_res = torch.squeeze(results[-1].detach()).cpu().numpy()
|
27 |
+
fuse_res = ((1 - fuse_res) * 255).astype(np.uint8)
|
28 |
+
cv2.imwrite(osp.join(save_dir, '%s_ss.png' % filename), fuse_res)
|
29 |
+
#print('\rRunning single-scale test [%d/%d]' % (idx + 1, len(test_loader)), end='')
|
30 |
+
print('Running single-scale test done')
|
31 |
+
|
32 |
+
|
33 |
+
def multi_scale_test(model, test_loader, test_list, save_dir):
|
34 |
+
model.eval()
|
35 |
+
if not osp.isdir(save_dir):
|
36 |
+
os.makedirs(save_dir)
|
37 |
+
scale = [0.5, 1, 1.5]
|
38 |
+
for idx, image in enumerate(test_loader):
|
39 |
+
in_ = image[0].numpy().transpose((1, 2, 0))
|
40 |
+
_, _, H, W = image.shape
|
41 |
+
ms_fuse = np.zeros((H, W), np.float32)
|
42 |
+
for k in range(len(scale)):
|
43 |
+
im_ = cv2.resize(in_, None, fx=scale[k], fy=scale[k], interpolation=cv2.INTER_LINEAR)
|
44 |
+
im_ = im_.transpose((2, 0, 1))
|
45 |
+
results = model(torch.unsqueeze(torch.from_numpy(im_).cuda(), 0))
|
46 |
+
fuse_res = torch.squeeze(results[-1].detach()).cpu().numpy()
|
47 |
+
fuse_res = cv2.resize(fuse_res, (W, H), interpolation=cv2.INTER_LINEAR)
|
48 |
+
ms_fuse += fuse_res
|
49 |
+
ms_fuse = ms_fuse / len(scale)
|
50 |
+
### rescale trick
|
51 |
+
# ms_fuse = (ms_fuse - ms_fuse.min()) / (ms_fuse.max() - ms_fuse.min())
|
52 |
+
filename = osp.splitext(test_list[idx])[0]
|
53 |
+
ms_fuse = ((1 - ms_fuse) * 255).astype(np.uint8)
|
54 |
+
cv2.imwrite(osp.join(save_dir, '%s_ms.png' % filename), ms_fuse)
|
55 |
+
#print('\rRunning multi-scale test [%d/%d]' % (idx + 1, len(test_loader)), end='')
|
56 |
+
print('Running multi-scale test done')
|
57 |
+
|
58 |
+
|
59 |
+
if __name__ == '__main__':
|
60 |
+
parser = argparse.ArgumentParser(description='PyTorch Testing')
|
61 |
+
parser.add_argument('--gpu', default='0', type=str, help='GPU ID')
|
62 |
+
parser.add_argument('--checkpoint', default=None, type=str, help='path to latest checkpoint')
|
63 |
+
parser.add_argument('--save-dir', help='output folder', default='results/RCF')
|
64 |
+
parser.add_argument('--dataset', help='root folder of dataset', default='data/HED-BSDS')
|
65 |
+
args = parser.parse_args()
|
66 |
+
|
67 |
+
os.environ['CUDA_DEVICE_ORDER'] = 'PCI_BUS_ID'
|
68 |
+
os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu
|
69 |
+
|
70 |
+
if not osp.isdir(args.save_dir):
|
71 |
+
os.makedirs(args.save_dir)
|
72 |
+
|
73 |
+
test_dataset = BSDS_Dataset(root=args.dataset, split='test')
|
74 |
+
test_loader = DataLoader(test_dataset, batch_size=1, num_workers=1, drop_last=False, shuffle=False)
|
75 |
+
test_list = [osp.split(i.rstrip())[1] for i in test_dataset.file_list]
|
76 |
+
assert len(test_list) == len(test_loader)
|
77 |
+
|
78 |
+
model = RCF().cuda()
|
79 |
+
|
80 |
+
if osp.isfile(args.checkpoint):
|
81 |
+
print("=> loading checkpoint from '{}'".format(args.checkpoint))
|
82 |
+
checkpoint = torch.load(args.checkpoint)
|
83 |
+
model.load_state_dict(checkpoint)
|
84 |
+
print("=> checkpoint loaded")
|
85 |
+
else:
|
86 |
+
print("=> no checkpoint found at '{}'".format(args.checkpoint))
|
87 |
+
|
88 |
+
print('Performing the testing...')
|
89 |
+
single_scale_test(model, test_loader, test_list, args.save_dir)
|
90 |
+
multi_scale_test(model, test_loader, test_list, args.save_dir)
|
RCFPyTorch0/train.py
ADDED
@@ -0,0 +1,209 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import numpy as np
|
3 |
+
import os.path as osp
|
4 |
+
import cv2
|
5 |
+
import argparse
|
6 |
+
import time
|
7 |
+
import torch
|
8 |
+
from torch.utils.data import DataLoader
|
9 |
+
import torchvision
|
10 |
+
from dataset import BSDS_Dataset
|
11 |
+
from models import RCF
|
12 |
+
from utils import Logger, Averagvalue, Cross_entropy_loss
|
13 |
+
|
14 |
+
|
15 |
+
def train(args, model, train_loader, optimizer, epoch, logger):
|
16 |
+
batch_time = Averagvalue()
|
17 |
+
losses = Averagvalue()
|
18 |
+
model.train()
|
19 |
+
end = time.time()
|
20 |
+
counter = 0
|
21 |
+
for i, (image, label) in enumerate(train_loader):
|
22 |
+
image, label = image.cuda(), label.cuda()
|
23 |
+
outputs = model(image)
|
24 |
+
loss = torch.zeros(1).cuda()
|
25 |
+
for o in outputs:
|
26 |
+
loss = loss + Cross_entropy_loss(o, label)
|
27 |
+
counter += 1
|
28 |
+
loss = loss / args.iter_size
|
29 |
+
loss.backward()
|
30 |
+
if counter == args.iter_size:
|
31 |
+
optimizer.step()
|
32 |
+
optimizer.zero_grad()
|
33 |
+
counter = 0
|
34 |
+
# measure accuracy and record loss
|
35 |
+
losses.update(loss.item(), image.size(0))
|
36 |
+
batch_time.update(time.time() - end)
|
37 |
+
if i % args.print_freq == 0:
|
38 |
+
logger.info('Epoch: [{0}/{1}][{2}/{3}] '.format(epoch + 1, args.max_epoch, i, len(train_loader)) + \
|
39 |
+
'Time {batch_time.val:.3f} (avg: {batch_time.avg:.3f}) '.format(batch_time=batch_time) + \
|
40 |
+
'Loss {loss.val:f} (avg: {loss.avg:f}) '.format(loss=losses))
|
41 |
+
end = time.time()
|
42 |
+
|
43 |
+
|
44 |
+
def single_scale_test(model, test_loader, test_list, save_dir):
|
45 |
+
model.eval()
|
46 |
+
if not osp.isdir(save_dir):
|
47 |
+
os.makedirs(save_dir)
|
48 |
+
for idx, image in enumerate(test_loader):
|
49 |
+
image = image.cuda()
|
50 |
+
_, _, H, W = image.shape
|
51 |
+
results = model(image)
|
52 |
+
all_res = torch.zeros((len(results), 1, H, W))
|
53 |
+
for i in range(len(results)):
|
54 |
+
all_res[i, 0, :, :] = results[i]
|
55 |
+
filename = osp.splitext(test_list[idx])[0]
|
56 |
+
torchvision.utils.save_image(1 - all_res, osp.join(save_dir, '%s.jpg' % filename))
|
57 |
+
fuse_res = torch.squeeze(results[-1].detach()).cpu().numpy()
|
58 |
+
fuse_res = ((1 - fuse_res) * 255).astype(np.uint8)
|
59 |
+
cv2.imwrite(osp.join(save_dir, '%s_ss.png' % filename), fuse_res)
|
60 |
+
#print('\rRunning single-scale test [%d/%d]' % (idx + 1, len(test_loader)), end='')
|
61 |
+
logger.info('Running single-scale test done')
|
62 |
+
|
63 |
+
|
64 |
+
def multi_scale_test(model, test_loader, test_list, save_dir):
|
65 |
+
model.eval()
|
66 |
+
if not osp.isdir(save_dir):
|
67 |
+
os.makedirs(save_dir)
|
68 |
+
scale = [0.5, 1, 1.5]
|
69 |
+
for idx, image in enumerate(test_loader):
|
70 |
+
in_ = image[0].numpy().transpose((1, 2, 0))
|
71 |
+
_, _, H, W = image.shape
|
72 |
+
ms_fuse = np.zeros((H, W), np.float32)
|
73 |
+
for k in range(len(scale)):
|
74 |
+
im_ = cv2.resize(in_, None, fx=scale[k], fy=scale[k], interpolation=cv2.INTER_LINEAR)
|
75 |
+
im_ = im_.transpose((2, 0, 1))
|
76 |
+
results = model(torch.unsqueeze(torch.from_numpy(im_).cuda(), 0))
|
77 |
+
fuse_res = torch.squeeze(results[-1].detach()).cpu().numpy()
|
78 |
+
fuse_res = cv2.resize(fuse_res, (W, H), interpolation=cv2.INTER_LINEAR)
|
79 |
+
ms_fuse += fuse_res
|
80 |
+
ms_fuse = ms_fuse / len(scale)
|
81 |
+
### rescale trick
|
82 |
+
# ms_fuse = (ms_fuse - ms_fuse.min()) / (ms_fuse.max() - ms_fuse.min())
|
83 |
+
filename = osp.splitext(test_list[idx])[0]
|
84 |
+
ms_fuse = ((1 - ms_fuse) * 255).astype(np.uint8)
|
85 |
+
cv2.imwrite(osp.join(save_dir, '%s_ms.png' % filename), ms_fuse)
|
86 |
+
#print('\rRunning multi-scale test [%d/%d]' % (idx + 1, len(test_loader)), end='')
|
87 |
+
logger.info('Running multi-scale test done')
|
88 |
+
|
89 |
+
|
90 |
+
if __name__ == '__main__':
|
91 |
+
parser = argparse.ArgumentParser(description='PyTorch Training')
|
92 |
+
parser.add_argument('--batch-size', default=1, type=int, help='batch size')
|
93 |
+
parser.add_argument('--lr', default=1e-6, type=float, help='initial learning rate')
|
94 |
+
parser.add_argument('--momentum', default=0.9, type=float, help='momentum')
|
95 |
+
parser.add_argument('--weight-decay', default=2e-4, type=float, help='weight decay')
|
96 |
+
parser.add_argument('--stepsize', default=3, type=int, help='learning rate step size')
|
97 |
+
parser.add_argument('--gamma', default=0.1, type=float, help='learning rate decay rate')
|
98 |
+
parser.add_argument('--max-epoch', default=10, type=int, help='the number of training epochs')
|
99 |
+
parser.add_argument('--iter-size', default=10, type=int, help='iter size')
|
100 |
+
parser.add_argument('--start-epoch', default=0, type=int, help='manual epoch number')
|
101 |
+
parser.add_argument('--print-freq', default=200, type=int, help='print frequency')
|
102 |
+
parser.add_argument('--gpu', default='0', type=str, help='GPU ID')
|
103 |
+
parser.add_argument('--resume', default=None, type=str, help='path to latest checkpoint')
|
104 |
+
parser.add_argument('--save-dir', help='output folder', default='results/RCF')
|
105 |
+
parser.add_argument('--dataset', help='root folder of dataset', default='data')
|
106 |
+
args = parser.parse_args()
|
107 |
+
|
108 |
+
os.environ['CUDA_DEVICE_ORDER'] = 'PCI_BUS_ID'
|
109 |
+
os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu
|
110 |
+
|
111 |
+
if not osp.isdir(args.save_dir):
|
112 |
+
os.makedirs(args.save_dir)
|
113 |
+
|
114 |
+
logger = Logger(osp.join(args.save_dir, 'log.txt'))
|
115 |
+
logger.info('Called with args:')
|
116 |
+
for (key, value) in vars(args).items():
|
117 |
+
logger.info('{0:15} | {1}'.format(key, value))
|
118 |
+
|
119 |
+
train_dataset = BSDS_Dataset(root=args.dataset, split='train')
|
120 |
+
test_dataset = BSDS_Dataset(root=osp.join(args.dataset, 'HED-BSDS'), split='test')
|
121 |
+
train_loader = DataLoader(train_dataset, batch_size=args.batch_size, num_workers=4, drop_last=True, shuffle=True)
|
122 |
+
test_loader = DataLoader(test_dataset, batch_size=args.batch_size, num_workers=4, drop_last=False, shuffle=False)
|
123 |
+
test_list = [osp.split(i.rstrip())[1] for i in test_dataset.file_list]
|
124 |
+
assert len(test_list) == len(test_loader)
|
125 |
+
|
126 |
+
model = RCF(pretrained='vgg16convs.mat').cuda()
|
127 |
+
parameters = {'conv1-4.weight': [], 'conv1-4.bias': [], 'conv5.weight': [], 'conv5.bias': [],
|
128 |
+
'conv_down_1-5.weight': [], 'conv_down_1-5.bias': [], 'score_dsn_1-5.weight': [],
|
129 |
+
'score_dsn_1-5.bias': [], 'score_fuse.weight': [], 'score_fuse.bias': []}
|
130 |
+
for pname, p in model.named_parameters():
|
131 |
+
if pname in ['conv1_1.weight','conv1_2.weight',
|
132 |
+
'conv2_1.weight','conv2_2.weight',
|
133 |
+
'conv3_1.weight','conv3_2.weight','conv3_3.weight',
|
134 |
+
'conv4_1.weight','conv4_2.weight','conv4_3.weight']:
|
135 |
+
parameters['conv1-4.weight'].append(p)
|
136 |
+
elif pname in ['conv1_1.bias','conv1_2.bias',
|
137 |
+
'conv2_1.bias','conv2_2.bias',
|
138 |
+
'conv3_1.bias','conv3_2.bias','conv3_3.bias',
|
139 |
+
'conv4_1.bias','conv4_2.bias','conv4_3.bias']:
|
140 |
+
parameters['conv1-4.bias'].append(p)
|
141 |
+
elif pname in ['conv5_1.weight','conv5_2.weight','conv5_3.weight']:
|
142 |
+
parameters['conv5.weight'].append(p)
|
143 |
+
elif pname in ['conv5_1.bias','conv5_2.bias','conv5_3.bias']:
|
144 |
+
parameters['conv5.bias'].append(p)
|
145 |
+
elif pname in ['conv1_1_down.weight','conv1_2_down.weight',
|
146 |
+
'conv2_1_down.weight','conv2_2_down.weight',
|
147 |
+
'conv3_1_down.weight','conv3_2_down.weight','conv3_3_down.weight',
|
148 |
+
'conv4_1_down.weight','conv4_2_down.weight','conv4_3_down.weight',
|
149 |
+
'conv5_1_down.weight','conv5_2_down.weight','conv5_3_down.weight']:
|
150 |
+
parameters['conv_down_1-5.weight'].append(p)
|
151 |
+
elif pname in ['conv1_1_down.bias','conv1_2_down.bias',
|
152 |
+
'conv2_1_down.bias','conv2_2_down.bias',
|
153 |
+
'conv3_1_down.bias','conv3_2_down.bias','conv3_3_down.bias',
|
154 |
+
'conv4_1_down.bias','conv4_2_down.bias','conv4_3_down.bias',
|
155 |
+
'conv5_1_down.bias','conv5_2_down.bias','conv5_3_down.bias']:
|
156 |
+
parameters['conv_down_1-5.bias'].append(p)
|
157 |
+
elif pname in ['score_dsn1.weight','score_dsn2.weight','score_dsn3.weight', 'score_dsn4.weight','score_dsn5.weight']:
|
158 |
+
parameters['score_dsn_1-5.weight'].append(p)
|
159 |
+
elif pname in ['score_dsn1.bias','score_dsn2.bias','score_dsn3.bias', 'score_dsn4.bias','score_dsn5.bias']:
|
160 |
+
parameters['score_dsn_1-5.bias'].append(p)
|
161 |
+
elif pname in ['score_fuse.weight']:
|
162 |
+
parameters['score_fuse.weight'].append(p)
|
163 |
+
elif pname in ['score_fuse.bias']:
|
164 |
+
parameters['score_fuse.bias'].append(p)
|
165 |
+
|
166 |
+
optimizer = torch.optim.SGD([
|
167 |
+
{'params': parameters['conv1-4.weight'], 'lr': args.lr*1, 'weight_decay': args.weight_decay},
|
168 |
+
{'params': parameters['conv1-4.bias'], 'lr': args.lr*2, 'weight_decay': 0.},
|
169 |
+
{'params': parameters['conv5.weight'], 'lr': args.lr*100, 'weight_decay': args.weight_decay},
|
170 |
+
{'params': parameters['conv5.bias'], 'lr': args.lr*200, 'weight_decay': 0.},
|
171 |
+
{'params': parameters['conv_down_1-5.weight'], 'lr': args.lr*0.1, 'weight_decay': args.weight_decay},
|
172 |
+
{'params': parameters['conv_down_1-5.bias'], 'lr': args.lr*0.2, 'weight_decay': 0.},
|
173 |
+
{'params': parameters['score_dsn_1-5.weight'], 'lr': args.lr*0.01, 'weight_decay': args.weight_decay},
|
174 |
+
{'params': parameters['score_dsn_1-5.bias'], 'lr': args.lr*0.02, 'weight_decay': 0.},
|
175 |
+
{'params': parameters['score_fuse.weight'], 'lr': args.lr*0.001, 'weight_decay': args.weight_decay},
|
176 |
+
{'params': parameters['score_fuse.bias'], 'lr': args.lr*0.002, 'weight_decay': 0.},
|
177 |
+
], lr=args.lr, momentum=args.momentum, weight_decay=args.weight_decay)
|
178 |
+
lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=args.stepsize, gamma=args.gamma)
|
179 |
+
|
180 |
+
if args.resume is not None:
|
181 |
+
if osp.isfile(args.resume):
|
182 |
+
logger.info("=> loading checkpoint from '{}'".format(args.resume))
|
183 |
+
checkpoint = torch.load(args.resume)
|
184 |
+
model.load_state_dict(checkpoint['state_dict'])
|
185 |
+
optimizer.load_state_dict(checkpoint['optimizer'])
|
186 |
+
lr_scheduler.load_state_dict(checkpoint['lr_scheduler'])
|
187 |
+
args.start_epoch = checkpoint['epoch'] + 1
|
188 |
+
logger.info("=> checkpoint loaded")
|
189 |
+
else:
|
190 |
+
logger.info("=> no checkpoint found at '{}'".format(args.resume))
|
191 |
+
|
192 |
+
for epoch in range(args.start_epoch, args.max_epoch):
|
193 |
+
logger.info('Performing initial testing...')
|
194 |
+
train(args, model, train_loader, optimizer, epoch, logger)
|
195 |
+
save_dir = osp.join(args.save_dir, 'epoch%d-test' % (epoch + 1))
|
196 |
+
single_scale_test(model, test_loader, test_list, save_dir)
|
197 |
+
multi_scale_test(model, test_loader, test_list, save_dir)
|
198 |
+
# Save checkpoint
|
199 |
+
save_file = osp.join(args.save_dir, 'checkpoint_epoch{}.pth'.format(epoch + 1))
|
200 |
+
torch.save({
|
201 |
+
'epoch': epoch,
|
202 |
+
'args': args,
|
203 |
+
'state_dict': model.state_dict(),
|
204 |
+
'optimizer': optimizer.state_dict(),
|
205 |
+
'lr_scheduler': lr_scheduler.state_dict(),
|
206 |
+
}, save_file)
|
207 |
+
lr_scheduler.step() # will adjust learning rate
|
208 |
+
|
209 |
+
logger.close()
|
RCFPyTorch0/utils.py
ADDED
@@ -0,0 +1,54 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import logging
|
3 |
+
import numpy as np
|
4 |
+
import torch
|
5 |
+
import torch.nn.functional as F
|
6 |
+
|
7 |
+
|
8 |
+
class Logger(object):
|
9 |
+
def __init__(self, path='log.txt'):
|
10 |
+
self.logger = logging.getLogger('Logger')
|
11 |
+
self.file_handler = logging.FileHandler(path, 'w')
|
12 |
+
self.stdout_handler = logging.StreamHandler()
|
13 |
+
self.logger.addHandler(self.file_handler)
|
14 |
+
self.logger.addHandler(self.stdout_handler)
|
15 |
+
self.stdout_handler.setFormatter(logging.Formatter('%(asctime)s %(levelname)s %(message)s'))
|
16 |
+
self.file_handler.setFormatter(logging.Formatter('%(asctime)s %(levelname)s %(message)s'))
|
17 |
+
self.logger.setLevel(logging.INFO)
|
18 |
+
|
19 |
+
def info(self, txt):
|
20 |
+
self.logger.info(txt)
|
21 |
+
|
22 |
+
def close(self):
|
23 |
+
self.file_handler.close()
|
24 |
+
self.stdout_handler.close()
|
25 |
+
|
26 |
+
|
27 |
+
class Averagvalue(object):
|
28 |
+
"""Computes and stores the average and current value"""
|
29 |
+
def __init__(self):
|
30 |
+
self.reset()
|
31 |
+
|
32 |
+
def reset(self):
|
33 |
+
self.val = 0
|
34 |
+
self.avg = 0
|
35 |
+
self.sum = 0
|
36 |
+
self.count = 0
|
37 |
+
|
38 |
+
def update(self, val, n=1):
|
39 |
+
self.val = val
|
40 |
+
self.sum += val * n
|
41 |
+
self.count += n
|
42 |
+
self.avg = self.sum / self.count
|
43 |
+
|
44 |
+
|
45 |
+
def Cross_entropy_loss(prediction, label):
|
46 |
+
mask = label.clone()
|
47 |
+
num_positive = torch.sum((mask == 1).float()).float()
|
48 |
+
num_negative = torch.sum((mask == 0).float()).float()
|
49 |
+
|
50 |
+
mask[mask == 1] = 1.0 * num_negative / (num_positive + num_negative)
|
51 |
+
mask[mask == 0] = 1.1 * num_positive / (num_positive + num_negative)
|
52 |
+
mask[mask == 2] = 0
|
53 |
+
cost = F.binary_cross_entropy(prediction, label, weight=mask, reduce=False)
|
54 |
+
return torch.sum(cost)
|
RCFPyTorch0/vgg16convs.mat
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:bce56b30c32d4c72954355fe970c87dceba15bc180aa89524960fda1e0e32cd9
|
3 |
+
size 58860856
|