Image Classification
English
cloudwalker commited on
Commit
f892555
1 Parent(s): 9aa1a34

Upload 4 files

Browse files
wavemix/SemSegment.py ADDED
@@ -0,0 +1,56 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from wavemix import Level4Waveblock, Level3Waveblock, Level2Waveblock, Level1Waveblock
2
+ import torch.nn as nn
3
+
4
+
5
+
6
class WaveMix(nn.Module):
    """WaveMix network for semantic segmentation.

    A two-layer convolutional stem embeds the 3-channel image, ``depth``
    wavelet blocks are applied with a residual connection each, and a head
    made of two stride-2 transposed convolutions plus a 1x1 convolution
    produces ``num_classes`` output channels.

    Args:
        num_classes: Number of output channels of the head.
        depth: Number of wavelet blocks.
        mult: Channel expansion factor forwarded to every block.
        ff_channel: Feed-forward channel width forwarded to every block.
        final_dim: Channel width of the feature map between blocks.
        dropout: Dropout probability forwarded to every block.
        level: Selects the block type; 4, 3 and 2 pick the matching
            ``LevelNWaveblock`` and any other value picks ``Level1Waveblock``.
        stride: Stride of each of the two 3x3 stem convolutions.
    """

    def __init__(
        self,
        *,
        num_classes=20,
        depth=16,
        mult=2,
        ff_channel=256,
        final_dim=256,
        dropout=0.,
        level=4,
        stride=2
    ):
        super().__init__()

        half_dim = int(final_dim / 2)
        quarter_dim = int(final_dim / 4)

        def new_block():
            # The block class is picked from `level` each time a block is built.
            if level == 4:
                block_cls = Level4Waveblock
            elif level == 3:
                block_cls = Level3Waveblock
            elif level == 2:
                block_cls = Level2Waveblock
            else:
                block_cls = Level1Waveblock
            return block_cls(mult=mult, ff_channel=ff_channel,
                             final_dim=final_dim, dropout=dropout)

        self.layers = nn.ModuleList([new_block() for _ in range(depth)])

        # Head: upsample by 2 twice, then project to class scores.
        self.expand = nn.Sequential(
            nn.ConvTranspose2d(final_dim, half_dim, 4, stride=2, padding=1),
            nn.ConvTranspose2d(half_dim, quarter_dim, 4, stride=2, padding=1),
            nn.Conv2d(quarter_dim, num_classes, 1),
        )

        # Stem: two 3x3 convolutions, each applying `stride`.
        self.conv = nn.Sequential(
            nn.Conv2d(3, half_dim, kernel_size=3, stride=stride, padding=1),
            nn.Conv2d(half_dim, final_dim, kernel_size=3, stride=stride, padding=1),
        )

    def forward(self, img):
        """Return the head output for the image batch ``img``."""
        features = self.conv(img)
        for block in self.layers:
            features = block(features) + features
        return self.expand(features)
wavemix/__init__.py ADDED
@@ -0,0 +1,621 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn.functional as F
3
+ import numpy as np
4
+ from torch.autograd import Function
5
+ import torch.nn as nn
6
+ import pywt
7
+ from einops import rearrange, repeat
8
+ from einops.layers.torch import Rearrange
9
+
10
+
11
# Module-wide default device: the first CUDA device when available, else CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")
12
+
13
def sfb1d(lo, hi, g0, g1, mode='zero', dim=-1):
    """1D synthesis filter bank applied along one axis of a 4D tensor.

    Args:
        lo: 4D tensor of lowpass coefficients.
        hi: 4D tensor of highpass coefficients (same shape as ``lo``).
        g0: Lowpass synthesis filter. Array-likes are flattened and converted
            to a float tensor on ``lo.device``; tensors are used as given and
            only reshaped if needed.
        g1: Highpass synthesis filter, handled like ``g0``.
        mode: Padding mode name. ``'per'``/``'periodization'`` wraps the
            overhang around; ``'zero'``, ``'symmetric'``, ``'reflect'`` and
            ``'periodic'`` all use a plain transposed convolution.
        dim: Axis to filter along (2 or 3, negative values allowed).

    Returns:
        The reconstructed tensor, upsampled by 2 along ``dim``.

    Raises:
        ValueError: If ``mode`` is not one of the names above.
    """
    C = lo.shape[1]
    d = dim % 4
    # Array-like filters become tensors; tensors are assumed to already be
    # in the right tap order.
    if not isinstance(g0, torch.Tensor):
        g0 = torch.tensor(np.copy(np.array(g0).ravel()),
                          dtype=torch.float, device=lo.device)
    if not isinstance(g1, torch.Tensor):
        g1 = torch.tensor(np.copy(np.array(g1).ravel()),
                          dtype=torch.float, device=lo.device)
    L = g0.numel()
    shape = [1, 1, 1, 1]
    shape[d] = L
    N = 2 * lo.shape[d]
    # Lay the taps out along the filtered axis.
    if g0.shape != tuple(shape):
        g0 = g0.reshape(*shape)
    if g1.shape != tuple(shape):
        g1 = g1.reshape(*shape)

    s = (2, 1) if d == 2 else (1, 2)
    # One copy of the filter per channel for the grouped convolution.
    g0 = torch.cat([g0] * C, dim=0)
    g1 = torch.cat([g1] * C, dim=0)

    if mode == 'per' or mode == 'periodization':
        y = F.conv_transpose2d(lo, g0, stride=s, groups=C) + \
            F.conv_transpose2d(hi, g1, stride=s, groups=C)
        # Fold the samples past position N back onto the start, then crop.
        if d == 2:
            y[:, :, :L-2] = y[:, :, :L-2] + y[:, :, N:N+L-2]
            y = y[:, :, :N]
        else:
            y[:, :, :, :L-2] = y[:, :, :, :L-2] + y[:, :, :, N:N+L-2]
            y = y[:, :, :, :N]
        # torch.roll performs the circular shift; no `roll` helper is
        # defined in this module.
        y = torch.roll(y, 1 - L // 2, dims=d)
    elif mode == 'zero' or mode == 'symmetric' or mode == 'reflect' or \
            mode == 'periodic':
        pad = (L-2, 0) if d == 2 else (0, L-2)
        y = F.conv_transpose2d(lo, g0, stride=s, padding=pad, groups=C) + \
            F.conv_transpose2d(hi, g1, stride=s, padding=pad, groups=C)
    else:
        raise ValueError("Unkown pad type: {}".format(mode))

    return y
59
+
60
def reflect(x, minx, maxx):
    """Fold the values of *x* back and forth between *minx* and *maxx*.

    A long linearly increasing series becomes a waveform that ramps up and
    down between the two bounds. When *x* holds integers and the bounds are
    (integer + 0.5), the extreme samples are repeated at each turn.

    The result is returned as a new array with the dtype of *x*.

    .. codeauthor:: Rich Wareham <rjw57@cantab.net>, Aug 2013
    .. codeauthor:: Nick Kingsbury, Cambridge University, January 1999.
    """
    values = np.asanyarray(x)
    span = maxx - minx
    period = 2 * span
    # Position inside one up-and-down period, moved into [0, period).
    offset = np.fmod(values - minx, period)
    offset = np.where(offset < 0, offset + period, offset)
    # The second half of the period runs back down towards minx.
    folded = np.where(offset >= span, period - offset, offset)
    return np.array(folded + minx, dtype=values.dtype)
76
+
77
def mode_to_int(mode):
    """Return the integer code for a padding mode name.

    ``'per'`` and ``'periodization'`` share code 2.

    Raises:
        ValueError: If ``mode`` is not a known padding mode name.
    """
    codes = (
        ('zero', 0),
        ('symmetric', 1),
        ('per', 2),
        ('periodization', 2),
        ('constant', 3),
        ('reflect', 4),
        ('replicate', 5),
        ('periodic', 6),
    )
    for name, code in codes:
        if mode == name:
            return code
    raise ValueError("Unkown pad type: {}".format(mode))
94
+
95
def int_to_mode(mode):
    """Return the padding mode name for an integer code (0 through 6).

    Code 2 maps to ``'periodization'``.

    Raises:
        ValueError: If ``mode`` does not equal any known code.
    """
    names = (
        'zero',
        'symmetric',
        'periodization',
        'constant',
        'reflect',
        'replicate',
        'periodic',
    )
    for code, name in enumerate(names):
        if mode == code:
            return name
    raise ValueError("Unkown pad type: {}".format(mode))
112
+
113
def afb1d(x, h0, h1, mode='zero', dim=-1):
    """1D analysis filter bank applied along one axis of a 4D tensor.

    Inputs:
        x (tensor): 4D input with the last two dimensions the spatial input
        h0 (tensor): 4D input for the lowpass filter. Should have shape (1, 1,
            h, 1) or (1, 1, 1, w). Array-likes are flattened, reversed and
            converted to a float tensor on ``x.device``.
        h1 (tensor): 4D input for the highpass filter, handled like ``h0``.
        mode (str): padding method
        dim (int) - dimension of filtering. d=2 is for a vertical filter (called
            column filtering but filters across the rows). d=3 is for a
            horizontal filter, (called row filtering but filters across the
            columns).
    Returns:
        lohi: lowpass and highpass subbands concatenated along the channel
            dimension (interleaved per input channel: lo, hi, lo, hi, ...)
    Raises:
        ValueError: if ``mode`` is not a supported padding method
    """
    C = x.shape[1]
    # Convert the dim to positive
    d = dim % 4
    s = (2, 1) if d == 2 else (1, 2)
    N = x.shape[d]
    # Array-like filters become tensors (reversed, since conv2d computes a
    # correlation); tensors are assumed to already be in the right order.
    if not isinstance(h0, torch.Tensor):
        h0 = torch.tensor(np.copy(np.array(h0).ravel()[::-1]),
                          dtype=torch.float, device=x.device)
    if not isinstance(h1, torch.Tensor):
        h1 = torch.tensor(np.copy(np.array(h1).ravel()[::-1]),
                          dtype=torch.float, device=x.device)
    L = h0.numel()
    L2 = L // 2
    shape = [1, 1, 1, 1]
    shape[d] = L
    # Lay the taps out along the filtered axis.
    if h0.shape != tuple(shape):
        h0 = h0.reshape(*shape)
    if h1.shape != tuple(shape):
        h1 = h1.reshape(*shape)
    # One (lo, hi) filter pair per channel for the grouped convolution.
    h = torch.cat([h0, h1] * C, dim=0)

    if mode == 'per' or mode == 'periodization':
        # Odd-length signals are extended by repeating the last sample.
        if x.shape[dim] % 2 == 1:
            if d == 2:
                x = torch.cat((x, x[:, :, -1:]), dim=2)
            else:
                x = torch.cat((x, x[:, :, :, -1:]), dim=3)
            N += 1
        # torch.roll performs the circular shift; no `roll` helper is
        # defined in this module.
        x = torch.roll(x, -L2, dims=d)
        pad = (L-1, 0) if d == 2 else (0, L-1)
        lohi = F.conv2d(x, h, padding=pad, stride=s, groups=C)
        N2 = N // 2
        # Wrap the overhanging outputs back onto the start, then crop.
        if d == 2:
            lohi[:, :, :L2] = lohi[:, :, :L2] + lohi[:, :, N2:N2+L2]
            lohi = lohi[:, :, :N2]
        else:
            lohi[:, :, :, :L2] = lohi[:, :, :, :L2] + lohi[:, :, :, N2:N2+L2]
            lohi = lohi[:, :, :, :N2]
        return lohi

    # Calculate the pad size
    outsize = pywt.dwt_coeff_len(N, L, mode=mode)
    p = 2 * (outsize - 1) - N + L
    if mode == 'zero':
        # conv2d pads equally before and after, so an odd total pad needs
        # one extra trailing sample added up front.
        if p % 2 == 1:
            pad = (0, 0, 0, 1) if d == 2 else (0, 1, 0, 0)
            x = F.pad(x, pad)
        pad = (p//2, 0) if d == 2 else (0, p//2)
        # Calculate the high and lowpass
        lohi = F.conv2d(x, h, padding=pad, stride=s, groups=C)
    elif mode == 'symmetric' or mode == 'reflect' or mode == 'periodic':
        pad = (0, 0, p//2, (p+1)//2) if d == 2 else (p//2, (p+1)//2, 0, 0)
        # NOTE(review): `mypad` is not defined or imported in the visible
        # module; unless it is provided elsewhere these modes raise NameError.
        x = mypad(x, pad=pad, mode=mode)
        lohi = F.conv2d(x, h, stride=s, groups=C)
    else:
        raise ValueError("Unkown pad type: {}".format(mode))

    return lohi
193
+
194
+
195
+
196
class AFB2D(Function):
    """Single-level 2D wavelet decomposition as a custom autograd function.

    The forward pass filters along dim 3 and then along dim 2 with two calls
    to ``afb1d``. Because the backward pass is written by hand (three calls
    to ``sfb1d``), only the filters are saved, not the input tensor.

    Inputs:
        x (torch.Tensor): Input to decompose
        h0_row: lowpass filter used for the dim-3 pass
        h1_row: highpass filter used for the dim-3 pass
        h0_col: lowpass filter used for the dim-2 pass
        h1_col: highpass filter used for the dim-2 pass
        mode (int): padding mode code; use mode_to_int to obtain it. An
            integer is used because gradcheck causes an error when a string
            is provided.
    Returns:
        (low, highs): the lowpass band, and the three remaining bands
        stacked along a new dimension 2.
    """

    @staticmethod
    def forward(ctx, x, h0_row, h1_row, h0_col, h1_col, mode):
        ctx.save_for_backward(h0_row, h1_row, h0_col, h1_col)
        # The input's spatial size is needed to crop the gradient later.
        ctx.shape = x.shape[-2:]
        pad_mode = int_to_mode(mode)
        ctx.mode = pad_mode

        along_width = afb1d(x, h0_row, h1_row, mode=pad_mode, dim=3)
        coeffs = afb1d(along_width, h0_col, h1_col, mode=pad_mode, dim=2)

        # Split the channel axis into (input channel, 4 subbands).
        out_shape = coeffs.shape
        coeffs = coeffs.reshape(out_shape[0], -1, 4, out_shape[-2], out_shape[-1])
        return coeffs[:, :, 0].contiguous(), coeffs[:, :, 1:].contiguous()

    @staticmethod
    def backward(ctx, low, highs):
        # Only the input tensor can receive a gradient.
        if not ctx.needs_input_grad[0]:
            return None, None, None, None, None, None

        pad_mode = ctx.mode
        h0_row, h1_row, h0_col, h1_col = ctx.saved_tensors
        lh, hl, hh = torch.unbind(highs, dim=2)
        lo = sfb1d(low, lh, h0_col, h1_col, mode=pad_mode, dim=2)
        hi = sfb1d(hl, hh, h0_col, h1_col, mode=pad_mode, dim=2)
        dx = sfb1d(lo, hi, h0_row, h1_row, mode=pad_mode, dim=3)

        # Crop any rows/columns beyond the original input size.
        rows, cols = ctx.shape[-2], ctx.shape[-1]
        if dx.shape[-2] > rows or dx.shape[-1] > cols:
            dx = dx[:, :, :rows, :cols]
        return dx, None, None, None, None, None
246
+
247
+
248
def prep_filt_afb2d(h0_col, h1_col, h0_row=None, h1_row=None, device=device):
    """
    Prepares the filters to be of the right form for the afb2d function. In
    particular, makes the tensors the right shape. It takes mirror images of
    them as afb2d uses conv2d which acts like normal correlation.
    Inputs:
        h0_col (array-like): low pass column filter bank
        h1_col (array-like): high pass column filter bank
        h0_row (array-like): low pass row filter bank. If none, will assume the
            same as column filter
        h1_row (array-like): high pass row filter bank. If none, will assume the
            same as column filter
        device: which device to put the tensors on to
    Returns:
        (h0_col, h1_col, h0_row, h1_row), with the column filters shaped
        (1, 1, L, 1) and the row filters shaped (1, 1, 1, L)
    """
    h0_col, h1_col = prep_filt_afb1d(h0_col, h1_col, device)
    if h0_row is None:
        # Reuse the column filters for the rows. Both row names must be
        # bound here, otherwise h1_row stays None and the reshape below fails.
        h0_row, h1_row = h0_col, h1_col
    else:
        h0_row, h1_row = prep_filt_afb1d(h0_row, h1_row, device)

    h0_col = h0_col.reshape((1, 1, -1, 1))
    h1_col = h1_col.reshape((1, 1, -1, 1))
    h0_row = h0_row.reshape((1, 1, 1, -1))
    h1_row = h1_row.reshape((1, 1, 1, -1))
    return h0_col, h1_col, h0_row, h1_row
275
+
276
+
277
def prep_filt_afb1d(h0, h1, device=device):
    """
    Converts a pair of analysis filters into tensors for the filter bank.

    Each filter is reversed (conv2d computes a correlation), flattened, and
    turned into a tensor of the default torch dtype with shape (1, 1, L).
    Inputs:
        h0 (array-like): low pass filter bank
        h1 (array-like): high pass filter bank
        device: which device to put the tensors on to
    Returns:
        (h0, h1)
    """
    lo_taps = np.array(h0[::-1]).ravel()
    hi_taps = np.array(h1[::-1]).ravel()
    dtype = torch.get_default_dtype()

    def as_filter(taps):
        return torch.tensor(taps, device=device, dtype=dtype).reshape((1, 1, -1))

    return as_filter(lo_taps), as_filter(hi_taps)
295
+
296
class DWTForward(nn.Module):
    """ Performs a 2d DWT Forward decomposition of an image
    Args:
        J (int): Number of levels of decomposition
        wave (str or pywt.Wavelet or tuple(ndarray)): Which wavelet to use.
            Can be:
            1) a string to pass to pywt.Wavelet constructor
            2) a pywt.Wavelet class
            3) a tuple of numpy arrays, either (h0, h1) or (h0_col, h1_col, h0_row, h1_row)
        mode (str): 'zero', 'symmetric', 'reflect' or 'periodization'. The
            padding scheme
    Raises:
        ValueError: if ``wave`` is a sequence whose length is neither 2 nor 4
    """
    def __init__(self, J=1, wave='db1', mode='zero'):
        super().__init__()
        if isinstance(wave, str):
            wave = pywt.Wavelet(wave)
        if isinstance(wave, pywt.Wavelet):
            h0_col, h1_col = wave.dec_lo, wave.dec_hi
            h0_row, h1_row = h0_col, h1_col
        elif len(wave) == 2:
            h0_col, h1_col = wave[0], wave[1]
            h0_row, h1_row = h0_col, h1_col
        elif len(wave) == 4:
            h0_col, h1_col = wave[0], wave[1]
            h0_row, h1_row = wave[2], wave[3]
        else:
            # Fail here with a clear message instead of hitting unbound
            # filter names a few lines further down.
            raise ValueError(
                "wave must be a wavelet name, a pywt.Wavelet, or a tuple of "
                "2 or 4 filters; got a sequence of length {}".format(len(wave)))

        # Prepare the filters
        filts = prep_filt_afb2d(h0_col, h1_col, h0_row, h1_row)
        self.register_buffer('h0_col', filts[0])
        self.register_buffer('h1_col', filts[1])
        self.register_buffer('h0_row', filts[2])
        self.register_buffer('h1_row', filts[3])
        self.J = J
        self.mode = mode

    def forward(self, x):
        """ Forward pass of the DWT.
        Args:
            x (tensor): Input of shape :math:`(N, C_{in}, H_{in}, W_{in})`
        Returns:
            (yl, yh)
                tuple of lowpass (yl) and bandpass (yh) coefficients.
                yh is a list of length J with the first entry
                being the finest scale coefficients. yl has shape
                :math:`(N, C_{in}, H_{in}', W_{in}')` and yh has shape
                :math:`list(N, C_{in}, 3, H_{in}'', W_{in}'')`. The new
                dimension in yh iterates over the LH, HL and HH coefficients.
        Note:
            :math:`H_{in}', W_{in}', H_{in}'', W_{in}''` denote the correctly
            downsampled shapes of the DWT pyramid.
        """
        yh = []
        ll = x
        mode = mode_to_int(self.mode)

        # Do a multilevel transform: each level decomposes the previous
        # lowpass band.
        for _ in range(self.J):
            # NOTE(review): the column filters are passed in AFB2D.forward's
            # h0_row/h1_row positions and vice versa; harmless when row and
            # column filters are equal -- confirm for 4-filter inputs.
            ll, high = AFB2D.apply(
                ll, self.h0_col, self.h1_col, self.h0_row, self.h1_row, mode)
            yh.append(high)

        return ll, yh
360
+
361
# `hamming` is not referenced elsewhere in the visible module. It is imported
# from the public numpy namespace because the private
# `numpy.lib.function_base` path is not importable on NumPy 2.x.
from numpy import hamming

# Shared module-level DWT operators for 1 to 4 decomposition levels, all
# using the 'db1' wavelet with zero padding, placed on the module `device`.
xf1 = DWTForward(J=1, mode='zero', wave='db1').to(device)
xf2 = DWTForward(J=2, mode='zero', wave='db1').to(device)
xf3 = DWTForward(J=3, mode='zero', wave='db1').to(device)
xf4 = DWTForward(J=4, mode='zero', wave='db1').to(device)
367
+
368
class Level1Waveblock(nn.Module):
    """Wavelet block using a one-level 2D DWT.

    The input is reduced to ``final_dim / 4`` channels by a 1x1 convolution,
    decomposed by the module-level ``xf1`` transform, and the lowpass band is
    concatenated with the three detail bands along the channel axis. A
    feed-forward stack ending in a stride-2 transposed convolution and batch
    norm produces ``final_dim`` channels.

    Args:
        mult: Channel expansion factor of the first 1x1 convolution.
        ff_channel: Channel width fed to the transposed convolution.
        final_dim: Input and output channel width of the block.
        dropout: Dropout probability inside the feed-forward stack.
    """

    def __init__(self, *, mult=2, ff_channel=16, final_dim=16, dropout=0.5):
        super().__init__()

        hidden = final_dim * mult
        self.feedforward = nn.Sequential(
            nn.Conv2d(final_dim, hidden, 1),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Conv2d(hidden, ff_channel, 1),
            nn.ConvTranspose2d(ff_channel, final_dim, 4, stride=2, padding=1),
            nn.BatchNorm2d(final_dim),
        )

        self.reduction = nn.Conv2d(final_dim, int(final_dim / 4), 1)

    def forward(self, x):
        """Apply the block to a 4D feature map ``x``."""
        batch, channels, height, width = x.shape

        reduced = self.reduction(x)
        approx, details = xf1(reduced)

        # Fold the 3 detail subbands into the channel axis. The target shape
        # is derived from the input size, so it presumably requires channels
        # divisible by 4 and even height/width.
        detail_maps = torch.reshape(
            details[0],
            (batch, int(channels * 3 / 4), int(height / 2), int(width / 2)),
        )

        stacked = torch.cat((approx, detail_maps), dim=1)
        return self.feedforward(stacked)
407
+
408
class Level2Waveblock(nn.Module):
    """Wavelet block using a two-level 2D DWT.

    The input is reduced to ``final_dim / 4`` channels and decomposed twice.
    At each level the lowpass band is concatenated with its three detail
    bands along the channel axis. The coarse level is processed by
    ``feedforward2`` (which upsamples by 2), concatenated with the fine
    level, and processed by ``feedforward1`` to give ``final_dim`` channels.

    Args:
        mult: Channel expansion factor of the first 1x1 convolution.
        ff_channel: Channel width fed to each transposed convolution.
        final_dim: Input and output channel width of the block.
        dropout: Dropout probability inside the feed-forward stacks.
    """

    @staticmethod
    def _mixer(in_ch, hidden, ff_channel, out_ch, dropout):
        # 1x1 expand -> GELU -> dropout -> 1x1 project -> 2x upsample -> BN.
        return nn.Sequential(
            nn.Conv2d(in_ch, hidden, 1),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Conv2d(hidden, ff_channel, 1),
            nn.ConvTranspose2d(ff_channel, out_ch, 4, stride=2, padding=1),
            nn.BatchNorm2d(out_ch),
        )

    def __init__(
        self,
        *,
        mult = 2,
        ff_channel = 16,
        final_dim = 16,
        dropout = 0.5,
    ):
        super().__init__()

        half = int(final_dim/2)
        hidden = final_dim*mult

        # Submodules are created in a fixed order (feedforward1,
        # feedforward2, reduction) so seeded initialisation is reproducible.
        self.feedforward1 = self._mixer(final_dim + half, hidden, ff_channel, final_dim, dropout)
        self.feedforward2 = self._mixer(final_dim, hidden, ff_channel, half, dropout)

        self.reduction = nn.Conv2d(final_dim, int(final_dim/4), 1)

    def forward(self, x):
        """Apply the block to a 4D feature map ``x``."""
        b, c, h, w = x.shape

        low = self.reduction(x)

        # Chain the single-level transform: each level decomposes the
        # previous lowpass band, so no level is computed more than once.
        bands = []
        for scale in (2, 4):
            low, highs = xf1(low)
            detail = torch.reshape(
                highs[0], (b, int(c*3/4), int(h/scale), int(w/scale)))
            bands.append(torch.cat((low, detail), dim = 1))
        x1, x2 = bands

        # Coarse-to-fine merge.
        x2 = self.feedforward2(x2)
        x1 = torch.cat((x1, x2), dim = 1)
        return self.feedforward1(x1)
461
+
462
+
463
class Level3Waveblock(nn.Module):
    """Wavelet block using a three-level 2D DWT.

    The input is reduced to ``final_dim / 4`` channels and decomposed three
    times. At each level the lowpass band is concatenated with its three
    detail bands along the channel axis. Levels are merged coarse-to-fine:
    each ``feedforwardN`` upsamples by 2 and its output is concatenated with
    the next finer level, until ``feedforward1`` gives ``final_dim`` channels.

    Args:
        mult: Channel expansion factor of the first 1x1 convolution.
        ff_channel: Channel width fed to each transposed convolution.
        final_dim: Input and output channel width of the block.
        dropout: Dropout probability inside the feed-forward stacks.
    """

    @staticmethod
    def _mixer(in_ch, hidden, ff_channel, out_ch, dropout):
        # 1x1 expand -> GELU -> dropout -> 1x1 project -> 2x upsample -> BN.
        return nn.Sequential(
            nn.Conv2d(in_ch, hidden, 1),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Conv2d(hidden, ff_channel, 1),
            nn.ConvTranspose2d(ff_channel, out_ch, 4, stride=2, padding=1),
            nn.BatchNorm2d(out_ch),
        )

    def __init__(
        self,
        *,
        mult = 2,
        ff_channel = 16,
        final_dim = 16,
        dropout = 0.5,
    ):
        super().__init__()

        half = int(final_dim/2)
        hidden = final_dim*mult

        # Submodules are created in a fixed order (feedforward1..3,
        # reduction) so seeded initialisation is reproducible.
        self.feedforward1 = self._mixer(final_dim + half, hidden, ff_channel, final_dim, dropout)
        self.feedforward2 = self._mixer(final_dim + half, hidden, ff_channel, half, dropout)
        self.feedforward3 = self._mixer(final_dim, hidden, ff_channel, half, dropout)

        self.reduction = nn.Conv2d(final_dim, int(final_dim/4), 1)

    def forward(self, x):
        """Apply the block to a 4D feature map ``x``."""
        b, c, h, w = x.shape

        low = self.reduction(x)

        # Chain the single-level transform: each level decomposes the
        # previous lowpass band, so no level is computed more than once.
        bands = []
        for scale in (2, 4, 8):
            low, highs = xf1(low)
            detail = torch.reshape(
                highs[0], (b, int(c*3/4), int(h/scale), int(w/scale)))
            bands.append(torch.cat((low, detail), dim = 1))
        x1, x2, x3 = bands

        # Coarse-to-fine merge.
        x3 = self.feedforward3(x3)
        x2 = self.feedforward2(torch.cat((x2, x3), dim = 1))
        x1 = torch.cat((x1, x2), dim = 1)
        return self.feedforward1(x1)
534
+
535
+
536
class Level4Waveblock(nn.Module):
    """Wavelet block using a four-level 2D DWT.

    The input is reduced to ``final_dim / 4`` channels and decomposed four
    times. At each level the lowpass band is concatenated with its three
    detail bands along the channel axis. Levels are merged coarse-to-fine:
    each ``feedforwardN`` upsamples by 2 and its output is concatenated with
    the next finer level, until ``feedforward1`` gives ``final_dim`` channels.

    Args:
        mult: Channel expansion factor of the first 1x1 convolution.
        ff_channel: Channel width fed to each transposed convolution.
        final_dim: Input and output channel width of the block.
        dropout: Dropout probability inside the feed-forward stacks.
    """

    @staticmethod
    def _mixer(in_ch, hidden, ff_channel, out_ch, dropout):
        # 1x1 expand -> GELU -> dropout -> 1x1 project -> 2x upsample -> BN.
        return nn.Sequential(
            nn.Conv2d(in_ch, hidden, 1),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Conv2d(hidden, ff_channel, 1),
            nn.ConvTranspose2d(ff_channel, out_ch, 4, stride=2, padding=1),
            nn.BatchNorm2d(out_ch),
        )

    def __init__(
        self,
        *,
        mult = 2,
        ff_channel = 16,
        final_dim = 16,
        dropout = 0.5,
    ):
        super().__init__()

        half = int(final_dim/2)
        hidden = final_dim*mult

        # Submodules are created in a fixed order (feedforward1..4,
        # reduction) so seeded initialisation is reproducible.
        self.feedforward1 = self._mixer(final_dim + half, hidden, ff_channel, final_dim, dropout)
        self.feedforward2 = self._mixer(final_dim + half, hidden, ff_channel, half, dropout)
        self.feedforward3 = self._mixer(final_dim + half, hidden, ff_channel, half, dropout)
        self.feedforward4 = self._mixer(final_dim, hidden, ff_channel, half, dropout)

        self.reduction = nn.Conv2d(final_dim, int(final_dim/4), 1)

    def forward(self, x):
        """Apply the block to a 4D feature map ``x``."""
        b, c, h, w = x.shape

        low = self.reduction(x)

        # Chain the single-level transform: each level decomposes the
        # previous lowpass band, so no level is computed more than once.
        bands = []
        for scale in (2, 4, 8, 16):
            low, highs = xf1(low)
            detail = torch.reshape(
                highs[0], (b, int(c*3/4), int(h/scale), int(w/scale)))
            bands.append(torch.cat((low, detail), dim = 1))
        x1, x2, x3, x4 = bands

        # Coarse-to-fine merge.
        x4 = self.feedforward4(x4)
        x3 = self.feedforward3(torch.cat((x3, x4), dim = 1))
        x2 = self.feedforward2(torch.cat((x2, x3), dim = 1))
        x1 = torch.cat((x1, x2), dim = 1)
        return self.feedforward1(x1)
wavemix/classification.py ADDED
@@ -0,0 +1,63 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from wavemix import Level4Waveblock, Level3Waveblock, Level2Waveblock, Level1Waveblock
2
+ import torch.nn as nn
3
+ from einops.layers.torch import Rearrange
4
+
5
class WaveMix(nn.Module):
    """WaveMix network for image classification.

    A convolutional stem embeds the 3-channel image, ``depth`` wavelet
    blocks are applied with a residual connection each, and the result is
    globally average-pooled and mapped to ``num_classes`` logits by a linear
    layer.

    Args:
        num_classes: Number of output logits.
        depth: Number of wavelet blocks.
        mult: Channel expansion factor forwarded to every block.
        ff_channel: Feed-forward channel width forwarded to every block.
        final_dim: Channel width of the feature map between blocks.
        dropout: Dropout probability forwarded to every block.
        level: Selects the block type; 4, 3 and 2 pick the matching
            ``LevelNWaveblock`` and any other value picks ``Level1Waveblock``.
        initial_conv: ``'strided'`` selects two strided 3x3 convolutions;
            any other value selects the patch-embedding stem.
        patch_size: Kernel size and stride of the patch-embedding convolution.
        stride: Stride of each convolution in the ``'strided'`` stem.
    """

    def __init__(
        self,
        *,
        num_classes=1000,
        depth=16,
        mult=2,
        ff_channel=192,
        final_dim=192,
        dropout=0.5,
        level=3,
        initial_conv='pachify',  # or 'strided'
        patch_size=4,
        stride=2,
    ):
        super().__init__()

        half_dim = int(final_dim / 2)
        quarter_dim = int(final_dim / 4)

        def new_block():
            # The block class is picked from `level` each time a block is built.
            if level == 4:
                block_cls = Level4Waveblock
            elif level == 3:
                block_cls = Level3Waveblock
            elif level == 2:
                block_cls = Level2Waveblock
            else:
                block_cls = Level1Waveblock
            return block_cls(mult=mult, ff_channel=ff_channel,
                             final_dim=final_dim, dropout=dropout)

        self.layers = nn.ModuleList([new_block() for _ in range(depth)])

        # Classification head: global average pool, drop the two unit
        # spatial axes, linear projection to logits.
        self.pool = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            Rearrange('... () () -> ...'),
            nn.Linear(final_dim, num_classes),
        )

        if initial_conv == 'strided':
            stem_layers = [
                nn.Conv2d(3, half_dim, 3, stride, 1),
                nn.Conv2d(half_dim, final_dim, 3, stride, 1),
            ]
        else:
            stem_layers = [
                nn.Conv2d(3, quarter_dim, 3, 1, 1),
                nn.Conv2d(quarter_dim, half_dim, 3, 1, 1),
                nn.Conv2d(half_dim, final_dim, patch_size, patch_size),
                nn.GELU(),
                nn.BatchNorm2d(final_dim),
            ]
        self.conv = nn.Sequential(*stem_layers)

    def forward(self, img):
        """Return the logits for the image batch ``img``."""
        features = self.conv(img)
        for block in self.layers:
            features = block(features) + features
        return self.pool(features)
wavemix/sisr.py ADDED
@@ -0,0 +1,53 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from wavemix import Level4Waveblock, Level3Waveblock, Level2Waveblock, Level1Waveblock
2
+ import torch.nn as nn
3
+
4
+
5
+
6
class WaveMix(nn.Module):
    """WaveMix network for single-image super-resolution.

    Two stride-1 3x3 convolutions embed the 3-channel image, ``depth``
    wavelet blocks are applied with a residual connection each, and a
    stride-2 transposed convolution followed by a 1x1 convolution produces a
    3-channel output.

    Args:
        depth: Number of wavelet blocks.
        mult: Channel expansion factor forwarded to every block.
        ff_channel: Feed-forward channel width forwarded to every block.
        final_dim: Channel width of the feature map between blocks.
        dropout: Dropout probability forwarded to every block.
        level: Selects the block type; 4, 3 and 2 pick the matching
            ``LevelNWaveblock`` and any other value picks ``Level1Waveblock``.
    """

    def __init__(
        self,
        *,
        depth=4,
        mult=2,
        ff_channel=144,
        final_dim=144,
        dropout=0.,
        level=1,
    ):
        super().__init__()

        half_dim = int(final_dim / 2)

        def new_block():
            # The block class is picked from `level` each time a block is built.
            if level == 4:
                block_cls = Level4Waveblock
            elif level == 3:
                block_cls = Level3Waveblock
            elif level == 2:
                block_cls = Level2Waveblock
            else:
                block_cls = Level1Waveblock
            return block_cls(mult=mult, ff_channel=ff_channel,
                             final_dim=final_dim, dropout=dropout)

        self.layers = nn.ModuleList([new_block() for _ in range(depth)])

        # Head: upsample by 2, then project back to 3 channels.
        self.expand = nn.Sequential(
            nn.ConvTranspose2d(final_dim, half_dim, 4, stride=2, padding=1),
            nn.Conv2d(half_dim, 3, 1),
        )

        # Stem: two 3x3 convolutions that keep the spatial size.
        self.conv = nn.Sequential(
            nn.Conv2d(3, half_dim, kernel_size=3, stride=1, padding=1),
            nn.Conv2d(half_dim, final_dim, kernel_size=3, stride=1, padding=1),
        )

    def forward(self, img):
        """Return the head output for the image batch ``img``."""
        features = self.conv(img)
        for block in self.layers:
            features = block(features) + features
        return self.expand(features)