# FullySpikingAutoEncoder / modeling_fsae.py
import random

import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import PreTrainedModel

from .configuration_fsae import FSAEConfig
# Global LIF-neuron hyperparameters shared by all layers below.
dt = 5       # simulation time step (not referenced below)
a = 0.25     # surrogate-gradient width parameter (not referenced below; see `aa`)
aa = 0.5     # half-width of the rectangular surrogate-gradient window
Vth = 0.2    # firing threshold of the membrane potential
tau = 0.25   # membrane-potential decay constant
class SpikeAct(torch.autograd.Function):
"""
Implementation of the spiking activation function with an approximation of gradient.
"""
@staticmethod
def forward(ctx, input):
ctx.save_for_backward(input)
# if input = u > Vth then output = 1
output = torch.gt(input, Vth)
return output.float()
@staticmethod
def backward(ctx, grad_output):
input, = ctx.saved_tensors
grad_input = grad_output.clone()
        # hu approximates d(spike)/du with a rectangular window of width 2*aa
hu = abs(input) < aa
hu = hu.float() / (2 * aa)
return grad_input * hu
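
# A minimal sketch (not part of the original file): with Vth = 0.2 and
# aa = 0.5, inputs above 0.2 fire, and the backward pass sees a rectangular
# surrogate gradient of 1/(2*aa) = 1.0 wherever |input| < 0.5.
def _demo_spikeact():
    u = torch.tensor([0.1, 0.3, 0.9], requires_grad=True)
    s = SpikeAct.apply(u)  # tensor([0., 1., 1.])
    s.sum().backward()
    return u.grad          # tensor([1., 1., 0.])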
class LIFSpike(nn.Module):
"""
Generates spikes based on LIF module. It can be considered as an activation function and is used similar to ReLU. The input tensor needs to have an additional time dimension, which in this case is on the last dimension of the data.
"""
def __init__(self):
super(LIFSpike, self).__init__()
    def forward(self, x):
        # x: (..., T). Iterate over the trailing time dimension, carrying the
        # membrane potential u across steps.
        nsteps = x.shape[-1]
        u = torch.zeros(x.shape[:-1], device=x.device)
        out = torch.zeros(x.shape, device=x.device)
        for step in range(nsteps):
            u, out[..., step] = self.state_update(u, out[..., max(step - 1, 0)], x[..., step])
        return out

    def state_update(self, u_t_n1, o_t_n1, W_mul_o_t1_n, tau=tau):
        # Hard reset via (1 - o): the potential decays only where no spike
        # fired at the previous step, then the new input current is added.
        u_t1_n1 = tau * u_t_n1 * (1 - o_t_n1) + W_mul_o_t1_n
        o_t1_n1 = SpikeAct.apply(u_t1_n1)
        return u_t1_n1, o_t1_n1
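
# A minimal sketch (not part of the original file): LIFSpike behaves like a
# stateful activation over the trailing time dimension; shapes are illustrative.
def _demo_lifspike():
    lif = LIFSpike()
    x = torch.rand(4, 8, 16)  # (N, C, T): input currents over 16 time steps
    out = lif(x)              # (N, C, T): binary spike trains
    return out.shape          # torch.Size([4, 8, 16])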
class tdLinear(nn.Linear):
def __init__(self,
in_features,
out_features,
bias=True,
bn=None,
spike=None):
        assert type(in_features) == int, 'in_features must be an int. It was: {}'.format(in_features)
        assert type(out_features) == int, 'out_features must be an int. It was: {}'.format(out_features)
super(tdLinear, self).__init__(in_features, out_features, bias=bias)
self.bn = bn
self.spike = spike
def forward(self, x):
"""
x : (N,C,T)
"""
        x = x.transpose(1, 2)  # (N, T, C)
        y = F.linear(x, self.weight, self.bias)
        y = y.transpose(1, 2)  # (N, C, T)
        if self.bn is not None:
            # tdBatchNorm expects (N, C, H, W, T); add singleton spatial dims.
            y = y[:, :, None, None, :]
            y = self.bn(y)
            y = y[:, :, 0, 0, :]
if self.spike is not None:
y = self.spike(y)
return y
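
# A minimal sketch (not part of the original file): tdLinear applies the same
# linear map at every time step; the bn/spike choices here are illustrative.
def _demo_tdlinear():
    layer = tdLinear(8, 16, bias=True, bn=tdBatchNorm(16), spike=LIFSpike())
    x = torch.rand(4, 8, 16)  # (N, C, T)
    return layer(x).shape     # torch.Size([4, 16, 16])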
class tdConv(nn.Conv3d):
def __init__(self,
in_channels,
out_channels,
kernel_size,
stride=1,
padding=0,
dilation=1,
groups=1,
bias=True,
bn=None,
spike=None,
is_first_conv=False):
        # kernel
        if type(kernel_size) == int:
            kernel = (kernel_size, kernel_size, 1)
        elif len(kernel_size) == 2:
            kernel = (kernel_size[0], kernel_size[1], 1)
        else:
            raise ValueError('kernel_size must be an int or a tuple of size 2. It was: {}'.format(kernel_size))
        # stride
        if type(stride) == int:
            stride = (stride, stride, 1)
        elif len(stride) == 2:
            stride = (stride[0], stride[1], 1)
        else:
            raise ValueError('stride must be an int or a tuple of size 2. It was: {}'.format(stride))
        # padding
        if type(padding) == int:
            padding = (padding, padding, 0)
        elif len(padding) == 2:
            padding = (padding[0], padding[1], 0)
        else:
            raise ValueError('padding must be an int or a tuple of size 2. It was: {}'.format(padding))
        # dilation
        if type(dilation) == int:
            dilation = (dilation, dilation, 1)
        elif len(dilation) == 2:
            dilation = (dilation[0], dilation[1], 1)
        else:
            raise ValueError('dilation must be an int or a tuple of size 2. It was: {}'.format(dilation))
super(tdConv, self).__init__(in_channels, out_channels, kernel, stride, padding, dilation, groups,
bias=bias)
self.bn = bn
self.spike = spike
self.is_first_conv = is_first_conv
def forward(self, x):
x = F.conv3d(x, self.weight, self.bias,
self.stride, self.padding, self.dilation, self.groups)
if self.bn is not None:
x = self.bn(x)
if self.spike is not None:
x = self.spike(x)
return x
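
# A minimal sketch (not part of the original file): tdConv is a 2D convolution
# replicated over time via Conv3d with a kernel of size 1 on the time axis.
def _demo_tdconv():
    conv = tdConv(1, 32, kernel_size=3, stride=2, padding=1,
                  bn=tdBatchNorm(32), spike=LIFSpike())
    x = torch.rand(4, 1, 32, 32, 16)  # (N, C, H, W, T)
    return conv(x).shape              # torch.Size([4, 32, 16, 16, 16])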
class tdConvTranspose(nn.ConvTranspose3d):
def __init__(self,
in_channels,
out_channels,
kernel_size,
stride=1,
padding=0,
output_padding=0,
dilation=1,
groups=1,
bias=True,
bn=None,
spike=None):
        # kernel
        if type(kernel_size) == int:
            kernel = (kernel_size, kernel_size, 1)
        elif len(kernel_size) == 2:
            kernel = (kernel_size[0], kernel_size[1], 1)
        else:
            raise ValueError('kernel_size must be an int or a tuple of size 2. It was: {}'.format(kernel_size))
        # stride
        if type(stride) == int:
            stride = (stride, stride, 1)
        elif len(stride) == 2:
            stride = (stride[0], stride[1], 1)
        else:
            raise ValueError('stride must be an int or a tuple of size 2. It was: {}'.format(stride))
        # padding
        if type(padding) == int:
            padding = (padding, padding, 0)
        elif len(padding) == 2:
            padding = (padding[0], padding[1], 0)
        else:
            raise ValueError('padding must be an int or a tuple of size 2. It was: {}'.format(padding))
        # dilation
        if type(dilation) == int:
            dilation = (dilation, dilation, 1)
        elif len(dilation) == 2:
            dilation = (dilation[0], dilation[1], 1)
        else:
            raise ValueError('dilation must be an int or a tuple of size 2. It was: {}'.format(dilation))
        # output padding
        if type(output_padding) == int:
            output_padding = (output_padding, output_padding, 0)
        elif len(output_padding) == 2:
            output_padding = (output_padding[0], output_padding[1], 0)
        else:
            raise ValueError('output_padding must be an int or a tuple of size 2. It was: {}'.format(output_padding))
super().__init__(in_channels, out_channels, kernel, stride, padding, output_padding, groups,
bias=bias, dilation=dilation)
self.bn = bn
self.spike = spike
def forward(self, x):
x = F.conv_transpose3d(x, self.weight, self.bias,
self.stride, self.padding,
self.output_padding, self.groups, self.dilation)
if self.bn is not None:
x = self.bn(x)
if self.spike is not None:
x = self.spike(x)
return x
class tdBatchNorm(nn.BatchNorm2d):
"""
Implementation of tdBN. Link to related paper: https://arxiv.org/pdf/2011.05280. In short it is averaged over the time domain as well when doing BN.
Args:
num_features (int): same with nn.BatchNorm2d
eps (float): same with nn.BatchNorm2d
momentum (float): same with nn.BatchNorm2d
alpha (float): an addtional parameter which may change in resblock.
affine (bool): same with nn.BatchNorm2d
track_running_stats (bool): same with nn.BatchNorm2d
"""
def __init__(self, num_features, eps=1e-05, momentum=0.1, alpha=1, affine=True, track_running_stats=True):
super(tdBatchNorm, self).__init__(
num_features, eps, momentum, affine, track_running_stats)
self.alpha = alpha
def forward(self, input):
exponential_average_factor = 0.0
if self.training and self.track_running_stats:
if self.num_batches_tracked is not None:
self.num_batches_tracked += 1
if self.momentum is None: # use cumulative moving average
exponential_average_factor = 1.0 / float(self.num_batches_tracked)
else: # use exponential moving average
exponential_average_factor = self.momentum
# calculate running estimates
if self.training:
mean = input.mean([0, 2, 3, 4])
# use biased var in train
var = input.var([0, 2, 3, 4], unbiased=False)
n = input.numel() / input.size(1)
with torch.no_grad():
self.running_mean = exponential_average_factor * mean\
+ (1 - exponential_average_factor) * self.running_mean
# update running_var with unbiased var
self.running_var = exponential_average_factor * var * n / (n - 1)\
+ (1 - exponential_average_factor) * self.running_var
else:
mean = self.running_mean
var = self.running_var
        input = self.alpha * Vth * (input - mean[None, :, None, None, None]) \
            / torch.sqrt(var[None, :, None, None, None] + self.eps)
if self.affine:
input = input * self.weight[None, :, None, None, None] + self.bias[None, :, None, None, None]
return input
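
# A minimal sketch (not part of the original file): at initialization, in eval
# mode (running mean 0, running var 1, weight 1, bias 0), tdBatchNorm reduces
# to y ~= alpha * Vth * x, i.e. activations are rescaled toward the threshold.
def _demo_tdbn():
    bn = tdBatchNorm(8, alpha=1).eval()
    x = torch.rand(2, 8, 4, 4, 16)  # (N, C, H, W, T)
    y = bn(x)                       # y ~= 0.2 * x with the defaults above
    return torch.allclose(y, Vth * x, atol=1e-3)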
class PSP(torch.nn.Module):
def __init__(self):
super().__init__()
self.tau_s = 2
    def forward(self, inputs):
        """
        First-order synaptic filter: syn[t] = syn[t-1] + (inputs[t] - syn[t-1]) / tau_s.
        inputs: (N, C, T)
        """
syns = None
syn = 0
n_steps = inputs.shape[-1]
for t in range(n_steps):
syn = syn + (inputs[...,t] - syn) / self.tau_s
if syns is None:
syns = syn.unsqueeze(-1)
else:
syns = torch.cat([syns, syn.unsqueeze(-1)], dim=-1)
return syns
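
# A minimal sketch (not part of the original file): a single spike at t=0
# decays geometrically under the first-order filter with tau_s = 2.
def _demo_psp():
    psp = PSP()
    spikes = torch.zeros(1, 1, 5)
    spikes[..., 0] = 1.0
    return psp(spikes).squeeze()  # tensor([0.5000, 0.2500, 0.1250, 0.0625, 0.0312])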
class MembraneOutputLayer(nn.Module):
"""
outputs the last time membrane potential of the LIF neuron with V_th=infty
"""
    def __init__(self) -> None:
        super().__init__()
        n_steps = 16  # hard-coded here; originally glv.n_steps
        arr = torch.arange(n_steps - 1, -1, -1)
        self.register_buffer("coef", torch.pow(0.8, arr)[None, None, None, None, :])  # (1,1,1,1,T)
def forward(self, x):
"""
x : (N,C,H,W,T)
"""
out = torch.sum(x*self.coef, dim=-1)
return out
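
# A minimal sketch (not part of the original file): the layer decodes a spike
# train into a real-valued map by a decay-weighted sum over the 16 time steps.
def _demo_membrane_output():
    layer = MembraneOutputLayer()
    x = torch.rand(4, 1, 32, 32, 16).gt(0.5).float()  # (N, C, H, W, T) spikes
    return layer(x).shape                             # torch.Size([4, 1, 32, 32])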
class PriorBernoulliSTBP(nn.Module):
def __init__(self, k=20) -> None:
"""
modeling of p(z_t|z_<t)
"""
        super().__init__()
        self.channels = 128  # hard-coded here; originally glv.network_config['latent_dim']
        self.k = k
        self.n_steps = 16    # hard-coded here; originally glv.network_config['n_steps']
self.layers = nn.Sequential(
tdLinear(self.channels,
self.channels*2,
bias=True,
bn=tdBatchNorm(self.channels*2, alpha=2),
spike=LIFSpike()),
tdLinear(self.channels*2,
self.channels*4,
bias=True,
bn=tdBatchNorm(self.channels*4, alpha=2),
spike=LIFSpike()),
tdLinear(self.channels*4,
self.channels*k,
bias=True,
bn=tdBatchNorm(self.channels*k, alpha=2),
spike=LIFSpike())
)
self.register_buffer('initial_input', torch.zeros(1, self.channels, 1))# (1,C,1)
def forward(self, z, scheduled=False, p=None):
if scheduled:
return self._forward_scheduled_sampling(z, p)
else:
return self._forward(z)
def _forward(self, z):
"""
input z: (B,C,T) # latent spike sampled from posterior
output : (B,C,k,T) # indicates p(z_t|z_<t) (t=1,...,T)
"""
z_shape = z.shape # (B,C,T)
batch_size = z_shape[0]
z = z.detach()
z0 = self.initial_input.repeat(batch_size, 1, 1) # (B,C,1)
inputs = torch.cat([z0, z[...,:-1]], dim=-1) # (B,C,T)
outputs = self.layers(inputs) # (B,C*k,T)
p_z = outputs.view(batch_size, self.channels, self.k, self.n_steps) # (B,C,k,T)
return p_z
def _forward_scheduled_sampling(self, z, p):
"""
use scheduled sampling
input
z: (B,C,T) # latent spike sampled from posterior
p: float # prob of scheduled sampling
output : (B,C,k,T) # indicates p(z_t|z_<t) (t=1,...,T)
"""
z_shape = z.shape # (B,C,T)
batch_size = z_shape[0]
z = z.detach()
z_t_minus = self.initial_input.repeat(batch_size,1,1) # z_<t, z0=zeros:(B,C,1)
if self.training:
with torch.no_grad():
                for t in range(self.n_steps - 1):
                    if t >= 5 and random.random() < p:  # scheduled sampling
                        outputs = self.layers(z_t_minus.detach())  # binary (B, C*k, t+1), z_<=t
                        p_z_t = outputs[..., -1]  # (B, C*k)
                        # sample from p(z_t | z_<t)
                        prob1 = p_z_t.view(batch_size, self.channels, self.k).mean(-1)  # (B,C)
                        prob1 = prob1 + 1e-3 * torch.randn_like(prob1)
                        z_t = (prob1 > 0.5).float()  # (B,C)
                        z_t = z_t.view(batch_size, self.channels, 1)  # (B,C,1)
z_t_minus = torch.cat([z_t_minus, z_t], dim=-1) # (B,C,t+2)
else:
z_t_minus = torch.cat([z_t_minus, z[...,t].unsqueeze(-1)], dim=-1) # (B,C,t+2)
else: # for test time
z_t_minus = torch.cat([z_t_minus, z[:,:,:-1]], dim=-1) # (B,C,T)
z_t_minus = z_t_minus.detach() # (B,C,T) z_{<=T-1}
p_z = self.layers(z_t_minus) # (B,C*k,T)
p_z = p_z.view(batch_size, self.channels, self.k, self.n_steps)# (B,C,k,T)
return p_z
def sample(self, batch_size=64):
z_minus_t = self.initial_input.repeat(batch_size, 1, 1) # (B, C, 1)
for t in range(self.n_steps):
outputs = self.layers(z_minus_t) # (B, C*k, t+1)
            p_z_t = outputs[..., -1]  # (B, C*k)
random_index = torch.randint(0, self.k, (batch_size*self.channels,)) \
+ torch.arange(start=0, end=batch_size*self.channels*self.k, step=self.k) #(B*C,) pick one from k
random_index = random_index.to(z_minus_t.device)
z_t = p_z_t.view(batch_size*self.channels*self.k)[random_index] # (B*C,)
z_t = z_t.view(batch_size, self.channels, 1) #(B,C,1)
z_minus_t = torch.cat([z_minus_t, z_t], dim=-1) # (B,C,t+2)
sampled_z = z_minus_t[...,1:] # (B,C,T)
return sampled_z
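
# A minimal sketch (not part of the original file): ancestral sampling from
# the prior; eval mode keeps tdBatchNorm on its running statistics.
def _demo_prior_sample():
    prior = PriorBernoulliSTBP(k=20).eval()
    with torch.no_grad():
        z = prior.sample(batch_size=2)  # (B, C, T) = (2, 128, 16)
    return z.shape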
class PosteriorBernoulliSTBP(nn.Module):
def __init__(self, k=20) -> None:
"""
modeling of q(z_t | x_<=t, z_<t)
"""
        super().__init__()
        self.channels = 128  # hard-coded here; originally glv.network_config['latent_dim']
        self.k = k
        self.n_steps = 16    # hard-coded here; originally glv.network_config['n_steps']
self.layers = nn.Sequential(
tdLinear(self.channels*2,
self.channels*2,
bias=True,
bn=tdBatchNorm(self.channels*2, alpha=2),
spike=LIFSpike()),
tdLinear(self.channels*2,
self.channels*4,
bias=True,
bn=tdBatchNorm(self.channels*4, alpha=2),
spike=LIFSpike()),
tdLinear(self.channels*4,
self.channels*k,
bias=True,
bn=tdBatchNorm(self.channels*k, alpha=2),
spike=LIFSpike())
)
self.register_buffer('initial_input', torch.zeros(1, self.channels, 1))# (1,C,1)
self.is_true_scheduled_sampling = True
def forward(self, x):
"""
input:
x:(B,C,T)
returns:
sampled_z:(B,C,T)
q_z: (B,C,k,T) # indicates q(z_t | x_<=t, z_<t) (t=1,...,T)
"""
x_shape = x.shape # (B,C,T)
batch_size=x_shape[0]
random_indices = []
        # sample z in advance, without gradients
with torch.no_grad():
z_t_minus = self.initial_input.repeat(x_shape[0],1,1) # z_<t z0=zeros:(B,C,1)
for t in range(self.n_steps-1):
inputs = torch.cat([x[...,:t+1].detach(), z_t_minus.detach()], dim=1) # (B,C+C,t+1) x_<=t and z_<t
outputs = self.layers(inputs) #(B, C*k, t+1)
                q_z_t = outputs[..., -1]  # (B, C*k), q(z_t | x_<=t, z_<t)
# sampling from q(z_t | x_<=t, z_<t)
random_index = torch.randint(0, self.k, (batch_size*self.channels,)) \
+ torch.arange(start=0, end=batch_size*self.channels*self.k, step=self.k) #(B*C,) select 1 from every k value
random_index = random_index.to(x.device)
random_indices.append(random_index)
z_t = q_z_t.view(batch_size*self.channels*self.k)[random_index] # (B*C,)
z_t = z_t.view(batch_size, self.channels, 1) #(B,C,1)
z_t_minus = torch.cat([z_t_minus, z_t], dim=-1) # (B,C,t+2)
z_t_minus = z_t_minus.detach() # (B,C,T) z_0,...,z_{T-1}
q_z = self.layers(torch.cat([x, z_t_minus], dim=1)) # (B,C*k,T)
        # run the full sequence through the layers again so tdBN sees all T time steps
sampled_z = None
for t in range(self.n_steps):
            if t == self.n_steps - 1:
                # when t = T: draw fresh indices (and move them to the input device)
                random_index = torch.randint(0, self.k, (batch_size*self.channels,)) \
                    + torch.arange(start=0, end=batch_size*self.channels*self.k, step=self.k)
                random_index = random_index.to(x.device)
                random_indices.append(random_index)
else:
# when t<=T-1
random_index = random_indices[t]
# sampling
sampled_z_t = q_z[...,t].view(batch_size*self.channels*self.k)[random_index] # (B*C,)
sampled_z_t = sampled_z_t.view(batch_size, self.channels, 1) #(B,C,1)
if t==0:
sampled_z = sampled_z_t
else:
sampled_z = torch.cat([sampled_z, sampled_z_t], dim=-1)
q_z = q_z.view(batch_size, self.channels, self.k, self.n_steps)# (B,C,k,T)
return sampled_z, q_z
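
# A minimal sketch (not part of the original file): the posterior consumes a
# binary feature sequence and returns both a sampled latent and q(z_t | .).
def _demo_posterior():
    posterior = PosteriorBernoulliSTBP(k=20).eval()
    x = torch.rand(2, 128, 16).gt(0.5).float()  # (B, C, T)
    with torch.no_grad():
        sampled_z, q_z = posterior(x)
    return sampled_z.shape, q_z.shape  # (2, 128, 16), (2, 128, 20, 16)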
class FSAEModel(PreTrainedModel):
config_class = FSAEConfig
    def __init__(self, config):
        super().__init__(config)
        self.in_channels = config.in_channels
        self.hidden_dims = config.hidden_dims
        self.latent_dim = config.latent_dim
        self.n_steps = config.n_steps
        self.k = config.k
        in_channels = self.in_channels
        hidden_dims = self.hidden_dims
        latent_dim = self.latent_dim
# Build Encoder
modules = []
is_first_conv = True
for h_dim in hidden_dims:
modules.append(
tdConv(
in_channels,
out_channels=h_dim,
kernel_size=3,
stride=2,
padding=1,
bias=True,
bn=tdBatchNorm(h_dim),
spike=LIFSpike(),
is_first_conv=is_first_conv,
)
)
in_channels = h_dim
is_first_conv = False
self.encoder = nn.Sequential(*modules)
self.before_latent_layer = tdLinear(
hidden_dims[-1] * 4,
latent_dim,
bias=True,
bn=tdBatchNorm(latent_dim),
spike=LIFSpike(),
)
# Build Decoder
modules = []
self.decoder_input = tdLinear(
latent_dim,
hidden_dims[-1] * 4,
bias=True,
bn=tdBatchNorm(hidden_dims[-1] * 4),
spike=LIFSpike(),
)
hidden_reverse = hidden_dims[::-1]
for i in range(len(hidden_reverse) - 1):
modules.append(
tdConvTranspose(
hidden_reverse[i],
hidden_reverse[i + 1],
kernel_size=3,
stride=2,
padding=1,
output_padding=1,
bias=True,
bn=tdBatchNorm(hidden_reverse[i + 1]),
spike=LIFSpike(),
)
)
self.decoder = nn.Sequential(*modules)
self.final_layer = nn.Sequential(
tdConvTranspose(
hidden_reverse[-1],
hidden_reverse[-1],
kernel_size=3,
stride=2,
padding=1,
output_padding=1,
bias=True,
bn=tdBatchNorm(hidden_reverse[-1]),
spike=LIFSpike(),
),
tdConvTranspose(
hidden_reverse[-1],
out_channels=1,
kernel_size=3,
padding=1,
bias=True,
bn=None,
spike=None,
),
)
        self.p = 0  # scheduled-sampling probability (not read anywhere below)
self.membrane_output_layer = MembraneOutputLayer()
def forward(self, x, scheduled=False):
sampled_z = self.encode(x, scheduled)
x_recon = self.decode(sampled_z)
return x_recon, sampled_z
    def encode(self, x, scheduled=False):
        # `scheduled` is accepted for interface compatibility but is unused here.
        x = self.encoder(x)  # (N,C,H,W,T)
        x = torch.flatten(x, start_dim=1, end_dim=3)  # (N,C*H*W,T)
        latent_x = self.before_latent_layer(x)  # (N,latent_dim,T)
        return latent_x
def decode(self, z):
result = self.decoder_input(z) # (N,C*H*W,T)
        result = result.view(
            result.shape[0], self.hidden_dims[-1], 2, 2, self.n_steps
        )  # (N,C,2,2,T): the encoder is assumed to reduce inputs to 2x2 spatially
result = self.decoder(result) # (N,C,H,W,T)
result = self.final_layer(result) # (N,C,H,W,T)
out = torch.tanh(self.membrane_output_layer(result))
return out
def sample(self, batch_size=64):
raise NotImplementedError()
    def loss_function(self, recons_img, input_img):
        """
        Reconstruction loss of the autoencoder: mean-squared error between the
        reconstructed image and the input image.
        """
        recons_loss = F.mse_loss(recons_img, input_img)
        return recons_loss
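
# A hedged end-to-end sketch (not part of the original file). FSAEConfig is
# defined in configuration_fsae.py, which is not shown here, so the field
# values below are assumptions chosen so that the decoder's hard-coded 2x2
# spatial size matches 32x32 inputs (four stride-2 convolutions: 32 -> 2).
if __name__ == "__main__":
    config = FSAEConfig(
        in_channels=1,
        hidden_dims=[32, 64, 128, 256],
        latent_dim=128,
        n_steps=16,
        k=20,
    )
    model = FSAEModel(config).eval()
    x = torch.rand(2, 1, 32, 32, 16).gt(0.5).float()  # (N, C, H, W, T) spike input
    with torch.no_grad():
        x_recon, z = model(x)
    print(x_recon.shape, z.shape)  # (2, 1, 32, 32), (2, 128, 16)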