# layers/MultiWaveletCorrelation.py
import math
from functools import partial
from typing import List, Tuple

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from scipy.special import eval_legendre
from sympy import Poly, Symbol, chebyshevt, legendre
from torch import Tensor
def legendreDer(k, x):
    # Derivative of the Legendre polynomial P_k, via the identity
    # P_k'(x) = sum_{i = k-1, k-3, ...} (2i + 1) * P_i(x).
    def _legendre(k, x):
        return (2 * k + 1) * eval_legendre(k, x)

    out = 0
    for i in np.arange(k - 1, -1, -2):
        out += _legendre(i, x)
    return out
def phi_(phi_c, x, lb=0, ub=1):
    # Evaluate the polynomial with coefficients phi_c on [lb, ub]; zero outside.
    mask = np.logical_or(x < lb, x > ub) * 1.0
    return np.polynomial.polynomial.Polynomial(phi_c)(x) * (1 - mask)
def get_phi_psi(k, base):
    # Construct the first k scaling functions (phi) and the corresponding
    # wavelets (psi1 on [0, 0.5], psi2 on [0.5, 1]) of the multiwavelet
    # basis via Gram-Schmidt orthogonalization.
    x = Symbol('x')
    phi_coeff = np.zeros((k, k))
    phi_2x_coeff = np.zeros((k, k))
    if base == 'legendre':
        for ki in range(k):
            coeff_ = Poly(legendre(ki, 2 * x - 1), x).all_coeffs()
            phi_coeff[ki, :ki + 1] = np.flip(np.sqrt(2 * ki + 1) * np.array(coeff_).astype(np.float64))
            coeff_ = Poly(legendre(ki, 4 * x - 1), x).all_coeffs()
            phi_2x_coeff[ki, :ki + 1] = np.flip(np.sqrt(2) * np.sqrt(2 * ki + 1) * np.array(coeff_).astype(np.float64))

        psi1_coeff = np.zeros((k, k))
        psi2_coeff = np.zeros((k, k))
        for ki in range(k):
            psi1_coeff[ki, :] = phi_2x_coeff[ki, :]
            for i in range(k):
                a = phi_2x_coeff[ki, :ki + 1]
                b = phi_coeff[i, :i + 1]
                prod_ = np.convolve(a, b)
                prod_[np.abs(prod_) < 1e-8] = 0
                proj_ = (prod_ * 1 / (np.arange(len(prod_)) + 1) * np.power(0.5, 1 + np.arange(len(prod_)))).sum()
                psi1_coeff[ki, :] -= proj_ * phi_coeff[i, :]
                psi2_coeff[ki, :] -= proj_ * phi_coeff[i, :]
            for j in range(ki):
                a = phi_2x_coeff[ki, :ki + 1]
                b = psi1_coeff[j, :]
                prod_ = np.convolve(a, b)
                prod_[np.abs(prod_) < 1e-8] = 0
                proj_ = (prod_ * 1 / (np.arange(len(prod_)) + 1) * np.power(0.5, 1 + np.arange(len(prod_)))).sum()
                psi1_coeff[ki, :] -= proj_ * psi1_coeff[j, :]
                psi2_coeff[ki, :] -= proj_ * psi2_coeff[j, :]

            a = psi1_coeff[ki, :]
            prod_ = np.convolve(a, a)
            prod_[np.abs(prod_) < 1e-8] = 0
            norm1 = (prod_ * 1 / (np.arange(len(prod_)) + 1) * np.power(0.5, 1 + np.arange(len(prod_)))).sum()

            a = psi2_coeff[ki, :]
            prod_ = np.convolve(a, a)
            prod_[np.abs(prod_) < 1e-8] = 0
            norm2 = (prod_ * 1 / (np.arange(len(prod_)) + 1) * (1 - np.power(0.5, 1 + np.arange(len(prod_))))).sum()

            norm_ = np.sqrt(norm1 + norm2)
            psi1_coeff[ki, :] /= norm_
            psi2_coeff[ki, :] /= norm_
            psi1_coeff[np.abs(psi1_coeff) < 1e-8] = 0
            psi2_coeff[np.abs(psi2_coeff) < 1e-8] = 0

        phi = [np.poly1d(np.flip(phi_coeff[i, :])) for i in range(k)]
        psi1 = [np.poly1d(np.flip(psi1_coeff[i, :])) for i in range(k)]
        psi2 = [np.poly1d(np.flip(psi2_coeff[i, :])) for i in range(k)]

    elif base == 'chebyshev':
        for ki in range(k):
            if ki == 0:
                phi_coeff[ki, :ki + 1] = np.sqrt(2 / np.pi)
                phi_2x_coeff[ki, :ki + 1] = np.sqrt(2 / np.pi) * np.sqrt(2)
            else:
                coeff_ = Poly(chebyshevt(ki, 2 * x - 1), x).all_coeffs()
                phi_coeff[ki, :ki + 1] = np.flip(2 / np.sqrt(np.pi) * np.array(coeff_).astype(np.float64))
                coeff_ = Poly(chebyshevt(ki, 4 * x - 1), x).all_coeffs()
                phi_2x_coeff[ki, :ki + 1] = np.flip(
                    np.sqrt(2) * 2 / np.sqrt(np.pi) * np.array(coeff_).astype(np.float64))

        phi = [partial(phi_, phi_coeff[i, :]) for i in range(k)]

        x = Symbol('x')
        kUse = 2 * k
        roots = Poly(chebyshevt(kUse, 2 * x - 1)).all_roots()
        x_m = np.array([rt.evalf(20) for rt in roots]).astype(np.float64)
        # x_m[x_m == 0.5] = 0.5 + 1e-8  # add small noise to avoid 0.5 belonging
        # to both phi(2x) and phi(2x-1); not needed here since k is always even
        wm = np.pi / kUse / 2

        psi1_coeff = np.zeros((k, k))
        psi2_coeff = np.zeros((k, k))

        psi1 = [[] for _ in range(k)]
        psi2 = [[] for _ in range(k)]

        for ki in range(k):
            psi1_coeff[ki, :] = phi_2x_coeff[ki, :]
            for i in range(k):
                proj_ = (wm * phi[i](x_m) * np.sqrt(2) * phi[ki](2 * x_m)).sum()
                psi1_coeff[ki, :] -= proj_ * phi_coeff[i, :]
                psi2_coeff[ki, :] -= proj_ * phi_coeff[i, :]

            for j in range(ki):
                proj_ = (wm * psi1[j](x_m) * np.sqrt(2) * phi[ki](2 * x_m)).sum()
                psi1_coeff[ki, :] -= proj_ * psi1_coeff[j, :]
                psi2_coeff[ki, :] -= proj_ * psi2_coeff[j, :]

            psi1[ki] = partial(phi_, psi1_coeff[ki, :], lb=0, ub=0.5)
            psi2[ki] = partial(phi_, psi2_coeff[ki, :], lb=0.5, ub=1)

            norm1 = (wm * psi1[ki](x_m) * psi1[ki](x_m)).sum()
            norm2 = (wm * psi2[ki](x_m) * psi2[ki](x_m)).sum()

            norm_ = np.sqrt(norm1 + norm2)
            psi1_coeff[ki, :] /= norm_
            psi2_coeff[ki, :] /= norm_
            psi1_coeff[np.abs(psi1_coeff) < 1e-8] = 0
            psi2_coeff[np.abs(psi2_coeff) < 1e-8] = 0

            psi1[ki] = partial(phi_, psi1_coeff[ki, :], lb=0, ub=0.5 + 1e-16)
            psi2[ki] = partial(phi_, psi2_coeff[ki, :], lb=0.5 + 1e-16, ub=1)

    return phi, psi1, psi2
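
# A quick orthonormality sketch (an illustrative addition, not part of the
# original module): the Legendre scaling functions returned by get_phi_psi
# should satisfy <phi_i, phi_j> ~= delta_ij on [0, 1]; a dense uniform-grid
# average approximates the integral closely enough to see this.
def _check_phi_orthonormal(k=4):
    phi, _, _ = get_phi_psi(k, 'legendre')
    xs = np.linspace(0, 1, 20001)
    gram = np.array([[(phi[i](xs) * phi[j](xs)).mean() for j in range(k)]
                     for i in range(k)])
    assert np.allclose(gram, np.eye(k), atol=1e-2)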
def get_filter(base, k):
    # Build the two-scale filter matrices of the multiwavelet basis:
    # H0/H1 are the low-pass (scaling) filters on the even/odd halves,
    # G0/G1 the high-pass (wavelet) filters, and PHI0/PHI1 correction
    # matrices (identity for the orthogonal Legendre basis).
    def psi(psi1, psi2, i, inp):
        mask = (inp <= 0.5) * 1.0
        return psi1[i](inp) * mask + psi2[i](inp) * (1 - mask)

    if base not in ['legendre', 'chebyshev']:
        raise Exception('Base not supported')

    x = Symbol('x')
    H0 = np.zeros((k, k))
    H1 = np.zeros((k, k))
    G0 = np.zeros((k, k))
    G1 = np.zeros((k, k))
    PHI0 = np.zeros((k, k))
    PHI1 = np.zeros((k, k))
    phi, psi1, psi2 = get_phi_psi(k, base)
    if base == 'legendre':
        # Gauss-Legendre quadrature nodes and weights on [0, 1].
        roots = Poly(legendre(k, 2 * x - 1)).all_roots()
        x_m = np.array([rt.evalf(20) for rt in roots]).astype(np.float64)
        wm = 1 / k / legendreDer(k, 2 * x_m - 1) / eval_legendre(k - 1, 2 * x_m - 1)

        for ki in range(k):
            for kpi in range(k):
                H0[ki, kpi] = 1 / np.sqrt(2) * (wm * phi[ki](x_m / 2) * phi[kpi](x_m)).sum()
                G0[ki, kpi] = 1 / np.sqrt(2) * (wm * psi(psi1, psi2, ki, x_m / 2) * phi[kpi](x_m)).sum()
                H1[ki, kpi] = 1 / np.sqrt(2) * (wm * phi[ki]((x_m + 1) / 2) * phi[kpi](x_m)).sum()
                G1[ki, kpi] = 1 / np.sqrt(2) * (wm * psi(psi1, psi2, ki, (x_m + 1) / 2) * phi[kpi](x_m)).sum()

        PHI0 = np.eye(k)
        PHI1 = np.eye(k)

    elif base == 'chebyshev':
        x = Symbol('x')
        kUse = 2 * k
        roots = Poly(chebyshevt(kUse, 2 * x - 1)).all_roots()
        x_m = np.array([rt.evalf(20) for rt in roots]).astype(np.float64)
        # x_m[x_m == 0.5] = 0.5 + 1e-8  # add small noise to avoid 0.5 belonging
        # to both phi(2x) and phi(2x-1); not needed here since k is always even
        wm = np.pi / kUse / 2

        for ki in range(k):
            for kpi in range(k):
                H0[ki, kpi] = 1 / np.sqrt(2) * (wm * phi[ki](x_m / 2) * phi[kpi](x_m)).sum()
                G0[ki, kpi] = 1 / np.sqrt(2) * (wm * psi(psi1, psi2, ki, x_m / 2) * phi[kpi](x_m)).sum()
                H1[ki, kpi] = 1 / np.sqrt(2) * (wm * phi[ki]((x_m + 1) / 2) * phi[kpi](x_m)).sum()
                G1[ki, kpi] = 1 / np.sqrt(2) * (wm * psi(psi1, psi2, ki, (x_m + 1) / 2) * phi[kpi](x_m)).sum()

                PHI0[ki, kpi] = (wm * phi[ki](2 * x_m) * phi[kpi](2 * x_m)).sum() * 2
                PHI1[ki, kpi] = (wm * phi[ki](2 * x_m - 1) * phi[kpi](2 * x_m - 1)).sum() * 2

        PHI0[np.abs(PHI0) < 1e-8] = 0
        PHI1[np.abs(PHI1) < 1e-8] = 0

    H0[np.abs(H0) < 1e-8] = 0
    H1[np.abs(H1) < 1e-8] = 0
    G0[np.abs(G0) < 1e-8] = 0
    G1[np.abs(G1) < 1e-8] = 0

    return H0, H1, G0, G1, PHI0, PHI1
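
# A minimal sanity check (an illustrative addition, not part of the original
# module): for the orthonormal Legendre basis, the stacked analysis matrix
# [ec_s | ec_d] built from these filters should be orthogonal, so wavelet
# decomposition followed by reconstruction is the identity up to numerical
# error.
def _check_filters(k=4):
    H0, H1, G0, G1, PHI0, PHI1 = get_filter('legendre', k)
    W = np.concatenate([np.concatenate((H0.T, H1.T), axis=0),
                        np.concatenate((G0.T, G1.T), axis=0)], axis=1)
    assert np.allclose(W @ W.T, np.eye(2 * k), atol=1e-5)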
class MultiWaveletTransform(nn.Module):
    """
    1D multiwavelet block.
    """

    def __init__(self, ich=1, k=8, alpha=16, c=128,
                 nCZ=1, L=0, base='legendre', attention_dropout=0.1):
        super(MultiWaveletTransform, self).__init__()
        print('base', base)
        self.k = k
        self.c = c
        self.L = L
        self.nCZ = nCZ
        self.Lk0 = nn.Linear(ich, c * k)
        self.Lk1 = nn.Linear(c * k, ich)
        self.ich = ich
        self.MWT_CZ = nn.ModuleList(MWT_CZ1d(k, alpha, L, c, base) for i in range(nCZ))

    def forward(self, queries, keys, values, attn_mask):
        B, L, H, E = queries.shape
        _, S, _, D = values.shape
        # Pad or truncate keys/values along the time axis so they match the
        # query length L.
        if L > S:
            zeros = torch.zeros_like(queries[:, :(L - S), :]).float()
            values = torch.cat([values, zeros], dim=1)
            keys = torch.cat([keys, zeros], dim=1)
        else:
            values = values[:, :L, :, :]
            keys = keys[:, :L, :, :]
        values = values.view(B, L, -1)

        V = self.Lk0(values).view(B, L, self.c, -1)
        for i in range(self.nCZ):
            V = self.MWT_CZ[i](V)
            if i < self.nCZ - 1:
                V = F.relu(V)

        V = self.Lk1(V.view(B, L, -1))
        V = V.view(B, L, -1, D)
        return (V.contiguous(), None)
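
# A minimal usage sketch (an illustrative addition, not part of the original
# module): the block consumes attention-style (batch, length, heads, head_dim)
# tensors and returns the same shape plus a None attention map; ich must equal
# heads * head_dim. c=16 just keeps the spectral weights small for a smoke test.
def _demo_multiwavelet_transform():
    attn = MultiWaveletTransform(ich=8 * 64, k=8, alpha=16, c=16, nCZ=1)
    q = torch.randn(2, 96, 8, 64)
    out, _ = attn(q, q, q, attn_mask=None)
    assert out.shape == (2, 96, 8, 64)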
class MultiWaveletCross(nn.Module):
    """
    1D multiwavelet cross attention layer.
    """

    def __init__(self, in_channels, out_channels, seq_len_q, seq_len_kv, modes, c=64,
                 k=8, ich=512,
                 L=0,
                 base='legendre',
                 mode_select_method='random',
                 initializer=None, activation='tanh',
                 **kwargs):
        super(MultiWaveletCross, self).__init__()
        print('base', base)

        self.c = c
        self.k = k
        self.L = L
        H0, H1, G0, G1, PHI0, PHI1 = get_filter(base, k)
        H0r = H0 @ PHI0
        G0r = G0 @ PHI0
        H1r = H1 @ PHI1
        G1r = G1 @ PHI1

        H0r[np.abs(H0r) < 1e-8] = 0
        H1r[np.abs(H1r) < 1e-8] = 0
        G0r[np.abs(G0r) < 1e-8] = 0
        G1r[np.abs(G1r) < 1e-8] = 0
        self.max_item = 3

        # One Fourier cross attention module per coefficient stream: attn1 for
        # the detail coefficients, attn2 for the same-level approximation,
        # attn3 for the stored detail stream, and attn4 for the coarsest scale.
        self.attn1 = FourierCrossAttentionW(in_channels=in_channels, out_channels=out_channels, seq_len_q=seq_len_q,
                                            seq_len_kv=seq_len_kv, modes=modes, activation=activation,
                                            mode_select_method=mode_select_method)
        self.attn2 = FourierCrossAttentionW(in_channels=in_channels, out_channels=out_channels, seq_len_q=seq_len_q,
                                            seq_len_kv=seq_len_kv, modes=modes, activation=activation,
                                            mode_select_method=mode_select_method)
        self.attn3 = FourierCrossAttentionW(in_channels=in_channels, out_channels=out_channels, seq_len_q=seq_len_q,
                                            seq_len_kv=seq_len_kv, modes=modes, activation=activation,
                                            mode_select_method=mode_select_method)
        self.attn4 = FourierCrossAttentionW(in_channels=in_channels, out_channels=out_channels, seq_len_q=seq_len_q,
                                            seq_len_kv=seq_len_kv, modes=modes, activation=activation,
                                            mode_select_method=mode_select_method)
        self.T0 = nn.Linear(k, k)

        # Analysis (ec_*) and synthesis (rc_*) filter matrices as buffers.
        self.register_buffer('ec_s', torch.Tensor(
            np.concatenate((H0.T, H1.T), axis=0)))
        self.register_buffer('ec_d', torch.Tensor(
            np.concatenate((G0.T, G1.T), axis=0)))
        self.register_buffer('rc_e', torch.Tensor(
            np.concatenate((H0r, G0r), axis=0)))
        self.register_buffer('rc_o', torch.Tensor(
            np.concatenate((H1r, G1r), axis=0)))

        self.Lk = nn.Linear(ich, c * k)
        self.Lq = nn.Linear(ich, c * k)
        self.Lv = nn.Linear(ich, c * k)
        self.out = nn.Linear(c * k, ich)
        self.modes1 = modes
    def forward(self, q, k, v, mask=None):
        B, N, H, E = q.shape  # (B, N, H, E)
        _, S, _, _ = k.shape  # (B, S, H, E)

        q = q.view(q.shape[0], q.shape[1], -1)
        k = k.view(k.shape[0], k.shape[1], -1)
        v = v.view(v.shape[0], v.shape[1], -1)
        q = self.Lq(q)
        q = q.view(q.shape[0], q.shape[1], self.c, self.k)
        k = self.Lk(k)
        k = k.view(k.shape[0], k.shape[1], self.c, self.k)
        v = self.Lv(v)
        v = v.view(v.shape[0], v.shape[1], self.c, self.k)

        # Pad or truncate keys/values along time so they match the query length.
        if N > S:
            zeros = torch.zeros_like(q[:, :(N - S), :]).float()
            v = torch.cat([v, zeros], dim=1)
            k = torch.cat([k, zeros], dim=1)
        else:
            v = v[:, :N, :, :]
            k = k[:, :N, :, :]

        # Extend the sequence to the next power of two for the wavelet transform.
        ns = math.floor(np.log2(N))
        nl = pow(2, math.ceil(np.log2(N)))
        extra_q = q[:, 0:nl - N, :, :]
        extra_k = k[:, 0:nl - N, :, :]
        extra_v = v[:, 0:nl - N, :, :]
        q = torch.cat([q, extra_q], 1)
        k = torch.cat([k, extra_k], 1)
        v = torch.cat([v, extra_v], 1)

        Ud_q = torch.jit.annotate(List[Tuple[Tensor, Tensor]], [])
        Ud_k = torch.jit.annotate(List[Tuple[Tensor, Tensor]], [])
        Ud_v = torch.jit.annotate(List[Tuple[Tensor, Tensor]], [])

        Us_q = torch.jit.annotate(List[Tensor], [])
        Us_k = torch.jit.annotate(List[Tensor], [])
        Us_v = torch.jit.annotate(List[Tensor], [])

        Ud = torch.jit.annotate(List[Tensor], [])
        Us = torch.jit.annotate(List[Tensor], [])

        # decompose
        for i in range(ns - self.L):
            d, q = self.wavelet_transform(q)
            Ud_q += [tuple([d, q])]
            Us_q += [d]
        for i in range(ns - self.L):
            d, k = self.wavelet_transform(k)
            Ud_k += [tuple([d, k])]
            Us_k += [d]
        for i in range(ns - self.L):
            d, v = self.wavelet_transform(v)
            Ud_v += [tuple([d, v])]
            Us_v += [d]

        # cross attention per decomposition level
        for i in range(ns - self.L):
            dk, sk = Ud_k[i], Us_k[i]
            dq, sq = Ud_q[i], Us_q[i]
            dv, sv = Ud_v[i], Us_v[i]
            Ud += [self.attn1(dq[0], dk[0], dv[0], mask)[0] + self.attn2(dq[1], dk[1], dv[1], mask)[0]]
            Us += [self.attn3(sq, sk, sv, mask)[0]]
        v = self.attn4(q, k, v, mask)[0]

        # reconstruct
        for i in range(ns - 1 - self.L, -1, -1):
            v = v + Us[i]
            v = torch.cat((v, Ud[i]), -1)
            v = self.evenOdd(v)
        v = self.out(v[:, :N, :, :].contiguous().view(B, N, -1))
        return (v.contiguous(), None)
    def wavelet_transform(self, x):
        # Split the sequence into even/odd samples and project onto the
        # detail (d) and smooth (s) filter matrices.
        xa = torch.cat([x[:, ::2, :, :],
                        x[:, 1::2, :, :],
                        ], -1)
        d = torch.matmul(xa, self.ec_d)
        s = torch.matmul(xa, self.ec_s)
        return d, s

    def evenOdd(self, x):
        # Inverse of wavelet_transform: synthesize even/odd samples and
        # interleave them, doubling the sequence length.
        B, N, c, ich = x.shape  # (B, N, c, 2k)
        assert ich == 2 * self.k
        x_e = torch.matmul(x, self.rc_e)
        x_o = torch.matmul(x, self.rc_o)

        x = torch.zeros(B, N * 2, c, self.k,
                        device=x.device)
        x[..., ::2, :, :] = x_e
        x[..., 1::2, :, :] = x_o
        return x
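
# A minimal usage sketch (an illustrative addition, not part of the original
# module): queries may be longer than keys/values; the layer pads internally
# and returns a (B, N, ich) tensor plus a None attention map.
def _demo_multiwavelet_cross():
    attn = MultiWaveletCross(in_channels=64, out_channels=64,
                             seq_len_q=96, seq_len_kv=48, modes=16,
                             c=16, k=8, ich=8 * 64, base='legendre')
    q = torch.randn(2, 96, 8, 64)
    kv = torch.randn(2, 48, 8, 64)
    out, _ = attn(q, kv, kv, mask=None)
    assert out.shape == (2, 96, 8 * 64)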
class FourierCrossAttentionW(nn.Module):
    def __init__(self, in_channels, out_channels, seq_len_q, seq_len_kv, modes=16, activation='tanh',
                 mode_select_method='random'):
        super(FourierCrossAttentionW, self).__init__()
        print('cross fourier correlation used!')
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.modes1 = modes
        self.activation = activation

    def compl_mul1d(self, order, x, weights):
        # Complex einsum that also accepts real inputs: real tensors are
        # promoted to complex, and the product is expanded into real and
        # imaginary parts.
        x_flag = True
        w_flag = True
        if not torch.is_complex(x):
            x_flag = False
            x = torch.complex(x, torch.zeros_like(x).to(x.device))
        if not torch.is_complex(weights):
            w_flag = False
            weights = torch.complex(weights, torch.zeros_like(weights).to(weights.device))
        if x_flag or w_flag:
            return torch.complex(torch.einsum(order, x.real, weights.real) - torch.einsum(order, x.imag, weights.imag),
                                 torch.einsum(order, x.real, weights.imag) + torch.einsum(order, x.imag, weights.real))
        else:
            return torch.einsum(order, x.real, weights.real)

    def forward(self, q, k, v, mask):
        B, L, E, H = q.shape

        xq = q.permute(0, 3, 2, 1)  # size = [B, H, E, L]
        xk = k.permute(0, 3, 2, 1)
        xv = v.permute(0, 3, 2, 1)
        self.index_q = list(range(0, min(int(L // 2), self.modes1)))
        self.index_k_v = list(range(0, min(int(xv.shape[3] // 2), self.modes1)))

        # Compute Fourier coefficients and keep only the selected low modes.
        xq_ft_ = torch.zeros(B, H, E, len(self.index_q), device=xq.device, dtype=torch.cfloat)
        xq_ft = torch.fft.rfft(xq, dim=-1)
        for i, j in enumerate(self.index_q):
            xq_ft_[:, :, :, i] = xq_ft[:, :, :, j]

        xk_ft_ = torch.zeros(B, H, E, len(self.index_k_v), device=xq.device, dtype=torch.cfloat)
        xk_ft = torch.fft.rfft(xk, dim=-1)
        for i, j in enumerate(self.index_k_v):
            xk_ft_[:, :, :, i] = xk_ft[:, :, :, j]

        # Attention weights in the frequency domain.
        xqk_ft = (self.compl_mul1d("bhex,bhey->bhxy", xq_ft_, xk_ft_))
        if self.activation == 'tanh':
            xqk_ft = torch.complex(xqk_ft.real.tanh(), xqk_ft.imag.tanh())
        elif self.activation == 'softmax':
            xqk_ft = torch.softmax(abs(xqk_ft), dim=-1)
            xqk_ft = torch.complex(xqk_ft, torch.zeros_like(xqk_ft))
        else:
            raise Exception('{} activation function is not implemented'.format(self.activation))
        xqkv_ft = self.compl_mul1d("bhxy,bhey->bhex", xqk_ft, xk_ft_)

        xqkvw = xqkv_ft
        out_ft = torch.zeros(B, H, E, L // 2 + 1, device=xq.device, dtype=torch.cfloat)
        for i, j in enumerate(self.index_q):
            out_ft[:, :, :, j] = xqkvw[:, :, :, i]

        out = torch.fft.irfft(out_ft / self.in_channels / self.out_channels, n=xq.size(-1)).permute(0, 3, 2, 1)
        # size = [B, L, E, H]
        return (out, None)
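
# A minimal usage sketch (an illustrative addition, not part of the original
# module): inputs are (B, L, E, H); in_channels/out_channels only rescale the
# output here since this variant carries no learned spectral weights.
def _demo_fourier_cross_attention():
    attn = FourierCrossAttentionW(in_channels=64, out_channels=64,
                                  seq_len_q=96, seq_len_kv=96, modes=16)
    q = torch.randn(2, 96, 16, 8)
    out, _ = attn(q, q, q, mask=None)
    assert out.shape == (2, 96, 16, 8)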
class sparseKernelFT1d(nn.Module):
    def __init__(self,
                 k, alpha, c=1,
                 nl=1,
                 initializer=None,
                 **kwargs):
        super(sparseKernelFT1d, self).__init__()

        self.modes1 = alpha
        self.scale = (1 / (c * k * c * k))
        # Learned complex spectral weights (real and imaginary parts) for the
        # lowest self.modes1 Fourier modes.
        self.weights1 = nn.Parameter(self.scale * torch.rand(c * k, c * k, self.modes1, dtype=torch.float))
        self.weights2 = nn.Parameter(self.scale * torch.rand(c * k, c * k, self.modes1, dtype=torch.float))
        self.k = k

    def compl_mul1d(self, order, x, weights):
        # Complex einsum that also accepts real inputs: real tensors are
        # promoted to complex, and the product is expanded into real and
        # imaginary parts.
        x_flag = True
        w_flag = True
        if not torch.is_complex(x):
            x_flag = False
            x = torch.complex(x, torch.zeros_like(x).to(x.device))
        if not torch.is_complex(weights):
            w_flag = False
            weights = torch.complex(weights, torch.zeros_like(weights).to(weights.device))
        if x_flag or w_flag:
            return torch.complex(torch.einsum(order, x.real, weights.real) - torch.einsum(order, x.imag, weights.imag),
                                 torch.einsum(order, x.real, weights.imag) + torch.einsum(order, x.imag, weights.real))
        else:
            return torch.einsum(order, x.real, weights.real)

    def forward(self, x):
        B, N, c, k = x.shape  # (B, N, c, k)

        x = x.view(B, N, -1)
        x = x.permute(0, 2, 1)
        x_fft = torch.fft.rfft(x)
        # Multiply the relevant Fourier modes by the learned spectral filter.
        l = min(self.modes1, N // 2 + 1)
        out_ft = torch.zeros(B, c * k, N // 2 + 1, device=x.device, dtype=torch.cfloat)
        out_ft[:, :, :l] = self.compl_mul1d("bix,iox->box", x_fft[:, :, :l],
                                            torch.complex(self.weights1, self.weights2)[:, :, :l])
        x = torch.fft.irfft(out_ft, n=N)
        x = x.permute(0, 2, 1).view(B, N, c, k)
        return x
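
# A minimal usage sketch (an illustrative addition, not part of the original
# module): the layer maps (B, N, c, k) to the same shape, filtering only the
# lowest alpha Fourier modes along the sequence axis.
def _demo_sparse_kernel_ft():
    layer = sparseKernelFT1d(k=8, alpha=16, c=16)
    x = torch.randn(2, 64, 16, 8)
    assert layer(x).shape == x.shape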
class MWT_CZ1d(nn.Module):
    def __init__(self,
                 k=3, alpha=64,
                 L=0, c=1,
                 base='legendre',
                 initializer=None,
                 **kwargs):
        super(MWT_CZ1d, self).__init__()

        self.k = k
        self.L = L
        H0, H1, G0, G1, PHI0, PHI1 = get_filter(base, k)
        H0r = H0 @ PHI0
        G0r = G0 @ PHI0
        H1r = H1 @ PHI1
        G1r = G1 @ PHI1

        H0r[np.abs(H0r) < 1e-8] = 0
        H1r[np.abs(H1r) < 1e-8] = 0
        G0r[np.abs(G0r) < 1e-8] = 0
        G1r[np.abs(G1r) < 1e-8] = 0
        self.max_item = 3

        # Spectral operators applied to the multiwavelet coefficients.
        self.A = sparseKernelFT1d(k, alpha, c)
        self.B = sparseKernelFT1d(k, alpha, c)
        self.C = sparseKernelFT1d(k, alpha, c)

        self.T0 = nn.Linear(k, k)

        # Analysis (ec_*) and synthesis (rc_*) filter matrices as buffers.
        self.register_buffer('ec_s', torch.Tensor(
            np.concatenate((H0.T, H1.T), axis=0)))
        self.register_buffer('ec_d', torch.Tensor(
            np.concatenate((G0.T, G1.T), axis=0)))
        self.register_buffer('rc_e', torch.Tensor(
            np.concatenate((H0r, G0r), axis=0)))
        self.register_buffer('rc_o', torch.Tensor(
            np.concatenate((H1r, G1r), axis=0)))
    def forward(self, x):
        B, N, c, k = x.shape  # (B, N, c, k)
        ns = math.floor(np.log2(N))
        nl = pow(2, math.ceil(np.log2(N)))
        # Extend the sequence to the next power of two for the wavelet transform.
        extra_x = x[:, 0:nl - N, :, :]
        x = torch.cat([x, extra_x], 1)
        Ud = torch.jit.annotate(List[Tensor], [])
        Us = torch.jit.annotate(List[Tensor], [])
        # decompose
        for i in range(ns - self.L):
            d, x = self.wavelet_transform(x)
            Ud += [self.A(d) + self.B(x)]
            Us += [self.C(d)]
        x = self.T0(x)  # coarsest scale transform

        # reconstruct
        for i in range(ns - 1 - self.L, -1, -1):
            x = x + Us[i]
            x = torch.cat((x, Ud[i]), -1)
            x = self.evenOdd(x)
        x = x[:, :N, :, :]
        return x

    def wavelet_transform(self, x):
        # Split the sequence into even/odd samples and project onto the
        # detail (d) and smooth (s) filter matrices.
        xa = torch.cat([x[:, ::2, :, :],
                        x[:, 1::2, :, :],
                        ], -1)
        d = torch.matmul(xa, self.ec_d)
        s = torch.matmul(xa, self.ec_s)
        return d, s

    def evenOdd(self, x):
        # Inverse of wavelet_transform: synthesize even/odd samples and
        # interleave them, doubling the sequence length.
        B, N, c, ich = x.shape  # (B, N, c, 2k)
        assert ich == 2 * self.k
        x_e = torch.matmul(x, self.rc_e)
        x_o = torch.matmul(x, self.rc_o)

        x = torch.zeros(B, N * 2, c, self.k,
                        device=x.device)
        x[..., ::2, :, :] = x_e
        x[..., 1::2, :, :] = x_o
        return x
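
# A minimal usage sketch for the block above, plus a smoke test tying the
# illustrative _check_*/_demo_* additions together; none of these helpers are
# part of the original module.
def _demo_mwt_cz1d():
    block = MWT_CZ1d(k=8, alpha=16, L=0, c=16, base='legendre')
    x = torch.randn(2, 96, 16, 8)  # (B, N, c, k)
    assert block(x).shape == x.shape


if __name__ == '__main__':
    _check_phi_orthonormal()
    _check_filters()
    _demo_sparse_kernel_ft()
    _demo_mwt_cz1d()
    _demo_fourier_cross_attention()
    _demo_multiwavelet_transform()
    _demo_multiwavelet_cross()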