import torch |
import torch.nn as nn |
import torch.nn.functional as F |
import torch.optim as optim |
from argparse import ZERO_OR_MORE |
import math |
import random |
from torch.nn.modules.module import T |
from transformers import PreTrainedModel |
from .configuration_fsae import FSAEConfig |
# Module-level hyper-parameters shared by the spiking layers in this file.
dt: int = 5  # not referenced by any code in this module
a: float = 0.25  # not referenced by any code in this module
aa: float = 0.5  # half-width of the |u| < aa window used by SpikeAct.backward
Vth: float = 0.2  # spike threshold (SpikeAct.forward) and tdBatchNorm output scale
tau: float = 0.25  # default membrane decay factor in LIFSpike.state_update
class SpikeAct(torch.autograd.Function):
    """Heaviside spike function with a rectangular surrogate gradient.

    Forward: returns 1.0 where ``input > Vth`` and 0.0 elsewhere, always as a
    float32 tensor (``.float()``).
    Backward: the step function's derivative is replaced by a box of height
    ``1 / (2 * aa)`` on ``|input| < aa``; the incoming gradient is passed
    through (scaled) only inside that window.
    """

    @staticmethod
    def forward(ctx, input):
        ctx.save_for_backward(input)
        fired = input > Vth
        return fired.float()

    @staticmethod
    def backward(ctx, grad_output):
        (membrane,) = ctx.saved_tensors
        # NOTE(review): the surrogate window is centred on 0, not on Vth --
        # kept as-is; confirm this is intended.
        inside_window = (membrane.abs() < aa).float()
        surrogate = inside_window / (2 * aa)
        return grad_output * surrogate
class LIFSpike(nn.Module):
    """Leaky integrate-and-fire spiking non-linearity over the last axis.

    Use it like an activation function (in place of e.g. ReLU). The input must
    carry an additional trailing time dimension: ``x[..., t]`` is the input at
    step ``t``. The output has the same shape as ``x`` and holds the float32
    spikes produced at every step.
    """

    def __init__(self):
        super().__init__()

    def forward(self, x):
        time_steps = x.shape[-1]
        # Both the membrane potential and the "previous spike" start at zero.
        membrane = torch.zeros(x.shape[:-1], device=x.device)
        prev_spike = torch.zeros(x.shape[:-1], device=x.device)
        spikes = torch.zeros(x.shape, device=x.device)
        for t in range(time_steps):
            membrane, prev_spike = self.state_update(membrane, prev_spike, x[..., t])
            spikes[..., t] = prev_spike
        return spikes

    def state_update(self, u_t_n1, o_t_n1, W_mul_o_t1_n, tau=tau):
        """Advance the neuron by one step.

        The previous potential ``u_t_n1`` is scaled by ``tau`` and zeroed where
        the previous spike ``o_t_n1`` is 1, then the new input ``W_mul_o_t1_n``
        is added. The resulting potential is thresholded by ``SpikeAct``.

        Returns:
            tuple: (new membrane potential, new spikes).
        """
        decayed = tau * u_t_n1 * (1 - o_t_n1)
        potential = decayed + W_mul_o_t1_n
        return potential, SpikeAct.apply(potential)
class tdLinear(nn.Linear):
    """Time-distributed fully connected layer for tensors shaped (N, C, T).

    The same ``nn.Linear`` weights are applied at every time step: the channel
    axis (dim 1) is mapped from ``in_features`` to ``out_features`` while the
    trailing time axis is left untouched.

    Args:
        in_features (int): number of input channels.
        out_features (int): number of output channels.
        bias (bool): whether to learn an additive bias (as in ``nn.Linear``).
        bn: optional normalisation module applied to the output. It receives a
            5-D view ``(N, C, 1, 1, T)`` so 5-D batch-norm layers can be reused.
        spike: optional callable (e.g. a spiking activation) applied last.

    Raises:
        AssertionError: if ``in_features`` or ``out_features`` is not an int.
    """

    def __init__(self,
                 in_features,
                 out_features,
                 bias=True,
                 bn=None,
                 spike=None):
        # Report the offending value itself. Calling ``.shape`` on it (as
        # before) raised AttributeError for tuples/lists and hid the message.
        assert isinstance(in_features, int), \
            'inFeatures should not be more than 1 dimension. It was: {!r}'.format(in_features)
        assert isinstance(out_features, int), \
            'outFeatures should not be more than 1 dimension. It was: {!r}'.format(out_features)
        super(tdLinear, self).__init__(in_features, out_features, bias=bias)
        self.bn = bn
        self.spike = spike

    def forward(self, x):
        """
        x : (N,C,T)
        returns: (N,out_features,T)
        """
        # F.linear contracts the last axis, so move channels there and back.
        y = F.linear(x.transpose(1, 2), self.weight, self.bias).transpose(1, 2)
        if self.bn is not None:
            # (N, C, T) -> (N, C, 1, 1, T) for the 5-D norm, then squeeze back.
            y = self.bn(y[:, :, None, None, :])[:, :, 0, 0, :]
        if self.spike is not None:
            y = self.spike(y)
        return y
class tdConv(nn.Conv3d):
    """Time-distributed 2-D convolution implemented as a 3-D convolution.

    Inputs are shaped (N, C, H, W, T). Spatial settings (``kernel_size``,
    ``stride``, ``padding``, ``dilation``) may be an int or a length-2
    sequence. The time axis always gets kernel 1, stride 1, padding 0 and
    dilation 1, so each time step is convolved independently with shared
    weights.

    Args:
        bn: optional module applied after the convolution.
        spike: optional module applied after ``bn``.
        is_first_conv (bool): stored on the instance; not used inside this class.

    Raises:
        ValueError: if a spatial setting is a sequence whose length is not 2.
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 kernel_size,
                 stride=1,
                 padding=0,
                 dilation=1,
                 groups=1,
                 bias=True,
                 bn=None,
                 spike=None,
                 is_first_conv=False):
        kernel = self._spatial_to_3d(
            kernel_size, 1, 'kernelSize can only be of 1 or 2 dimension. It was: {!r}')
        stride = self._spatial_to_3d(
            stride, 1, 'stride can be either int or tuple of size 2. It was: {!r}')
        padding = self._spatial_to_3d(
            padding, 0, 'padding can be either int or tuple of size 2. It was: {!r}')
        dilation = self._spatial_to_3d(
            dilation, 1, 'dilation can be either int or tuple of size 2. It was: {!r}')
        super(tdConv, self).__init__(in_channels, out_channels, kernel, stride, padding, dilation, groups,
                                     bias=bias)
        self.bn = bn
        self.spike = spike
        self.is_first_conv = is_first_conv

    @staticmethod
    def _spatial_to_3d(value, time_value, message):
        """Expand an int or length-2 spatial setting to ``(h, w, time_value)``.

        ``message`` is formatted with the offending value. The previous code
        formatted ``value.shape``, which raised AttributeError for tuples.
        """
        if isinstance(value, int):
            return (value, value, time_value)
        if len(value) == 2:
            return (value[0], value[1], time_value)
        raise ValueError(message.format(value))

    def forward(self, x):
        """
        x : (N,C,H,W,T)
        """
        x = F.conv3d(x, self.weight, self.bias,
                     self.stride, self.padding, self.dilation, self.groups)
        if self.bn is not None:
            x = self.bn(x)
        if self.spike is not None:
            x = self.spike(x)
        return x
class tdConvTranspose(nn.ConvTranspose3d):
    """Time-distributed 2-D transposed convolution implemented in 3-D.

    Inputs are shaped (N, C, H, W, T). Spatial settings (``kernel_size``,
    ``stride``, ``padding``, ``output_padding``, ``dilation``) may be an int
    or a length-2 sequence. The time axis always gets kernel 1, stride 1,
    padding 0, output padding 0 and dilation 1, so each time step is processed
    independently with shared weights.

    Args:
        bn: optional module applied after the transposed convolution.
        spike: optional module applied after ``bn``.

    Raises:
        ValueError: if a spatial setting is a sequence whose length is not 2.
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 kernel_size,
                 stride=1,
                 padding=0,
                 output_padding=0,
                 dilation=1,
                 groups=1,
                 bias=True,
                 bn=None,
                 spike=None):
        kernel = self._spatial_to_3d(
            kernel_size, 1, 'kernelSize can only be of 1 or 2 dimension. It was: {!r}')
        stride = self._spatial_to_3d(
            stride, 1, 'stride can be either int or tuple of size 2. It was: {!r}')
        padding = self._spatial_to_3d(
            padding, 0, 'padding can be either int or tuple of size 2. It was: {!r}')
        dilation = self._spatial_to_3d(
            dilation, 1, 'dilation can be either int or tuple of size 2. It was: {!r}')
        # The previous message reported ``padding`` here instead of output_padding.
        output_padding = self._spatial_to_3d(
            output_padding, 0, 'output_padding can be either int or tuple of size 2. It was: {!r}')
        super().__init__(in_channels, out_channels, kernel, stride, padding, output_padding, groups,
                         bias=bias, dilation=dilation)
        self.bn = bn
        self.spike = spike

    @staticmethod
    def _spatial_to_3d(value, time_value, message):
        """Expand an int or length-2 spatial setting to ``(h, w, time_value)``.

        ``message`` is formatted with the offending value. The previous code
        formatted ``.shape`` of a tuple, which raised AttributeError.
        """
        if isinstance(value, int):
            return (value, value, time_value)
        if len(value) == 2:
            return (value[0], value[1], time_value)
        raise ValueError(message.format(value))

    def forward(self, x):
        """
        x : (N,C,H,W,T)
        """
        x = F.conv_transpose3d(x, self.weight, self.bias,
                               self.stride, self.padding,
                               self.output_padding, self.groups, self.dilation)
        if self.bn is not None:
            x = self.bn(x)
        if self.spike is not None:
            x = self.spike(x)
        return x
class tdBatchNorm(nn.BatchNorm2d):
    """
    Implementation of tdBN. Link to related paper: https://arxiv.org/pdf/2011.05280. In short it is averaged over the time domain as well when doing BN.

    Normalises 5-D inputs shaped (N, C, H, W, T) per channel. Statistics are
    taken over the N, H, W and T axes, and the normalised value is scaled by
    ``alpha * Vth`` before the optional affine transform.

    Args:
        num_features (int): same with nn.BatchNorm2d
        eps (float): same with nn.BatchNorm2d
        momentum (float): same with nn.BatchNorm2d
        alpha (float): an additional parameter which may change in resblock.
        affine (bool): same with nn.BatchNorm2d
        track_running_stats (bool): same with nn.BatchNorm2d. When the running
            buffers are absent (``track_running_stats=False``), batch
            statistics are used in both training and evaluation.
    """

    def __init__(self, num_features, eps=1e-05, momentum=0.1, alpha=1, affine=True, track_running_stats=True):
        super(tdBatchNorm, self).__init__(
            num_features, eps, momentum, affine, track_running_stats)
        self.alpha = alpha

    def forward(self, input):
        exponential_average_factor = 0.0
        if self.training and self.track_running_stats:
            if self.num_batches_tracked is not None:
                self.num_batches_tracked += 1
                if self.momentum is None:
                    # Cumulative moving average.
                    exponential_average_factor = 1.0 / float(self.num_batches_tracked)
                else:
                    # Exponential moving average.
                    exponential_average_factor = self.momentum

        has_running_stats = self.running_mean is not None and self.running_var is not None
        # Previously the running buffers were read unconditionally, so
        # track_running_stats=False crashed on ``None``. Like nn.BatchNorm2d,
        # fall back to batch statistics when no running estimates exist.
        if self.training or not has_running_stats:
            mean = input.mean([0, 2, 3, 4])
            # Biased variance is used to normalise the current batch.
            var = input.var([0, 2, 3, 4], unbiased=False)
            if self.training and has_running_stats:
                n = input.numel() / input.size(1)
                with torch.no_grad():
                    self.running_mean = exponential_average_factor * mean \
                        + (1 - exponential_average_factor) * self.running_mean
                    # Running variance is stored with the n / (n - 1) correction.
                    # NOTE(review): n == 1 makes this divide by zero -- confirm inputs never have that.
                    self.running_var = exponential_average_factor * var * n / (n - 1) \
                        + (1 - exponential_average_factor) * self.running_var
        else:
            mean = self.running_mean
            var = self.running_var
        input = self.alpha * Vth * (input - mean[None, :, None, None, None]) / (torch.sqrt(var[None, :, None, None, None] + self.eps))
        if self.affine:
            input = input * self.weight[None, :, None, None, None] + self.bias[None, :, None, None, None]
        return input
class PSP(torch.nn.Module):
    """First-order low-pass (post-synaptic potential) filter over the last axis.

    For each step t: ``syn_t = syn_{t-1} + (x_t - syn_{t-1}) / tau_s`` with
    ``syn_{-1} = 0``. All ``syn_t`` are returned stacked along the last axis.

    Args:
        tau_s: synaptic time constant (default 2, the previously hard-coded value).
    """

    def __init__(self, tau_s=2):
        super().__init__()
        self.tau_s = tau_s

    def forward(self, inputs):
        """
        inputs: (N, C, T) -- any tensor whose last axis is time.
        returns: tensor of the same shape, or None when T == 0 (kept for
            backward compatibility).
        """
        syn = 0
        history = []
        for t in range(inputs.shape[-1]):
            syn = syn + (inputs[..., t] - syn) / self.tau_s
            history.append(syn)
        if not history:
            return None
        # A single stack is O(T). Concatenating inside the loop re-copied the
        # whole prefix on every step, which is O(T^2).
        return torch.stack(history, dim=-1)
class MembraneOutputLayer(nn.Module):
    """
    outputs the last time membrane potential of the LIF neuron with V_th=infty

    A leaky integrator that never fires accumulates
    ``out = sum_t decay ** (n_steps - 1 - t) * x[..., t]``, which is what
    ``forward`` computes.

    Args:
        n_steps: length of the time axis (default 16, the previously hard-coded value).
        decay: per-step leak factor (default 0.8, the previously hard-coded value).
    """

    def __init__(self, n_steps=16, decay=0.8) -> None:
        super().__init__()
        exponents = torch.arange(n_steps - 1, -1, -1)
        # Shaped (1, 1, 1, 1, T) to broadcast against (N, C, H, W, T) inputs.
        self.register_buffer("coef", torch.pow(decay, exponents)[None, None, None, None, :])

    def forward(self, x):
        """
        x : (N,C,H,W,T) with T equal to n_steps
        returns: (N,C,H,W)
        """
        return torch.sum(x * self.coef, dim=-1)
class PriorBernoulliSTBP(nn.Module):
    """Autoregressive prior p(z_t | z_<t) over binary latent spike trains.

    Three spiking ``tdLinear`` layers map the history ``z_<t`` (B, C, t) to
    ``C * k`` spike outputs per time step, viewed as (B, C, k, T).
    """

    def __init__(self, k=20, channels=128, n_steps=16) -> None:
        """
        modeling of p(z_t|z_<t)

        Args:
            k: number of spike outputs per latent channel and time step. One of
                them is picked at random in ``sample``; their mean is compared
                with 0.5 during scheduled sampling.
            channels: number of latent channels C (default 128, the value that
                used to be hard-coded).
            n_steps: number of time steps T (default 16, the value that used to
                be hard-coded).
        """
        super().__init__()
        self.channels = channels
        self.k = k
        self.n_steps = n_steps
        self.layers = nn.Sequential(
            tdLinear(self.channels,
                     self.channels*2,
                     bias=True,
                     bn=tdBatchNorm(self.channels*2, alpha=2),
                     spike=LIFSpike()),
            tdLinear(self.channels*2,
                     self.channels*4,
                     bias=True,
                     bn=tdBatchNorm(self.channels*4, alpha=2),
                     spike=LIFSpike()),
            tdLinear(self.channels*4,
                     self.channels*k,
                     bias=True,
                     bn=tdBatchNorm(self.channels*k, alpha=2),
                     spike=LIFSpike())
        )
        # All-zero z_0 fed as the first step's history (a buffer, not trained).
        self.register_buffer('initial_input', torch.zeros(1, self.channels, 1))

    def forward(self, z, scheduled=False, p=None):
        """Dispatch to scheduled sampling when ``scheduled`` is True, else teacher forcing."""
        if scheduled:
            return self._forward_scheduled_sampling(z, p)
        else:
            return self._forward(z)

    def _forward(self, z):
        """
        input z: (B,C,T) # latent spike sampled from posterior
        output : (B,C,k,T) # indicates p(z_t|z_<t) (t=1,...,T)
        """
        z_shape = z.shape
        batch_size = z_shape[0]
        z = z.detach()
        z0 = self.initial_input.repeat(batch_size, 1, 1)
        # Shift right by one step: the output at t sees z_0 .. z_{t-1}.
        inputs = torch.cat([z0, z[..., :-1]], dim=-1)
        outputs = self.layers(inputs)
        p_z = outputs.view(batch_size, self.channels, self.k, self.n_steps)
        return p_z

    def _forward_scheduled_sampling(self, z, p):
        """
        use scheduled sampling
        input
        z: (B,C,T) # latent spike sampled from posterior
        p: float # prob of scheduled sampling
        output : (B,C,k,T) # indicates p(z_t|z_<t) (t=1,...,T)
        """
        z_shape = z.shape
        batch_size = z_shape[0]
        z = z.detach()
        z_t_minus = self.initial_input.repeat(batch_size, 1, 1)
        if self.training:
            with torch.no_grad():
                for t in range(self.n_steps-1):
                    # From step 5 on, with probability p, feed back the prior's
                    # own (thresholded) prediction instead of the posterior sample.
                    # random.random() is only drawn when t >= 5 (short-circuit).
                    if t >= 5 and random.random() < p:
                        outputs = self.layers(z_t_minus.detach())
                        p_z_t = outputs[..., -1]
                        prob1 = p_z_t.view(batch_size, self.channels, self.k).mean(-1)
                        # Small noise before thresholding.
                        prob1 = prob1 + 1e-3 * torch.randn_like(prob1)
                        z_t = (prob1 > 0.5).float()
                        z_t = z_t.view(batch_size, self.channels, 1)
                        z_t_minus = torch.cat([z_t_minus, z_t], dim=-1)
                    else:
                        z_t_minus = torch.cat([z_t_minus, z[..., t].unsqueeze(-1)], dim=-1)
        else:
            z_t_minus = torch.cat([z_t_minus, z[:, :, :-1]], dim=-1)
        z_t_minus = z_t_minus.detach()
        p_z = self.layers(z_t_minus)
        p_z = p_z.view(batch_size, self.channels, self.k, self.n_steps)
        return p_z

    def sample(self, batch_size=64):
        """Autoregressively draw ``batch_size`` latent spike trains of shape (B, C, T).

        At each step one of the k outputs per (batch, channel) is picked
        uniformly at random and appended to the history.
        """
        z_minus_t = self.initial_input.repeat(batch_size, 1, 1)
        for t in range(self.n_steps):
            outputs = self.layers(z_minus_t)
            p_z_t = outputs[..., -1]
            # Random offset in [0, k) plus the start of each k-sized group in
            # the flattened (B*C*k,) view.
            random_index = torch.randint(0, self.k, (batch_size*self.channels,)) \
                + torch.arange(start=0, end=batch_size*self.channels*self.k, step=self.k)
            random_index = random_index.to(z_minus_t.device)
            z_t = p_z_t.view(batch_size*self.channels*self.k)[random_index]
            z_t = z_t.view(batch_size, self.channels, 1)
            z_minus_t = torch.cat([z_minus_t, z_t], dim=-1)
        # Drop the all-zero z_0.
        sampled_z = z_minus_t[..., 1:]
        return sampled_z
class PosteriorBernoulliSTBP(nn.Module):
    """Autoregressive posterior q(z_t | x_<=t, z_<t) over binary latent spike trains.

    Three spiking ``tdLinear`` layers map the channel-wise concatenation of
    ``x`` and the latent history to ``C * k`` spike outputs per step, viewed
    as (B, C, k, T).
    """

    def __init__(self, k=20, channels=128, n_steps=16) -> None:
        """
        modeling of q(z_t | x_<=t, z_<t)

        Args:
            k: number of spike outputs per latent channel and time step; one is
                picked uniformly at random when sampling.
            channels: number of latent channels C; ``x`` must also have C
                channels (default 128, the value that used to be hard-coded).
            n_steps: number of time steps T (default 16, the value that used to
                be hard-coded).
        """
        super().__init__()
        self.channels = channels
        self.k = k
        self.n_steps = n_steps
        self.layers = nn.Sequential(
            tdLinear(self.channels*2,
                     self.channels*2,
                     bias=True,
                     bn=tdBatchNorm(self.channels*2, alpha=2),
                     spike=LIFSpike()),
            tdLinear(self.channels*2,
                     self.channels*4,
                     bias=True,
                     bn=tdBatchNorm(self.channels*4, alpha=2),
                     spike=LIFSpike()),
            tdLinear(self.channels*4,
                     self.channels*k,
                     bias=True,
                     bn=tdBatchNorm(self.channels*k, alpha=2),
                     spike=LIFSpike())
        )
        # All-zero z_0 fed as the first step's history (a buffer, not trained).
        self.register_buffer('initial_input', torch.zeros(1, self.channels, 1))
        # Stored flag; not read inside this class.
        self.is_true_scheduled_sampling = True

    def _random_indices(self, batch_size, device):
        """Pick one of the k outputs per (batch, channel) in the flattened (B*C*k,) view.

        The draw happens on the CPU generator (as before) and is then moved to ``device``.
        """
        offsets = torch.randint(0, self.k, (batch_size*self.channels,))
        starts = torch.arange(start=0, end=batch_size*self.channels*self.k, step=self.k)
        return (offsets + starts).to(device)

    def forward(self, x):
        """
        input:
            x:(B,C,T)
        returns:
            sampled_z:(B,C,T)
            q_z: (B,C,k,T) # indicates q(z_t | x_<=t, z_<t) (t=1,...,T)
        """
        x_shape = x.shape
        batch_size = x_shape[0]
        random_indices = []
        # Phase 1: sample z_1 .. z_{T-1} step by step without tracking gradients.
        with torch.no_grad():
            z_t_minus = self.initial_input.repeat(x_shape[0], 1, 1)
            for t in range(self.n_steps-1):
                inputs = torch.cat([x[..., :t+1].detach(), z_t_minus.detach()], dim=1)
                outputs = self.layers(inputs)
                q_z_t = outputs[..., -1]
                random_index = self._random_indices(batch_size, x.device)
                random_indices.append(random_index)
                z_t = q_z_t.view(batch_size*self.channels*self.k)[random_index]
                z_t = z_t.view(batch_size, self.channels, 1)
                z_t_minus = torch.cat([z_t_minus, z_t], dim=-1)
        z_t_minus = z_t_minus.detach()
        # Phase 2: one differentiable pass over the whole sequence.
        q_z = self.layers(torch.cat([x, z_t_minus], dim=1))
        # Reuse the phase-1 choices for steps 0..T-2 (presumably so sampled_z
        # follows the sequence fed back above -- confirm) and draw a fresh one
        # for the last step. The last index used to stay on the CPU; it is now
        # moved to x.device like the others.
        random_indices.append(self._random_indices(batch_size, x.device))
        sampled_steps = [
            q_z[..., t].view(batch_size*self.channels*self.k)[random_indices[t]]
            .view(batch_size, self.channels, 1)
            for t in range(self.n_steps)
        ]
        # One concatenation instead of growing the tensor step by step.
        sampled_z = torch.cat(sampled_steps, dim=-1)
        q_z = q_z.view(batch_size, self.channels, self.k, self.n_steps)
        return sampled_z, q_z
class FSAEModel(PreTrainedModel):
    """Fully spiking autoencoder built from time-distributed spiking layers.

    Encoder: one stride-2 ``tdConv`` per entry of ``hidden_dims``, flattened
    and projected by ``before_latent_layer`` to a latent spike train
    (N, latent_dim, T).
    Decoder: ``decoder_input`` maps the latent back to
    (N, hidden_dims[-1], 2, 2, T). Stride-2 ``tdConvTranspose`` blocks
    upsample it, a final single-channel ``tdConvTranspose`` follows, and
    ``MembraneOutputLayer`` plus tanh collapse the time axis.

    Config fields read: in_channels, hidden_dims, latent_dim, n_steps, k.

    NOTE(review): inputs are assumed to be (N, in_channels, H, W, T) with H and
    W reduced to 2x2 by the encoder (``before_latent_layer`` expects
    hidden_dims[-1] * 4 features) -- confirm against the data pipeline.
    """
    config_class = FSAEConfig

    def __init__(self, config):
        super().__init__(config)
        self.in_channels = config.in_channels
        self.hidden_dims = config.hidden_dims
        self.latent_dim = config.latent_dim
        self.n_steps = config.n_steps
        # Read from the config; not used by any layer built here.
        self.k = config.k
        in_channels = self.in_channels
        hidden_dims = self.hidden_dims
        latent_dim = self.latent_dim

        # Layers are created in the original order so parameter initialisation
        # (which draws from the global RNG) is unchanged.
        modules = []
        is_first_conv = True
        for h_dim in hidden_dims:
            # kernel 3, stride 2, padding 1: downsamples H and W by 2 (rounding up).
            modules.append(
                tdConv(
                    in_channels,
                    out_channels=h_dim,
                    kernel_size=3,
                    stride=2,
                    padding=1,
                    bias=True,
                    bn=tdBatchNorm(h_dim),
                    spike=LIFSpike(),
                    is_first_conv=is_first_conv,
                )
            )
            in_channels = h_dim
            is_first_conv = False
        self.encoder = nn.Sequential(*modules)
        # "* 4": the flattened encoder output is assumed to be 2x2 spatially.
        self.before_latent_layer = tdLinear(
            hidden_dims[-1] * 4,
            latent_dim,
            bias=True,
            bn=tdBatchNorm(latent_dim),
            spike=LIFSpike(),
        )

        modules = []
        self.decoder_input = tdLinear(
            latent_dim,
            hidden_dims[-1] * 4,
            bias=True,
            bn=tdBatchNorm(hidden_dims[-1] * 4),
            spike=LIFSpike(),
        )
        hidden_reverse = hidden_dims[::-1]
        for i in range(len(hidden_reverse) - 1):
            # kernel 3, stride 2, padding 1, output_padding 1: doubles H and W.
            modules.append(
                tdConvTranspose(
                    hidden_reverse[i],
                    hidden_reverse[i + 1],
                    kernel_size=3,
                    stride=2,
                    padding=1,
                    output_padding=1,
                    bias=True,
                    bn=tdBatchNorm(hidden_reverse[i + 1]),
                    spike=LIFSpike(),
                )
            )
        self.decoder = nn.Sequential(*modules)
        self.final_layer = nn.Sequential(
            tdConvTranspose(
                hidden_reverse[-1],
                hidden_reverse[-1],
                kernel_size=3,
                stride=2,
                padding=1,
                output_padding=1,
                bias=True,
                bn=tdBatchNorm(hidden_reverse[-1]),
                spike=LIFSpike(),
            ),
            # No bn/spike: the raw output is read out by membrane_output_layer.
            tdConvTranspose(
                hidden_reverse[-1],
                out_channels=1,
                kernel_size=3,
                padding=1,
                bias=True,
                bn=None,
                spike=None,
            ),
        )
        # Not used inside this class.
        self.p = 0
        # NOTE(review): built with its own defaults rather than config.n_steps;
        # its time length must match n_steps for decode to broadcast -- confirm.
        self.membrane_output_layer = MembraneOutputLayer()

    def forward(self, x, scheduled=False):
        """Encode then decode ``x``.

        Returns:
            tuple: (reconstruction, latent spike train from ``encode``).
        ``scheduled`` is forwarded to ``encode``, which ignores it.
        """
        sampled_z = self.encode(x, scheduled)
        x_recon = self.decode(sampled_z)
        return x_recon, sampled_z

    def encode(self, x, scheduled=False):
        """Map ``x`` to the latent spike train (N, latent_dim, T). ``scheduled`` is unused."""
        x = self.encoder(x)
        # (N, C, H, W, T) -> (N, C*H*W, T)
        x = torch.flatten(x, start_dim=1, end_dim=3)
        latent_x = self.before_latent_layer(x)
        return latent_x

    def decode(self, z):
        """Map a latent spike train (N, latent_dim, T) to an image-like output in (-1, 1)."""
        result = self.decoder_input(z)
        result = result.view(
            result.shape[0], self.hidden_dims[-1], 2, 2, self.n_steps
        )
        result = self.decoder(result)
        result = self.final_layer(result)
        # Collapse the time axis, then squash to (-1, 1).
        out = torch.tanh(self.membrane_output_layer(result))
        return out

    def sample(self, batch_size=64):
        """Not implemented for this model."""
        raise NotImplementedError()

    def loss_function(self, recons_img, input_img):
        """Return the reconstruction loss ``F.mse_loss(recons_img, input_img)``.

        Only the mean-squared reconstruction error is computed; there is no KL
        or prior term.
        """
        recons_loss = F.mse_loss(recons_img, input_img)
        return recons_loss