Yuanhao Zhai commited on
Commit
88677a1
1 Parent(s): b35230f

add missing files

Browse files
models/lib/__init__.py ADDED
File without changes
models/lib/nn/__init__.py ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ from .modules import *
2
+ from .parallel import (UserScatteredDataParallel, async_copy_to,
3
+ user_scattered_collate)
models/lib/nn/modules/__init__.py ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+ # File : __init__.py
3
+ # Author : Jiayuan Mao
4
+ # Email : maojiayuan@gmail.com
5
+ # Date : 27/01/2018
6
+ #
7
+ # This file is part of Synchronized-BatchNorm-PyTorch.
8
+ # https://github.com/vacancy/Synchronized-BatchNorm-PyTorch
9
+ # Distributed under MIT License.
10
+
11
+ from .batchnorm import (SynchronizedBatchNorm1d, SynchronizedBatchNorm2d,
12
+ SynchronizedBatchNorm3d)
13
+ from .replicate import DataParallelWithCallback, patch_replication_callback
models/lib/nn/modules/batchnorm.py ADDED
@@ -0,0 +1,328 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+ # File : batchnorm.py
3
+ # Author : Jiayuan Mao
4
+ # Email : maojiayuan@gmail.com
5
+ # Date : 27/01/2018
6
+ #
7
+ # This file is part of Synchronized-BatchNorm-PyTorch.
8
+ # https://github.com/vacancy/Synchronized-BatchNorm-PyTorch
9
+ # Distributed under MIT License.
10
+
11
+ import collections
12
+
13
+ import torch
14
+ import torch.nn.functional as F
15
+ from torch.nn.modules.batchnorm import _BatchNorm
16
+ from torch.nn.parallel._functions import Broadcast, ReduceAddCoalesced
17
+
18
+ from .comm import SyncMaster
19
+
20
+ __all__ = ['SynchronizedBatchNorm1d', 'SynchronizedBatchNorm2d', 'SynchronizedBatchNorm3d']
21
+
22
+
23
+ def _sum_ft(tensor):
24
+ """sum over the first and last dimention"""
25
+ return tensor.sum(dim=0).sum(dim=-1)
26
+
27
+
28
+ def _unsqueeze_ft(tensor):
29
+ """add new dementions at the front and the tail"""
30
+ return tensor.unsqueeze(0).unsqueeze(-1)
31
+
32
+
33
+ _ChildMessage = collections.namedtuple('_ChildMessage', ['sum', 'ssum', 'sum_size'])
34
+ _MasterMessage = collections.namedtuple('_MasterMessage', ['sum', 'inv_std'])
35
+
36
+
37
+ class _SynchronizedBatchNorm(_BatchNorm):
38
+ def __init__(self, num_features, eps=1e-5, momentum=0.001, affine=True):
39
+ super(_SynchronizedBatchNorm, self).__init__(num_features, eps=eps, momentum=momentum, affine=affine)
40
+
41
+ self._sync_master = SyncMaster(self._data_parallel_master)
42
+
43
+ self._is_parallel = False
44
+ self._parallel_id = None
45
+ self._slave_pipe = None
46
+
47
+ # customed batch norm statistics
48
+ self._moving_average_fraction = 1. - momentum
49
+ self.register_buffer('_tmp_running_mean', torch.zeros(self.num_features))
50
+ self.register_buffer('_tmp_running_var', torch.ones(self.num_features))
51
+ self.register_buffer('_running_iter', torch.ones(1))
52
+ self._tmp_running_mean = self.running_mean.clone() * self._running_iter
53
+ self._tmp_running_var = self.running_var.clone() * self._running_iter
54
+
55
+ def forward(self, input):
56
+ # If it is not parallel computation or is in evaluation mode, use PyTorch's implementation.
57
+ if not (self._is_parallel and self.training):
58
+ return F.batch_norm(
59
+ input, self.running_mean, self.running_var, self.weight, self.bias,
60
+ self.training, self.momentum, self.eps)
61
+
62
+ # Resize the input to (B, C, -1).
63
+ input_shape = input.size()
64
+ input = input.view(input.size(0), self.num_features, -1)
65
+
66
+ # Compute the sum and square-sum.
67
+ sum_size = input.size(0) * input.size(2)
68
+ input_sum = _sum_ft(input)
69
+ input_ssum = _sum_ft(input ** 2)
70
+
71
+ # Reduce-and-broadcast the statistics.
72
+ if self._parallel_id == 0:
73
+ mean, inv_std = self._sync_master.run_master(_ChildMessage(input_sum, input_ssum, sum_size))
74
+ else:
75
+ mean, inv_std = self._slave_pipe.run_slave(_ChildMessage(input_sum, input_ssum, sum_size))
76
+
77
+ # Compute the output.
78
+ if self.affine:
79
+ # MJY:: Fuse the multiplication for speed.
80
+ output = (input - _unsqueeze_ft(mean)) * _unsqueeze_ft(inv_std * self.weight) + _unsqueeze_ft(self.bias)
81
+ else:
82
+ output = (input - _unsqueeze_ft(mean)) * _unsqueeze_ft(inv_std)
83
+
84
+ # Reshape it.
85
+ return output.view(input_shape)
86
+
87
+ def __data_parallel_replicate__(self, ctx, copy_id):
88
+ self._is_parallel = True
89
+ self._parallel_id = copy_id
90
+
91
+ # parallel_id == 0 means master device.
92
+ if self._parallel_id == 0:
93
+ ctx.sync_master = self._sync_master
94
+ else:
95
+ self._slave_pipe = ctx.sync_master.register_slave(copy_id)
96
+
97
+ def _data_parallel_master(self, intermediates):
98
+ """Reduce the sum and square-sum, compute the statistics, and broadcast it."""
99
+ intermediates = sorted(intermediates, key=lambda i: i[1].sum.get_device())
100
+
101
+ to_reduce = [i[1][:2] for i in intermediates]
102
+ to_reduce = [j for i in to_reduce for j in i] # flatten
103
+ target_gpus = [i[1].sum.get_device() for i in intermediates]
104
+
105
+ sum_size = sum([i[1].sum_size for i in intermediates])
106
+ sum_, ssum = ReduceAddCoalesced.apply(target_gpus[0], 2, *to_reduce)
107
+
108
+ mean, inv_std = self._compute_mean_std(sum_, ssum, sum_size)
109
+
110
+ broadcasted = Broadcast.apply(target_gpus, mean, inv_std)
111
+
112
+ outputs = []
113
+ for i, rec in enumerate(intermediates):
114
+ outputs.append((rec[0], _MasterMessage(*broadcasted[i*2:i*2+2])))
115
+
116
+ return outputs
117
+
118
+ def _add_weighted(self, dest, delta, alpha=1, beta=1, bias=0):
119
+ """return *dest* by `dest := dest*alpha + delta*beta + bias`"""
120
+ return dest * alpha + delta * beta + bias
121
+
122
+ def _compute_mean_std(self, sum_, ssum, size):
123
+ """Compute the mean and standard-deviation with sum and square-sum. This method
124
+ also maintains the moving average on the master device."""
125
+ assert size > 1, 'BatchNorm computes unbiased standard-deviation, which requires size > 1.'
126
+ mean = sum_ / size
127
+ sumvar = ssum - sum_ * mean
128
+ unbias_var = sumvar / (size - 1)
129
+ bias_var = sumvar / size
130
+
131
+ self._tmp_running_mean = self._add_weighted(self._tmp_running_mean, mean.data, alpha=self._moving_average_fraction)
132
+ self._tmp_running_var = self._add_weighted(self._tmp_running_var, unbias_var.data, alpha=self._moving_average_fraction)
133
+ self._running_iter = self._add_weighted(self._running_iter, 1, alpha=self._moving_average_fraction)
134
+
135
+ self.running_mean = self._tmp_running_mean / self._running_iter
136
+ self.running_var = self._tmp_running_var / self._running_iter
137
+
138
+ return mean, bias_var.clamp(self.eps) ** -0.5
139
+
140
+
141
+ class SynchronizedBatchNorm1d(_SynchronizedBatchNorm):
142
+ r"""Applies Synchronized Batch Normalization over a 2d or 3d input that is seen as a
143
+ mini-batch.
144
+
145
+ .. math::
146
+
147
+ y = \frac{x - mean[x]}{ \sqrt{Var[x] + \epsilon}} * gamma + beta
148
+
149
+ This module differs from the built-in PyTorch BatchNorm1d as the mean and
150
+ standard-deviation are reduced across all devices during training.
151
+
152
+ For example, when one uses `nn.DataParallel` to wrap the network during
153
+ training, PyTorch's implementation normalize the tensor on each device using
154
+ the statistics only on that device, which accelerated the computation and
155
+ is also easy to implement, but the statistics might be inaccurate.
156
+ Instead, in this synchronized version, the statistics will be computed
157
+ over all training samples distributed on multiple devices.
158
+
159
+ Note that, for one-GPU or CPU-only case, this module behaves exactly same
160
+ as the built-in PyTorch implementation.
161
+
162
+ The mean and standard-deviation are calculated per-dimension over
163
+ the mini-batches and gamma and beta are learnable parameter vectors
164
+ of size C (where C is the input size).
165
+
166
+ During training, this layer keeps a running estimate of its computed mean
167
+ and variance. The running sum is kept with a default momentum of 0.1.
168
+
169
+ During evaluation, this running mean/variance is used for normalization.
170
+
171
+ Because the BatchNorm is done over the `C` dimension, computing statistics
172
+ on `(N, L)` slices, it's common terminology to call this Temporal BatchNorm
173
+
174
+ Args:
175
+ num_features: num_features from an expected input of size
176
+ `batch_size x num_features [x width]`
177
+ eps: a value added to the denominator for numerical stability.
178
+ Default: 1e-5
179
+ momentum: the value used for the running_mean and running_var
180
+ computation. Default: 0.1
181
+ affine: a boolean value that when set to ``True``, gives the layer learnable
182
+ affine parameters. Default: ``True``
183
+
184
+ Shape:
185
+ - Input: :math:`(N, C)` or :math:`(N, C, L)`
186
+ - Output: :math:`(N, C)` or :math:`(N, C, L)` (same shape as input)
187
+
188
+ Examples:
189
+ >>> # With Learnable Parameters
190
+ >>> m = SynchronizedBatchNorm1d(100)
191
+ >>> # Without Learnable Parameters
192
+ >>> m = SynchronizedBatchNorm1d(100, affine=False)
193
+ >>> input = torch.autograd.Variable(torch.randn(20, 100))
194
+ >>> output = m(input)
195
+ """
196
+
197
+ def _check_input_dim(self, input):
198
+ if input.dim() != 2 and input.dim() != 3:
199
+ raise ValueError('expected 2D or 3D input (got {}D input)'
200
+ .format(input.dim()))
201
+ super(SynchronizedBatchNorm1d, self)._check_input_dim(input)
202
+
203
+
204
+ class SynchronizedBatchNorm2d(_SynchronizedBatchNorm):
205
+ r"""Applies Batch Normalization over a 4d input that is seen as a mini-batch
206
+ of 3d inputs
207
+
208
+ .. math::
209
+
210
+ y = \frac{x - mean[x]}{ \sqrt{Var[x] + \epsilon}} * gamma + beta
211
+
212
+ This module differs from the built-in PyTorch BatchNorm2d as the mean and
213
+ standard-deviation are reduced across all devices during training.
214
+
215
+ For example, when one uses `nn.DataParallel` to wrap the network during
216
+ training, PyTorch's implementation normalize the tensor on each device using
217
+ the statistics only on that device, which accelerated the computation and
218
+ is also easy to implement, but the statistics might be inaccurate.
219
+ Instead, in this synchronized version, the statistics will be computed
220
+ over all training samples distributed on multiple devices.
221
+
222
+ Note that, for one-GPU or CPU-only case, this module behaves exactly same
223
+ as the built-in PyTorch implementation.
224
+
225
+ The mean and standard-deviation are calculated per-dimension over
226
+ the mini-batches and gamma and beta are learnable parameter vectors
227
+ of size C (where C is the input size).
228
+
229
+ During training, this layer keeps a running estimate of its computed mean
230
+ and variance. The running sum is kept with a default momentum of 0.1.
231
+
232
+ During evaluation, this running mean/variance is used for normalization.
233
+
234
+ Because the BatchNorm is done over the `C` dimension, computing statistics
235
+ on `(N, H, W)` slices, it's common terminology to call this Spatial BatchNorm
236
+
237
+ Args:
238
+ num_features: num_features from an expected input of
239
+ size batch_size x num_features x height x width
240
+ eps: a value added to the denominator for numerical stability.
241
+ Default: 1e-5
242
+ momentum: the value used for the running_mean and running_var
243
+ computation. Default: 0.1
244
+ affine: a boolean value that when set to ``True``, gives the layer learnable
245
+ affine parameters. Default: ``True``
246
+
247
+ Shape:
248
+ - Input: :math:`(N, C, H, W)`
249
+ - Output: :math:`(N, C, H, W)` (same shape as input)
250
+
251
+ Examples:
252
+ >>> # With Learnable Parameters
253
+ >>> m = SynchronizedBatchNorm2d(100)
254
+ >>> # Without Learnable Parameters
255
+ >>> m = SynchronizedBatchNorm2d(100, affine=False)
256
+ >>> input = torch.autograd.Variable(torch.randn(20, 100, 35, 45))
257
+ >>> output = m(input)
258
+ """
259
+
260
+ def _check_input_dim(self, input):
261
+ if input.dim() != 4:
262
+ raise ValueError('expected 4D input (got {}D input)'
263
+ .format(input.dim()))
264
+ super(SynchronizedBatchNorm2d, self)._check_input_dim(input)
265
+
266
+
267
+ class SynchronizedBatchNorm3d(_SynchronizedBatchNorm):
268
+ r"""Applies Batch Normalization over a 5d input that is seen as a mini-batch
269
+ of 4d inputs
270
+
271
+ .. math::
272
+
273
+ y = \frac{x - mean[x]}{ \sqrt{Var[x] + \epsilon}} * gamma + beta
274
+
275
+ This module differs from the built-in PyTorch BatchNorm3d as the mean and
276
+ standard-deviation are reduced across all devices during training.
277
+
278
+ For example, when one uses `nn.DataParallel` to wrap the network during
279
+ training, PyTorch's implementation normalize the tensor on each device using
280
+ the statistics only on that device, which accelerated the computation and
281
+ is also easy to implement, but the statistics might be inaccurate.
282
+ Instead, in this synchronized version, the statistics will be computed
283
+ over all training samples distributed on multiple devices.
284
+
285
+ Note that, for one-GPU or CPU-only case, this module behaves exactly same
286
+ as the built-in PyTorch implementation.
287
+
288
+ The mean and standard-deviation are calculated per-dimension over
289
+ the mini-batches and gamma and beta are learnable parameter vectors
290
+ of size C (where C is the input size).
291
+
292
+ During training, this layer keeps a running estimate of its computed mean
293
+ and variance. The running sum is kept with a default momentum of 0.1.
294
+
295
+ During evaluation, this running mean/variance is used for normalization.
296
+
297
+ Because the BatchNorm is done over the `C` dimension, computing statistics
298
+ on `(N, D, H, W)` slices, it's common terminology to call this Volumetric BatchNorm
299
+ or Spatio-temporal BatchNorm
300
+
301
+ Args:
302
+ num_features: num_features from an expected input of
303
+ size batch_size x num_features x depth x height x width
304
+ eps: a value added to the denominator for numerical stability.
305
+ Default: 1e-5
306
+ momentum: the value used for the running_mean and running_var
307
+ computation. Default: 0.1
308
+ affine: a boolean value that when set to ``True``, gives the layer learnable
309
+ affine parameters. Default: ``True``
310
+
311
+ Shape:
312
+ - Input: :math:`(N, C, D, H, W)`
313
+ - Output: :math:`(N, C, D, H, W)` (same shape as input)
314
+
315
+ Examples:
316
+ >>> # With Learnable Parameters
317
+ >>> m = SynchronizedBatchNorm3d(100)
318
+ >>> # Without Learnable Parameters
319
+ >>> m = SynchronizedBatchNorm3d(100, affine=False)
320
+ >>> input = torch.autograd.Variable(torch.randn(20, 100, 35, 45, 10))
321
+ >>> output = m(input)
322
+ """
323
+
324
+ def _check_input_dim(self, input):
325
+ if input.dim() != 5:
326
+ raise ValueError('expected 5D input (got {}D input)'
327
+ .format(input.dim()))
328
+ super(SynchronizedBatchNorm3d, self)._check_input_dim(input)
models/lib/nn/modules/comm.py ADDED
@@ -0,0 +1,131 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+ # File : comm.py
3
+ # Author : Jiayuan Mao
4
+ # Email : maojiayuan@gmail.com
5
+ # Date : 27/01/2018
6
+ #
7
+ # This file is part of Synchronized-BatchNorm-PyTorch.
8
+ # https://github.com/vacancy/Synchronized-BatchNorm-PyTorch
9
+ # Distributed under MIT License.
10
+
11
+ import collections
12
+ import queue
13
+ import threading
14
+
15
+ __all__ = ['FutureResult', 'SlavePipe', 'SyncMaster']
16
+
17
+
18
+ class FutureResult(object):
19
+ """A thread-safe future implementation. Used only as one-to-one pipe."""
20
+
21
+ def __init__(self):
22
+ self._result = None
23
+ self._lock = threading.Lock()
24
+ self._cond = threading.Condition(self._lock)
25
+
26
+ def put(self, result):
27
+ with self._lock:
28
+ assert self._result is None, 'Previous result has\'t been fetched.'
29
+ self._result = result
30
+ self._cond.notify()
31
+
32
+ def get(self):
33
+ with self._lock:
34
+ if self._result is None:
35
+ self._cond.wait()
36
+
37
+ res = self._result
38
+ self._result = None
39
+ return res
40
+
41
+
42
+ _MasterRegistry = collections.namedtuple('MasterRegistry', ['result'])
43
+ _SlavePipeBase = collections.namedtuple('_SlavePipeBase', ['identifier', 'queue', 'result'])
44
+
45
+
46
+ class SlavePipe(_SlavePipeBase):
47
+ """Pipe for master-slave communication."""
48
+
49
+ def run_slave(self, msg):
50
+ self.queue.put((self.identifier, msg))
51
+ ret = self.result.get()
52
+ self.queue.put(True)
53
+ return ret
54
+
55
+
56
+ class SyncMaster(object):
57
+ """An abstract `SyncMaster` object.
58
+
59
+ - During the replication, as the data parallel will trigger an callback of each module, all slave devices should
60
+ call `register(id)` and obtain an `SlavePipe` to communicate with the master.
61
+ - During the forward pass, master device invokes `run_master`, all messages from slave devices will be collected,
62
+ and passed to a registered callback.
63
+ - After receiving the messages, the master device should gather the information and determine to message passed
64
+ back to each slave devices.
65
+ """
66
+
67
+ def __init__(self, master_callback):
68
+ """
69
+
70
+ Args:
71
+ master_callback: a callback to be invoked after having collected messages from slave devices.
72
+ """
73
+ self._master_callback = master_callback
74
+ self._queue = queue.Queue()
75
+ self._registry = collections.OrderedDict()
76
+ self._activated = False
77
+
78
+ def register_slave(self, identifier):
79
+ """
80
+ Register an slave device.
81
+
82
+ Args:
83
+ identifier: an identifier, usually is the device id.
84
+
85
+ Returns: a `SlavePipe` object which can be used to communicate with the master device.
86
+
87
+ """
88
+ if self._activated:
89
+ assert self._queue.empty(), 'Queue is not clean before next initialization.'
90
+ self._activated = False
91
+ self._registry.clear()
92
+ future = FutureResult()
93
+ self._registry[identifier] = _MasterRegistry(future)
94
+ return SlavePipe(identifier, self._queue, future)
95
+
96
+ def run_master(self, master_msg):
97
+ """
98
+ Main entry for the master device in each forward pass.
99
+ The messages were first collected from each devices (including the master device), and then
100
+ an callback will be invoked to compute the message to be sent back to each devices
101
+ (including the master device).
102
+
103
+ Args:
104
+ master_msg: the message that the master want to send to itself. This will be placed as the first
105
+ message when calling `master_callback`. For detailed usage, see `_SynchronizedBatchNorm` for an example.
106
+
107
+ Returns: the message to be sent back to the master device.
108
+
109
+ """
110
+ self._activated = True
111
+
112
+ intermediates = [(0, master_msg)]
113
+ for i in range(self.nr_slaves):
114
+ intermediates.append(self._queue.get())
115
+
116
+ results = self._master_callback(intermediates)
117
+ assert results[0][0] == 0, 'The first result should belongs to the master.'
118
+
119
+ for i, res in results:
120
+ if i == 0:
121
+ continue
122
+ self._registry[i].result.put(res)
123
+
124
+ for i in range(self.nr_slaves):
125
+ assert self._queue.get() is True
126
+
127
+ return results[0][1]
128
+
129
+ @property
130
+ def nr_slaves(self):
131
+ return len(self._registry)
models/lib/nn/modules/replicate.py ADDED
@@ -0,0 +1,94 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+ # File : replicate.py
3
+ # Author : Jiayuan Mao
4
+ # Email : maojiayuan@gmail.com
5
+ # Date : 27/01/2018
6
+ #
7
+ # This file is part of Synchronized-BatchNorm-PyTorch.
8
+ # https://github.com/vacancy/Synchronized-BatchNorm-PyTorch
9
+ # Distributed under MIT License.
10
+
11
+ import functools
12
+
13
+ from torch.nn.parallel.data_parallel import DataParallel
14
+
15
+ __all__ = [
16
+ 'CallbackContext',
17
+ 'execute_replication_callbacks',
18
+ 'DataParallelWithCallback',
19
+ 'patch_replication_callback'
20
+ ]
21
+
22
+
23
+ class CallbackContext(object):
24
+ pass
25
+
26
+
27
+ def execute_replication_callbacks(modules):
28
+ """
29
+ Execute an replication callback `__data_parallel_replicate__` on each module created by original replication.
30
+
31
+ The callback will be invoked with arguments `__data_parallel_replicate__(ctx, copy_id)`
32
+
33
+ Note that, as all modules are isomorphism, we assign each sub-module with a context
34
+ (shared among multiple copies of this module on different devices).
35
+ Through this context, different copies can share some information.
36
+
37
+ We guarantee that the callback on the master copy (the first copy) will be called ahead of calling the callback
38
+ of any slave copies.
39
+ """
40
+ master_copy = modules[0]
41
+ nr_modules = len(list(master_copy.modules()))
42
+ ctxs = [CallbackContext() for _ in range(nr_modules)]
43
+
44
+ for i, module in enumerate(modules):
45
+ for j, m in enumerate(module.modules()):
46
+ if hasattr(m, '__data_parallel_replicate__'):
47
+ m.__data_parallel_replicate__(ctxs[j], i)
48
+
49
+
50
+ class DataParallelWithCallback(DataParallel):
51
+ """
52
+ Data Parallel with a replication callback.
53
+
54
+ An replication callback `__data_parallel_replicate__` of each module will be invoked after being created by
55
+ original `replicate` function.
56
+ The callback will be invoked with arguments `__data_parallel_replicate__(ctx, copy_id)`
57
+
58
+ Examples:
59
+ > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False)
60
+ > sync_bn = DataParallelWithCallback(sync_bn, device_ids=[0, 1])
61
+ # sync_bn.__data_parallel_replicate__ will be invoked.
62
+ """
63
+
64
+ def replicate(self, module, device_ids):
65
+ modules = super(DataParallelWithCallback, self).replicate(module, device_ids)
66
+ execute_replication_callbacks(modules)
67
+ return modules
68
+
69
+
70
+ def patch_replication_callback(data_parallel):
71
+ """
72
+ Monkey-patch an existing `DataParallel` object. Add the replication callback.
73
+ Useful when you have customized `DataParallel` implementation.
74
+
75
+ Examples:
76
+ > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False)
77
+ > sync_bn = DataParallel(sync_bn, device_ids=[0, 1])
78
+ > patch_replication_callback(sync_bn)
79
+ # this is equivalent to
80
+ > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False)
81
+ > sync_bn = DataParallelWithCallback(sync_bn, device_ids=[0, 1])
82
+ """
83
+
84
+ assert isinstance(data_parallel, DataParallel)
85
+
86
+ old_replicate = data_parallel.replicate
87
+
88
+ @functools.wraps(old_replicate)
89
+ def new_replicate(module, device_ids):
90
+ modules = old_replicate(module, device_ids)
91
+ execute_replication_callbacks(modules)
92
+ return modules
93
+
94
+ data_parallel.replicate = new_replicate
models/lib/nn/modules/tests/__init__.py ADDED
File without changes
models/lib/nn/modules/tests/test_numeric_batchnorm.py ADDED
@@ -0,0 +1,55 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+ # File : test_numeric_batchnorm.py
3
+ # Author : Jiayuan Mao
4
+ # Email : maojiayuan@gmail.com
5
+ # Date : 27/01/2018
6
+ #
7
+ # This file is part of Synchronized-BatchNorm-PyTorch.
8
+
9
+ import unittest
10
+
11
+ import torch
12
+ import torch.nn as nn
13
+ from sync_batchnorm.unittest import TorchTestCase
14
+ from torch.autograd import Variable
15
+
16
+
17
+ def handy_var(a, unbias=True):
18
+ n = a.size(0)
19
+ asum = a.sum(dim=0)
20
+ as_sum = (a ** 2).sum(dim=0) # a square sum
21
+ sumvar = as_sum - asum * asum / n
22
+ if unbias:
23
+ return sumvar / (n - 1)
24
+ else:
25
+ return sumvar / n
26
+
27
+
28
+ class NumericTestCase(TorchTestCase):
29
+ def testNumericBatchNorm(self):
30
+ a = torch.rand(16, 10)
31
+ bn = nn.BatchNorm2d(10, momentum=1, eps=1e-5, affine=False)
32
+ bn.train()
33
+
34
+ a_var1 = Variable(a, requires_grad=True)
35
+ b_var1 = bn(a_var1)
36
+ loss1 = b_var1.sum()
37
+ loss1.backward()
38
+
39
+ a_var2 = Variable(a, requires_grad=True)
40
+ a_mean2 = a_var2.mean(dim=0, keepdim=True)
41
+ a_std2 = torch.sqrt(handy_var(a_var2, unbias=False).clamp(min=1e-5))
42
+ # a_std2 = torch.sqrt(a_var2.var(dim=0, keepdim=True, unbiased=False) + 1e-5)
43
+ b_var2 = (a_var2 - a_mean2) / a_std2
44
+ loss2 = b_var2.sum()
45
+ loss2.backward()
46
+
47
+ self.assertTensorClose(bn.running_mean, a.mean(dim=0))
48
+ self.assertTensorClose(bn.running_var, handy_var(a))
49
+ self.assertTensorClose(a_var1.data, a_var2.data)
50
+ self.assertTensorClose(b_var1.data, b_var2.data)
51
+ self.assertTensorClose(a_var1.grad, a_var2.grad)
52
+
53
+
54
+ if __name__ == '__main__':
55
+ unittest.main()
models/lib/nn/modules/tests/test_sync_batchnorm.py ADDED
@@ -0,0 +1,111 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+ # File : test_sync_batchnorm.py
3
+ # Author : Jiayuan Mao
4
+ # Email : maojiayuan@gmail.com
5
+ # Date : 27/01/2018
6
+ #
7
+ # This file is part of Synchronized-BatchNorm-PyTorch.
8
+
9
+ import unittest
10
+
11
+ import torch
12
+ import torch.nn as nn
13
+ from sync_batchnorm import (DataParallelWithCallback, SynchronizedBatchNorm1d,
14
+ SynchronizedBatchNorm2d)
15
+ from sync_batchnorm.unittest import TorchTestCase
16
+ from torch.autograd import Variable
17
+
18
+
19
+ def handy_var(a, unbias=True):
20
+ n = a.size(0)
21
+ asum = a.sum(dim=0)
22
+ as_sum = (a ** 2).sum(dim=0) # a square sum
23
+ sumvar = as_sum - asum * asum / n
24
+ if unbias:
25
+ return sumvar / (n - 1)
26
+ else:
27
+ return sumvar / n
28
+
29
+
30
+ def _find_bn(module):
31
+ for m in module.modules():
32
+ if isinstance(m, (nn.BatchNorm1d, nn.BatchNorm2d, SynchronizedBatchNorm1d, SynchronizedBatchNorm2d)):
33
+ return m
34
+
35
+
36
+ class SyncTestCase(TorchTestCase):
37
+ def _syncParameters(self, bn1, bn2):
38
+ bn1.reset_parameters()
39
+ bn2.reset_parameters()
40
+ if bn1.affine and bn2.affine:
41
+ bn2.weight.data.copy_(bn1.weight.data)
42
+ bn2.bias.data.copy_(bn1.bias.data)
43
+
44
+ def _checkBatchNormResult(self, bn1, bn2, input, is_train, cuda=False):
45
+ """Check the forward and backward for the customized batch normalization."""
46
+ bn1.train(mode=is_train)
47
+ bn2.train(mode=is_train)
48
+
49
+ if cuda:
50
+ input = input.cuda()
51
+
52
+ self._syncParameters(_find_bn(bn1), _find_bn(bn2))
53
+
54
+ input1 = Variable(input, requires_grad=True)
55
+ output1 = bn1(input1)
56
+ output1.sum().backward()
57
+ input2 = Variable(input, requires_grad=True)
58
+ output2 = bn2(input2)
59
+ output2.sum().backward()
60
+
61
+ self.assertTensorClose(input1.data, input2.data)
62
+ self.assertTensorClose(output1.data, output2.data)
63
+ self.assertTensorClose(input1.grad, input2.grad)
64
+ self.assertTensorClose(_find_bn(bn1).running_mean, _find_bn(bn2).running_mean)
65
+ self.assertTensorClose(_find_bn(bn1).running_var, _find_bn(bn2).running_var)
66
+
67
+ def testSyncBatchNormNormalTrain(self):
68
+ bn = nn.BatchNorm1d(10)
69
+ sync_bn = SynchronizedBatchNorm1d(10)
70
+
71
+ self._checkBatchNormResult(bn, sync_bn, torch.rand(16, 10), True)
72
+
73
+ def testSyncBatchNormNormalEval(self):
74
+ bn = nn.BatchNorm1d(10)
75
+ sync_bn = SynchronizedBatchNorm1d(10)
76
+
77
+ self._checkBatchNormResult(bn, sync_bn, torch.rand(16, 10), False)
78
+
79
+ def testSyncBatchNormSyncTrain(self):
80
+ bn = nn.BatchNorm1d(10, eps=1e-5, affine=False)
81
+ sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False)
82
+ sync_bn = DataParallelWithCallback(sync_bn, device_ids=[0, 1])
83
+
84
+ bn.cuda()
85
+ sync_bn.cuda()
86
+
87
+ self._checkBatchNormResult(bn, sync_bn, torch.rand(16, 10), True, cuda=True)
88
+
89
+ def testSyncBatchNormSyncEval(self):
90
+ bn = nn.BatchNorm1d(10, eps=1e-5, affine=False)
91
+ sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False)
92
+ sync_bn = DataParallelWithCallback(sync_bn, device_ids=[0, 1])
93
+
94
+ bn.cuda()
95
+ sync_bn.cuda()
96
+
97
+ self._checkBatchNormResult(bn, sync_bn, torch.rand(16, 10), False, cuda=True)
98
+
99
+ def testSyncBatchNorm2DSyncTrain(self):
100
+ bn = nn.BatchNorm2d(10)
101
+ sync_bn = SynchronizedBatchNorm2d(10)
102
+ sync_bn = DataParallelWithCallback(sync_bn, device_ids=[0, 1])
103
+
104
+ bn.cuda()
105
+ sync_bn.cuda()
106
+
107
+ self._checkBatchNormResult(bn, sync_bn, torch.rand(16, 10, 16, 16), True, cuda=True)
108
+
109
+
110
+ if __name__ == '__main__':
111
+ unittest.main()
models/lib/nn/modules/unittest.py ADDED
@@ -0,0 +1,29 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+ # File : unittest.py
3
+ # Author : Jiayuan Mao
4
+ # Email : maojiayuan@gmail.com
5
+ # Date : 27/01/2018
6
+ #
7
+ # This file is part of Synchronized-BatchNorm-PyTorch.
8
+ # https://github.com/vacancy/Synchronized-BatchNorm-PyTorch
9
+ # Distributed under MIT License.
10
+
11
+ import unittest
12
+
13
+ import numpy as np
14
+ from torch.autograd import Variable
15
+
16
+
17
+ def as_numpy(v):
18
+ if isinstance(v, Variable):
19
+ v = v.data
20
+ return v.cpu().numpy()
21
+
22
+
23
+ class TorchTestCase(unittest.TestCase):
24
+ def assertTensorClose(self, a, b, atol=1e-3, rtol=1e-3):
25
+ npa, npb = as_numpy(a), as_numpy(b)
26
+ self.assertTrue(
27
+ np.allclose(npa, npb, atol=atol),
28
+ 'Tensor close check failed\n{}\n{}\nadiff={}, rdiff={}'.format(a, b, np.abs(npa - npb).max(), np.abs((npa - npb) / np.fmax(npa, 1e-5)).max())
29
+ )
models/lib/nn/parallel/__init__.py ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ from .data_parallel import (UserScatteredDataParallel, async_copy_to,
2
+ user_scattered_collate)
models/lib/nn/parallel/data_parallel.py ADDED
@@ -0,0 +1,112 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf8 -*-
2
+
3
+ import collections
4
+
5
+ import torch
6
+ import torch.cuda as cuda
7
+ import torch.nn as nn
8
+ from torch.nn.parallel._functions import Gather
9
+
10
+ __all__ = ['UserScatteredDataParallel', 'user_scattered_collate', 'async_copy_to']
11
+
12
+
13
+ def async_copy_to(obj, dev, main_stream=None):
14
+ if torch.is_tensor(obj):
15
+ v = obj.cuda(dev, non_blocking=True)
16
+ if main_stream is not None:
17
+ v.data.record_stream(main_stream)
18
+ return v
19
+ elif isinstance(obj, collections.Mapping):
20
+ return {k: async_copy_to(o, dev, main_stream) for k, o in obj.items()}
21
+ elif isinstance(obj, collections.Sequence):
22
+ return [async_copy_to(o, dev, main_stream) for o in obj]
23
+ else:
24
+ return obj
25
+
26
+
27
+ def dict_gather(outputs, target_device, dim=0):
28
+ """
29
+ Gathers variables from different GPUs on a specified device
30
+ (-1 means the CPU), with dictionary support.
31
+ """
32
+ def gather_map(outputs):
33
+ out = outputs[0]
34
+ if torch.is_tensor(out):
35
+ # MJY(20180330) HACK:: force nr_dims > 0
36
+ if out.dim() == 0:
37
+ outputs = [o.unsqueeze(0) for o in outputs]
38
+ return Gather.apply(target_device, dim, *outputs)
39
+ elif out is None:
40
+ return None
41
+ elif isinstance(out, collections.Mapping):
42
+ return {k: gather_map([o[k] for o in outputs]) for k in out}
43
+ elif isinstance(out, collections.Sequence):
44
+ return type(out)(map(gather_map, zip(*outputs)))
45
+ return gather_map(outputs)
46
+
47
+
48
+ class DictGatherDataParallel(nn.DataParallel):
49
+ def gather(self, outputs, output_device):
50
+ return dict_gather(outputs, output_device, dim=self.dim)
51
+
52
+
53
+ class UserScatteredDataParallel(DictGatherDataParallel):
54
+ def scatter(self, inputs, kwargs, device_ids):
55
+ assert len(inputs) == 1
56
+ inputs = inputs[0]
57
+ inputs = _async_copy_stream(inputs, device_ids)
58
+ inputs = [[i] for i in inputs]
59
+ assert len(kwargs) == 0
60
+ kwargs = [{} for _ in range(len(inputs))]
61
+
62
+ return inputs, kwargs
63
+
64
+
65
+ def user_scattered_collate(batch):
66
+ return batch
67
+
68
+
69
+ def _async_copy(inputs, device_ids):
70
+ nr_devs = len(device_ids)
71
+ assert type(inputs) in (tuple, list)
72
+ assert len(inputs) == nr_devs
73
+
74
+ outputs = []
75
+ for i, dev in zip(inputs, device_ids):
76
+ with cuda.device(dev):
77
+ outputs.append(async_copy_to(i, dev))
78
+
79
+ return tuple(outputs)
80
+
81
+
82
+ def _async_copy_stream(inputs, device_ids):
83
+ nr_devs = len(device_ids)
84
+ assert type(inputs) in (tuple, list)
85
+ assert len(inputs) == nr_devs
86
+
87
+ outputs = []
88
+ streams = [_get_stream(d) for d in device_ids]
89
+ for i, dev, stream in zip(inputs, device_ids, streams):
90
+ with cuda.device(dev):
91
+ main_stream = cuda.current_stream()
92
+ with cuda.stream(stream):
93
+ outputs.append(async_copy_to(i, dev, main_stream=main_stream))
94
+ main_stream.wait_stream(stream)
95
+
96
+ return outputs
97
+
98
+
99
+ """Adapted from: torch/nn/parallel/_functions.py"""
100
+ # background streams used for copying
101
+ _streams = None
102
+
103
+
104
+ def _get_stream(device):
105
+ """Gets a background stream for copying between CPU and GPU"""
106
+ global _streams
107
+ if device == -1:
108
+ return None
109
+ if _streams is None:
110
+ _streams = [None] * cuda.device_count()
111
+ if _streams[device] is None: _streams[device] = cuda.Stream(device)
112
+ return _streams[device]
models/lib/utils/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ from .th import *
models/lib/utils/data/__init__.py ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+
2
+ from .dataloader import DataLoader
3
+ from .dataset import ConcatDataset, Dataset, TensorDataset
models/lib/utils/data/dataloader.py ADDED
@@ -0,0 +1,429 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.multiprocessing as multiprocessing
3
+ from torch._C import (_error_if_any_worker_fails, _remove_worker_pids,
4
+ _set_worker_signal_handlers)
5
+
6
+ try:
7
+ from torch._C import _set_worker_pids
8
+ except:
9
+ from torch._C import _update_worker_pids as _set_worker_pids
10
+
11
+ import collections
12
+ import re
13
+ import signal
14
+ import sys
15
+ import threading
16
+ import traceback
17
+
18
+ import numpy as np
19
+ from torch._six import int_classes, string_classes
20
+
21
+ from .sampler import BatchSampler, RandomSampler, SequentialSampler
22
+
23
+ if sys.version_info[0] == 2:
24
+ import Queue as queue
25
+ else:
26
+ import queue
27
+
28
+
29
+ class ExceptionWrapper(object):
30
+ r"Wraps an exception plus traceback to communicate across threads"
31
+
32
+ def __init__(self, exc_info):
33
+ self.exc_type = exc_info[0]
34
+ self.exc_msg = "".join(traceback.format_exception(*exc_info))
35
+
36
+
37
+ _use_shared_memory = False
38
+ """Whether to use shared memory in default_collate"""
39
+
40
+
41
+ def _worker_loop(dataset, index_queue, data_queue, collate_fn, seed, init_fn, worker_id):
42
+ global _use_shared_memory
43
+ _use_shared_memory = True
44
+
45
+ # Intialize C side signal handlers for SIGBUS and SIGSEGV. Python signal
46
+ # module's handlers are executed after Python returns from C low-level
47
+ # handlers, likely when the same fatal signal happened again already.
48
+ # https://docs.python.org/3/library/signal.html Sec. 18.8.1.1
49
+ _set_worker_signal_handlers()
50
+
51
+ torch.set_num_threads(1)
52
+ torch.manual_seed(seed)
53
+ np.random.seed(seed)
54
+
55
+ if init_fn is not None:
56
+ init_fn(worker_id)
57
+
58
+ while True:
59
+ r = index_queue.get()
60
+ if r is None:
61
+ break
62
+ idx, batch_indices = r
63
+ try:
64
+ samples = collate_fn([dataset[i] for i in batch_indices])
65
+ except Exception:
66
+ data_queue.put((idx, ExceptionWrapper(sys.exc_info())))
67
+ else:
68
+ data_queue.put((idx, samples))
69
+
70
+
71
+ def _worker_manager_loop(in_queue, out_queue, done_event, pin_memory, device_id):
72
+ if pin_memory:
73
+ torch.cuda.set_device(device_id)
74
+
75
+ while True:
76
+ try:
77
+ r = in_queue.get()
78
+ except Exception:
79
+ if done_event.is_set():
80
+ return
81
+ raise
82
+ if r is None:
83
+ break
84
+ if isinstance(r[1], ExceptionWrapper):
85
+ out_queue.put(r)
86
+ continue
87
+ idx, batch = r
88
+ try:
89
+ if pin_memory:
90
+ batch = pin_memory_batch(batch)
91
+ except Exception:
92
+ out_queue.put((idx, ExceptionWrapper(sys.exc_info())))
93
+ else:
94
+ out_queue.put((idx, batch))
95
+
96
+ numpy_type_map = {
97
+ 'float64': torch.DoubleTensor,
98
+ 'float32': torch.FloatTensor,
99
+ 'float16': torch.HalfTensor,
100
+ 'int64': torch.LongTensor,
101
+ 'int32': torch.IntTensor,
102
+ 'int16': torch.ShortTensor,
103
+ 'int8': torch.CharTensor,
104
+ 'uint8': torch.ByteTensor,
105
+ }
106
+
107
+
108
+ def default_collate(batch):
109
+ "Puts each data field into a tensor with outer dimension batch size"
110
+
111
+ error_msg = "batch must contain tensors, numbers, dicts or lists; found {}"
112
+ elem_type = type(batch[0])
113
+ if torch.is_tensor(batch[0]):
114
+ out = None
115
+ if _use_shared_memory:
116
+ # If we're in a background process, concatenate directly into a
117
+ # shared memory tensor to avoid an extra copy
118
+ numel = sum([x.numel() for x in batch])
119
+ storage = batch[0].storage()._new_shared(numel)
120
+ out = batch[0].new(storage)
121
+ return torch.stack(batch, 0, out=out)
122
+ elif elem_type.__module__ == 'numpy' and elem_type.__name__ != 'str_' \
123
+ and elem_type.__name__ != 'string_':
124
+ elem = batch[0]
125
+ if elem_type.__name__ == 'ndarray':
126
+ # array of string classes and object
127
+ if re.search('[SaUO]', elem.dtype.str) is not None:
128
+ raise TypeError(error_msg.format(elem.dtype))
129
+
130
+ return torch.stack([torch.from_numpy(b) for b in batch], 0)
131
+ if elem.shape == (): # scalars
132
+ py_type = float if elem.dtype.name.startswith('float') else int
133
+ return numpy_type_map[elem.dtype.name](list(map(py_type, batch)))
134
+ elif isinstance(batch[0], int_classes):
135
+ return torch.LongTensor(batch)
136
+ elif isinstance(batch[0], float):
137
+ return torch.DoubleTensor(batch)
138
+ elif isinstance(batch[0], string_classes):
139
+ return batch
140
+ elif isinstance(batch[0], collections.Mapping):
141
+ return {key: default_collate([d[key] for d in batch]) for key in batch[0]}
142
+ elif isinstance(batch[0], collections.Sequence):
143
+ transposed = zip(*batch)
144
+ return [default_collate(samples) for samples in transposed]
145
+
146
+ raise TypeError((error_msg.format(type(batch[0]))))
147
+
148
+
149
+ def pin_memory_batch(batch):
150
+ if torch.is_tensor(batch):
151
+ return batch.pin_memory()
152
+ elif isinstance(batch, string_classes):
153
+ return batch
154
+ elif isinstance(batch, collections.Mapping):
155
+ return {k: pin_memory_batch(sample) for k, sample in batch.items()}
156
+ elif isinstance(batch, collections.Sequence):
157
+ return [pin_memory_batch(sample) for sample in batch]
158
+ else:
159
+ return batch
160
+
161
+
162
+ _SIGCHLD_handler_set = False
163
+ """Whether SIGCHLD handler is set for DataLoader worker failures. Only one
164
+ handler needs to be set for all DataLoaders in a process."""
165
+
166
+
167
+ def _set_SIGCHLD_handler():
168
+ # Windows doesn't support SIGCHLD handler
169
+ if sys.platform == 'win32':
170
+ return
171
+ # can't set signal in child threads
172
+ if not isinstance(threading.current_thread(), threading._MainThread):
173
+ return
174
+ global _SIGCHLD_handler_set
175
+ if _SIGCHLD_handler_set:
176
+ return
177
+ previous_handler = signal.getsignal(signal.SIGCHLD)
178
+ if not callable(previous_handler):
179
+ previous_handler = None
180
+
181
+ def handler(signum, frame):
182
+ # This following call uses `waitid` with WNOHANG from C side. Therefore,
183
+ # Python can still get and update the process status successfully.
184
+ _error_if_any_worker_fails()
185
+ if previous_handler is not None:
186
+ previous_handler(signum, frame)
187
+
188
+ signal.signal(signal.SIGCHLD, handler)
189
+ _SIGCHLD_handler_set = True
190
+
191
+
192
+ class DataLoaderIter(object):
193
+ "Iterates once over the DataLoader's dataset, as specified by the sampler"
194
+
195
+ def __init__(self, loader):
196
+ self.dataset = loader.dataset
197
+ self.collate_fn = loader.collate_fn
198
+ self.batch_sampler = loader.batch_sampler
199
+ self.num_workers = loader.num_workers
200
+ self.pin_memory = loader.pin_memory and torch.cuda.is_available()
201
+ self.timeout = loader.timeout
202
+ self.done_event = threading.Event()
203
+
204
+ self.sample_iter = iter(self.batch_sampler)
205
+
206
+ if self.num_workers > 0:
207
+ self.worker_init_fn = loader.worker_init_fn
208
+ self.index_queue = multiprocessing.SimpleQueue()
209
+ self.worker_result_queue = multiprocessing.SimpleQueue()
210
+ self.batches_outstanding = 0
211
+ self.worker_pids_set = False
212
+ self.shutdown = False
213
+ self.send_idx = 0
214
+ self.rcvd_idx = 0
215
+ self.reorder_dict = {}
216
+
217
+ base_seed = torch.LongTensor(1).random_(0, 2**31-1)[0]
218
+ self.workers = [
219
+ multiprocessing.Process(
220
+ target=_worker_loop,
221
+ args=(self.dataset, self.index_queue, self.worker_result_queue, self.collate_fn,
222
+ base_seed + i, self.worker_init_fn, i))
223
+ for i in range(self.num_workers)]
224
+
225
+ if self.pin_memory or self.timeout > 0:
226
+ self.data_queue = queue.Queue()
227
+ if self.pin_memory:
228
+ maybe_device_id = torch.cuda.current_device()
229
+ else:
230
+ # do not initialize cuda context if not necessary
231
+ maybe_device_id = None
232
+ self.worker_manager_thread = threading.Thread(
233
+ target=_worker_manager_loop,
234
+ args=(self.worker_result_queue, self.data_queue, self.done_event, self.pin_memory,
235
+ maybe_device_id))
236
+ self.worker_manager_thread.daemon = True
237
+ self.worker_manager_thread.start()
238
+ else:
239
+ self.data_queue = self.worker_result_queue
240
+
241
+ for w in self.workers:
242
+ w.daemon = True # ensure that the worker exits on process exit
243
+ w.start()
244
+
245
+ _set_worker_pids(id(self), tuple(w.pid for w in self.workers))
246
+ _set_SIGCHLD_handler()
247
+ self.worker_pids_set = True
248
+
249
+ # prime the prefetch loop
250
+ for _ in range(2 * self.num_workers):
251
+ self._put_indices()
252
+
253
+ def __len__(self):
254
+ return len(self.batch_sampler)
255
+
256
+ def _get_batch(self):
257
+ if self.timeout > 0:
258
+ try:
259
+ return self.data_queue.get(timeout=self.timeout)
260
+ except queue.Empty:
261
+ raise RuntimeError('DataLoader timed out after {} seconds'.format(self.timeout))
262
+ else:
263
+ return self.data_queue.get()
264
+
265
+ def __next__(self):
266
+ if self.num_workers == 0: # same-process loading
267
+ indices = next(self.sample_iter) # may raise StopIteration
268
+ batch = self.collate_fn([self.dataset[i] for i in indices])
269
+ if self.pin_memory:
270
+ batch = pin_memory_batch(batch)
271
+ return batch
272
+
273
+ # check if the next sample has already been generated
274
+ if self.rcvd_idx in self.reorder_dict:
275
+ batch = self.reorder_dict.pop(self.rcvd_idx)
276
+ return self._process_next_batch(batch)
277
+
278
+ if self.batches_outstanding == 0:
279
+ self._shutdown_workers()
280
+ raise StopIteration
281
+
282
+ while True:
283
+ assert (not self.shutdown and self.batches_outstanding > 0)
284
+ idx, batch = self._get_batch()
285
+ self.batches_outstanding -= 1
286
+ if idx != self.rcvd_idx:
287
+ # store out-of-order samples
288
+ self.reorder_dict[idx] = batch
289
+ continue
290
+ return self._process_next_batch(batch)
291
+
292
+ next = __next__ # Python 2 compatibility
293
+
294
+ def __iter__(self):
295
+ return self
296
+
297
+ def _put_indices(self):
298
+ assert self.batches_outstanding < 2 * self.num_workers
299
+ indices = next(self.sample_iter, None)
300
+ if indices is None:
301
+ return
302
+ self.index_queue.put((self.send_idx, indices))
303
+ self.batches_outstanding += 1
304
+ self.send_idx += 1
305
+
306
+ def _process_next_batch(self, batch):
307
+ self.rcvd_idx += 1
308
+ self._put_indices()
309
+ if isinstance(batch, ExceptionWrapper):
310
+ raise batch.exc_type(batch.exc_msg)
311
+ return batch
312
+
313
+ def __getstate__(self):
314
+ # TODO: add limited pickling support for sharing an iterator
315
+ # across multiple threads for HOGWILD.
316
+ # Probably the best way to do this is by moving the sample pushing
317
+ # to a separate thread and then just sharing the data queue
318
+ # but signalling the end is tricky without a non-blocking API
319
+ raise NotImplementedError("DataLoaderIterator cannot be pickled")
320
+
321
+ def _shutdown_workers(self):
322
+ try:
323
+ if not self.shutdown:
324
+ self.shutdown = True
325
+ self.done_event.set()
326
+ # if worker_manager_thread is waiting to put
327
+ while not self.data_queue.empty():
328
+ self.data_queue.get()
329
+ for _ in self.workers:
330
+ self.index_queue.put(None)
331
+ # done_event should be sufficient to exit worker_manager_thread,
332
+ # but be safe here and put another None
333
+ self.worker_result_queue.put(None)
334
+ finally:
335
+ # removes pids no matter what
336
+ if self.worker_pids_set:
337
+ _remove_worker_pids(id(self))
338
+ self.worker_pids_set = False
339
+
340
+ def __del__(self):
341
+ if self.num_workers > 0:
342
+ self._shutdown_workers()
343
+
344
+
345
+ class DataLoader(object):
346
+ """
347
+ Data loader. Combines a dataset and a sampler, and provides
348
+ single- or multi-process iterators over the dataset.
349
+
350
+ Arguments:
351
+ dataset (Dataset): dataset from which to load the data.
352
+ batch_size (int, optional): how many samples per batch to load
353
+ (default: 1).
354
+ shuffle (bool, optional): set to ``True`` to have the data reshuffled
355
+ at every epoch (default: False).
356
+ sampler (Sampler, optional): defines the strategy to draw samples from
357
+ the dataset. If specified, ``shuffle`` must be False.
358
+ batch_sampler (Sampler, optional): like sampler, but returns a batch of
359
+ indices at a time. Mutually exclusive with batch_size, shuffle,
360
+ sampler, and drop_last.
361
+ num_workers (int, optional): how many subprocesses to use for data
362
+ loading. 0 means that the data will be loaded in the main process.
363
+ (default: 0)
364
+ collate_fn (callable, optional): merges a list of samples to form a mini-batch.
365
+ pin_memory (bool, optional): If ``True``, the data loader will copy tensors
366
+ into CUDA pinned memory before returning them.
367
+ drop_last (bool, optional): set to ``True`` to drop the last incomplete batch,
368
+ if the dataset size is not divisible by the batch size. If ``False`` and
369
+ the size of dataset is not divisible by the batch size, then the last batch
370
+ will be smaller. (default: False)
371
+ timeout (numeric, optional): if positive, the timeout value for collecting a batch
372
+ from workers. Should always be non-negative. (default: 0)
373
+ worker_init_fn (callable, optional): If not None, this will be called on each
374
+ worker subprocess with the worker id (an int in ``[0, num_workers - 1]``) as
375
+ input, after seeding and before data loading. (default: None)
376
+
377
+ .. note:: By default, each worker will have its PyTorch seed set to
378
+ ``base_seed + worker_id``, where ``base_seed`` is a long generated
379
+ by main process using its RNG. You may use ``torch.initial_seed()`` to access
380
+ this value in :attr:`worker_init_fn`, which can be used to set other seeds
381
+ (e.g. NumPy) before data loading.
382
+
383
+ .. warning:: If ``spawn'' start method is used, :attr:`worker_init_fn` cannot be an
384
+ unpicklable object, e.g., a lambda function.
385
+ """
386
+
387
+ def __init__(self, dataset, batch_size=1, shuffle=False, sampler=None, batch_sampler=None,
388
+ num_workers=0, collate_fn=default_collate, pin_memory=False, drop_last=False,
389
+ timeout=0, worker_init_fn=None):
390
+ self.dataset = dataset
391
+ self.batch_size = batch_size
392
+ self.num_workers = num_workers
393
+ self.collate_fn = collate_fn
394
+ self.pin_memory = pin_memory
395
+ self.drop_last = drop_last
396
+ self.timeout = timeout
397
+ self.worker_init_fn = worker_init_fn
398
+
399
+ if timeout < 0:
400
+ raise ValueError('timeout option should be non-negative')
401
+
402
+ if batch_sampler is not None:
403
+ if batch_size > 1 or shuffle or sampler is not None or drop_last:
404
+ raise ValueError('batch_sampler is mutually exclusive with '
405
+ 'batch_size, shuffle, sampler, and drop_last')
406
+
407
+ if sampler is not None and shuffle:
408
+ raise ValueError('sampler is mutually exclusive with shuffle')
409
+
410
+ if self.num_workers < 0:
411
+ raise ValueError('num_workers cannot be negative; '
412
+ 'use num_workers=0 to disable multiprocessing.')
413
+
414
+ if batch_sampler is None:
415
+ if sampler is None:
416
+ if shuffle:
417
+ sampler = RandomSampler(dataset)
418
+ else:
419
+ sampler = SequentialSampler(dataset)
420
+ batch_sampler = BatchSampler(sampler, batch_size, drop_last)
421
+
422
+ self.sampler = sampler
423
+ self.batch_sampler = batch_sampler
424
+
425
+ def __iter__(self):
426
+ return DataLoaderIter(self)
427
+
428
+ def __len__(self):
429
+ return len(self.batch_sampler)
models/lib/utils/data/dataset.py ADDED
@@ -0,0 +1,118 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import bisect
2
+ import warnings
3
+
4
+ from torch import randperm
5
+ from torch._utils import _accumulate
6
+
7
+
8
+ class Dataset(object):
9
+ """An abstract class representing a Dataset.
10
+
11
+ All other datasets should subclass it. All subclasses should override
12
+ ``__len__``, that provides the size of the dataset, and ``__getitem__``,
13
+ supporting integer indexing in range from 0 to len(self) exclusive.
14
+ """
15
+
16
+ def __getitem__(self, index):
17
+ raise NotImplementedError
18
+
19
+ def __len__(self):
20
+ raise NotImplementedError
21
+
22
+ def __add__(self, other):
23
+ return ConcatDataset([self, other])
24
+
25
+
26
+ class TensorDataset(Dataset):
27
+ """Dataset wrapping data and target tensors.
28
+
29
+ Each sample will be retrieved by indexing both tensors along the first
30
+ dimension.
31
+
32
+ Arguments:
33
+ data_tensor (Tensor): contains sample data.
34
+ target_tensor (Tensor): contains sample targets (labels).
35
+ """
36
+
37
+ def __init__(self, data_tensor, target_tensor):
38
+ assert data_tensor.size(0) == target_tensor.size(0)
39
+ self.data_tensor = data_tensor
40
+ self.target_tensor = target_tensor
41
+
42
+ def __getitem__(self, index):
43
+ return self.data_tensor[index], self.target_tensor[index]
44
+
45
+ def __len__(self):
46
+ return self.data_tensor.size(0)
47
+
48
+
49
+ class ConcatDataset(Dataset):
50
+ """
51
+ Dataset to concatenate multiple datasets.
52
+ Purpose: useful to assemble different existing datasets, possibly
53
+ large-scale datasets as the concatenation operation is done in an
54
+ on-the-fly manner.
55
+
56
+ Arguments:
57
+ datasets (iterable): List of datasets to be concatenated
58
+ """
59
+
60
+ @staticmethod
61
+ def cumsum(sequence):
62
+ r, s = [], 0
63
+ for e in sequence:
64
+ l = len(e)
65
+ r.append(l + s)
66
+ s += l
67
+ return r
68
+
69
+ def __init__(self, datasets):
70
+ super(ConcatDataset, self).__init__()
71
+ assert len(datasets) > 0, 'datasets should not be an empty iterable'
72
+ self.datasets = list(datasets)
73
+ self.cumulative_sizes = self.cumsum(self.datasets)
74
+
75
+ def __len__(self):
76
+ return self.cumulative_sizes[-1]
77
+
78
+ def __getitem__(self, idx):
79
+ dataset_idx = bisect.bisect_right(self.cumulative_sizes, idx)
80
+ if dataset_idx == 0:
81
+ sample_idx = idx
82
+ else:
83
+ sample_idx = idx - self.cumulative_sizes[dataset_idx - 1]
84
+ return self.datasets[dataset_idx][sample_idx]
85
+
86
+ @property
87
+ def cummulative_sizes(self):
88
+ warnings.warn("cummulative_sizes attribute is renamed to "
89
+ "cumulative_sizes", DeprecationWarning, stacklevel=2)
90
+ return self.cumulative_sizes
91
+
92
+
93
+ class Subset(Dataset):
94
+ def __init__(self, dataset, indices):
95
+ self.dataset = dataset
96
+ self.indices = indices
97
+
98
+ def __getitem__(self, idx):
99
+ return self.dataset[self.indices[idx]]
100
+
101
+ def __len__(self):
102
+ return len(self.indices)
103
+
104
+
105
+ def random_split(dataset, lengths):
106
+ """
107
+ Randomly split a dataset into non-overlapping new datasets of given lengths
108
+ ds
109
+
110
+ Arguments:
111
+ dataset (Dataset): Dataset to be split
112
+ lengths (iterable): lengths of splits to be produced
113
+ """
114
+ if sum(lengths) != len(dataset):
115
+ raise ValueError("Sum of input lengths does not equal the length of the input dataset!")
116
+
117
+ indices = randperm(sum(lengths))
118
+ return [Subset(dataset, indices[offset - length:offset]) for offset, length in zip(_accumulate(lengths), lengths)]
models/lib/utils/data/distributed.py ADDED
@@ -0,0 +1,60 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+
3
+ import torch
4
+ from torch.distributed import get_rank, get_world_size
5
+
6
+ from .sampler import Sampler
7
+
8
+
9
+ class DistributedSampler(Sampler):
10
+ """Sampler that restricts data loading to a subset of the dataset.
11
+
12
+ It is especially useful in conjunction with
13
+ :class:`torch.nn.parallel.DistributedDataParallel`. In such case, each
14
+ process can pass a DistributedSampler instance as a DataLoader sampler,
15
+ and load a subset of the original dataset that is exclusive to it.
16
+
17
+ .. note::
18
+ Dataset is assumed to be of constant size.
19
+
20
+ Arguments:
21
+ dataset: Dataset used for sampling.
22
+ num_replicas (optional): Number of processes participating in
23
+ distributed training.
24
+ rank (optional): Rank of the current process within num_replicas.
25
+ """
26
+
27
+ def __init__(self, dataset, num_replicas=None, rank=None):
28
+ if num_replicas is None:
29
+ num_replicas = get_world_size()
30
+ if rank is None:
31
+ rank = get_rank()
32
+ self.dataset = dataset
33
+ self.num_replicas = num_replicas
34
+ self.rank = rank
35
+ self.epoch = 0
36
+ self.num_samples = int(math.ceil(len(self.dataset) * 1.0 / self.num_replicas))
37
+ self.total_size = self.num_samples * self.num_replicas
38
+
39
+ def __iter__(self):
40
+ # deterministically shuffle based on epoch
41
+ g = torch.Generator()
42
+ g.manual_seed(self.epoch)
43
+ indices = list(torch.randperm(len(self.dataset), generator=g))
44
+
45
+ # add extra samples to make it evenly divisible
46
+ indices += indices[:(self.total_size - len(indices))]
47
+ assert len(indices) == self.total_size
48
+
49
+ # subsample
50
+ offset = self.num_samples * self.rank
51
+ indices = indices[offset:offset + self.num_samples]
52
+ assert len(indices) == self.num_samples
53
+
54
+ return iter(indices)
55
+
56
+ def __len__(self):
57
+ return self.num_samples
58
+
59
+ def set_epoch(self, epoch):
60
+ self.epoch = epoch
models/lib/utils/data/sampler.py ADDED
@@ -0,0 +1,131 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+
3
+
4
+ class Sampler(object):
5
+ """Base class for all Samplers.
6
+
7
+ Every Sampler subclass has to provide an __iter__ method, providing a way
8
+ to iterate over indices of dataset elements, and a __len__ method that
9
+ returns the length of the returned iterators.
10
+ """
11
+
12
+ def __init__(self, data_source):
13
+ pass
14
+
15
+ def __iter__(self):
16
+ raise NotImplementedError
17
+
18
+ def __len__(self):
19
+ raise NotImplementedError
20
+
21
+
22
+ class SequentialSampler(Sampler):
23
+ """Samples elements sequentially, always in the same order.
24
+
25
+ Arguments:
26
+ data_source (Dataset): dataset to sample from
27
+ """
28
+
29
+ def __init__(self, data_source):
30
+ self.data_source = data_source
31
+
32
+ def __iter__(self):
33
+ return iter(range(len(self.data_source)))
34
+
35
+ def __len__(self):
36
+ return len(self.data_source)
37
+
38
+
39
+ class RandomSampler(Sampler):
40
+ """Samples elements randomly, without replacement.
41
+
42
+ Arguments:
43
+ data_source (Dataset): dataset to sample from
44
+ """
45
+
46
+ def __init__(self, data_source):
47
+ self.data_source = data_source
48
+
49
+ def __iter__(self):
50
+ return iter(torch.randperm(len(self.data_source)).long())
51
+
52
+ def __len__(self):
53
+ return len(self.data_source)
54
+
55
+
56
+ class SubsetRandomSampler(Sampler):
57
+ """Samples elements randomly from a given list of indices, without replacement.
58
+
59
+ Arguments:
60
+ indices (list): a list of indices
61
+ """
62
+
63
+ def __init__(self, indices):
64
+ self.indices = indices
65
+
66
+ def __iter__(self):
67
+ return (self.indices[i] for i in torch.randperm(len(self.indices)))
68
+
69
+ def __len__(self):
70
+ return len(self.indices)
71
+
72
+
73
+ class WeightedRandomSampler(Sampler):
74
+ """Samples elements from [0,..,len(weights)-1] with given probabilities (weights).
75
+
76
+ Arguments:
77
+ weights (list) : a list of weights, not necessary summing up to one
78
+ num_samples (int): number of samples to draw
79
+ replacement (bool): if ``True``, samples are drawn with replacement.
80
+ If not, they are drawn without replacement, which means that when a
81
+ sample index is drawn for a row, it cannot be drawn again for that row.
82
+ """
83
+
84
+ def __init__(self, weights, num_samples, replacement=True):
85
+ self.weights = torch.DoubleTensor(weights)
86
+ self.num_samples = num_samples
87
+ self.replacement = replacement
88
+
89
+ def __iter__(self):
90
+ return iter(torch.multinomial(self.weights, self.num_samples, self.replacement))
91
+
92
+ def __len__(self):
93
+ return self.num_samples
94
+
95
+
96
+ class BatchSampler(object):
97
+ """Wraps another sampler to yield a mini-batch of indices.
98
+
99
+ Args:
100
+ sampler (Sampler): Base sampler.
101
+ batch_size (int): Size of mini-batch.
102
+ drop_last (bool): If ``True``, the sampler will drop the last batch if
103
+ its size would be less than ``batch_size``
104
+
105
+ Example:
106
+ >>> list(BatchSampler(range(10), batch_size=3, drop_last=False))
107
+ [[0, 1, 2], [3, 4, 5], [6, 7, 8], [9]]
108
+ >>> list(BatchSampler(range(10), batch_size=3, drop_last=True))
109
+ [[0, 1, 2], [3, 4, 5], [6, 7, 8]]
110
+ """
111
+
112
+ def __init__(self, sampler, batch_size, drop_last):
113
+ self.sampler = sampler
114
+ self.batch_size = batch_size
115
+ self.drop_last = drop_last
116
+
117
+ def __iter__(self):
118
+ batch = []
119
+ for idx in self.sampler:
120
+ batch.append(idx)
121
+ if len(batch) == self.batch_size:
122
+ yield batch
123
+ batch = []
124
+ if len(batch) > 0 and not self.drop_last:
125
+ yield batch
126
+
127
+ def __len__(self):
128
+ if self.drop_last:
129
+ return len(self.sampler) // self.batch_size
130
+ else:
131
+ return (len(self.sampler) + self.batch_size - 1) // self.batch_size
models/lib/utils/th.py ADDED
@@ -0,0 +1,42 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import collections
2
+
3
+ import numpy as np
4
+ import torch
5
+ from torch.autograd import Variable
6
+
7
+ __all__ = ['as_variable', 'as_numpy', 'mark_volatile']
8
+
9
+ def as_variable(obj):
10
+ if isinstance(obj, Variable):
11
+ return obj
12
+ if isinstance(obj, collections.Sequence):
13
+ return [as_variable(v) for v in obj]
14
+ elif isinstance(obj, collections.Mapping):
15
+ return {k: as_variable(v) for k, v in obj.items()}
16
+ else:
17
+ return Variable(obj)
18
+
19
+ def as_numpy(obj):
20
+ if isinstance(obj, collections.Sequence):
21
+ return [as_numpy(v) for v in obj]
22
+ elif isinstance(obj, collections.Mapping):
23
+ return {k: as_numpy(v) for k, v in obj.items()}
24
+ elif isinstance(obj, Variable):
25
+ return obj.data.cpu().numpy()
26
+ elif torch.is_tensor(obj):
27
+ return obj.cpu().numpy()
28
+ else:
29
+ return np.array(obj)
30
+
31
+ def mark_volatile(obj):
32
+ if torch.is_tensor(obj):
33
+ obj = Variable(obj)
34
+ if isinstance(obj, Variable):
35
+ obj.no_grad = True
36
+ return obj
37
+ elif isinstance(obj, collections.Mapping):
38
+ return {k: mark_volatile(o) for k, o in obj.items()}
39
+ elif isinstance(obj, collections.Sequence):
40
+ return [mark_volatile(o) for o in obj]
41
+ else:
42
+ return obj