Sophie98 committed on
Commit
ea2f7c7
1 Parent(s): 210920a
ViT_helper.py ADDED
@@ -0,0 +1,111 @@
+ import torch
+ from torch import nn
+
+
+ def drop_path(x, drop_prob: float = 0., training: bool = False):
+     """Drop paths (Stochastic Depth) per sample (when applied in the main path of residual blocks).
+
+     This is the same as the DropConnect impl I created for EfficientNet, etc. networks; however,
+     the original name is misleading, as 'Drop Connect' is a different form of dropout in a separate
+     paper... See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956
+     I've opted to change the layer and argument names to 'drop path' rather than mix DropConnect
+     as a layer name and use 'survival rate' as the argument.
+     """
+     if drop_prob == 0. or not training:
+         return x
+     keep_prob = 1 - drop_prob
+     shape = (x.shape[0],) + (1,) * (x.ndim - 1)  # work with tensors of any rank, not just 2D ConvNets
+     random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device)
+     random_tensor.floor_()  # binarize
+     output = x.div(keep_prob) * random_tensor
+     return output
+
+
+ class DropPath(nn.Module):
+     """Drop paths (Stochastic Depth) per sample (when applied in the main path of residual blocks)."""
+     def __init__(self, drop_prob=None):
+         super(DropPath, self).__init__()
+         self.drop_prob = drop_prob
+
+     def forward(self, x):
+         return drop_path(x, self.drop_prob, self.training)
+
+
+ from itertools import repeat
+ try:
+     from torch._six import container_abcs  # torch <= 1.8, matching the pin in requirements.txt
+ except ImportError:  # torch >= 1.9 removed torch._six.container_abcs
+     import collections.abc as container_abcs
+
+
+ # From PyTorch internals
+ def _ntuple(n):
+     def parse(x):
+         if isinstance(x, container_abcs.Iterable):
+             return x
+         return tuple(repeat(x, n))
+     return parse
+
+
+ to_1tuple = _ntuple(1)
+ to_2tuple = _ntuple(2)
+ to_3tuple = _ntuple(3)
+ to_4tuple = _ntuple(4)
+
+
+ import math
+ import warnings
+
+
+ def _no_grad_trunc_normal_(tensor, mean, std, a, b):
+     # Cut & paste from PyTorch official master until it's in a few official releases - RW
+     # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf
+     def norm_cdf(x):
+         # Computes the standard normal cumulative distribution function
+         return (1. + math.erf(x / math.sqrt(2.))) / 2.
+
+     if (mean < a - 2 * std) or (mean > b + 2 * std):
+         warnings.warn("mean is more than 2 std from [a, b] in nn.init.trunc_normal_. "
+                       "The distribution of values may be incorrect.",
+                       stacklevel=2)
+
+     with torch.no_grad():
+         # Values are generated by using a truncated uniform distribution and
+         # then using the inverse CDF for the normal distribution.
+         # Get upper and lower cdf values
+         l = norm_cdf((a - mean) / std)
+         u = norm_cdf((b - mean) / std)
+
+         # Uniformly fill tensor with values from [l, u], then translate to [2l-1, 2u-1].
+         tensor.uniform_(2 * l - 1, 2 * u - 1)
+
+         # Use inverse cdf transform for normal distribution to get truncated standard normal
+         tensor.erfinv_()
+
+         # Transform to proper mean, std
+         tensor.mul_(std * math.sqrt(2.))
+         tensor.add_(mean)
+
+         # Clamp to ensure it's in the proper range
+         tensor.clamp_(min=a, max=b)
+         return tensor
+
+
+ def trunc_normal_(tensor, mean=0., std=1., a=-2., b=2.):
+     # type: (Tensor, float, float, float, float) -> Tensor
+     r"""Fills the input Tensor with values drawn from a truncated
+     normal distribution. The values are effectively drawn from the
+     normal distribution :math:`\mathcal{N}(\text{mean}, \text{std}^2)`
+     with values outside :math:`[a, b]` redrawn until they are within
+     the bounds. The method used for generating the random values works
+     best when :math:`a \leq \text{mean} \leq b`.
+
+     Args:
+         tensor: an n-dimensional `torch.Tensor`
+         mean: the mean of the normal distribution
+         std: the standard deviation of the normal distribution
+         a: the minimum cutoff value
+         b: the maximum cutoff value
+
+     Examples:
+         >>> w = torch.empty(3, 5)
+         >>> nn.init.trunc_normal_(w)
+     """
+     return _no_grad_trunc_normal_(tensor, mean, std, a, b)
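Usage sketch (not part of the commit): a minimal, hypothetical MLP block showing how DropPath, to_2tuple, and trunc_normal_ are typically wired into a ViT, assuming ViT_helper.py is on the import path.

# Illustrative only -- the block structure below is an assumption, not this repo's model.
import torch
from torch import nn
from ViT_helper import DropPath, to_2tuple, trunc_normal_

class Mlp(nn.Module):
    def __init__(self, dim=768, hidden=3072, drop_path_rate=0.1):
        super().__init__()
        self.fc1 = nn.Linear(dim, hidden)
        self.fc2 = nn.Linear(hidden, dim)
        self.drop_path = DropPath(drop_path_rate)  # no-op at drop_path_rate == 0 or in eval mode
        trunc_normal_(self.fc1.weight, std=.02)    # truncated-normal init, as in ViT
        trunc_normal_(self.fc2.weight, std=.02)

    def forward(self, x):
        # residual branch is randomly dropped per sample during training
        return x + self.drop_path(self.fc2(torch.relu(self.fc1(x))))

patch_size = to_2tuple(16)  # scalars are broadcast to tuples: (16, 16)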
function.py ADDED
@@ -0,0 +1,73 @@
+ import torch
+
+
+ def calc_mean_std(feat, eps=1e-5):
+     # eps is a small value added to the variance to avoid divide-by-zero.
+     size = feat.size()
+     assert (len(size) == 4)
+     N, C = size[:2]
+     feat_var = feat.view(N, C, -1).var(dim=2) + eps
+     feat_std = feat_var.sqrt().view(N, C, 1, 1)
+     feat_mean = feat.view(N, C, -1).mean(dim=2).view(N, C, 1, 1)
+     return feat_mean, feat_std
+
+
+ def calc_mean_std1(feat, eps=1e-5):
+     # Variant for sequence-shaped features (WH, N, C); eps avoids divide-by-zero.
+     size = feat.size()
+     WH, N, C = size
+     feat_var = feat.var(dim=0) + eps
+     feat_std = feat_var.sqrt()
+     feat_mean = feat.mean(dim=0)
+     return feat_mean, feat_std
+
+
+ def normal(feat, eps=1e-5):
+     feat_mean, feat_std = calc_mean_std(feat, eps)
+     normalized = (feat - feat_mean) / feat_std
+     return normalized
+
+
+ def normal_style(feat, eps=1e-5):
+     feat_mean, feat_std = calc_mean_std1(feat, eps)
+     normalized = (feat - feat_mean) / feat_std
+     return normalized
+
+
+ def _calc_feat_flatten_mean_std(feat):
+     # Takes a 3D feat (C, H, W); returns the flattened feat plus per-channel mean and std.
+     assert (feat.size()[0] == 3)
+     assert (isinstance(feat, torch.FloatTensor))
+     feat_flatten = feat.view(3, -1)
+     mean = feat_flatten.mean(dim=-1, keepdim=True)
+     std = feat_flatten.std(dim=-1, keepdim=True)
+     return feat_flatten, mean, std
+
+
+ def _mat_sqrt(x):
+     U, D, V = torch.svd(x)
+     return torch.mm(torch.mm(U, D.pow(0.5).diag()), V.t())
+
+
+ def coral(source, target):
+     # Assumes both source and target are 3D arrays (C, H, W).
+     # Note: flatten -> f
+     source_f, source_f_mean, source_f_std = _calc_feat_flatten_mean_std(source)
+     source_f_norm = (source_f - source_f_mean.expand_as(source_f)) / source_f_std.expand_as(source_f)
+     source_f_cov_eye = torch.mm(source_f_norm, source_f_norm.t()) + torch.eye(3)
+
+     target_f, target_f_mean, target_f_std = _calc_feat_flatten_mean_std(target)
+     target_f_norm = (target_f - target_f_mean.expand_as(target_f)) / target_f_std.expand_as(target_f)
+     target_f_cov_eye = torch.mm(target_f_norm, target_f_norm.t()) + torch.eye(3)
+
+     source_f_norm_transfer = torch.mm(
+         _mat_sqrt(target_f_cov_eye),
+         torch.mm(torch.inverse(_mat_sqrt(source_f_cov_eye)), source_f_norm)
+     )
+
+     source_f_transfer = source_f_norm_transfer * target_f_std.expand_as(source_f_norm) + \
+         target_f_mean.expand_as(source_f_norm)
+
+     return source_f_transfer.view(source.size())
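Usage sketch (not part of the commit): the shapes follow the asserts above; the module name `function` is an assumption based on the filename.

# Illustrative only.
import torch
from function import calc_mean_std, normal, coral

feat = torch.randn(2, 512, 32, 32)    # (N, C, H, W) feature map
mean, std = calc_mean_std(feat)       # both (N, C, 1, 1), broadcastable over H and W
normed = normal(feat)                 # per-sample, per-channel whitening (AdaIN-style)

source = torch.rand(3, 64, 64)        # coral() expects 3-channel (C, H, W) FloatTensors
target = torch.rand(3, 64, 64)
recolored = coral(source, target)     # source re-colored to match target's channel statistics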
misc.py ADDED
@@ -0,0 +1,469 @@
+ # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
+ """
+ Misc functions, including distributed helpers.
+
+ Mostly copy-paste from torchvision references.
+ """
+ import os
+ import subprocess
+ import time
+ from collections import defaultdict, deque
+ import datetime
+ import pickle
+ from typing import Optional, List
+
+ import torch
+ import torch.distributed as dist
+ from torch import Tensor
+
+ # needed due to empty tensor bug in pytorch and torchvision 0.5
+ import torchvision
+ # Note: this prefix parse assumes torchvision < 0.10 (requirements.txt pins 0.9.0).
+ if float(torchvision.__version__[:3]) < 0.7:
+     from torchvision.ops import _new_empty_tensor
+     from torchvision.ops.misc import _output_size
+
+
+ class SmoothedValue(object):
+     """Track a series of values and provide access to smoothed values over a
+     window or the global series average.
+     """
+
+     def __init__(self, window_size=20, fmt=None):
+         if fmt is None:
+             fmt = "{median:.4f} ({global_avg:.4f})"
+         self.deque = deque(maxlen=window_size)
+         self.total = 0.0
+         self.count = 0
+         self.fmt = fmt
+
+     def update(self, value, n=1):
+         self.deque.append(value)
+         self.count += n
+         self.total += value * n
+
+     def synchronize_between_processes(self):
+         """
+         Warning: does not synchronize the deque!
+         """
+         if not is_dist_avail_and_initialized():
+             return
+         t = torch.tensor([self.count, self.total], dtype=torch.float64, device='cuda')
+         dist.barrier()
+         dist.all_reduce(t)
+         t = t.tolist()
+         self.count = int(t[0])
+         self.total = t[1]
+
+     @property
+     def median(self):
+         d = torch.tensor(list(self.deque))
+         return d.median().item()
+
+     @property
+     def avg(self):
+         d = torch.tensor(list(self.deque), dtype=torch.float32)
+         return d.mean().item()
+
+     @property
+     def global_avg(self):
+         return self.total / self.count
+
+     @property
+     def max(self):
+         return max(self.deque)
+
+     @property
+     def value(self):
+         return self.deque[-1]
+
+     def __str__(self):
+         return self.fmt.format(
+             median=self.median,
+             avg=self.avg,
+             global_avg=self.global_avg,
+             max=self.max,
+             value=self.value)
+
+
+ def all_gather(data):
+     """
+     Run all_gather on arbitrary picklable data (not necessarily tensors).
+
+     Args:
+         data: any picklable object
+     Returns:
+         list[data]: list of data gathered from each rank
+     """
+     world_size = get_world_size()
+     if world_size == 1:
+         return [data]
+
+     # serialize to a Tensor
+     buffer = pickle.dumps(data)
+     storage = torch.ByteStorage.from_buffer(buffer)
+     tensor = torch.ByteTensor(storage).to("cuda")
+
+     # obtain Tensor size of each rank
+     local_size = torch.tensor([tensor.numel()], device="cuda")
+     size_list = [torch.tensor([0], device="cuda") for _ in range(world_size)]
+     dist.all_gather(size_list, local_size)
+     size_list = [int(size.item()) for size in size_list]
+     max_size = max(size_list)
+
+     # receiving Tensor from all ranks
+     # we pad the tensor because torch all_gather does not support
+     # gathering tensors of different shapes
+     tensor_list = []
+     for _ in size_list:
+         tensor_list.append(torch.empty((max_size,), dtype=torch.uint8, device="cuda"))
+     if local_size != max_size:
+         padding = torch.empty(size=(max_size - local_size,), dtype=torch.uint8, device="cuda")
+         tensor = torch.cat((tensor, padding), dim=0)
+     dist.all_gather(tensor_list, tensor)
+
+     data_list = []
+     for size, tensor in zip(size_list, tensor_list):
+         buffer = tensor.cpu().numpy().tobytes()[:size]
+         data_list.append(pickle.loads(buffer))
+
+     return data_list
+
+
+ def reduce_dict(input_dict, average=True):
+     """
+     Reduce the values in the dictionary from all processes so that all processes
+     have the averaged results. Returns a dict with the same fields as
+     input_dict, after reduction.
+
+     Args:
+         input_dict (dict): all the values will be reduced
+         average (bool): whether to do average or sum
+     """
+     world_size = get_world_size()
+     if world_size < 2:
+         return input_dict
+     with torch.no_grad():
+         names = []
+         values = []
+         # sort the keys so that they are consistent across processes
+         for k in sorted(input_dict.keys()):
+             names.append(k)
+             values.append(input_dict[k])
+         values = torch.stack(values, dim=0)
+         dist.all_reduce(values)
+         if average:
+             values /= world_size
+         reduced_dict = {k: v for k, v in zip(names, values)}
+     return reduced_dict
+
+
+ class MetricLogger(object):
+     def __init__(self, delimiter="\t"):
+         self.meters = defaultdict(SmoothedValue)
+         self.delimiter = delimiter
+
+     def update(self, **kwargs):
+         for k, v in kwargs.items():
+             if isinstance(v, torch.Tensor):
+                 v = v.item()
+             assert isinstance(v, (float, int))
+             self.meters[k].update(v)
+
+     def __getattr__(self, attr):
+         if attr in self.meters:
+             return self.meters[attr]
+         if attr in self.__dict__:
+             return self.__dict__[attr]
+         raise AttributeError("'{}' object has no attribute '{}'".format(
+             type(self).__name__, attr))
+
+     def __str__(self):
+         loss_str = []
+         for name, meter in self.meters.items():
+             loss_str.append(
+                 "{}: {}".format(name, str(meter))
+             )
+         return self.delimiter.join(loss_str)
+
+     def synchronize_between_processes(self):
+         for meter in self.meters.values():
+             meter.synchronize_between_processes()
+
+     def add_meter(self, name, meter):
+         self.meters[name] = meter
+
+     def log_every(self, iterable, print_freq, header=None):
+         i = 0
+         if not header:
+             header = ''
+         start_time = time.time()
+         end = time.time()
+         iter_time = SmoothedValue(fmt='{avg:.4f}')
+         data_time = SmoothedValue(fmt='{avg:.4f}')
+         space_fmt = ':' + str(len(str(len(iterable)))) + 'd'
+         if torch.cuda.is_available():
+             log_msg = self.delimiter.join([
+                 header,
+                 '[{0' + space_fmt + '}/{1}]',
+                 'eta: {eta}',
+                 '{meters}',
+                 'time: {time}',
+                 'data: {data}',
+                 'max mem: {memory:.0f}'
+             ])
+         else:
+             log_msg = self.delimiter.join([
+                 header,
+                 '[{0' + space_fmt + '}/{1}]',
+                 'eta: {eta}',
+                 '{meters}',
+                 'time: {time}',
+                 'data: {data}'
+             ])
+         MB = 1024.0 * 1024.0
+         for obj in iterable:
+             data_time.update(time.time() - end)
+             yield obj
+             iter_time.update(time.time() - end)
+             if i % print_freq == 0 or i == len(iterable) - 1:
+                 eta_seconds = iter_time.global_avg * (len(iterable) - i)
+                 eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))
+                 if torch.cuda.is_available():
+                     print(log_msg.format(
+                         i, len(iterable), eta=eta_string,
+                         meters=str(self),
+                         time=str(iter_time), data=str(data_time),
+                         memory=torch.cuda.max_memory_allocated() / MB))
+                 else:
+                     print(log_msg.format(
+                         i, len(iterable), eta=eta_string,
+                         meters=str(self),
+                         time=str(iter_time), data=str(data_time)))
+             i += 1
+             end = time.time()
+         total_time = time.time() - start_time
+         total_time_str = str(datetime.timedelta(seconds=int(total_time)))
+         print('{} Total time: {} ({:.4f} s / it)'.format(
+             header, total_time_str, total_time / len(iterable)))
+
+
+ def get_sha():
+     cwd = os.path.dirname(os.path.abspath(__file__))
+
+     def _run(command):
+         return subprocess.check_output(command, cwd=cwd).decode('ascii').strip()
+     sha = 'N/A'
+     diff = "clean"
+     branch = 'N/A'
+     try:
+         sha = _run(['git', 'rev-parse', 'HEAD'])
+         subprocess.check_output(['git', 'diff'], cwd=cwd)
+         diff = _run(['git', 'diff-index', 'HEAD'])
+         diff = "has uncommitted changes" if diff else "clean"
+         branch = _run(['git', 'rev-parse', '--abbrev-ref', 'HEAD'])
+     except Exception:
+         pass
+     message = f"sha: {sha}, status: {diff}, branch: {branch}"
+     return message
+
+
+ def collate_fn(batch):
+     batch = list(zip(*batch))
+     batch[0] = nested_tensor_from_tensor_list(batch[0])
+     return tuple(batch)
+
+
+ def _max_by_axis(the_list):
+     # type: (List[List[int]]) -> List[int]
+     maxes = the_list[0]
+     for sublist in the_list[1:]:
+         for index, item in enumerate(sublist):
+             maxes[index] = max(maxes[index], item)
+     return maxes
+
+
+ class NestedTensor(object):
+     def __init__(self, tensors, mask: Optional[Tensor]):
+         self.tensors = tensors
+         self.mask = mask
+
+     def to(self, device):
+         # type: (Device) -> NestedTensor # noqa
+         cast_tensor = self.tensors.to(device)
+         mask = self.mask
+         if mask is not None:
+             # the assert helps torchscript narrow Optional[Tensor] to Tensor
+             assert mask is not None
+             cast_mask = mask.to(device)
+         else:
+             cast_mask = None
+         return NestedTensor(cast_tensor, cast_mask)
+
+     def decompose(self):
+         return self.tensors, self.mask
+
+     def __repr__(self):
+         return str(self.tensors)
+
+
+ def nested_tensor_from_tensor_list(tensor_list: List[Tensor]):
+     # TODO make this more general
+     if tensor_list[0].ndim == 3:
+         if torchvision._is_tracing():
+             # nested_tensor_from_tensor_list() does not export well to ONNX
+             # call _onnx_nested_tensor_from_tensor_list() instead
+             return _onnx_nested_tensor_from_tensor_list(tensor_list)
+
+         # TODO make it support different-sized images
+         max_size = _max_by_axis([list(img.shape) for img in tensor_list])
+         batch_shape = [len(tensor_list)] + max_size
+         b, c, h, w = batch_shape
+         dtype = tensor_list[0].dtype
+         device = tensor_list[0].device
+         tensor = torch.zeros(batch_shape, dtype=dtype, device=device)
+         mask = torch.ones((b, h, w), dtype=torch.bool, device=device)
+         for img, pad_img, m in zip(tensor_list, tensor, mask):
+             pad_img[: img.shape[0], : img.shape[1], : img.shape[2]].copy_(img)
+             m[: img.shape[1], : img.shape[2]] = False
+     else:
+         raise ValueError('not supported')
+     return NestedTensor(tensor, mask)
+
+
+ # _onnx_nested_tensor_from_tensor_list() is an implementation of
+ # nested_tensor_from_tensor_list() that is supported by ONNX tracing.
+ @torch.jit.unused
+ def _onnx_nested_tensor_from_tensor_list(tensor_list: List[Tensor]) -> NestedTensor:
+     max_size = []
+     for i in range(tensor_list[0].dim()):
+         max_size_i = torch.max(torch.stack([img.shape[i] for img in tensor_list]).to(torch.float32)).to(torch.int64)
+         max_size.append(max_size_i)
+     max_size = tuple(max_size)
+
+     # work around for
+     # pad_img[: img.shape[0], : img.shape[1], : img.shape[2]].copy_(img)
+     # m[: img.shape[1], :img.shape[2]] = False
+     # which is not yet supported in onnx
+     padded_imgs = []
+     padded_masks = []
+     for img in tensor_list:
+         padding = [(s1 - s2) for s1, s2 in zip(max_size, tuple(img.shape))]
+         padded_img = torch.nn.functional.pad(img, (0, padding[2], 0, padding[1], 0, padding[0]))
+         padded_imgs.append(padded_img)
+
+         m = torch.zeros_like(img[0], dtype=torch.int, device=img.device)
+         padded_mask = torch.nn.functional.pad(m, (0, padding[2], 0, padding[1]), "constant", 1)
+         padded_masks.append(padded_mask.to(torch.bool))
+
+     tensor = torch.stack(padded_imgs)
+     mask = torch.stack(padded_masks)
+
+     return NestedTensor(tensor, mask=mask)
+
+
+ def setup_for_distributed(is_master):
+     """
+     This function disables printing when not in the master process.
+     """
+     import builtins as __builtin__
+     builtin_print = __builtin__.print
+
+     def print(*args, **kwargs):
+         force = kwargs.pop('force', False)
+         if is_master or force:
+             builtin_print(*args, **kwargs)
+
+     __builtin__.print = print
+
+
+ def is_dist_avail_and_initialized():
+     if not dist.is_available():
+         return False
+     if not dist.is_initialized():
+         return False
+     return True
+
+
+ def get_world_size():
+     if not is_dist_avail_and_initialized():
+         return 1
+     return dist.get_world_size()
+
+
+ def get_rank():
+     if not is_dist_avail_and_initialized():
+         return 0
+     return dist.get_rank()
+
+
+ def is_main_process():
+     return get_rank() == 0
+
+
+ def save_on_master(*args, **kwargs):
+     if is_main_process():
+         torch.save(*args, **kwargs)
+
+
+ def init_distributed_mode(args):
+     if 'RANK' in os.environ and 'WORLD_SIZE' in os.environ:
+         args.rank = int(os.environ["RANK"])
+         args.world_size = int(os.environ['WORLD_SIZE'])
+         args.gpu = int(os.environ['LOCAL_RANK'])
+     elif 'SLURM_PROCID' in os.environ:
+         args.rank = int(os.environ['SLURM_PROCID'])
+         args.gpu = args.rank % torch.cuda.device_count()
+     else:
+         print('Not using distributed mode')
+         args.distributed = False
+         return
+
+     args.distributed = True
+
+     torch.cuda.set_device(args.gpu)
+     args.dist_backend = 'nccl'
+     print('| distributed init (rank {}): {}'.format(
+         args.rank, args.dist_url), flush=True)
+     torch.distributed.init_process_group(backend=args.dist_backend, init_method=args.dist_url,
+                                          world_size=args.world_size, rank=args.rank)
+     torch.distributed.barrier()
+     setup_for_distributed(args.rank == 0)
+
+
+ @torch.no_grad()
+ def accuracy(output, target, topk=(1,)):
+     """Computes the precision@k for the specified values of k."""
+     if target.numel() == 0:
+         return [torch.zeros([], device=output.device)]
+     maxk = max(topk)
+     batch_size = target.size(0)
+
+     _, pred = output.topk(maxk, 1, True, True)
+     pred = pred.t()
+     correct = pred.eq(target.view(1, -1).expand_as(pred))
+
+     res = []
+     for k in topk:
+         correct_k = correct[:k].view(-1).float().sum(0)
+         res.append(correct_k.mul_(100.0 / batch_size))
+     return res
+
+
+ def interpolate(input, size=None, scale_factor=None, mode="nearest", align_corners=None):
+     # type: (Tensor, Optional[List[int]], Optional[float], str, Optional[bool]) -> Tensor
+     """
+     Equivalent to nn.functional.interpolate, but with support for empty batch sizes.
+     This will eventually be supported natively by PyTorch, and this
+     function can go away.
+     """
+     if float(torchvision.__version__[:3]) < 0.7:  # same version-prefix assumption as above
+         if input.numel() > 0:
+             return torch.nn.functional.interpolate(
+                 input, size, scale_factor, mode, align_corners
+             )
+
+         output_shape = _output_size(2, input, size, scale_factor)
+         output_shape = list(input.shape[:-2]) + list(output_shape)
+         return _new_empty_tensor(input, output_shape)
+     else:
+         return torchvision.ops.misc.interpolate(input, size, scale_factor, mode, align_corners)
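Usage sketch (not part of the commit): a single-process example of the logging and padding helpers, assuming the file is importable as `misc`; the distributed helpers degrade gracefully to rank 0 / world size 1.

# Illustrative only.
import torch
from misc import MetricLogger, SmoothedValue, nested_tensor_from_tensor_list

logger = MetricLogger(delimiter="  ")
logger.add_meter('lr', SmoothedValue(window_size=1, fmt='{value:.6f}'))
for step in logger.log_every(range(100), print_freq=10, header='Epoch [0]'):
    loss = torch.rand(1)                       # stand-in for a real training step
    logger.update(loss=loss.item(), lr=1e-4)   # meters smooth over a sliding window

imgs = [torch.rand(3, 480, 640), torch.rand(3, 520, 600)]  # different spatial sizes
nt = nested_tensor_from_tensor_list(imgs)      # zero-pads to a common (b, c, h, w) shape
padded, mask = nt.decompose()                  # mask is True exactly on padded pixels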
requirements.txt ADDED
@@ -0,0 +1,12 @@
+ torch==1.8.0
+ torchvision==0.9.0
+ pillow
+ scipy
+ numpy
+ tqdm
+ matplotlib
+ gradio
+
+ segmentation_models
+ opencv-python-headless
+ tensorflow
sofa.jpg ADDED
sofaApp.py CHANGED
@@ -1,7 +1,7 @@
  import numpy as np
  import gradio as gr
- from Segmentation.segmentation import get_mask,replace_sofa
- from StyleTransfer.styleTransfer import resize_sofa,resize_style,create_styledSofa
+ from segmentation import get_mask,replace_sofa
+ from styleTransfer import resize_sofa,resize_style,create_styledSofa
  from PIL import Image
 
  def style_sofa(input_img: np.ndarray, style_img: np.ndarray):
@@ -33,11 +33,11 @@ demo = gr.Interface(
  [image,style],
  'image',
  examples=[
- ['input/sofa_example1.jpg','input/style_example1.jpg'],
- ['input/sofa_example1.jpg','input/style_example2.jpg'],
- ['input/sofa_example1.jpg','input/style_example3.jpg'],
- ['input/sofa_example1.jpg','input/style_example4.jpg'],
- ['input/sofa_example1.jpg','input/style_example5.jpg'],
+ ['sofa_example1.jpg','style_example1.jpg'],
+ ['sofa_example1.jpg','style_example2.jpg'],
+ ['sofa_example1.jpg','style_example3.jpg'],
+ ['sofa_example1.jpg','style_example4.jpg'],
+ ['sofa_example1.jpg','style_example5.jpg'],
  ],
  title="Style your sofa",
  description="🛋 Customize your sofa to your wildest dreams! 🛋",
sofa_example1.jpg ADDED
sofa_stylized_style.jpg ADDED
style.jpg ADDED
styleTransfer.py ADDED
@@ -0,0 +1,75 @@
+ import torch
+ from PIL import Image
+ import numpy as np
+ import os
+ import cv2
+
+
+ def resize_sofa(img):
+     # Pad the image to a square on a white background, then resize to 640x640.
+     img = Image.fromarray(img)
+     width, height = img.size
+     idx = np.argmin([width, height])
+
+     if idx == 0:  # width is the short side: pad left/right
+         img1 = Image.new(img.mode, (height, height), (255, 255, 255))
+         img1.paste(img, ((height - width) // 2, 0))
+     else:         # height is the short side: pad top/bottom
+         img1 = Image.new(img.mode, (width, width), (255, 255, 255))
+         img1.paste(img, (0, (width - height) // 2))
+
+     newsize = (640, 640)  # parameters from test script
+     im1 = img1.resize(newsize)
+     return im1
+
+
+ def resize_style(img):
+     # Center-crop the style image to a square, then tile mirrored copies into a 640x640 mosaic.
+     img = Image.fromarray(img)
+     width, height = img.size
+     idx = np.argmin([width, height])
+
+     if idx == 0:  # width is the short side: crop top/bottom
+         top = (height - width) // 2
+         bottom = height - (height - width) // 2
+         left = 0
+         right = width
+     else:         # height is the short side: crop left/right
+         left = (width - height) // 2
+         right = width - (width - height) // 2
+         top = 0
+         bottom = height
+
+     newsize = (640, 640)  # parameters from test script
+     im1 = img.crop((left, top, right, bottom))
+
+     copies = 8
+     resize = (newsize[0] // copies, newsize[1] // copies)
+     dst = Image.new('RGB', (resize[0] * copies, resize[1] * copies))
+     im2 = im1.resize(resize)
+     for row in range(copies):
+         im2 = im2.transpose(Image.FLIP_LEFT_RIGHT)
+         for column in range(copies):
+             im2 = im2.transpose(Image.FLIP_TOP_BOTTOM)
+             dst.paste(im2, (resize[0] * row, resize[1] * column))
+     dst = dst.resize(newsize)
+     return dst
+
+
+ def create_styledSofa(sofa, style):
+     # Write the inputs to disk, run the style-transfer test script on them,
+     # and read back the stylized result.
+     path_sofa, path_style = 'sofa.jpg', 'style.jpg'
+     sofa.save(path_sofa)
+     style.save(path_style)
+     os.system("time python3 test.py --content " + path_sofa + " \
+               --style " + path_style + " \
+               --output . \
+               --vgg vgg_normalised.pth \
+               --decoder_path decoder_iter_160000.pth \
+               --Trans_path transformer_iter_160000.pth \
+               --embedding_path embedding_iter_160000.pth")
+     styled_sofa = cv2.imread('sofa_stylized_style.jpg')
+
+     return styled_sofa
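Usage sketch (not part of the commit): this exercises only the pure-PIL helpers; create_styledSofa is skipped because it shells out to test.py and needs the .pth checkpoints listed above. The example images are the ones added in this commit.

# Illustrative only.
import numpy as np
from PIL import Image
from styleTransfer import resize_sofa, resize_style

sofa = np.array(Image.open('sofa_example1.jpg'))
style = np.array(Image.open('style_example2.jpg'))
sofa_sq = resize_sofa(sofa)        # 640x640, white-padded to square
style_tiled = resize_style(style)  # 640x640 mosaic of mirrored style tiles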
style_example2.jpg ADDED
style_example3.jpg ADDED
style_example4.jpg ADDED
style_example5.jpg ADDED