import torch
import torch.nn as nn
import torch.nn.functional as F
import random

from transformers import PreTrainedModel

from .configuration_fsae import FSAEConfig

# Global hyperparameters of the LIF spiking neuron model.
dt = 5      # simulation time step
a = 0.25
aa = 0.5    # half-width of the rectangular surrogate-gradient window
Vth = 0.2   # firing threshold
tau = 0.25  # membrane potential decay factor
					
						
class SpikeAct(torch.autograd.Function):
    """
    Implementation of the spiking activation function with an approximation of
    its gradient (a rectangular surrogate window of half-width aa).
    """
    @staticmethod
    def forward(ctx, input):
        ctx.save_for_backward(input)
        # Fire a spike wherever the membrane potential exceeds the threshold Vth.
        output = torch.gt(input, Vth)
        return output.float()

    @staticmethod
    def backward(ctx, grad_output):
        input, = ctx.saved_tensors
        grad_input = grad_output.clone()
        # Rectangular surrogate gradient: 1 / (2 * aa) inside |input| < aa,
        # zero outside.
        hu = abs(input) < aa
        hu = hu.float() / (2 * aa)
        return grad_input * hu
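
# A minimal sketch of SpikeAct's behaviour (uses the module-level constants
# Vth = 0.2 and aa = 0.5; values below are worked out by hand):
#
#   u = torch.tensor([0.1, 0.3, 0.9], requires_grad=True)
#   s = SpikeAct.apply(u)   # tensor([0., 1., 1.])  -- hard threshold at Vth
#   s.sum().backward()
#   u.grad                  # tensor([1., 1., 0.])  -- 1/(2*aa) = 1.0 where |u| < aa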
					
						
class LIFSpike(nn.Module):
    """
    Generates spikes based on the LIF module. It can be considered an
    activation function and is used similarly to ReLU, except that the input
    tensor needs an additional time dimension, which here is the last
    dimension of the data.
    """
    def __init__(self):
        super(LIFSpike, self).__init__()

    def forward(self, x):
        nsteps = x.shape[-1]
        u = torch.zeros(x.shape[:-1], device=x.device)
        out = torch.zeros(x.shape, device=x.device)
        for step in range(nsteps):
            u, out[..., step] = self.state_update(u, out[..., max(step - 1, 0)], x[..., step])
        return out

    def state_update(self, u_t_n1, o_t_n1, W_mul_o_t1_n, tau=tau):
        # Decay the membrane potential by tau, reset it where a spike fired at
        # the previous step, then integrate the new synaptic input.
        u_t1_n1 = tau * u_t_n1 * (1 - o_t_n1) + W_mul_o_t1_n
        o_t1_n1 = SpikeAct.apply(u_t1_n1)
        return u_t1_n1, o_t1_n1
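
# Usage sketch (shapes are illustrative): LIFSpike works on any tensor whose
# last dimension is time, e.g. a (N, C, H, W, T) feature map.
#
#   spike = LIFSpike()
#   x = torch.randn(2, 8, 4, 4, 16)   # (N, C, H, W, T)
#   s = spike(x)                      # binary spike train, same shape as x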
					
						
class tdLinear(nn.Linear):
    def __init__(self,
                 in_features,
                 out_features,
                 bias=True,
                 bn=None,
                 spike=None):
        assert type(in_features) == int, 'in_features should not be more than 1 dimension. It was: {}'.format(in_features)
        assert type(out_features) == int, 'out_features should not be more than 1 dimension. It was: {}'.format(out_features)

        super(tdLinear, self).__init__(in_features, out_features, bias=bias)

        self.bn = bn
        self.spike = spike

    def forward(self, x):
        """
        x : (N, C, T)
        """
        x = x.transpose(1, 2)  # (N, T, C): apply the linear layer per time step
        y = F.linear(x, self.weight, self.bias)
        y = y.transpose(1, 2)  # back to (N, C, T)

        if self.bn is not None:
            # tdBatchNorm expects a 5D (N, C, H, W, T) tensor, so insert and
            # then remove singleton spatial dimensions.
            y = y[:, :, None, None, :]
            y = self.bn(y)
            y = y[:, :, 0, 0, :]
        if self.spike is not None:
            y = self.spike(y)
        return y
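
# Usage sketch (hypothetical sizes): a time-distributed linear layer with
# tdBN (defined below) and a LIF spike activation chained after it.
#
#   layer = tdLinear(128, 256, bias=True,
#                    bn=tdBatchNorm(256, alpha=2), spike=LIFSpike())
#   z = torch.rand(4, 128, 16)   # (N, C, T) input
#   out = layer(z)               # (4, 256, 16) binary spikes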
					
						
class tdConv(nn.Conv3d):
    def __init__(self,
                 in_channels,
                 out_channels,
                 kernel_size,
                 stride=1,
                 padding=0,
                 dilation=1,
                 groups=1,
                 bias=True,
                 bn=None,
                 spike=None,
                 is_first_conv=False):
        # The time axis is treated as the third "spatial" dimension of a 3D
        # convolution with kernel size 1, so weights are shared across time.
        if type(kernel_size) == int:
            kernel = (kernel_size, kernel_size, 1)
        elif len(kernel_size) == 2:
            kernel = (kernel_size[0], kernel_size[1], 1)
        else:
            raise Exception('kernel_size can only be of 1 or 2 dimensions. It was: {}'.format(kernel_size))

        if type(stride) == int:
            stride = (stride, stride, 1)
        elif len(stride) == 2:
            stride = (stride[0], stride[1], 1)
        else:
            raise Exception('stride can be either int or tuple of size 2. It was: {}'.format(stride))

        if type(padding) == int:
            padding = (padding, padding, 0)
        elif len(padding) == 2:
            padding = (padding[0], padding[1], 0)
        else:
            raise Exception('padding can be either int or tuple of size 2. It was: {}'.format(padding))

        if type(dilation) == int:
            dilation = (dilation, dilation, 1)
        elif len(dilation) == 2:
            dilation = (dilation[0], dilation[1], 1)
        else:
            raise Exception('dilation can be either int or tuple of size 2. It was: {}'.format(dilation))

        super(tdConv, self).__init__(in_channels, out_channels, kernel, stride, padding, dilation, groups,
                                     bias=bias)
        self.bn = bn
        self.spike = spike
        self.is_first_conv = is_first_conv

    def forward(self, x):
        x = F.conv3d(x, self.weight, self.bias,
                     self.stride, self.padding, self.dilation, self.groups)
        if self.bn is not None:
            x = self.bn(x)
        if self.spike is not None:
            x = self.spike(x)
        return x
					
						
class tdConvTranspose(nn.ConvTranspose3d):
    def __init__(self,
                 in_channels,
                 out_channels,
                 kernel_size,
                 stride=1,
                 padding=0,
                 output_padding=0,
                 dilation=1,
                 groups=1,
                 bias=True,
                 bn=None,
                 spike=None):
        # As in tdConv, time is the third dimension of the 3D kernel and is
        # never convolved over (kernel size 1 along time).
        if type(kernel_size) == int:
            kernel = (kernel_size, kernel_size, 1)
        elif len(kernel_size) == 2:
            kernel = (kernel_size[0], kernel_size[1], 1)
        else:
            raise Exception('kernel_size can only be of 1 or 2 dimensions. It was: {}'.format(kernel_size))

        if type(stride) == int:
            stride = (stride, stride, 1)
        elif len(stride) == 2:
            stride = (stride[0], stride[1], 1)
        else:
            raise Exception('stride can be either int or tuple of size 2. It was: {}'.format(stride))

        if type(padding) == int:
            padding = (padding, padding, 0)
        elif len(padding) == 2:
            padding = (padding[0], padding[1], 0)
        else:
            raise Exception('padding can be either int or tuple of size 2. It was: {}'.format(padding))

        if type(dilation) == int:
            dilation = (dilation, dilation, 1)
        elif len(dilation) == 2:
            dilation = (dilation[0], dilation[1], 1)
        else:
            raise Exception('dilation can be either int or tuple of size 2. It was: {}'.format(dilation))

        if type(output_padding) == int:
            output_padding = (output_padding, output_padding, 0)
        elif len(output_padding) == 2:
            output_padding = (output_padding[0], output_padding[1], 0)
        else:
            raise Exception('output_padding can be either int or tuple of size 2. It was: {}'.format(output_padding))

        super().__init__(in_channels, out_channels, kernel, stride, padding, output_padding, groups,
                         bias=bias, dilation=dilation)

        self.bn = bn
        self.spike = spike

    def forward(self, x):
        x = F.conv_transpose3d(x, self.weight, self.bias,
                               self.stride, self.padding,
                               self.output_padding, self.groups, self.dilation)

        if self.bn is not None:
            x = self.bn(x)
        if self.spike is not None:
            x = self.spike(x)
        return x
					
						
class tdBatchNorm(nn.BatchNorm2d):
    """
    Implementation of tdBN (threshold-dependent batch normalization). Link to
    the related paper: https://arxiv.org/pdf/2011.05280. In short, statistics
    are averaged over the time dimension as well when doing BN, and the result
    is rescaled by alpha * Vth. The nn.BatchNorm2d base class is subclassed
    only to reuse its parameters and buffers; forward handles the 5D
    (N, C, H, W, T) input itself.

    Args:
        num_features (int): same as nn.BatchNorm2d
        eps (float): same as nn.BatchNorm2d
        momentum (float): same as nn.BatchNorm2d
        alpha (float): an additional scale parameter which may change in a resblock.
        affine (bool): same as nn.BatchNorm2d
        track_running_stats (bool): same as nn.BatchNorm2d
    """
    def __init__(self, num_features, eps=1e-05, momentum=0.1, alpha=1, affine=True, track_running_stats=True):
        super(tdBatchNorm, self).__init__(
            num_features, eps, momentum, affine, track_running_stats)
        self.alpha = alpha

    def forward(self, input):
        exponential_average_factor = 0.0

        if self.training and self.track_running_stats:
            if self.num_batches_tracked is not None:
                self.num_batches_tracked += 1
                if self.momentum is None:  # cumulative moving average
                    exponential_average_factor = 1.0 / float(self.num_batches_tracked)
                else:  # exponential moving average
                    exponential_average_factor = self.momentum

        if self.training:
            # Statistics are computed over the batch, spatial AND time axes.
            mean = input.mean([0, 2, 3, 4])
            var = input.var([0, 2, 3, 4], unbiased=False)
            n = input.numel() / input.size(1)
            with torch.no_grad():
                self.running_mean = exponential_average_factor * mean \
                    + (1 - exponential_average_factor) * self.running_mean
                # Update running_var with the unbiased variance.
                self.running_var = exponential_average_factor * var * n / (n - 1) \
                    + (1 - exponential_average_factor) * self.running_var
        else:
            mean = self.running_mean
            var = self.running_var

        # Normalize, then rescale to alpha * Vth so pre-activations sit near
        # the firing threshold.
        input = self.alpha * Vth * (input - mean[None, :, None, None, None]) / (torch.sqrt(var[None, :, None, None, None] + self.eps))
        if self.affine:
            input = input * self.weight[None, :, None, None, None] + self.bias[None, :, None, None, None]

        return input
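
# A minimal sketch of the tdBN transform, assuming the module-level Vth = 0.2:
# each channel c of a (N, C, H, W, T) tensor is normalized as
#
#   x_hat = alpha * Vth * (x - mean_c) / sqrt(var_c + eps)
#
# with mean_c and var_c taken over the batch, spatial and time axes.
#
#   bn = tdBatchNorm(8, alpha=1)
#   x = torch.randn(4, 8, 16, 16, 16)  # (N, C, H, W, T)
#   y = bn(x)                          # same shape; per-channel std ~ alpha*Vth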
					
						
class PSP(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.tau_s = 2

    def forward(self, inputs):
        """
        First-order synaptic filter: each step blends the new input into the
        running synaptic state with time constant tau_s.

        inputs: (N, C, T)
        """
        syns = None
        syn = 0
        n_steps = inputs.shape[-1]
        for t in range(n_steps):
            syn = syn + (inputs[..., t] - syn) / self.tau_s
            if syns is None:
                syns = syn.unsqueeze(-1)
            else:
                syns = torch.cat([syns, syn.unsqueeze(-1)], dim=-1)

        return syns
					
						
class MembraneOutputLayer(nn.Module):
    """
    Outputs the last-time-step membrane potential of a LIF neuron with
    V_th = infinity (i.e. a non-firing integrator with decay 0.8).
    """
    def __init__(self) -> None:
        super().__init__()
        n_steps = 16

        # Decay coefficients 0.8^(T-1-t): the most recent step is weighted 1.
        arr = torch.arange(n_steps - 1, -1, -1)
        self.register_buffer("coef", torch.pow(0.8, arr)[None, None, None, None, :])

    def forward(self, x):
        """
        x : (N, C, H, W, T)
        """
        out = torch.sum(x * self.coef, dim=-1)
        return out
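
# Equivalently (a hand-checked restatement): for a membrane that evolves as
# u[t] = 0.8 * u[t-1] + x[t], the returned value is u at the final step,
#
#   out = sum_t 0.8**(T-1-t) * x[..., t]
#
# Note that the number of time steps is hardcoded to 16 here.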
					
						
class PriorBernoulliSTBP(nn.Module):
    def __init__(self, k=20) -> None:
        """
        Models the autoregressive prior p(z_t | z_<t).
        """
        super().__init__()

        self.channels = 128
        self.k = k
        self.n_steps = 16

        self.layers = nn.Sequential(
            tdLinear(self.channels,
                     self.channels * 2,
                     bias=True,
                     bn=tdBatchNorm(self.channels * 2, alpha=2),
                     spike=LIFSpike()),
            tdLinear(self.channels * 2,
                     self.channels * 4,
                     bias=True,
                     bn=tdBatchNorm(self.channels * 4, alpha=2),
                     spike=LIFSpike()),
            tdLinear(self.channels * 4,
                     self.channels * k,
                     bias=True,
                     bn=tdBatchNorm(self.channels * k, alpha=2),
                     spike=LIFSpike())
        )
        # z_0: an all-zero spike vector used to seed the autoregression.
        self.register_buffer('initial_input', torch.zeros(1, self.channels, 1))
					
						
    def forward(self, z, scheduled=False, p=None):
        if scheduled:
            return self._forward_scheduled_sampling(z, p)
        else:
            return self._forward(z)

    def _forward(self, z):
        """
        input z: (B, C, T)     # latent spikes sampled from the posterior
        output : (B, C, k, T)  # indicates p(z_t | z_<t) (t = 1, ..., T)
        """
        z_shape = z.shape
        batch_size = z_shape[0]
        z = z.detach()

        # Feed z_0, z_1, ..., z_{T-1} so the output at step t depends only on z_<t.
        z0 = self.initial_input.repeat(batch_size, 1, 1)
        inputs = torch.cat([z0, z[..., :-1]], dim=-1)
        outputs = self.layers(inputs)

        p_z = outputs.view(batch_size, self.channels, self.k, self.n_steps)
        return p_z
					
						
    def _forward_scheduled_sampling(self, z, p):
        """
        Scheduled sampling: during training, with probability p (and only
        after the first 5 steps), feed the prior's own sample back in instead
        of the posterior's z_t.

        input
            z: (B, C, T)  # latent spikes sampled from the posterior
            p: float      # probability of scheduled sampling
        output : (B, C, k, T)  # indicates p(z_t | z_<t) (t = 1, ..., T)
        """
        z_shape = z.shape
        batch_size = z_shape[0]
        z = z.detach()

        z_t_minus = self.initial_input.repeat(batch_size, 1, 1)
        if self.training:
            with torch.no_grad():
                for t in range(self.n_steps - 1):
                    if t >= 5 and random.random() < p:
                        # Sample z_t from the current prior prediction.
                        outputs = self.layers(z_t_minus.detach())
                        p_z_t = outputs[..., -1]

                        # Average the k Bernoulli outputs per neuron and
                        # threshold at 0.5, with a little noise to break ties.
                        prob1 = p_z_t.view(batch_size, self.channels, self.k).mean(-1)
                        prob1 = prob1 + 1e-3 * torch.randn_like(prob1)
                        z_t = (prob1 > 0.5).float()
                        z_t = z_t.view(batch_size, self.channels, 1)
                        z_t_minus = torch.cat([z_t_minus, z_t], dim=-1)
                    else:
                        # Teacher forcing: use the posterior's z_t.
                        z_t_minus = torch.cat([z_t_minus, z[..., t].unsqueeze(-1)], dim=-1)
        else:
            z_t_minus = torch.cat([z_t_minus, z[:, :, :-1]], dim=-1)

        z_t_minus = z_t_minus.detach()
        p_z = self.layers(z_t_minus)
        p_z = p_z.view(batch_size, self.channels, self.k, self.n_steps)
        return p_z
					
						
    def sample(self, batch_size=64):
        # Autoregressively roll the prior forward from the zero seed, at each
        # step picking one of the k Bernoulli outputs per neuron at random.
        z_minus_t = self.initial_input.repeat(batch_size, 1, 1)
        for t in range(self.n_steps):
            outputs = self.layers(z_minus_t)
            p_z_t = outputs[..., -1]

            # One random index into the k samples for every (batch, channel).
            random_index = torch.randint(0, self.k, (batch_size * self.channels,)) \
                + torch.arange(start=0, end=batch_size * self.channels * self.k, step=self.k)
            random_index = random_index.to(z_minus_t.device)

            z_t = p_z_t.view(batch_size * self.channels * self.k)[random_index]
            z_t = z_t.view(batch_size, self.channels, 1)
            z_minus_t = torch.cat([z_minus_t, z_t], dim=-1)

        # Drop the zero seed z_0.
        sampled_z = z_minus_t[..., 1:]

        return sampled_z
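
# Generation sketch: draw latent spikes from the learned prior (decoding them
# is left to the surrounding model; FSAEModel below keeps its own sample()
# unimplemented).
#
#   prior = PriorBernoulliSTBP(k=20)
#   z = prior.sample(batch_size=8)   # (8, 128, 16) binary latent spikes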
					
						
class PosteriorBernoulliSTBP(nn.Module):
    def __init__(self, k=20) -> None:
        """
        Models the posterior q(z_t | x_<=t, z_<t).
        """
        super().__init__()

        self.channels = 128
        self.k = k
        self.n_steps = 16

        # Input is the channel-wise concatenation of x_t and z_{t-1}, hence
        # channels * 2.
        self.layers = nn.Sequential(
            tdLinear(self.channels * 2,
                     self.channels * 2,
                     bias=True,
                     bn=tdBatchNorm(self.channels * 2, alpha=2),
                     spike=LIFSpike()),
            tdLinear(self.channels * 2,
                     self.channels * 4,
                     bias=True,
                     bn=tdBatchNorm(self.channels * 4, alpha=2),
                     spike=LIFSpike()),
            tdLinear(self.channels * 4,
                     self.channels * k,
                     bias=True,
                     bn=tdBatchNorm(self.channels * k, alpha=2),
                     spike=LIFSpike())
        )
        self.register_buffer('initial_input', torch.zeros(1, self.channels, 1))

        self.is_true_scheduled_sampling = True
					
						
    def forward(self, x):
        """
        input:
            x: (B, C, T)
        returns:
            sampled_z: (B, C, T)
            q_z: (B, C, k, T)  # indicates q(z_t | x_<=t, z_<t) (t = 1, ..., T)
        """
        x_shape = x.shape
        batch_size = x_shape[0]
        random_indices = []

        # First pass (no grad): autoregressively sample z_1, ..., z_{T-1}.
        with torch.no_grad():
            z_t_minus = self.initial_input.repeat(x_shape[0], 1, 1)
            for t in range(self.n_steps - 1):
                inputs = torch.cat([x[..., :t + 1].detach(), z_t_minus.detach()], dim=1)
                outputs = self.layers(inputs)
                q_z_t = outputs[..., -1]

                # One random index into the k samples per (batch, channel).
                random_index = torch.randint(0, self.k, (batch_size * self.channels,)) \
                    + torch.arange(start=0, end=batch_size * self.channels * self.k, step=self.k)
                random_index = random_index.to(x.device)
                random_indices.append(random_index)

                z_t = q_z_t.view(batch_size * self.channels * self.k)[random_index]
                z_t = z_t.view(batch_size, self.channels, 1)

                z_t_minus = torch.cat([z_t_minus, z_t], dim=-1)

        # Second pass (with grad): recompute q_z for the whole sequence.
        z_t_minus = z_t_minus.detach()
        q_z = self.layers(torch.cat([x, z_t_minus], dim=1))

        # Re-select the same samples so sampled_z carries gradients from q_z.
        sampled_z = None
        for t in range(self.n_steps):
            if t == self.n_steps - 1:
                # The last step was never sampled in the first pass; draw it now.
                random_index = torch.randint(0, self.k, (batch_size * self.channels,)) \
                    + torch.arange(start=0, end=batch_size * self.channels * self.k, step=self.k)
                random_index = random_index.to(x.device)
                random_indices.append(random_index)
            else:
                random_index = random_indices[t]

            sampled_z_t = q_z[..., t].view(batch_size * self.channels * self.k)[random_index]
            sampled_z_t = sampled_z_t.view(batch_size, self.channels, 1)
            if t == 0:
                sampled_z = sampled_z_t
            else:
                sampled_z = torch.cat([sampled_z, sampled_z_t], dim=-1)

        q_z = q_z.view(batch_size, self.channels, self.k, self.n_steps)

        return sampled_z, q_z
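
# Inference sketch (hypothetical shapes): encode a (B, C, T) spike sequence
# into latent spikes plus the per-step Bernoulli parameters.
#
#   posterior = PosteriorBernoulliSTBP(k=20)
#   x = (torch.rand(4, 128, 16) > 0.5).float()   # (B, C, T) input spikes
#   z, q_z = posterior(x)                        # (4, 128, 16), (4, 128, 20, 16)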
					
						
class FSAEModel(PreTrainedModel):
    config_class = FSAEConfig

    def __init__(self, config):
        super().__init__(config)

        self.in_channels = config.in_channels
        in_channels = self.in_channels

        self.hidden_dims = config.hidden_dims
        hidden_dims = self.hidden_dims

        self.latent_dim = config.latent_dim
        latent_dim = self.latent_dim

        self.n_steps = config.n_steps
        n_steps = self.n_steps

        self.k = config.k
        k = self.k
					
						
        # Build the spiking encoder: a stack of stride-2 tdConv blocks.
        modules = []
        is_first_conv = True
        for h_dim in hidden_dims:
            modules.append(
                tdConv(
                    in_channels,
                    out_channels=h_dim,
                    kernel_size=3,
                    stride=2,
                    padding=1,
                    bias=True,
                    bn=tdBatchNorm(h_dim),
                    spike=LIFSpike(),
                    is_first_conv=is_first_conv,
                )
            )
            in_channels = h_dim
            is_first_conv = False

        self.encoder = nn.Sequential(*modules)
        self.before_latent_layer = tdLinear(
            hidden_dims[-1] * 4,
            latent_dim,
            bias=True,
            bn=tdBatchNorm(latent_dim),
            spike=LIFSpike(),
        )
					
						
        # Build the spiking decoder: mirror the encoder with tdConvTranspose.
        modules = []

        self.decoder_input = tdLinear(
            latent_dim,
            hidden_dims[-1] * 4,
            bias=True,
            bn=tdBatchNorm(hidden_dims[-1] * 4),
            spike=LIFSpike(),
        )

        hidden_reverse = hidden_dims[::-1]

        for i in range(len(hidden_reverse) - 1):
            modules.append(
                tdConvTranspose(
                    hidden_reverse[i],
                    hidden_reverse[i + 1],
                    kernel_size=3,
                    stride=2,
                    padding=1,
                    output_padding=1,
                    bias=True,
                    bn=tdBatchNorm(hidden_reverse[i + 1]),
                    spike=LIFSpike(),
                )
            )
        self.decoder = nn.Sequential(*modules)
					
						
        self.final_layer = nn.Sequential(
            tdConvTranspose(
                hidden_reverse[-1],
                hidden_reverse[-1],
                kernel_size=3,
                stride=2,
                padding=1,
                output_padding=1,
                bias=True,
                bn=tdBatchNorm(hidden_reverse[-1]),
                spike=LIFSpike(),
            ),
            tdConvTranspose(
                hidden_reverse[-1],
                out_channels=1,
                kernel_size=3,
                padding=1,
                bias=True,
                bn=None,
                spike=None,
            ),
        )

        # Scheduled-sampling probability (not used in this forward path).
        self.p = 0

        self.membrane_output_layer = MembraneOutputLayer()
					
						
    def forward(self, x, scheduled=False):
        sampled_z = self.encode(x, scheduled)
        x_recon = self.decode(sampled_z)
        return x_recon, sampled_z

    def encode(self, x, scheduled=False):
        x = self.encoder(x)  # (N, C', H', W', T)
        # Flatten channel and spatial dims, keeping the time axis last.
        x = torch.flatten(x, start_dim=1, end_dim=3)
        latent_x = self.before_latent_layer(x)
        return latent_x

    def decode(self, z):
        result = self.decoder_input(z)
        # Unflatten to an (N, C, 2, 2, T) feature map; the 2x2 spatial size
        # assumes the encoder downsampled the input to 2x2.
        result = result.view(
            result.shape[0], self.hidden_dims[-1], 2, 2, self.n_steps
        )
        result = self.decoder(result)
        result = self.final_layer(result)
        # Read out the last-step membrane potential, squashed to [-1, 1].
        out = torch.tanh(self.membrane_output_layer(result))
        return out

    def sample(self, batch_size=64):
        raise NotImplementedError()
					
						
    def loss_function(self, recons_img, input_img):
        """
        Computes the reconstruction term of the VAE loss. For reference, the
        Gaussian KL term of a standard VAE is
        KL(N(\mu, \sigma), N(0, 1)) = \log \frac{1}{\sigma} + \frac{\sigma^2 + \mu^2}{2} - \frac{1}{2},
        but only the MSE reconstruction loss is computed here.

        :param recons_img: reconstructed image
        :param input_img: original input image
        :return: reconstruction loss
        """
        recons_loss = F.mse_loss(recons_img, input_img)

        return recons_loss
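
# End-to-end usage sketch. The config fields match what __init__ reads above;
# the concrete values (and the FSAEConfig constructor signature) are
# hypothetical.
#
#   config = FSAEConfig(in_channels=1, hidden_dims=[32, 64, 128, 256],
#                       latent_dim=128, n_steps=16, k=20)
#   model = FSAEModel(config)
#   img = torch.rand(2, 1, 32, 32)                  # static input image
#   x = img.unsqueeze(-1).repeat(1, 1, 1, 1, 16)    # repeat over T=16 steps
#   x_recon, z = model(x)                           # (2, 1, 32, 32), (2, 128, 16)
#   loss = model.loss_function(x_recon, img)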