mickylan2367 committed on
Commit
ba54498
1 Parent(s): a4f7cb6

Upload model

Files changed (4)
  1. config.json +28 -0
  2. configuration_fsae.py +65 -0
  3. model.safetensors +3 -0
  4. modeling_fsae.py +657 -0
config.json ADDED
@@ -0,0 +1,28 @@
+ {
+   "Vth": 0.2,
+   "a": 0.25,
+   "aa": 0.5,
+   "architectures": [
+     "FSAEModel"
+   ],
+   "auto_map": {
+     "AutoConfig": "configuration_fsae.FSAEConfig",
+     "AutoModel": "modeling_fsae.FSAEModel"
+   },
+   "dt": 5,
+   "hidden_dims": [
+     32,
+     64,
+     128,
+     256
+   ],
+   "in_channels": 1,
+   "k": 20,
+   "latent_dim": 128,
+   "model_type": "fsae",
+   "n_steps": 16,
+   "scheduled": true,
+   "tau": 0.25,
+   "torch_dtype": "float32",
+   "transformers_version": "4.35.0"
+ }
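
With the auto_map above, AutoConfig/AutoModel can load this repo's custom classes straight from the Hub. A minimal loading sketch (the repo id below is a placeholder, not stated in this commit):

    import torch
    from transformers import AutoConfig, AutoModel

    repo_id = "mickylan2367/FSAE"  # hypothetical repo id; substitute the real one
    config = AutoConfig.from_pretrained(repo_id, trust_remote_code=True)
    model = AutoModel.from_pretrained(repo_id, trust_remote_code=True)

    # FSAEModel consumes spike trains shaped (N, C, H, W, T); see modeling_fsae.py below.
    x = (torch.rand(1, config.in_channels, 32, 32, config.n_steps) > 0.5).float()
    recon, z = model(x)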
configuration_fsae.py ADDED
@@ -0,0 +1,65 @@
+ # pip install transformers
+ from transformers import PretrainedConfig
+ from typing import List, Optional
+
+
+ '''
+ network_config = {
+     "epochs": 150,
+     "batch_size": 250,
+     "n_steps": 16,  # timesteps
+     "dataset": "CAPS",
+     "in_channels": 1,
+     "data_path": "./data",
+     "lr": 0.001,
+     "n_class": 10,
+     "latent_dim": 128,
+     "input_size": 32,
+     "model": "FSVAE",  # FSVAE or FSVAE_large
+     "k": 20,  # multiplier of channel
+     "scheduled": True,  # whether to apply scheduled sampling
+     "loss_func": 'kld',  # mmd or kld
+     "accum_iter": 1,
+     "devices": [0],
+ }
+
+ hidden_dims = [32, 64, 128, 256]
+ '''
+
+
+ class FSAEConfig(PretrainedConfig):
+     model_type = "fsae"
+
+     def __init__(
+         self,
+         in_channels: int = 1,
+         hidden_dims: Optional[List[int]] = None,  # default filled below to avoid a shared mutable default
+         k: int = 20,
+         n_steps: int = 16,
+         latent_dim: int = 128,
+         scheduled: bool = True,
+         # loss_func: str = "kld",
+         dt: float = 5,
+         a: float = 0.25,
+         aa: float = 0.5,
+         Vth: float = 0.2,  # threshold potential
+         tau: float = 0.25,
+         **kwargs,
+     ):
+         self.in_channels = in_channels
+         self.hidden_dims = hidden_dims if hidden_dims is not None else [32, 64, 128, 256]
+         self.k = k
+         self.n_steps = n_steps
+         self.latent_dim = latent_dim
+         self.scheduled = scheduled
+         self.dt = dt
+         self.a = a
+         self.aa = aa
+         self.Vth = Vth
+         self.tau = tau
+         super().__init__(**kwargs)
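
The defaults above mirror config.json exactly, and the auto_map entries in config.json are what the standard Transformers custom-code workflow produces. A sketch of that workflow, assuming both files are importable locally:

    from configuration_fsae import FSAEConfig
    from modeling_fsae import FSAEModel

    config = FSAEConfig()   # defaults mirror the uploaded config.json
    model = FSAEModel(config)

    # Register the custom classes so save_pretrained() writes auto_map into config.json.
    FSAEConfig.register_for_auto_class()
    FSAEModel.register_for_auto_class("AutoModel")
    model.save_pretrained("fsae-checkpoint")  # hypothetical output directory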
model.safetensors ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:df404010e0e2a77cf66e905d26b0b356397d8cfe61db32fddcca723319cb260b
+ size 4228636
modeling_fsae.py ADDED
@@ -0,0 +1,657 @@
+ import random
+
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+
+ from transformers import PreTrainedModel
+ from .configuration_fsae import FSAEConfig
+
+ # Neuron and surrogate-gradient constants (duplicated from the FSAEConfig defaults).
+ dt = 5
+ a = 0.25
+ aa = 0.5
+ Vth = 0.2
+ tau = 0.25
+
+
+ class SpikeAct(torch.autograd.Function):
+     """
+     Spiking activation function with a surrogate (approximate) gradient.
+     """
+     @staticmethod
+     def forward(ctx, input):
+         ctx.save_for_backward(input)
+         # if input = u > Vth then output = 1
+         output = torch.gt(input, Vth)
+         return output.float()
+
+     @staticmethod
+     def backward(ctx, grad_output):
+         input, = ctx.saved_tensors
+         grad_input = grad_output.clone()
+         # hu is an approximation of df/du: a rectangular window of width 2*aa
+         hu = abs(input) < aa
+         hu = hu.float() / (2 * aa)
+         return grad_input * hu
+
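The forward pass is a hard threshold at Vth, whose true derivative is zero almost everywhere; the backward pass substitutes a rectangular window of height 1/(2*aa) around zero. A quick illustrative check (not part of the commit):

    u = torch.tensor([-1.0, 0.1, 0.3, 1.0], requires_grad=True)
    s = SpikeAct.apply(u)   # tensor([0., 0., 1., 1.]) -- spikes where u > Vth = 0.2
    s.sum().backward()
    print(u.grad)           # tensor([0., 1., 1., 0.]) -- 1/(2*aa) = 1.0 where |u| < aa = 0.5
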
+ class LIFSpike(nn.Module):
+     """
+     Generates spikes from a LIF neuron model. It acts as an activation function,
+     used much like ReLU, except that the input tensor needs an additional time
+     dimension, which here is the last dimension of the data.
+     """
+     def __init__(self):
+         super(LIFSpike, self).__init__()
+
+     def forward(self, x):
+         nsteps = x.shape[-1]
+         u = torch.zeros(x.shape[:-1], device=x.device)
+         out = torch.zeros(x.shape, device=x.device)
+         for step in range(nsteps):
+             u, out[..., step] = self.state_update(u, out[..., max(step - 1, 0)], x[..., step])
+         return out
+
+     def state_update(self, u_t_n1, o_t_n1, W_mul_o_t1_n, tau=tau):
+         # Membrane update: decay by tau, reset to zero after a spike, add new input.
+         u_t1_n1 = tau * u_t_n1 * (1 - o_t_n1) + W_mul_o_t1_n
+         o_t1_n1 = SpikeAct.apply(u_t1_n1)
+         return u_t1_n1, o_t1_n1
+
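Driven by a constant current I, the membrane potential charges toward I/(1 - tau), fires once it crosses Vth, and resets; an illustrative run (not part of the commit):

    lif = LIFSpike()
    x = torch.full((1, 1, 8), 0.18)  # constant input current over 8 timesteps
    print(lif(x))  # tensor([[[0., 1., 0., 1., 0., 1., 0., 1.]]]) -- charge, fire, reset
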
+ class tdLinear(nn.Linear):
+     def __init__(self,
+                  in_features,
+                  out_features,
+                  bias=True,
+                  bn=None,
+                  spike=None):
+         assert type(in_features) == int, 'in_features should not be more than 1 dimension. It was: {}'.format(in_features)
+         assert type(out_features) == int, 'out_features should not be more than 1 dimension. It was: {}'.format(out_features)
+
+         super(tdLinear, self).__init__(in_features, out_features, bias=bias)
+
+         self.bn = bn
+         self.spike = spike
+
+     def forward(self, x):
+         """
+         x : (N,C,T)
+         """
+         x = x.transpose(1, 2)  # (N, T, C)
+         y = F.linear(x, self.weight, self.bias)
+         y = y.transpose(1, 2)  # (N, C, T)
+
+         if self.bn is not None:
+             y = y[:, :, None, None, :]  # (N, C, 1, 1, T) so tdBatchNorm sees a 5D input
+             y = self.bn(y)
+             y = y[:, :, 0, 0, :]
+         if self.spike is not None:
+             y = self.spike(y)
+         return y
+
+ class tdConv(nn.Conv3d):
+     def __init__(self,
+                  in_channels,
+                  out_channels,
+                  kernel_size,
+                  stride=1,
+                  padding=0,
+                  dilation=1,
+                  groups=1,
+                  bias=True,
+                  bn=None,
+                  spike=None,
+                  is_first_conv=False):
+
+         # kernel: a 2D spatial kernel extended with a size-1 time dimension
+         if type(kernel_size) == int:
+             kernel = (kernel_size, kernel_size, 1)
+         elif len(kernel_size) == 2:
+             kernel = (kernel_size[0], kernel_size[1], 1)
+         else:
+             raise Exception('kernel_size can only be of 1 or 2 dimensions. It was: {}'.format(kernel_size))
+
+         # stride
+         if type(stride) == int:
+             stride = (stride, stride, 1)
+         elif len(stride) == 2:
+             stride = (stride[0], stride[1], 1)
+         else:
+             raise Exception('stride can be either int or tuple of size 2. It was: {}'.format(stride))
+
+         # padding
+         if type(padding) == int:
+             padding = (padding, padding, 0)
+         elif len(padding) == 2:
+             padding = (padding[0], padding[1], 0)
+         else:
+             raise Exception('padding can be either int or tuple of size 2. It was: {}'.format(padding))
+
+         # dilation
+         if type(dilation) == int:
+             dilation = (dilation, dilation, 1)
+         elif len(dilation) == 2:
+             dilation = (dilation[0], dilation[1], 1)
+         else:
+             raise Exception('dilation can be either int or tuple of size 2. It was: {}'.format(dilation))
+
+         super(tdConv, self).__init__(in_channels, out_channels, kernel, stride, padding, dilation, groups,
+                                      bias=bias)
+         self.bn = bn
+         self.spike = spike
+         self.is_first_conv = is_first_conv
+
+     def forward(self, x):
+         x = F.conv3d(x, self.weight, self.bias,
+                      self.stride, self.padding, self.dilation, self.groups)
+         if self.bn is not None:
+             x = self.bn(x)
+         if self.spike is not None:
+             x = self.spike(x)
+         return x
+
+
+ class tdConvTranspose(nn.ConvTranspose3d):
+     def __init__(self,
+                  in_channels,
+                  out_channels,
+                  kernel_size,
+                  stride=1,
+                  padding=0,
+                  output_padding=0,
+                  dilation=1,
+                  groups=1,
+                  bias=True,
+                  bn=None,
+                  spike=None):
+
+         # kernel: a 2D spatial kernel extended with a size-1 time dimension
+         if type(kernel_size) == int:
+             kernel = (kernel_size, kernel_size, 1)
+         elif len(kernel_size) == 2:
+             kernel = (kernel_size[0], kernel_size[1], 1)
+         else:
+             raise Exception('kernel_size can only be of 1 or 2 dimensions. It was: {}'.format(kernel_size))
+
+         # stride
+         if type(stride) == int:
+             stride = (stride, stride, 1)
+         elif len(stride) == 2:
+             stride = (stride[0], stride[1], 1)
+         else:
+             raise Exception('stride can be either int or tuple of size 2. It was: {}'.format(stride))
+
+         # padding
+         if type(padding) == int:
+             padding = (padding, padding, 0)
+         elif len(padding) == 2:
+             padding = (padding[0], padding[1], 0)
+         else:
+             raise Exception('padding can be either int or tuple of size 2. It was: {}'.format(padding))
+
+         # dilation
+         if type(dilation) == int:
+             dilation = (dilation, dilation, 1)
+         elif len(dilation) == 2:
+             dilation = (dilation[0], dilation[1], 1)
+         else:
+             raise Exception('dilation can be either int or tuple of size 2. It was: {}'.format(dilation))
+
+         # output padding
+         if type(output_padding) == int:
+             output_padding = (output_padding, output_padding, 0)
+         elif len(output_padding) == 2:
+             output_padding = (output_padding[0], output_padding[1], 0)
+         else:
+             raise Exception('output_padding can be either int or tuple of size 2. It was: {}'.format(output_padding))
+
+         super().__init__(in_channels, out_channels, kernel, stride, padding, output_padding, groups,
+                          bias=bias, dilation=dilation)
+
+         self.bn = bn
+         self.spike = spike
+
+     def forward(self, x):
+         x = F.conv_transpose3d(x, self.weight, self.bias,
+                                self.stride, self.padding,
+                                self.output_padding, self.groups, self.dilation)
+
+         if self.bn is not None:
+             x = self.bn(x)
+         if self.spike is not None:
+             x = self.spike(x)
+         return x
+
+ class tdBatchNorm(nn.BatchNorm2d):
+     """
+     Implementation of tdBN (threshold-dependent batch normalization); see
+     https://arxiv.org/pdf/2011.05280. In short, batch statistics are averaged
+     over the time dimension as well when doing BN.
+     Args:
+         num_features (int): same as nn.BatchNorm2d
+         eps (float): same as nn.BatchNorm2d
+         momentum (float): same as nn.BatchNorm2d
+         alpha (float): an additional scale parameter, which may change in a resblock.
+         affine (bool): same as nn.BatchNorm2d
+         track_running_stats (bool): same as nn.BatchNorm2d
+     """
+     def __init__(self, num_features, eps=1e-05, momentum=0.1, alpha=1, affine=True, track_running_stats=True):
+         super(tdBatchNorm, self).__init__(
+             num_features, eps, momentum, affine, track_running_stats)
+         self.alpha = alpha
+
+     def forward(self, input):
+         exponential_average_factor = 0.0
+
+         if self.training and self.track_running_stats:
+             if self.num_batches_tracked is not None:
+                 self.num_batches_tracked += 1
+                 if self.momentum is None:  # use cumulative moving average
+                     exponential_average_factor = 1.0 / float(self.num_batches_tracked)
+                 else:  # use exponential moving average
+                     exponential_average_factor = self.momentum
+
+         # calculate running estimates
+         if self.training:
+             mean = input.mean([0, 2, 3, 4])
+             # use biased var in train
+             var = input.var([0, 2, 3, 4], unbiased=False)
+             n = input.numel() / input.size(1)
+             with torch.no_grad():
+                 self.running_mean = exponential_average_factor * mean \
+                     + (1 - exponential_average_factor) * self.running_mean
+                 # update running_var with unbiased var
+                 self.running_var = exponential_average_factor * var * n / (n - 1) \
+                     + (1 - exponential_average_factor) * self.running_var
+         else:
+             mean = self.running_mean
+             var = self.running_var
+
+         input = self.alpha * Vth * (input - mean[None, :, None, None, None]) / (torch.sqrt(var[None, :, None, None, None] + self.eps))
+         if self.affine:
+             input = input * self.weight[None, :, None, None, None] + self.bias[None, :, None, None, None]
+
+         return input
+
+
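Each channel is thus normalized jointly over batch, spatial, and time axes, then rescaled to alpha * Vth rather than unit variance. An illustrative check (not part of the commit):

    bn = tdBatchNorm(num_features=4).train()
    x = torch.randn(8, 4, 16, 16, 16) * 3 + 1  # (N, C, H, W, T)
    y = bn(x)
    print(round(y.mean().item(), 3), round(y.std().item(), 3))  # ~0.0 and ~alpha * Vth = 0.2
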
+ class PSP(torch.nn.Module):
+     def __init__(self):
+         super().__init__()
+         self.tau_s = 2
+
+     def forward(self, inputs):
+         """
+         inputs: (N, C, T)
+         """
+         syns = None
+         syn = 0
+         n_steps = inputs.shape[-1]
+         for t in range(n_steps):
+             # first-order synaptic filter of the input spike train
+             syn = syn + (inputs[..., t] - syn) / self.tau_s
+             if syns is None:
+                 syns = syn.unsqueeze(-1)
+             else:
+                 syns = torch.cat([syns, syn.unsqueeze(-1)], dim=-1)
+
+         return syns
+
+ class MembraneOutputLayer(nn.Module):
+     """
+     Outputs the final membrane potential of a LIF neuron that never fires (V_th = infinity).
+     """
+     def __init__(self) -> None:
+         super().__init__()
+         # n_steps = glv.n_steps  # original global; hardcoded here to match the config
+         n_steps = 16
+
+         arr = torch.arange(n_steps - 1, -1, -1)
+         self.register_buffer("coef", torch.pow(0.8, arr)[None, None, None, None, :])  # (1,1,1,1,T)
+
+     def forward(self, x):
+         """
+         x : (N,C,H,W,T)
+         """
+         out = torch.sum(x * self.coef, dim=-1)
+         return out
+
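Because timestep t is weighted by 0.8^(T-1-t), later spikes contribute most to the decoded image. A shape check (illustrative):

    layer = MembraneOutputLayer()
    spikes = (torch.rand(2, 1, 32, 32, 16) > 0.5).float()  # (N, C, H, W, T)
    print(layer(spikes).shape)  # torch.Size([2, 1, 32, 32]) -- time integrated out
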
+ class PriorBernoulliSTBP(nn.Module):
+     def __init__(self, k=20) -> None:
+         """
+         modeling of p(z_t | z_<t)
+         """
+         super().__init__()
+         # self.channels = glv.network_config['latent_dim']
+         self.channels = 128
+         self.k = k
+         # self.n_steps = glv.network_config['n_steps']
+         self.n_steps = 16
+
+         self.layers = nn.Sequential(
+             tdLinear(self.channels,
+                      self.channels * 2,
+                      bias=True,
+                      bn=tdBatchNorm(self.channels * 2, alpha=2),
+                      spike=LIFSpike()),
+             tdLinear(self.channels * 2,
+                      self.channels * 4,
+                      bias=True,
+                      bn=tdBatchNorm(self.channels * 4, alpha=2),
+                      spike=LIFSpike()),
+             tdLinear(self.channels * 4,
+                      self.channels * k,
+                      bias=True,
+                      bn=tdBatchNorm(self.channels * k, alpha=2),
+                      spike=LIFSpike())
+         )
+         self.register_buffer('initial_input', torch.zeros(1, self.channels, 1))  # (1,C,1)
+
+     def forward(self, z, scheduled=False, p=None):
+         if scheduled:
+             return self._forward_scheduled_sampling(z, p)
+         else:
+             return self._forward(z)
+
+     def _forward(self, z):
+         """
+         input z: (B,C,T) # latent spikes sampled from the posterior
+         output : (B,C,k,T) # indicates p(z_t|z_<t) (t=1,...,T)
+         """
+         z_shape = z.shape  # (B,C,T)
+         batch_size = z_shape[0]
+         z = z.detach()
+
+         z0 = self.initial_input.repeat(batch_size, 1, 1)  # (B,C,1)
+         inputs = torch.cat([z0, z[..., :-1]], dim=-1)  # (B,C,T)
+         outputs = self.layers(inputs)  # (B,C*k,T)
+
+         p_z = outputs.view(batch_size, self.channels, self.k, self.n_steps)  # (B,C,k,T)
+         return p_z
+
+     def _forward_scheduled_sampling(self, z, p):
+         """
+         use scheduled sampling
+         input
+             z: (B,C,T) # latent spikes sampled from the posterior
+             p: float   # probability of scheduled sampling
+         output : (B,C,k,T) # indicates p(z_t|z_<t) (t=1,...,T)
+         """
+         z_shape = z.shape  # (B,C,T)
+         batch_size = z_shape[0]
+         z = z.detach()
+
+         z_t_minus = self.initial_input.repeat(batch_size, 1, 1)  # z_<t, z0=zeros: (B,C,1)
+         if self.training:
+             with torch.no_grad():
+                 for t in range(self.n_steps - 1):
+                     if t >= 5 and random.random() < p:  # scheduled sampling
+                         outputs = self.layers(z_t_minus.detach())  # binary (B,C*k,t+1), z_<=t
+                         p_z_t = outputs[..., -1]  # (B,C*k)
+                         # sampling from p(z_t | z_<t)
+                         prob1 = p_z_t.view(batch_size, self.channels, self.k).mean(-1)  # (B,C)
+                         prob1 = prob1 + 1e-3 * torch.randn_like(prob1)
+                         z_t = (prob1 > 0.5).float()  # (B,C)
+                         z_t = z_t.view(batch_size, self.channels, 1)  # (B,C,1)
+                         z_t_minus = torch.cat([z_t_minus, z_t], dim=-1)  # (B,C,t+2)
+                     else:
+                         z_t_minus = torch.cat([z_t_minus, z[..., t].unsqueeze(-1)], dim=-1)  # (B,C,t+2)
+         else:  # at test time
+             z_t_minus = torch.cat([z_t_minus, z[:, :, :-1]], dim=-1)  # (B,C,T)
+
+         z_t_minus = z_t_minus.detach()  # (B,C,T), z_{<=T-1}
+         p_z = self.layers(z_t_minus)  # (B,C*k,T)
+         p_z = p_z.view(batch_size, self.channels, self.k, self.n_steps)  # (B,C,k,T)
+         return p_z
+
+     def sample(self, batch_size=64):
+         z_minus_t = self.initial_input.repeat(batch_size, 1, 1)  # (B,C,1)
+         for t in range(self.n_steps):
+             outputs = self.layers(z_minus_t)  # (B,C*k,t+1)
+             p_z_t = outputs[..., -1]  # (B,C*k)
+
+             random_index = torch.randint(0, self.k, (batch_size * self.channels,)) \
+                 + torch.arange(start=0, end=batch_size * self.channels * self.k, step=self.k)  # (B*C,) pick one from k
+             random_index = random_index.to(z_minus_t.device)
+
+             z_t = p_z_t.view(batch_size * self.channels * self.k)[random_index]  # (B*C,)
+             z_t = z_t.view(batch_size, self.channels, 1)  # (B,C,1)
+             z_minus_t = torch.cat([z_minus_t, z_t], dim=-1)  # (B,C,t+2)
+
+         sampled_z = z_minus_t[..., 1:]  # (B,C,T)
+
+         return sampled_z
+
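Generation walks the prior autoregressively: feed z_<t through the layers, read the last timestep, keep one of the k Bernoulli outputs at random, and append it. Illustrative usage (not part of the commit):

    prior = PriorBernoulliSTBP(k=20).eval()  # eval() so tdBatchNorm uses running stats
    with torch.no_grad():
        z = prior.sample(batch_size=4)
    print(z.shape)  # torch.Size([4, 128, 16]) -- binary latent spike trains
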
+ class PosteriorBernoulliSTBP(nn.Module):
+     def __init__(self, k=20) -> None:
+         """
+         modeling of q(z_t | x_<=t, z_<t)
+         """
+         super().__init__()
+         # self.channels = glv.network_config['latent_dim']
+         self.channels = 128
+         self.k = k
+         # self.n_steps = glv.network_config['n_steps']
+         self.n_steps = 16
+
+         self.layers = nn.Sequential(
+             tdLinear(self.channels * 2,
+                      self.channels * 2,
+                      bias=True,
+                      bn=tdBatchNorm(self.channels * 2, alpha=2),
+                      spike=LIFSpike()),
+             tdLinear(self.channels * 2,
+                      self.channels * 4,
+                      bias=True,
+                      bn=tdBatchNorm(self.channels * 4, alpha=2),
+                      spike=LIFSpike()),
+             tdLinear(self.channels * 4,
+                      self.channels * k,
+                      bias=True,
+                      bn=tdBatchNorm(self.channels * k, alpha=2),
+                      spike=LIFSpike())
+         )
+         self.register_buffer('initial_input', torch.zeros(1, self.channels, 1))  # (1,C,1)
+
+         self.is_true_scheduled_sampling = True
+
+     def forward(self, x):
+         """
+         input:
+             x: (B,C,T)
+         returns:
+             sampled_z: (B,C,T)
+             q_z: (B,C,k,T) # indicates q(z_t | x_<=t, z_<t) (t=1,...,T)
+         """
+         x_shape = x.shape  # (B,C,T)
+         batch_size = x_shape[0]
+         random_indices = []
+         # sample z in advance, without gradient
+         with torch.no_grad():
+             z_t_minus = self.initial_input.repeat(x_shape[0], 1, 1)  # z_<t, z0=zeros: (B,C,1)
+             for t in range(self.n_steps - 1):
+                 inputs = torch.cat([x[..., :t + 1].detach(), z_t_minus.detach()], dim=1)  # (B,C+C,t+1), x_<=t and z_<t
+                 outputs = self.layers(inputs)  # (B,C*k,t+1)
+                 q_z_t = outputs[..., -1]  # (B,C*k), q(z_t | x_<=t, z_<t)
+
+                 # sampling from q(z_t | x_<=t, z_<t)
+                 random_index = torch.randint(0, self.k, (batch_size * self.channels,)) \
+                     + torch.arange(start=0, end=batch_size * self.channels * self.k, step=self.k)  # (B*C,) select 1 from every k values
+                 random_index = random_index.to(x.device)
+                 random_indices.append(random_index)
+
+                 z_t = q_z_t.view(batch_size * self.channels * self.k)[random_index]  # (B*C,)
+                 z_t = z_t.view(batch_size, self.channels, 1)  # (B,C,1)
+
+                 z_t_minus = torch.cat([z_t_minus, z_t], dim=-1)  # (B,C,t+2)
+
+         z_t_minus = z_t_minus.detach()  # (B,C,T), z_0,...,z_{T-1}
+         q_z = self.layers(torch.cat([x, z_t_minus], dim=1))  # (B,C*k,T)
+
+         # feed z_t_minus through again so tdBN sees the full sequence
+         sampled_z = None
+         for t in range(self.n_steps):
+             if t == self.n_steps - 1:
+                 # when t=T
+                 random_index = torch.randint(0, self.k, (batch_size * self.channels,)) \
+                     + torch.arange(start=0, end=batch_size * self.channels * self.k, step=self.k)
+                 random_index = random_index.to(x.device)  # keep indices on the same device as q_z
+                 random_indices.append(random_index)
+             else:
+                 # when t<=T-1
+                 random_index = random_indices[t]
+
+             # sampling
+             sampled_z_t = q_z[..., t].view(batch_size * self.channels * self.k)[random_index]  # (B*C,)
+             sampled_z_t = sampled_z_t.view(batch_size, self.channels, 1)  # (B,C,1)
+             if t == 0:
+                 sampled_z = sampled_z_t
+             else:
+                 sampled_z = torch.cat([sampled_z, sampled_z_t], dim=-1)
+
+         q_z = q_z.view(batch_size, self.channels, self.k, self.n_steps)  # (B,C,k,T)
+
+         return sampled_z, q_z
+
+
+ class FSAEModel(PreTrainedModel):
+     config_class = FSAEConfig
+
+     def __init__(self, config):
+         super().__init__(config)
+
+         self.in_channels = config.in_channels
+         in_channels = self.in_channels
+
+         self.hidden_dims = config.hidden_dims
+         hidden_dims = self.hidden_dims
+
+         self.latent_dim = config.latent_dim
+         latent_dim = self.latent_dim
+
+         self.n_steps = config.n_steps
+         n_steps = self.n_steps
+
+         self.k = config.k
+         k = self.k
+
+         # Build Encoder
+         modules = []
+         is_first_conv = True
+         for h_dim in hidden_dims:
+             modules.append(
+                 tdConv(
+                     in_channels,
+                     out_channels=h_dim,
+                     kernel_size=3,
+                     stride=2,
+                     padding=1,
+                     bias=True,
+                     bn=tdBatchNorm(h_dim),
+                     spike=LIFSpike(),
+                     is_first_conv=is_first_conv,
+                 )
+             )
+             in_channels = h_dim
+             is_first_conv = False
+
+         self.encoder = nn.Sequential(*modules)
+         self.before_latent_layer = tdLinear(
+             hidden_dims[-1] * 4,
+             latent_dim,
+             bias=True,
+             bn=tdBatchNorm(latent_dim),
+             spike=LIFSpike(),
+         )
+
+         # Build Decoder
+         modules = []
+
+         self.decoder_input = tdLinear(
+             latent_dim,
+             hidden_dims[-1] * 4,
+             bias=True,
+             bn=tdBatchNorm(hidden_dims[-1] * 4),
+             spike=LIFSpike(),
+         )
+
+         hidden_reverse = hidden_dims[::-1]
+
+         for i in range(len(hidden_reverse) - 1):
+             modules.append(
+                 tdConvTranspose(
+                     hidden_reverse[i],
+                     hidden_reverse[i + 1],
+                     kernel_size=3,
+                     stride=2,
+                     padding=1,
+                     output_padding=1,
+                     bias=True,
+                     bn=tdBatchNorm(hidden_reverse[i + 1]),
+                     spike=LIFSpike(),
+                 )
+             )
+         self.decoder = nn.Sequential(*modules)
+
+         self.final_layer = nn.Sequential(
+             tdConvTranspose(
+                 hidden_reverse[-1],
+                 hidden_reverse[-1],
+                 kernel_size=3,
+                 stride=2,
+                 padding=1,
+                 output_padding=1,
+                 bias=True,
+                 bn=tdBatchNorm(hidden_reverse[-1]),
+                 spike=LIFSpike(),
+             ),
+             tdConvTranspose(
+                 hidden_reverse[-1],
+                 out_channels=1,
+                 kernel_size=3,
+                 padding=1,
+                 bias=True,
+                 bn=None,
+                 spike=None,
+             ),
+         )
+
+         self.p = 0
+
+         self.membrane_output_layer = MembraneOutputLayer()
+
+     def forward(self, x, scheduled=False):
+         sampled_z = self.encode(x, scheduled)
+         x_recon = self.decode(sampled_z)
+         return x_recon, sampled_z
+
+     def encode(self, x, scheduled=False):
+         # `scheduled` is kept for API compatibility; it is unused in this autoencoder variant.
+         x = self.encoder(x)  # (N,C,H,W,T)
+         x = torch.flatten(x, start_dim=1, end_dim=3)  # (N,C*H*W,T)
+         latent_x = self.before_latent_layer(x)  # (N,latent_dim,T)
+         return latent_x
+
+     def decode(self, z):
+         result = self.decoder_input(z)  # (N,C*H*W,T)
+         result = result.view(
+             result.shape[0], self.hidden_dims[-1], 2, 2, self.n_steps
+         )  # (N,C,H,W,T)
+         result = self.decoder(result)  # (N,C,H,W,T)
+         result = self.final_layer(result)  # (N,C,H,W,T)
+         out = torch.tanh(self.membrane_output_layer(result))
+         return out
+
+     def sample(self, batch_size=64):
+         raise NotImplementedError()
+
+     def loss_function(self, recons_img, input_img):
+         """
+         Reconstruction loss only (MSE between decoded and target images).
+         Unlike the FSVAE, this autoencoder computes no KL/prior term.
+         """
+         recons_loss = F.mse_loss(recons_img, input_img)
+
+         return recons_loss
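
Putting the pieces together, a minimal end-to-end sketch under the default config (32x32 inputs, 16 timesteps); illustrative, not part of the commit:

    config = FSAEConfig()
    model = FSAEModel(config).eval()

    x = (torch.rand(2, 1, 32, 32, 16) > 0.5).float()  # spike trains: (N, C, H, W, T)
    with torch.no_grad():
        recon, z = model(x)
    print(z.shape)      # torch.Size([2, 128, 16])   -- latent spike trains
    print(recon.shape)  # torch.Size([2, 1, 32, 32]) -- tanh-bounded membrane decode

    target = torch.rand(2, 1, 32, 32)  # placeholder for the real-valued source images
    loss = model.loss_function(recon, target)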