import math

import scipy
import torch
import torch.nn.functional as F
from torch import nn, einsum
from functools import partial
from einops import rearrange, reduce
from scipy.fftpack import next_fast_len

# helper functions


def exists(x):
    return x is not None


def default(val, d):
    if exists(val):
        return val
    return d() if callable(d) else d


def identity(t, *args, **kwargs):
    return t


def extract(a, t, x_shape):
    b, *_ = t.shape
    out = a.gather(-1, t)
    return out.reshape(b, *((1,) * (len(x_shape) - 1)))


def Upsample(dim, dim_out=None):
    return nn.Sequential(
        nn.Upsample(scale_factor=2, mode="nearest"),
        nn.Conv1d(dim, default(dim_out, dim), 3, padding=1),
    )


def Downsample(dim, dim_out=None):
    return nn.Conv1d(dim, default(dim_out, dim), 4, 2, 1)


# normalization functions


def normalize_to_neg_one_to_one(x):
    return x * 2 - 1


def unnormalize_to_zero_to_one(x):
    return (x + 1) * 0.5


# sinusoidal positional embeds


class SinusoidalPosEmb(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.dim = dim

    def forward(self, x):
        device = x.device
        half_dim = self.dim // 2
        emb = math.log(10000) / (half_dim - 1)
        emb = torch.exp(torch.arange(half_dim, device=device) * -emb)
        emb = x[:, None] * emb[None, :]
        emb = torch.cat((emb.sin(), emb.cos()), dim=-1)
        return emb


# learnable positional embeds


class LearnablePositionalEncoding(nn.Module):
    def __init__(self, d_model, dropout=0.1, max_len=1024):
        super(LearnablePositionalEncoding, self).__init__()
        self.dropout = nn.Dropout(p=dropout)
        # Each position gets its own embedding.
        # Since indices are always 0 ... max_len, we don't have to do a look-up.
        self.pe = nn.Parameter(
            torch.empty(1, max_len, d_model)
        )  # requires_grad automatically set to True
        nn.init.uniform_(self.pe, -0.02, 0.02)

    def forward(self, x):
        r"""Inputs of forward function
        Args:
            x: the sequence fed to the positional encoder model (required).
        Shape:
            x: [batch size, sequence length, embed dim]
            output: [batch size, sequence length, embed dim]
        """
        x = x + self.pe
        return self.dropout(x)
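# Illustrative shape sketch (not part of the original module; sizes below are
# arbitrary assumptions). SinusoidalPosEmb maps a batch of scalar diffusion
# timesteps to embeddings, while LearnablePositionalEncoding adds a learned
# per-position offset to an already-embedded sequence; note that its forward
# adds the full `pe` table, so `max_len` must match the sequence length.
#
#   pos_emb = SinusoidalPosEmb(dim=64)
#   t = torch.randint(0, 1000, (8,))      # (batch,) integer timesteps
#   pos_emb(t).shape                      # torch.Size([8, 64])
#
#   lpe = LearnablePositionalEncoding(d_model=64, max_len=24)
#   x = torch.randn(8, 24, 64)            # (batch, seq_len, embed_dim)
#   lpe(x).shape                          # torch.Size([8, 24, 64])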
class moving_avg(nn.Module):
    """
    Moving average block to highlight the trend of a time series.
    """

    def __init__(self, kernel_size, stride):
        super(moving_avg, self).__init__()
        self.kernel_size = kernel_size
        self.avg = nn.AvgPool1d(kernel_size=kernel_size, stride=stride, padding=0)

    def forward(self, x):
        # pad both ends of the time series by repeating the boundary values
        front = x[:, 0:1, :].repeat(
            1, self.kernel_size - 1 - math.floor((self.kernel_size - 1) // 2), 1
        )
        end = x[:, -1:, :].repeat(1, math.floor((self.kernel_size - 1) // 2), 1)
        x = torch.cat([front, x, end], dim=1)
        x = self.avg(x.permute(0, 2, 1))
        x = x.permute(0, 2, 1)
        return x


class series_decomp(nn.Module):
    """
    Series decomposition block: splits a series into a residual and a moving-average (trend) part.
    """

    def __init__(self, kernel_size):
        super(series_decomp, self).__init__()
        self.moving_avg = moving_avg(kernel_size, stride=1)

    def forward(self, x):
        moving_mean = self.moving_avg(x)
        res = x - moving_mean
        return res, moving_mean


class series_decomp_multi(nn.Module):
    """
    Series decomposition block with multiple kernel sizes, mixed by learned softmax weights.
    """

    def __init__(self, kernel_size):
        super(series_decomp_multi, self).__init__()
        # use a ModuleList so the sub-blocks are registered on the module
        self.moving_avg = nn.ModuleList(
            [moving_avg(kernel, stride=1) for kernel in kernel_size]
        )
        self.layer = torch.nn.Linear(1, len(kernel_size))

    def forward(self, x):
        moving_mean = []
        for func in self.moving_avg:
            moving_mean.append(func(x).unsqueeze(-1))
        moving_mean = torch.cat(moving_mean, dim=-1)
        moving_mean = torch.sum(
            moving_mean * nn.Softmax(-1)(self.layer(x.unsqueeze(-1))), dim=-1
        )
        res = x - moving_mean
        return res, moving_mean


class Transpose(nn.Module):
    """Wrapper class of torch.transpose() for Sequential module."""

    def __init__(self, shape: tuple):
        super(Transpose, self).__init__()
        self.shape = shape

    def forward(self, x):
        return x.transpose(*self.shape)


class Conv_MLP(nn.Module):
    def __init__(self, in_dim, out_dim, resid_pdrop=0.0):
        super().__init__()
        self.sequential = nn.Sequential(
            Transpose(shape=(1, 2)),
            nn.Conv1d(in_dim, out_dim, 3, stride=1, padding=1),
            nn.Dropout(p=resid_pdrop),
        )

    def forward(self, x):
        return self.sequential(x).transpose(1, 2)


class Transformer_MLP(nn.Module):
    def __init__(self, n_embd, mlp_hidden_times, act, resid_pdrop):
        super().__init__()
        self.sequential = nn.Sequential(
            nn.Conv1d(
                in_channels=n_embd,
                out_channels=int(mlp_hidden_times * n_embd),
                kernel_size=1,
                padding=0,
            ),
            act,
            nn.Conv1d(
                in_channels=int(mlp_hidden_times * n_embd),
                out_channels=int(mlp_hidden_times * n_embd),
                kernel_size=3,
                padding=1,
            ),
            act,
            nn.Conv1d(
                in_channels=int(mlp_hidden_times * n_embd),
                out_channels=n_embd,
                kernel_size=3,
                padding=1,
            ),
            nn.Dropout(p=resid_pdrop),
        )

    def forward(self, x):
        return self.sequential(x)


class GELU2(nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, x):
        # sigmoid-based GELU approximation; torch.sigmoid replaces the deprecated F.sigmoid
        return x * torch.sigmoid(1.702 * x)


class AdaLayerNorm(nn.Module):
    def __init__(self, n_embd):
        super().__init__()
        self.emb = SinusoidalPosEmb(n_embd)
        self.silu = nn.SiLU()
        self.linear = nn.Linear(n_embd, n_embd * 2)
        self.layernorm = nn.LayerNorm(n_embd, elementwise_affine=False)

    def forward(self, x, timestep, label_emb=None):
        emb = self.emb(timestep)
        if label_emb is not None:
            emb = emb + label_emb
        emb = self.linear(self.silu(emb)).unsqueeze(1)
        scale, shift = torch.chunk(emb, 2, dim=2)
        x = self.layernorm(x) * (1 + scale) + shift
        return x
class AdaInsNorm(nn.Module):
    def __init__(self, n_embd):
        super().__init__()
        self.emb = SinusoidalPosEmb(n_embd)
        self.silu = nn.SiLU()
        self.linear = nn.Linear(n_embd, n_embd * 2)
        self.instancenorm = nn.InstanceNorm1d(n_embd)

    def forward(self, x, timestep, label_emb=None):
        emb = self.emb(timestep)
        if label_emb is not None:
            emb = emb + label_emb
        emb = self.linear(self.silu(emb)).unsqueeze(1)
        scale, shift = torch.chunk(emb, 2, dim=2)
        x = (
            self.instancenorm(x.transpose(-1, -2)).transpose(-1, -2) * (1 + scale)
            + shift
        )
        return x
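

# Minimal smoke test of the blocks above. This is an illustrative sketch, not part
# of the original module: the sizes (batch 4, length 24, 16 features, embedding
# size 64) are arbitrary assumptions chosen only to exercise the shapes.
if __name__ == "__main__":
    b, length, n_feat, n_embd = 4, 24, 16, 64
    series = torch.randn(b, length, n_feat)

    # trend / residual decomposition: the two parts should sum back to the input
    res, trend = series_decomp(kernel_size=25)(series)
    assert torch.allclose(res + trend, series, atol=1e-5)

    # Conv_MLP lifts the feature dimension while keeping the (batch, length, dim) layout
    tokens = Conv_MLP(in_dim=n_feat, out_dim=n_embd)(series)
    assert tokens.shape == (b, length, n_embd)

    # AdaLayerNorm modulates normalized tokens with a timestep-dependent scale/shift
    t = torch.randint(0, 1000, (b,))
    out = AdaLayerNorm(n_embd)(tokens, t)
    assert out.shape == (b, length, n_embd)
    print("smoke test passed")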