mickylan2367 committed
Commit ba54498 • Parent(s): a4f7cb6

Upload model

Browse files:
- config.json +28 -0
- configuration_fsae.py +65 -0
- model.safetensors +3 -0
- modeling_fsae.py +657 -0
config.json
ADDED
@@ -0,0 +1,28 @@
{
  "Vth": 0.2,
  "a": 0.25,
  "aa": 0.5,
  "architectures": [
    "FSAEModel"
  ],
  "auto_map": {
    "AutoConfig": "configuration_fsae.FSAEConfig",
    "AutoModel": "modeling_fsae.FSAEModel"
  },
  "dt": 5,
  "hidden_dims": [
    32,
    64,
    128,
    256
  ],
  "in_channels": 1,
  "k": 20,
  "latent_dim": 128,
  "model_type": "fsae",
  "n_steps": 16,
  "scheduled": true,
  "tau": 0.25,
  "torch_dtype": "float32",
  "transformers_version": "4.35.0"
}
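The `auto_map` block wires this repository's custom classes into the Transformers auto classes, so the upload loads with `trust_remote_code=True`. A minimal loading sketch; "<repo-id>" is a placeholder, since the Hub path of this repository is not shown in the commit:

from transformers import AutoConfig, AutoModel

# "<repo-id>" stands in for the actual Hub path of this repository.
config = AutoConfig.from_pretrained("<repo-id>", trust_remote_code=True)  # resolves to FSAEConfig
model = AutoModel.from_pretrained("<repo-id>", trust_remote_code=True)    # resolves to FSAEModel
print(type(config).__name__, type(model).__name__)  # FSAEConfig FSAEModel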
configuration_fsae.py
ADDED
@@ -0,0 +1,65 @@
# pip install transformers
from transformers import PretrainedConfig
from typing import List


'''
network_config = {
    "epochs": 150,
    "batch_size": 250,
    "n_steps": 16,  # timesteps
    "dataset": "CAPS",
    "in_channels": 1,
    "data_path": "./data",
    "lr": 0.001,
    "n_class": 10,
    "latent_dim": 128,
    "input_size": 32,
    "model": "FSVAE",  # FSVAE or FSVAE_large
    "k": 20,  # multiplier of channels
    "scheduled": True,  # whether to apply scheduled sampling
    "loss_func": 'kld',  # mmd or kld
    "accum_iter": 1,
    "devices": [0],
}

hidden_dims = [32, 64, 128, 256]
'''

class FSAEConfig(PretrainedConfig):
    model_type = "fsae"

    def __init__(
        self,
        in_channels: int = 1,
        hidden_dims: List[int] = [32, 64, 128, 256],
        k: int = 20,
        n_steps: int = 16,
        latent_dim: int = 128,
        scheduled: bool = True,
        # loss_func: str = "kld",
        dt: float = 5,
        a: float = 0.25,
        aa: float = 0.5,
        Vth: float = 0.2,  # threshold potential
        tau: float = 0.25,
        **kwargs,
    ):
        self.in_channels = in_channels
        self.hidden_dims = hidden_dims
        self.k = k
        self.n_steps = n_steps
        self.latent_dim = latent_dim
        self.scheduled = scheduled
        self.dt = dt
        self.a = a
        self.aa = aa
        self.Vth = Vth
        self.tau = tau
        super().__init__(**kwargs)
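Because `FSAEConfig` forwards `**kwargs` to `PretrainedConfig`, it inherits the standard serialization round-trip. A small sketch; the `latent_dim=64` override and the `./fsae-test` directory are illustrative values, not part of the upload:

from configuration_fsae import FSAEConfig

config = FSAEConfig(latent_dim=64)        # override one default for illustration
config.save_pretrained("./fsae-test")     # writes a config.json with model_type "fsae"
reloaded = FSAEConfig.from_pretrained("./fsae-test")
assert reloaded.latent_dim == 64 and reloaded.model_type == "fsae"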
model.safetensors
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:df404010e0e2a77cf66e905d26b0b356397d8cfe61db32fddcca723319cb260b
size 4228636
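This is a Git LFS pointer, not the weights themselves; the `oid` is the SHA-256 of the 4,228,636-byte blob it resolves to. A sketch for checking that a local copy is the real blob rather than the pointer file (assumes `model.safetensors` has already been downloaded into the working directory):

import hashlib

EXPECTED = "df404010e0e2a77cf66e905d26b0b356397d8cfe61db32fddcca723319cb260b"

with open("model.safetensors", "rb") as f:
    digest = hashlib.sha256(f.read()).hexdigest()

assert digest == EXPECTED, "checksum mismatch: was the LFS pointer fetched instead of the blob?"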
modeling_fsae.py
ADDED
@@ -0,0 +1,657 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
import random

from transformers import PreTrainedModel
from .configuration_fsae import FSAEConfig

dt = 5
a = 0.25
aa = 0.5
Vth = 0.2
tau = 0.25


class SpikeAct(torch.autograd.Function):
    """
    Implementation of the spiking activation function with an approximation of the gradient.
    """
    @staticmethod
    def forward(ctx, input):
        ctx.save_for_backward(input)
        # if input = u > Vth then output = 1
        output = torch.gt(input, Vth)
        return output.float()

    @staticmethod
    def backward(ctx, grad_output):
        input, = ctx.saved_tensors
        grad_input = grad_output.clone()
        # hu is an approximation of df/du: a rectangular surrogate gradient of width 2*aa
        hu = abs(input) < aa
        hu = hu.float() / (2 * aa)
        return grad_input * hu


class LIFSpike(nn.Module):
    """
    Generates spikes based on the LIF module. It can be considered an activation function
    and is used similarly to ReLU. The input tensor needs an additional time dimension,
    which in this case is the last dimension of the data.
    """
    def __init__(self):
        super(LIFSpike, self).__init__()

    def forward(self, x):
        nsteps = x.shape[-1]
        u = torch.zeros(x.shape[:-1], device=x.device)
        out = torch.zeros(x.shape, device=x.device)
        for step in range(nsteps):
            u, out[..., step] = self.state_update(u, out[..., max(step - 1, 0)], x[..., step])
        return out

    def state_update(self, u_t_n1, o_t_n1, W_mul_o_t1_n, tau=tau):
        # the membrane potential decays by tau, is reset by the previous spike,
        # and integrates the new synaptic input
        u_t1_n1 = tau * u_t_n1 * (1 - o_t_n1) + W_mul_o_t1_n
        o_t1_n1 = SpikeAct.apply(u_t1_n1)
        return u_t1_n1, o_t1_n1


class tdLinear(nn.Linear):
    """Time-distributed linear layer: applies the same linear map at every timestep."""
    def __init__(self,
                 in_features,
                 out_features,
                 bias=True,
                 bn=None,
                 spike=None):
        assert type(in_features) == int, 'in_features should not be more than 1 dimension. It was: {}'.format(in_features)
        assert type(out_features) == int, 'out_features should not be more than 1 dimension. It was: {}'.format(out_features)

        super(tdLinear, self).__init__(in_features, out_features, bias=bias)

        self.bn = bn
        self.spike = spike

    def forward(self, x):
        """
        x : (N, C, T)
        """
        x = x.transpose(1, 2)  # (N, T, C)
        y = F.linear(x, self.weight, self.bias)
        y = y.transpose(1, 2)  # (N, C, T)

        if self.bn is not None:
            # tdBatchNorm expects a 5-D (N,C,H,W,T) tensor, so add dummy spatial dims
            y = y[:, :, None, None, :]
            y = self.bn(y)
            y = y[:, :, 0, 0, :]
        if self.spike is not None:
            y = self.spike(y)
        return y


class tdConv(nn.Conv3d):
    """Time-distributed 2-D convolution, realized as a Conv3d with kernel size 1 along time."""
    def __init__(self,
                 in_channels,
                 out_channels,
                 kernel_size,
                 stride=1,
                 padding=0,
                 dilation=1,
                 groups=1,
                 bias=True,
                 bn=None,
                 spike=None,
                 is_first_conv=False):

        # kernel
        if type(kernel_size) == int:
            kernel = (kernel_size, kernel_size, 1)
        elif len(kernel_size) == 2:
            kernel = (kernel_size[0], kernel_size[1], 1)
        else:
            raise Exception('kernel_size can only be of 1 or 2 dimensions. It was: {}'.format(kernel_size))

        # stride
        if type(stride) == int:
            stride = (stride, stride, 1)
        elif len(stride) == 2:
            stride = (stride[0], stride[1], 1)
        else:
            raise Exception('stride can be either int or tuple of size 2. It was: {}'.format(stride))

        # padding
        if type(padding) == int:
            padding = (padding, padding, 0)
        elif len(padding) == 2:
            padding = (padding[0], padding[1], 0)
        else:
            raise Exception('padding can be either int or tuple of size 2. It was: {}'.format(padding))

        # dilation
        if type(dilation) == int:
            dilation = (dilation, dilation, 1)
        elif len(dilation) == 2:
            dilation = (dilation[0], dilation[1], 1)
        else:
            raise Exception('dilation can be either int or tuple of size 2. It was: {}'.format(dilation))

        super(tdConv, self).__init__(in_channels, out_channels, kernel, stride, padding, dilation, groups,
                                     bias=bias)
        self.bn = bn
        self.spike = spike
        self.is_first_conv = is_first_conv

    def forward(self, x):
        x = F.conv3d(x, self.weight, self.bias,
                     self.stride, self.padding, self.dilation, self.groups)
        if self.bn is not None:
            x = self.bn(x)
        if self.spike is not None:
            x = self.spike(x)
        return x


class tdConvTranspose(nn.ConvTranspose3d):
    """Time-distributed 2-D transposed convolution (ConvTranspose3d with kernel size 1 along time)."""
    def __init__(self,
                 in_channels,
                 out_channels,
                 kernel_size,
                 stride=1,
                 padding=0,
                 output_padding=0,
                 dilation=1,
                 groups=1,
                 bias=True,
                 bn=None,
                 spike=None):

        # kernel
        if type(kernel_size) == int:
            kernel = (kernel_size, kernel_size, 1)
        elif len(kernel_size) == 2:
            kernel = (kernel_size[0], kernel_size[1], 1)
        else:
            raise Exception('kernel_size can only be of 1 or 2 dimensions. It was: {}'.format(kernel_size))

        # stride
        if type(stride) == int:
            stride = (stride, stride, 1)
        elif len(stride) == 2:
            stride = (stride[0], stride[1], 1)
        else:
            raise Exception('stride can be either int or tuple of size 2. It was: {}'.format(stride))

        # padding
        if type(padding) == int:
            padding = (padding, padding, 0)
        elif len(padding) == 2:
            padding = (padding[0], padding[1], 0)
        else:
            raise Exception('padding can be either int or tuple of size 2. It was: {}'.format(padding))

        # dilation
        if type(dilation) == int:
            dilation = (dilation, dilation, 1)
        elif len(dilation) == 2:
            dilation = (dilation[0], dilation[1], 1)
        else:
            raise Exception('dilation can be either int or tuple of size 2. It was: {}'.format(dilation))

        # output padding
        if type(output_padding) == int:
            output_padding = (output_padding, output_padding, 0)
        elif len(output_padding) == 2:
            output_padding = (output_padding[0], output_padding[1], 0)
        else:
            raise Exception('output_padding can be either int or tuple of size 2. It was: {}'.format(output_padding))

        super().__init__(in_channels, out_channels, kernel, stride, padding, output_padding, groups,
                         bias=bias, dilation=dilation)

        self.bn = bn
        self.spike = spike

    def forward(self, x):
        x = F.conv_transpose3d(x, self.weight, self.bias,
                               self.stride, self.padding,
                               self.output_padding, self.groups, self.dilation)

        if self.bn is not None:
            x = self.bn(x)
        if self.spike is not None:
            x = self.spike(x)
        return x


class tdBatchNorm(nn.BatchNorm2d):
    """
    Implementation of tdBN. Link to the related paper: https://arxiv.org/pdf/2011.05280.
    In short, the statistics are also averaged over the time domain when doing BN.
    Args:
        num_features (int): same as nn.BatchNorm2d
        eps (float): same as nn.BatchNorm2d
        momentum (float): same as nn.BatchNorm2d
        alpha (float): an additional parameter which may change in a resblock.
        affine (bool): same as nn.BatchNorm2d
        track_running_stats (bool): same as nn.BatchNorm2d
    """
    def __init__(self, num_features, eps=1e-05, momentum=0.1, alpha=1, affine=True, track_running_stats=True):
        super(tdBatchNorm, self).__init__(
            num_features, eps, momentum, affine, track_running_stats)
        self.alpha = alpha

    def forward(self, input):
        exponential_average_factor = 0.0

        if self.training and self.track_running_stats:
            if self.num_batches_tracked is not None:
                self.num_batches_tracked += 1
                if self.momentum is None:  # use cumulative moving average
                    exponential_average_factor = 1.0 / float(self.num_batches_tracked)
                else:  # use exponential moving average
                    exponential_average_factor = self.momentum

        # calculate running estimates over the batch, spatial, and time dimensions
        if self.training:
            mean = input.mean([0, 2, 3, 4])
            # use biased var in train
            var = input.var([0, 2, 3, 4], unbiased=False)
            n = input.numel() / input.size(1)
            with torch.no_grad():
                self.running_mean = exponential_average_factor * mean \
                    + (1 - exponential_average_factor) * self.running_mean
                # update running_var with unbiased var
                self.running_var = exponential_average_factor * var * n / (n - 1) \
                    + (1 - exponential_average_factor) * self.running_var
        else:
            mean = self.running_mean
            var = self.running_var

        # normalize and rescale by alpha * Vth as in the tdBN paper
        input = self.alpha * Vth * (input - mean[None, :, None, None, None]) / (torch.sqrt(var[None, :, None, None, None] + self.eps))
        if self.affine:
            input = input * self.weight[None, :, None, None, None] + self.bias[None, :, None, None, None]

        return input


class PSP(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.tau_s = 2

    def forward(self, inputs):
        """
        inputs: (N, C, T)
        """
        syns = None
        syn = 0
        n_steps = inputs.shape[-1]
        for t in range(n_steps):
            # first-order low-pass filter of the input spike train
            syn = syn + (inputs[..., t] - syn) / self.tau_s
            if syns is None:
                syns = syn.unsqueeze(-1)
            else:
                syns = torch.cat([syns, syn.unsqueeze(-1)], dim=-1)

        return syns


class MembraneOutputLayer(nn.Module):
    """
    Outputs the last-time membrane potential of the LIF neuron with V_th = infinity.
    """
    def __init__(self) -> None:
        super().__init__()
        # n_steps = glv.n_steps
        n_steps = 16

        arr = torch.arange(n_steps - 1, -1, -1)
        # exponentially decaying weights over time; the most recent step is weighted highest
        self.register_buffer("coef", torch.pow(0.8, arr)[None, None, None, None, :])  # (1,1,1,1,T)

    def forward(self, x):
        """
        x : (N,C,H,W,T)
        """
        out = torch.sum(x * self.coef, dim=-1)
        return out


class PriorBernoulliSTBP(nn.Module):
    def __init__(self, k=20) -> None:
        """
        modeling of p(z_t|z_<t)
        """
        super().__init__()
        # self.channels = glv.network_config['latent_dim']
        self.channels = 128
        self.k = k
        # self.n_steps = glv.network_config['n_steps']
        self.n_steps = 16

        self.layers = nn.Sequential(
            tdLinear(self.channels,
                     self.channels * 2,
                     bias=True,
                     bn=tdBatchNorm(self.channels * 2, alpha=2),
                     spike=LIFSpike()),
            tdLinear(self.channels * 2,
                     self.channels * 4,
                     bias=True,
                     bn=tdBatchNorm(self.channels * 4, alpha=2),
                     spike=LIFSpike()),
            tdLinear(self.channels * 4,
                     self.channels * k,
                     bias=True,
                     bn=tdBatchNorm(self.channels * k, alpha=2),
                     spike=LIFSpike())
        )
        self.register_buffer('initial_input', torch.zeros(1, self.channels, 1))  # (1,C,1)

    def forward(self, z, scheduled=False, p=None):
        if scheduled:
            return self._forward_scheduled_sampling(z, p)
        else:
            return self._forward(z)

    def _forward(self, z):
        """
        input z: (B,C,T) # latent spike sampled from the posterior
        output : (B,C,k,T) # indicates p(z_t|z_<t) (t=1,...,T)
        """
        z_shape = z.shape  # (B,C,T)
        batch_size = z_shape[0]
        z = z.detach()

        z0 = self.initial_input.repeat(batch_size, 1, 1)  # (B,C,1)
        inputs = torch.cat([z0, z[..., :-1]], dim=-1)  # (B,C,T)
        outputs = self.layers(inputs)  # (B,C*k,T)

        p_z = outputs.view(batch_size, self.channels, self.k, self.n_steps)  # (B,C,k,T)
        return p_z

    def _forward_scheduled_sampling(self, z, p):
        """
        use scheduled sampling
        input
            z: (B,C,T) # latent spike sampled from the posterior
            p: float   # probability of scheduled sampling
        output : (B,C,k,T) # indicates p(z_t|z_<t) (t=1,...,T)
        """
        z_shape = z.shape  # (B,C,T)
        batch_size = z_shape[0]
        z = z.detach()

        z_t_minus = self.initial_input.repeat(batch_size, 1, 1)  # z_<t, z0=zeros: (B,C,1)
        if self.training:
            with torch.no_grad():
                for t in range(self.n_steps - 1):
                    if t >= 5 and random.random() < p:  # scheduled sampling
                        outputs = self.layers(z_t_minus.detach())  # binary (B, C*k, t+1) z_<=t
                        p_z_t = outputs[..., -1]  # (B, C*k, 1)
                        # sampling from p(z_t | z_<t)
                        prob1 = p_z_t.view(batch_size, self.channels, self.k).mean(-1)  # (B,C)
                        prob1 = prob1 + 1e-3 * torch.randn_like(prob1)
                        z_t = (prob1 > 0.5).float()  # (B,C)
                        z_t = z_t.view(batch_size, self.channels, 1)  # (B,C,1)
                        z_t_minus = torch.cat([z_t_minus, z_t], dim=-1)  # (B,C,t+2)
                    else:
                        z_t_minus = torch.cat([z_t_minus, z[..., t].unsqueeze(-1)], dim=-1)  # (B,C,t+2)
        else:  # at test time
            z_t_minus = torch.cat([z_t_minus, z[:, :, :-1]], dim=-1)  # (B,C,T)

        z_t_minus = z_t_minus.detach()  # (B,C,T) z_{<=T-1}
        p_z = self.layers(z_t_minus)  # (B,C*k,T)
        p_z = p_z.view(batch_size, self.channels, self.k, self.n_steps)  # (B,C,k,T)
        return p_z

    def sample(self, batch_size=64):
        z_minus_t = self.initial_input.repeat(batch_size, 1, 1)  # (B, C, 1)
        for t in range(self.n_steps):
            outputs = self.layers(z_minus_t)  # (B, C*k, t+1)
            p_z_t = outputs[..., -1]  # (B, C*k, 1)

            random_index = torch.randint(0, self.k, (batch_size * self.channels,)) \
                + torch.arange(start=0, end=batch_size * self.channels * self.k, step=self.k)  # (B*C,) pick one from k
            random_index = random_index.to(z_minus_t.device)

            z_t = p_z_t.view(batch_size * self.channels * self.k)[random_index]  # (B*C,)
            z_t = z_t.view(batch_size, self.channels, 1)  # (B,C,1)
            z_minus_t = torch.cat([z_minus_t, z_t], dim=-1)  # (B,C,t+2)

        sampled_z = z_minus_t[..., 1:]  # (B,C,T)

        return sampled_z


class PosteriorBernoulliSTBP(nn.Module):
    def __init__(self, k=20) -> None:
        """
        modeling of q(z_t | x_<=t, z_<t)
        """
        super().__init__()
        # self.channels = glv.network_config['latent_dim']
        self.channels = 128
        self.k = k
        # self.n_steps = glv.network_config['n_steps']
        self.n_steps = 16

        self.layers = nn.Sequential(
            tdLinear(self.channels * 2,
                     self.channels * 2,
                     bias=True,
                     bn=tdBatchNorm(self.channels * 2, alpha=2),
                     spike=LIFSpike()),
            tdLinear(self.channels * 2,
                     self.channels * 4,
                     bias=True,
                     bn=tdBatchNorm(self.channels * 4, alpha=2),
                     spike=LIFSpike()),
            tdLinear(self.channels * 4,
                     self.channels * k,
                     bias=True,
                     bn=tdBatchNorm(self.channels * k, alpha=2),
                     spike=LIFSpike())
        )
        self.register_buffer('initial_input', torch.zeros(1, self.channels, 1))  # (1,C,1)

        self.is_true_scheduled_sampling = True

    def forward(self, x):
        """
        input:
            x: (B,C,T)
        returns:
            sampled_z: (B,C,T)
            q_z: (B,C,k,T) # indicates q(z_t | x_<=t, z_<t) (t=1,...,T)
        """
        x_shape = x.shape  # (B,C,T)
        batch_size = x_shape[0]
        random_indices = []
        # sample z in advance without gradient
        with torch.no_grad():
            z_t_minus = self.initial_input.repeat(x_shape[0], 1, 1)  # z_<t, z0=zeros: (B,C,1)
            for t in range(self.n_steps - 1):
                inputs = torch.cat([x[..., :t + 1].detach(), z_t_minus.detach()], dim=1)  # (B,C+C,t+1) x_<=t and z_<t
                outputs = self.layers(inputs)  # (B, C*k, t+1)
                q_z_t = outputs[..., -1]  # (B, C*k, 1) q(z_t | x_<=t, z_<t)

                # sampling from q(z_t | x_<=t, z_<t)
                random_index = torch.randint(0, self.k, (batch_size * self.channels,)) \
                    + torch.arange(start=0, end=batch_size * self.channels * self.k, step=self.k)  # (B*C,) select 1 from every k values
                random_index = random_index.to(x.device)
                random_indices.append(random_index)

                z_t = q_z_t.view(batch_size * self.channels * self.k)[random_index]  # (B*C,)
                z_t = z_t.view(batch_size, self.channels, 1)  # (B,C,1)

                z_t_minus = torch.cat([z_t_minus, z_t], dim=-1)  # (B,C,t+2)

        z_t_minus = z_t_minus.detach()  # (B,C,T) z_0,...,z_{T-1}
        # input z_t_minus again so tdBN is calculated with gradients
        q_z = self.layers(torch.cat([x, z_t_minus], dim=1))  # (B,C*k,T)

        sampled_z = None
        for t in range(self.n_steps):
            if t == self.n_steps - 1:
                # when t = T
                random_index = torch.randint(0, self.k, (batch_size * self.channels,)) \
                    + torch.arange(start=0, end=batch_size * self.channels * self.k, step=self.k)
                random_indices.append(random_index)
            else:
                # when t <= T-1
                random_index = random_indices[t]

            # sampling
            sampled_z_t = q_z[..., t].view(batch_size * self.channels * self.k)[random_index]  # (B*C,)
            sampled_z_t = sampled_z_t.view(batch_size, self.channels, 1)  # (B,C,1)
            if t == 0:
                sampled_z = sampled_z_t
            else:
                sampled_z = torch.cat([sampled_z, sampled_z_t], dim=-1)

        q_z = q_z.view(batch_size, self.channels, self.k, self.n_steps)  # (B,C,k,T)

        return sampled_z, q_z


class FSAEModel(PreTrainedModel):
    config_class = FSAEConfig

    def __init__(self, config):
        super().__init__(config)

        self.in_channels = config.in_channels
        in_channels = self.in_channels

        self.hidden_dims = config.hidden_dims
        hidden_dims = self.hidden_dims

        self.latent_dim = config.latent_dim
        latent_dim = self.latent_dim

        self.n_steps = config.n_steps
        n_steps = self.n_steps

        self.k = config.k
        k = self.k

        # Build Encoder
        modules = []
        is_first_conv = True
        for h_dim in hidden_dims:
            modules.append(
                tdConv(
                    in_channels,
                    out_channels=h_dim,
                    kernel_size=3,
                    stride=2,
                    padding=1,
                    bias=True,
                    bn=tdBatchNorm(h_dim),
                    spike=LIFSpike(),
                    is_first_conv=is_first_conv,
                )
            )
            in_channels = h_dim
            is_first_conv = False

        self.encoder = nn.Sequential(*modules)
        self.before_latent_layer = tdLinear(
            hidden_dims[-1] * 4,
            latent_dim,
            bias=True,
            bn=tdBatchNorm(latent_dim),
            spike=LIFSpike(),
        )

        # Build Decoder
        modules = []

        self.decoder_input = tdLinear(
            latent_dim,
            hidden_dims[-1] * 4,
            bias=True,
            bn=tdBatchNorm(hidden_dims[-1] * 4),
            spike=LIFSpike(),
        )

        hidden_reverse = hidden_dims[::-1]

        for i in range(len(hidden_reverse) - 1):
            modules.append(
                tdConvTranspose(
                    hidden_reverse[i],
                    hidden_reverse[i + 1],
                    kernel_size=3,
                    stride=2,
                    padding=1,
                    output_padding=1,
                    bias=True,
                    bn=tdBatchNorm(hidden_reverse[i + 1]),
                    spike=LIFSpike(),
                )
            )
        self.decoder = nn.Sequential(*modules)

        self.final_layer = nn.Sequential(
            tdConvTranspose(
                hidden_reverse[-1],
                hidden_reverse[-1],
                kernel_size=3,
                stride=2,
                padding=1,
                output_padding=1,
                bias=True,
                bn=tdBatchNorm(hidden_reverse[-1]),
                spike=LIFSpike(),
            ),
            tdConvTranspose(
                hidden_reverse[-1],
                out_channels=1,
                kernel_size=3,
                padding=1,
                bias=True,
                bn=None,
                spike=None,
            ),
        )

        self.p = 0

        self.membrane_output_layer = MembraneOutputLayer()

    def forward(self, x, scheduled=False):
        sampled_z = self.encode(x, scheduled)
        x_recon = self.decode(sampled_z)
        return x_recon, sampled_z

    def encode(self, x, scheduled=False):
        # `scheduled` is accepted for API compatibility but not used here
        x = self.encoder(x)  # (N,C,H,W,T)
        x = torch.flatten(x, start_dim=1, end_dim=3)  # (N,C*H*W,T)
        latent_x = self.before_latent_layer(x)  # (N,latent_dim,T)
        return latent_x

    def decode(self, z):
        result = self.decoder_input(z)  # (N,C*H*W,T)
        result = result.view(
            result.shape[0], self.hidden_dims[-1], 2, 2, self.n_steps
        )  # (N,C,H,W,T)
        result = self.decoder(result)  # (N,C,H,W,T)
        result = self.final_layer(result)  # (N,C,H,W,T)
        out = torch.tanh(self.membrane_output_layer(result))
        return out

    def sample(self, batch_size=64):
        raise NotImplementedError()

    def loss_function(self, recons_img, input_img):
        r"""
        Computes the VAE loss function.
        KL(N(\mu, \sigma), N(0, 1)) = \log \frac{1}{\sigma} + \frac{\sigma^2 + \mu^2}{2} - \frac{1}{2}
        """
        # only the reconstruction term is computed here; there is no KL/MMD term
        recons_loss = F.mse_loss(recons_img, input_img)

        return recons_loss
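modeling_fsae.py imports its config with a relative import, so it has to live in a package (or be loaded through the `auto_map` mechanism shown under config.json). A forward-pass smoke test under that assumption; the package name `fsae` and the dummy batch are illustrative. The input shape (N, C, H, W, T) = (2, 1, 32, 32, 16) follows from in_channels=1, input_size=32, and n_steps=16:

import torch
from fsae.configuration_fsae import FSAEConfig  # hypothetical package layout
from fsae.modeling_fsae import FSAEModel

config = FSAEConfig()
model = FSAEModel(config).eval()  # eval mode so tdBatchNorm uses running statistics

x = torch.rand(2, 1, 32, 32, 16)  # dummy batch, time on the last axis
with torch.no_grad():
    x_recon, sampled_z = model(x)

print(x_recon.shape)    # torch.Size([2, 1, 32, 32]) - reconstruction after membrane readout
print(sampled_z.shape)  # torch.Size([2, 128, 16])   - latent spike train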