# Last commit: 2875fe6 ("update") by PeterYu.
import math
import torch
import numpy as np
import torch.nn.functional as F
from torch import nn
from einops import rearrange, reduce, repeat
from .model_utils import (
LearnablePositionalEncoding,
Conv_MLP,
AdaLayerNorm,
Transpose,
GELU2,
series_decomp,
)
class TrendBlock(nn.Module):
    """
    Model trend of time series using the polynomial regressor.

    A small conv stack predicts ``trend_poly`` (= 3) coefficients per output
    feature; the trend is their weighted sum over the fixed basis
    ``s, s**2, s**3`` evaluated on ``out_dim`` evenly spaced points
    ``s = 1/(out_dim+1), ..., out_dim/(out_dim+1)``.
    """

    def __init__(self, in_dim, out_dim, in_feat, out_feat, act):
        super(TrendBlock, self).__init__()
        trend_poly = 3
        self.trend = nn.Sequential(
            nn.Conv1d(
                in_channels=in_dim, out_channels=trend_poly, kernel_size=3, padding=1
            ),
            act,
            # NOTE(review): assumed to swap dims 1 and 2 so the next conv sees
            # ``in_feat`` channels.
            Transpose(shape=(1, 2)),
            nn.Conv1d(in_feat, out_feat, 3, stride=1, padding=1),
        )
        lin_space = torch.arange(1, out_dim + 1, 1) / (out_dim + 1)
        # Non-persistent buffer: it moves/casts with the module (.to(), .cuda(),
        # .half()) instead of being copied on every forward, and it is neither
        # written to nor expected in state_dict checkpoints.
        self.register_buffer(
            "poly_space",
            torch.stack(
                [lin_space ** float(p + 1) for p in range(trend_poly)], dim=0
            ),
            persistent=False,
        )

    def forward(self, input):
        """Return the polynomial trend of ``input``.

        Args:
            input: tensor of shape (b, in_dim, in_feat) — the first conv takes
                ``in_dim`` channels and, after the transpose, the second conv
                takes ``in_feat`` channels.

        Returns:
            Tensor of shape (b, out_dim, out_feat).
        """
        # (b, out_feat, trend_poly): polynomial coefficients per output feature.
        coeffs = self.trend(input)
        # (b, out_feat, trend_poly) @ (trend_poly, out_dim) -> (b, out_feat, out_dim)
        trend_vals = torch.matmul(coeffs, self.poly_space.to(coeffs.device))
        return trend_vals.transpose(1, 2)
class MovingBlock(nn.Module):
    """
    Model trend of time series using the moving average.

    Thin wrapper around ``series_decomp`` whose window length is a quarter of
    ``out_dim``, clamped to the range [4, 24].
    """

    def __init__(self, out_dim):
        super().__init__()
        window = int(out_dim / 4)
        window = min(window, 24)  # upper bound on the window
        window = max(window, 4)  # lower bound on the window
        self.decomp = series_decomp(window)

    def forward(self, input):
        """Return the pair ``(x, trend_vals)`` produced by ``series_decomp``.

        ``input`` must be 3-D; the shape unpacking below raises otherwise.
        """
        _batch, _length, _dim = input.shape  # enforces a 3-D input
        remainder, trend_vals = self.decomp(input)
        return remainder, trend_vals
class FourierLayer(nn.Module):
    """
    Model seasonality of time series using the inverse DFT.

    The input is transformed with a real FFT along the time axis. The
    strongest frequency bins are kept independently for every
    (batch, channel) pair, and the series is re-synthesised on the original
    time grid as a sum of cosines built from those bins.
    """

    def __init__(self, d_model, low_freq=1, factor=1):
        super().__init__()
        self.d_model = d_model  # stored only; not used by the computation below
        self.factor = factor  # scales how many frequency bins are kept
        self.low_freq = low_freq  # number of lowest rfft bins (from DC up) to drop

    def forward(self, x):
        """Reconstruct the dominant periodic components of ``x``.

        Args:
            x: real tensor of shape (b, t, d); the FFT runs over dim 1.

        Returns:
            Tensor of shape (b, t, d).
        """
        b, t, d = x.shape
        x_freq = torch.fft.rfft(x, dim=1)
        freqs = torch.fft.rfftfreq(t, device=x.device)
        # Drop the lowest `low_freq` bins and, for even t, the final (Nyquist) bin.
        stop = -1 if t % 2 == 0 else None
        x_freq = x_freq[:, self.low_freq : stop]
        freqs = freqs[self.low_freq : stop]

        x_freq, index_tuple = self.topk_freq(x_freq)
        # Broadcast each bin's frequency to (b, f, d), gather the selected bins,
        # and add a singleton time axis -> (b, k, 1, d).
        f = freqs[None, :, None].expand(b, -1, d)[index_tuple].unsqueeze(2)
        return self.extrapolate(x_freq, f, t)

    def extrapolate(self, x_freq, f, t):
        """Synthesise ``t`` time steps from the selected rfft bins.

        Each bin is paired with its complex conjugate at the negated frequency,
        and the output is ``sum |X| * cos(2*pi*f*n + angle(X))`` over that
        doubled set, for n = 0 .. t-1.

        Args:
            x_freq: complex tensor (b, k, d) of selected bins.
            f: frequencies of those bins, shape (b, k, 1, d).
            t: number of time steps to generate.

        Returns:
            Real tensor of shape (b, t, d).
        """
        x_freq = torch.cat([x_freq, x_freq.conj()], dim=1)
        f = torch.cat([f, -f], dim=1)
        time_grid = torch.arange(t, dtype=torch.float, device=x_freq.device).view(
            1, 1, t, 1
        )
        amp = x_freq.abs().unsqueeze(2)  # (b, 2k, 1, d)
        phase = x_freq.angle().unsqueeze(2)  # (b, 2k, 1, d)
        x_time = amp * torch.cos(2 * math.pi * f * time_grid + phase)  # (b, 2k, t, d)
        return x_time.sum(dim=1)

    def topk_freq(self, x_freq):
        """Keep the largest-magnitude bins along dim 1 for each (batch, channel).

        The number of bins kept is ``int(factor * ln(length))``, clamped to
        ``[0, length]`` so that large ``factor`` values or very short spectra
        do not make ``torch.topk`` fail.

        Args:
            x_freq: complex tensor (b, length, d).

        Returns:
            ``(selected, index_tuple)`` — the gathered bins of shape (b, k, d)
            and the advanced-index tuple used to gather them.
        """
        length = x_freq.shape[1]
        # math.log(0) raises, so an empty spectrum keeps zero bins.
        top_k = int(self.factor * math.log(length)) if length > 0 else 0
        top_k = max(0, min(top_k, length))
        _, indices = torch.topk(
            x_freq.abs(), top_k, dim=1, largest=True, sorted=True
        )
        # Index tensors live on the same device as the data they index.
        mesh_a, mesh_b = torch.meshgrid(
            torch.arange(x_freq.size(0), device=x_freq.device),
            torch.arange(x_freq.size(2), device=x_freq.device),
            indexing="ij",
        )
        index_tuple = (mesh_a.unsqueeze(1), indices, mesh_b.unsqueeze(1))
        return x_freq[index_tuple], index_tuple
class SeasonBlock(nn.Module):
    """
    Model seasonality of time series using the Fourier series.

    A 1x1 conv maps ``in_dim`` channels to ``season_poly`` coefficients that
    weight a fixed basis of cosines (harmonics 1..p1) followed by sines
    (harmonics 1..p2), sampled on ``out_dim`` points of [0, 1).
    """

    def __init__(self, in_dim, out_dim, factor=1):
        super(SeasonBlock, self).__init__()
        season_poly = factor * min(32, int(out_dim // 2))
        self.season = nn.Conv1d(
            in_channels=in_dim, out_channels=season_poly, kernel_size=1, padding=0
        )
        fourier_space = torch.arange(0, out_dim, 1) / out_dim
        # Split the basis between cosines (p1) and sines (p2); an odd count
        # gives the extra term to the sines.
        p1 = season_poly // 2
        p2 = season_poly - p1
        rows = [torch.cos(2 * np.pi * p * fourier_space) for p in range(1, p1 + 1)]
        rows += [torch.sin(2 * np.pi * p * fourier_space) for p in range(1, p2 + 1)]
        # torch.stack rejects an empty list, which the old per-group stacking hit
        # when p1 == 0 (season_poly == 1); stack all rows at once instead.
        poly_space = (
            torch.stack(rows, dim=0)
            if rows
            else fourier_space.new_zeros((0, out_dim))
        )
        # Non-persistent buffer: follows the module's device/dtype without
        # appearing in state_dict checkpoints.
        self.register_buffer("poly_space", poly_space, persistent=False)

    def forward(self, input):
        """Return the seasonal component.

        Args:
            input: tensor of shape (b, in_dim, L).

        Returns:
            Tensor of shape (b, out_dim, L).
        """
        coeffs = self.season(input)  # (b, season_poly, L)
        # (b, L, season_poly) @ (season_poly, out_dim) -> (b, L, out_dim)
        season_vals = torch.matmul(
            coeffs.transpose(1, 2), self.poly_space.to(coeffs.device)
        )
        return season_vals.transpose(1, 2)
class FullAttention(nn.Module):
    """Multi-head scaled dot-product self-attention.

    ``forward`` returns the projected output and the attention weights
    (after dropout) averaged over heads.
    """

    def __init__(
        self,
        n_embd,  # the embed dim
        n_head,  # the number of heads
        attn_pdrop=0.1,  # attention dropout prob
        resid_pdrop=0.1,  # residual attention dropout prob
    ):
        super().__init__()
        assert n_embd % n_head == 0
        # Key / query / value maps for all heads at once.
        self.key = nn.Linear(n_embd, n_embd)
        self.query = nn.Linear(n_embd, n_embd)
        self.value = nn.Linear(n_embd, n_embd)
        # Dropout on the attention weights and on the projected output.
        self.attn_drop = nn.Dropout(attn_pdrop)
        self.resid_drop = nn.Dropout(resid_pdrop)
        # Output projection back to n_embd.
        self.proj = nn.Linear(n_embd, n_embd)
        self.n_head = n_head

    def forward(self, x, mask=None):
        """Attend over ``x`` of shape (B, T, C).

        NOTE: ``mask`` is accepted for interface compatibility but is never
        applied — every position attends to every other position.

        Returns:
            ``(y, att)`` with ``y`` of shape (B, T, C) and ``att`` of shape
            (B, T, T).
        """
        batch, steps, width = x.size()
        head_dim = width // self.n_head

        def split_heads(tensor):
            # (B, T, C) -> (B, nh, T, hs)
            return tensor.view(batch, steps, self.n_head, head_dim).transpose(1, 2)

        k = split_heads(self.key(x))
        q = split_heads(self.query(x))
        v = split_heads(self.value(x))

        scores = torch.matmul(q, k.transpose(-2, -1)) * (1.0 / math.sqrt(head_dim))
        weights = self.attn_drop(F.softmax(scores, dim=-1))  # (B, nh, T, T)
        # (B, nh, T, hs) -> (B, T, nh, hs) -> (B, T, C)
        merged = (
            torch.matmul(weights, v).transpose(1, 2).contiguous().view(batch, steps, width)
        )
        return self.resid_drop(self.proj(merged)), weights.mean(dim=1, keepdim=False)
class CrossAttention(nn.Module):
    """Multi-head scaled dot-product attention from ``x`` onto ``encoder_output``.

    Queries come from ``x`` (width ``n_embd``); keys and values come from
    ``encoder_output`` (width ``condition_embd``). ``forward`` returns the
    projected output and the attention weights (after dropout) averaged over
    heads.
    """

    def __init__(
        self,
        n_embd,  # the embed dim
        condition_embd,  # condition dim
        n_head,  # the number of heads
        attn_pdrop=0.1,  # attention dropout prob
        resid_pdrop=0.1,  # residual attention dropout prob
    ):
        super().__init__()
        assert n_embd % n_head == 0
        # Keys/values are projected from the condition, queries from x.
        self.key = nn.Linear(condition_embd, n_embd)
        self.query = nn.Linear(n_embd, n_embd)
        self.value = nn.Linear(condition_embd, n_embd)
        # Dropout on the attention weights and on the projected output.
        self.attn_drop = nn.Dropout(attn_pdrop)
        self.resid_drop = nn.Dropout(resid_pdrop)
        # Output projection back to n_embd.
        self.proj = nn.Linear(n_embd, n_embd)
        self.n_head = n_head

    def forward(self, x, encoder_output, mask=None):
        """Attend from ``x`` (B, T, C) to ``encoder_output`` (B, T_E, condition_embd).

        NOTE: ``mask`` is accepted for interface compatibility but is never
        applied. The batch size used for reshaping is taken from
        ``encoder_output``.

        Returns:
            ``(y, att)`` with ``y`` of shape (B, T, C) and ``att`` of shape
            (B, T, T_E).
        """
        _, q_len, width = x.size()
        batch, kv_len, _ = encoder_output.size()
        head_dim = width // self.n_head

        def split_heads(tensor, length):
            # (B, L, C) -> (B, nh, L, hs)
            return tensor.view(batch, length, self.n_head, head_dim).transpose(1, 2)

        k = split_heads(self.key(encoder_output), kv_len)  # (B, nh, T_E, hs)
        q = split_heads(self.query(x), q_len)  # (B, nh, T, hs)
        v = split_heads(self.value(encoder_output), kv_len)  # (B, nh, T_E, hs)

        scores = torch.matmul(q, k.transpose(-2, -1)) * (1.0 / math.sqrt(head_dim))
        weights = self.attn_drop(F.softmax(scores, dim=-1))  # (B, nh, T, T_E)
        # (B, nh, T, hs) -> (B, T, nh, hs) -> (B, T, C)
        merged = (
            torch.matmul(weights, v).transpose(1, 2).contiguous().view(batch, q_len, width)
        )
        return self.resid_drop(self.proj(merged)), weights.mean(dim=1, keepdim=False)
class EncoderBlock(nn.Module):
    """Transformer encoder block: adaptive-norm self-attention, then an MLP.

    Both sub-layers are added residually. The first norm (``AdaLayerNorm``)
    is conditioned on the timestep and optional label embedding.
    """

    def __init__(
        self,
        n_embd=1024,
        n_head=16,
        attn_pdrop=0.1,
        resid_pdrop=0.1,
        mlp_hidden_times=4,
        activate="GELU",
    ):
        super().__init__()
        self.ln1 = AdaLayerNorm(n_embd)
        self.ln2 = nn.LayerNorm(n_embd)
        self.attn = FullAttention(
            n_embd=n_embd,
            n_head=n_head,
            attn_pdrop=attn_pdrop,
            resid_pdrop=resid_pdrop,
        )
        assert activate in ["GELU", "GELU2"]
        act = nn.GELU() if activate == "GELU" else GELU2()
        hidden = mlp_hidden_times * n_embd
        self.mlp = nn.Sequential(
            nn.Linear(n_embd, hidden),
            act,
            nn.Linear(hidden, n_embd),
            nn.Dropout(resid_pdrop),
        )

    def forward(self, x, timestep, mask=None, label_emb=None):
        """Apply the block; returns ``(x, head-averaged attention weights)``."""
        normed = self.ln1(x, timestep, label_emb)
        attn_out, att = self.attn(normed, mask=mask)
        x = x + attn_out
        return x + self.mlp(self.ln2(x)), att
class Encoder(nn.Module):
    """Stack of ``n_layer`` identical ``EncoderBlock`` modules applied in order."""

    def __init__(
        self,
        n_layer=14,
        n_embd=1024,
        n_head=16,
        attn_pdrop=0.0,
        resid_pdrop=0.0,
        mlp_hidden_times=4,
        block_activate="GELU",
    ):
        super().__init__()
        self.blocks = nn.Sequential(
            *(
                EncoderBlock(
                    n_embd=n_embd,
                    n_head=n_head,
                    attn_pdrop=attn_pdrop,
                    resid_pdrop=resid_pdrop,
                    mlp_hidden_times=mlp_hidden_times,
                    activate=block_activate,
                )
                for _ in range(n_layer)
            )
        )

    def forward(self, input, t, padding_masks=None, label_emb=None):
        """Run every block on ``input``; attention maps are discarded."""
        hidden = input
        for block in self.blocks:
            hidden, _ = block(hidden, t, mask=padding_masks, label_emb=label_emb)
        return hidden
class DecoderBlock(nn.Module):
    """Transformer decoder block with trend and seasonal side outputs.

    Self-attention and cross-attention (onto the encoder output) are added
    residually. The intermediate state is then projected and split into a
    trend branch (``TrendBlock``) and a seasonal branch (``FourierLayer``),
    and an MLP is applied residually. The block returns the state with its
    mean over dim 1 removed, a linear map of that mean, and the two branch
    outputs.
    """

    def __init__(
        self,
        n_channel,
        n_feat,
        n_embd=1024,
        n_head=16,
        attn_pdrop=0.1,
        resid_pdrop=0.1,
        mlp_hidden_times=4,
        activate="GELU",
        condition_dim=1024,
    ):
        super().__init__()
        self.ln1 = AdaLayerNorm(n_embd)
        self.ln2 = nn.LayerNorm(n_embd)
        self.attn1 = FullAttention(
            n_embd=n_embd,
            n_head=n_head,
            attn_pdrop=attn_pdrop,
            resid_pdrop=resid_pdrop,
        )
        self.attn2 = CrossAttention(
            n_embd=n_embd,
            condition_embd=condition_dim,
            n_head=n_head,
            attn_pdrop=attn_pdrop,
            resid_pdrop=resid_pdrop,
        )
        self.ln1_1 = AdaLayerNorm(n_embd)
        assert activate in ["GELU", "GELU2"]
        # The same activation module is shared by the trend branch and the MLP.
        act = nn.GELU() if activate == "GELU" else GELU2()
        self.trend = TrendBlock(n_channel, n_channel, n_embd, n_feat, act=act)
        # self.decomp = MovingBlock(n_channel)
        self.seasonal = FourierLayer(d_model=n_embd)
        # self.seasonal = SeasonBlock(n_channel, n_channel)
        hidden = mlp_hidden_times * n_embd
        self.mlp = nn.Sequential(
            nn.Linear(n_embd, hidden),
            act,
            nn.Linear(hidden, n_embd),
            nn.Dropout(resid_pdrop),
        )
        # Doubles the channel dim so it can be split into trend/season inputs.
        self.proj = nn.Conv1d(n_channel, n_channel * 2, 1)
        self.linear = nn.Linear(n_embd, n_feat)

    def forward(self, x, encoder_output, timestep, mask=None, label_emb=None):
        """Return ``(x - mean, linear(mean), trend, season)``.

        ``x`` is assumed to be (b, n_channel, n_embd): ``proj`` convolves over
        dim 1 and ``ln2`` normalises the last dim.
        """
        self_out, _ = self.attn1(self.ln1(x, timestep, label_emb), mask=mask)
        x = x + self_out
        cross_out, _ = self.attn2(self.ln1_1(x, timestep), encoder_output, mask=mask)
        x = x + cross_out

        trend_in, season_in = self.proj(x).chunk(2, dim=1)
        trend = self.trend(trend_in)
        season = self.seasonal(season_in)

        x = x + self.mlp(self.ln2(x))
        mean = x.mean(dim=1, keepdim=True)
        return x - mean, self.linear(mean), trend, season
class Decoder(nn.Module):
    """Stack of ``n_layer`` ``DecoderBlock`` modules.

    Per-block trend and seasonal outputs are summed, and per-block means are
    concatenated along dim 1.
    """

    def __init__(
        self,
        n_channel,
        n_feat,
        n_embd=1024,
        n_head=16,
        n_layer=10,
        attn_pdrop=0.1,
        resid_pdrop=0.1,
        mlp_hidden_times=4,
        block_activate="GELU",
        condition_dim=512,
    ):
        super().__init__()
        self.d_model = n_embd
        self.n_feat = n_feat
        self.blocks = nn.Sequential(
            *(
                DecoderBlock(
                    n_feat=n_feat,
                    n_channel=n_channel,
                    n_embd=n_embd,
                    n_head=n_head,
                    attn_pdrop=attn_pdrop,
                    resid_pdrop=resid_pdrop,
                    mlp_hidden_times=mlp_hidden_times,
                    activate=block_activate,
                    condition_dim=condition_dim,
                )
                for _ in range(n_layer)
            )
        )

    def forward(self, x, t, enc, padding_masks=None, label_emb=None):
        """Return ``(x, means, trend, season)`` after running every block.

        ``season`` has shape (b, c, d_model), ``trend`` has shape
        (b, c, n_feat), and ``means`` concatenates the blocks' mean outputs
        along dim 1.
        """
        batch, channels, _ = x.shape
        season = torch.zeros((batch, channels, self.d_model), device=x.device)
        trend = torch.zeros((batch, channels, self.n_feat), device=x.device)
        block_means = []
        for block in self.blocks:
            x, block_mean, block_trend, block_season = block(
                x, enc, t, mask=padding_masks, label_emb=label_emb
            )
            # In-place accumulation keeps the float32 buffers created above.
            season += block_season
            trend += block_trend
            block_means.append(block_mean)
        return x, torch.cat(block_means, dim=1), trend, season
class Transformer(nn.Module):
    """Encoder/decoder Transformer that splits its output into trend and season.

    Extra keyword arguments (``**kwargs``) are accepted and ignored.
    """

    def __init__(
        self,
        n_feat,
        n_channel,
        n_layer_enc=5,
        n_layer_dec=14,
        n_embd=1024,
        n_heads=16,
        attn_pdrop=0.1,
        resid_pdrop=0.1,
        mlp_hidden_times=4,
        block_activate="GELU",
        max_len=2048,
        conv_params=None,
        **kwargs
    ):
        super().__init__()
        self.emb = Conv_MLP(n_feat, n_embd, resid_pdrop=resid_pdrop)
        self.inverse = Conv_MLP(n_embd, n_feat, resid_pdrop=resid_pdrop)
        # Kernel for the seasonal read-out: explicit (kernel_size, padding) from
        # conv_params, otherwise 1x1 for small problems and 5-wide for larger ones.
        if conv_params is None or conv_params[0] is None:
            if n_feat < 32 and n_channel < 64:
                kernel_size, padding = 1, 0
            else:
                kernel_size, padding = 5, 2
        else:
            kernel_size, padding = conv_params
        self.combine_s = nn.Conv1d(
            n_embd,
            n_feat,
            kernel_size=kernel_size,
            stride=1,
            padding=padding,
            padding_mode="circular",
            bias=False,
        )
        # Mixes the n_layer_dec per-block means into one; with padding=0 the
        # circular padding mode has no effect.
        self.combine_m = nn.Conv1d(
            n_layer_dec,
            1,
            kernel_size=1,
            stride=1,
            padding=0,
            padding_mode="circular",
            bias=False,
        )
        self.encoder = Encoder(
            n_layer_enc,
            n_embd,
            n_heads,
            attn_pdrop,
            resid_pdrop,
            mlp_hidden_times,
            block_activate,
        )
        self.pos_enc = LearnablePositionalEncoding(
            n_embd, dropout=resid_pdrop, max_len=max_len
        )
        self.decoder = Decoder(
            n_channel,
            n_feat,
            n_embd,
            n_heads,
            n_layer_dec,
            attn_pdrop,
            resid_pdrop,
            mlp_hidden_times,
            block_activate,
            condition_dim=n_embd,
        )
        self.pos_dec = LearnablePositionalEncoding(
            n_embd, dropout=resid_pdrop, max_len=max_len
        )

    def forward(self, input, t, padding_masks=None, return_res=False):
        """Run the encoder/decoder and assemble the trend and seasonal outputs.

        Args:
            input: model input; presumably (b, n_channel, n_feat) — TODO confirm
                against ``Conv_MLP``.
            t: timestep passed to the encoder and decoder blocks.
            padding_masks: forwarded to the attention layers.
            return_res: if True, return the seasonal read-out and the
                mean-removed residual separately instead of their sum.

        Returns:
            ``(trend, season_error)``, or ``(trend, season, residual)`` when
            ``return_res`` is True.
        """
        emb = self.emb(input)
        inp_enc = self.pos_enc(emb)
        enc_cond = self.encoder(inp_enc, t, padding_masks=padding_masks)
        inp_dec = self.pos_dec(emb)
        output, mean, trend, season = self.decoder(
            inp_dec, t, enc_cond, padding_masks=padding_masks
        )
        res = self.inverse(output)
        res_m = torch.mean(res, dim=1, keepdim=True)
        # Conv1d works on (b, channels, length), so move n_embd to dim 1 and back.
        # Computed once and shared by both return paths.
        season_out = self.combine_s(season.transpose(1, 2)).transpose(1, 2)
        # combine_m collapses the (b, n_layer_dec, n_feat) block means to (b, 1, n_feat).
        trend = self.combine_m(mean) + res_m + trend
        if return_res:
            return trend, season_out, res - res_m
        return trend, season_out + res - res_m
if __name__ == "__main__":
    # Intentionally empty: this module only defines model components.
    pass