import torch import numpy as np import torch.nn as nn import torch.nn.functional as F from torch import Tensor from typing import List, Tuple import math from functools import partial from torch import nn, einsum, diagonal from math import log2, ceil import pdb from sympy import Poly, legendre, Symbol, chebyshevt from scipy.special import eval_legendre def legendreDer(k, x): def _legendre(k, x): return (2 * k + 1) * eval_legendre(k, x) out = 0 for i in np.arange(k - 1, -1, -2): out += _legendre(i, x) return out def phi_(phi_c, x, lb=0, ub=1): mask = np.logical_or(x < lb, x > ub) * 1.0 return np.polynomial.polynomial.Polynomial(phi_c)(x) * (1 - mask) def get_phi_psi(k, base): x = Symbol('x') phi_coeff = np.zeros((k, k)) phi_2x_coeff = np.zeros((k, k)) if base == 'legendre': for ki in range(k): coeff_ = Poly(legendre(ki, 2 * x - 1), x).all_coeffs() phi_coeff[ki, :ki + 1] = np.flip(np.sqrt(2 * ki + 1) * np.array(coeff_).astype(np.float64)) coeff_ = Poly(legendre(ki, 4 * x - 1), x).all_coeffs() phi_2x_coeff[ki, :ki + 1] = np.flip(np.sqrt(2) * np.sqrt(2 * ki + 1) * np.array(coeff_).astype(np.float64)) psi1_coeff = np.zeros((k, k)) psi2_coeff = np.zeros((k, k)) for ki in range(k): psi1_coeff[ki, :] = phi_2x_coeff[ki, :] for i in range(k): a = phi_2x_coeff[ki, :ki + 1] b = phi_coeff[i, :i + 1] prod_ = np.convolve(a, b) prod_[np.abs(prod_) < 1e-8] = 0 proj_ = (prod_ * 1 / (np.arange(len(prod_)) + 1) * np.power(0.5, 1 + np.arange(len(prod_)))).sum() psi1_coeff[ki, :] -= proj_ * phi_coeff[i, :] psi2_coeff[ki, :] -= proj_ * phi_coeff[i, :] for j in range(ki): a = phi_2x_coeff[ki, :ki + 1] b = psi1_coeff[j, :] prod_ = np.convolve(a, b) prod_[np.abs(prod_) < 1e-8] = 0 proj_ = (prod_ * 1 / (np.arange(len(prod_)) + 1) * np.power(0.5, 1 + np.arange(len(prod_)))).sum() psi1_coeff[ki, :] -= proj_ * psi1_coeff[j, :] psi2_coeff[ki, :] -= proj_ * psi2_coeff[j, :] a = psi1_coeff[ki, :] prod_ = np.convolve(a, a) prod_[np.abs(prod_) < 1e-8] = 0 norm1 = (prod_ * 1 / (np.arange(len(prod_)) + 1) * np.power(0.5, 1 + np.arange(len(prod_)))).sum() a = psi2_coeff[ki, :] prod_ = np.convolve(a, a) prod_[np.abs(prod_) < 1e-8] = 0 norm2 = (prod_ * 1 / (np.arange(len(prod_)) + 1) * (1 - np.power(0.5, 1 + np.arange(len(prod_))))).sum() norm_ = np.sqrt(norm1 + norm2) psi1_coeff[ki, :] /= norm_ psi2_coeff[ki, :] /= norm_ psi1_coeff[np.abs(psi1_coeff) < 1e-8] = 0 psi2_coeff[np.abs(psi2_coeff) < 1e-8] = 0 phi = [np.poly1d(np.flip(phi_coeff[i, :])) for i in range(k)] psi1 = [np.poly1d(np.flip(psi1_coeff[i, :])) for i in range(k)] psi2 = [np.poly1d(np.flip(psi2_coeff[i, :])) for i in range(k)] elif base == 'chebyshev': for ki in range(k): if ki == 0: phi_coeff[ki, :ki + 1] = np.sqrt(2 / np.pi) phi_2x_coeff[ki, :ki + 1] = np.sqrt(2 / np.pi) * np.sqrt(2) else: coeff_ = Poly(chebyshevt(ki, 2 * x - 1), x).all_coeffs() phi_coeff[ki, :ki + 1] = np.flip(2 / np.sqrt(np.pi) * np.array(coeff_).astype(np.float64)) coeff_ = Poly(chebyshevt(ki, 4 * x - 1), x).all_coeffs() phi_2x_coeff[ki, :ki + 1] = np.flip( np.sqrt(2) * 2 / np.sqrt(np.pi) * np.array(coeff_).astype(np.float64)) phi = [partial(phi_, phi_coeff[i, :]) for i in range(k)] x = Symbol('x') kUse = 2 * k roots = Poly(chebyshevt(kUse, 2 * x - 1)).all_roots() x_m = np.array([rt.evalf(20) for rt in roots]).astype(np.float64) # x_m[x_m==0.5] = 0.5 + 1e-8 # add small noise to avoid the case of 0.5 belonging to both phi(2x) and phi(2x-1) # not needed for our purpose here, we use even k always to avoid wm = np.pi / kUse / 2 psi1_coeff = np.zeros((k, k)) psi2_coeff = np.zeros((k, k)) psi1 = [[] for _ in range(k)] psi2 = [[] for _ in range(k)] for ki in range(k): psi1_coeff[ki, :] = phi_2x_coeff[ki, :] for i in range(k): proj_ = (wm * phi[i](x_m) * np.sqrt(2) * phi[ki](2 * x_m)).sum() psi1_coeff[ki, :] -= proj_ * phi_coeff[i, :] psi2_coeff[ki, :] -= proj_ * phi_coeff[i, :] for j in range(ki): proj_ = (wm * psi1[j](x_m) * np.sqrt(2) * phi[ki](2 * x_m)).sum() psi1_coeff[ki, :] -= proj_ * psi1_coeff[j, :] psi2_coeff[ki, :] -= proj_ * psi2_coeff[j, :] psi1[ki] = partial(phi_, psi1_coeff[ki, :], lb=0, ub=0.5) psi2[ki] = partial(phi_, psi2_coeff[ki, :], lb=0.5, ub=1) norm1 = (wm * psi1[ki](x_m) * psi1[ki](x_m)).sum() norm2 = (wm * psi2[ki](x_m) * psi2[ki](x_m)).sum() norm_ = np.sqrt(norm1 + norm2) psi1_coeff[ki, :] /= norm_ psi2_coeff[ki, :] /= norm_ psi1_coeff[np.abs(psi1_coeff) < 1e-8] = 0 psi2_coeff[np.abs(psi2_coeff) < 1e-8] = 0 psi1[ki] = partial(phi_, psi1_coeff[ki, :], lb=0, ub=0.5 + 1e-16) psi2[ki] = partial(phi_, psi2_coeff[ki, :], lb=0.5 + 1e-16, ub=1) return phi, psi1, psi2 def get_filter(base, k): def psi(psi1, psi2, i, inp): mask = (inp <= 0.5) * 1.0 return psi1[i](inp) * mask + psi2[i](inp) * (1 - mask) if base not in ['legendre', 'chebyshev']: raise Exception('Base not supported') x = Symbol('x') H0 = np.zeros((k, k)) H1 = np.zeros((k, k)) G0 = np.zeros((k, k)) G1 = np.zeros((k, k)) PHI0 = np.zeros((k, k)) PHI1 = np.zeros((k, k)) phi, psi1, psi2 = get_phi_psi(k, base) if base == 'legendre': roots = Poly(legendre(k, 2 * x - 1)).all_roots() x_m = np.array([rt.evalf(20) for rt in roots]).astype(np.float64) wm = 1 / k / legendreDer(k, 2 * x_m - 1) / eval_legendre(k - 1, 2 * x_m - 1) for ki in range(k): for kpi in range(k): H0[ki, kpi] = 1 / np.sqrt(2) * (wm * phi[ki](x_m / 2) * phi[kpi](x_m)).sum() G0[ki, kpi] = 1 / np.sqrt(2) * (wm * psi(psi1, psi2, ki, x_m / 2) * phi[kpi](x_m)).sum() H1[ki, kpi] = 1 / np.sqrt(2) * (wm * phi[ki]((x_m + 1) / 2) * phi[kpi](x_m)).sum() G1[ki, kpi] = 1 / np.sqrt(2) * (wm * psi(psi1, psi2, ki, (x_m + 1) / 2) * phi[kpi](x_m)).sum() PHI0 = np.eye(k) PHI1 = np.eye(k) elif base == 'chebyshev': x = Symbol('x') kUse = 2 * k roots = Poly(chebyshevt(kUse, 2 * x - 1)).all_roots() x_m = np.array([rt.evalf(20) for rt in roots]).astype(np.float64) # x_m[x_m==0.5] = 0.5 + 1e-8 # add small noise to avoid the case of 0.5 belonging to both phi(2x) and phi(2x-1) # not needed for our purpose here, we use even k always to avoid wm = np.pi / kUse / 2 for ki in range(k): for kpi in range(k): H0[ki, kpi] = 1 / np.sqrt(2) * (wm * phi[ki](x_m / 2) * phi[kpi](x_m)).sum() G0[ki, kpi] = 1 / np.sqrt(2) * (wm * psi(psi1, psi2, ki, x_m / 2) * phi[kpi](x_m)).sum() H1[ki, kpi] = 1 / np.sqrt(2) * (wm * phi[ki]((x_m + 1) / 2) * phi[kpi](x_m)).sum() G1[ki, kpi] = 1 / np.sqrt(2) * (wm * psi(psi1, psi2, ki, (x_m + 1) / 2) * phi[kpi](x_m)).sum() PHI0[ki, kpi] = (wm * phi[ki](2 * x_m) * phi[kpi](2 * x_m)).sum() * 2 PHI1[ki, kpi] = (wm * phi[ki](2 * x_m - 1) * phi[kpi](2 * x_m - 1)).sum() * 2 PHI0[np.abs(PHI0) < 1e-8] = 0 PHI1[np.abs(PHI1) < 1e-8] = 0 H0[np.abs(H0) < 1e-8] = 0 H1[np.abs(H1) < 1e-8] = 0 G0[np.abs(G0) < 1e-8] = 0 G1[np.abs(G1) < 1e-8] = 0 return H0, H1, G0, G1, PHI0, PHI1 class MultiWaveletTransform(nn.Module): """ 1D multiwavelet block. """ def __init__(self, ich=1, k=8, alpha=16, c=128, nCZ=1, L=0, base='legendre', attention_dropout=0.1): super(MultiWaveletTransform, self).__init__() print('base', base) self.k = k self.c = c self.L = L self.nCZ = nCZ self.Lk0 = nn.Linear(ich, c * k) self.Lk1 = nn.Linear(c * k, ich) self.ich = ich self.MWT_CZ = nn.ModuleList(MWT_CZ1d(k, alpha, L, c, base) for i in range(nCZ)) def forward(self, queries, keys, values, attn_mask): B, L, H, E = queries.shape _, S, _, D = values.shape if L > S: zeros = torch.zeros_like(queries[:, :(L - S), :]).float() values = torch.cat([values, zeros], dim=1) keys = torch.cat([keys, zeros], dim=1) else: values = values[:, :L, :, :] keys = keys[:, :L, :, :] values = values.view(B, L, -1) V = self.Lk0(values).view(B, L, self.c, -1) for i in range(self.nCZ): V = self.MWT_CZ[i](V) if i < self.nCZ - 1: V = F.relu(V) V = self.Lk1(V.view(B, L, -1)) V = V.view(B, L, -1, D) return (V.contiguous(), None) class MultiWaveletCross(nn.Module): """ 1D Multiwavelet Cross Attention layer. """ def __init__(self, in_channels, out_channels, seq_len_q, seq_len_kv, modes, c=64, k=8, ich=512, L=0, base='legendre', mode_select_method='random', initializer=None, activation='tanh', **kwargs): super(MultiWaveletCross, self).__init__() print('base', base) self.c = c self.k = k self.L = L H0, H1, G0, G1, PHI0, PHI1 = get_filter(base, k) H0r = H0 @ PHI0 G0r = G0 @ PHI0 H1r = H1 @ PHI1 G1r = G1 @ PHI1 H0r[np.abs(H0r) < 1e-8] = 0 H1r[np.abs(H1r) < 1e-8] = 0 G0r[np.abs(G0r) < 1e-8] = 0 G1r[np.abs(G1r) < 1e-8] = 0 self.max_item = 3 self.attn1 = FourierCrossAttentionW(in_channels=in_channels, out_channels=out_channels, seq_len_q=seq_len_q, seq_len_kv=seq_len_kv, modes=modes, activation=activation, mode_select_method=mode_select_method) self.attn2 = FourierCrossAttentionW(in_channels=in_channels, out_channels=out_channels, seq_len_q=seq_len_q, seq_len_kv=seq_len_kv, modes=modes, activation=activation, mode_select_method=mode_select_method) self.attn3 = FourierCrossAttentionW(in_channels=in_channels, out_channels=out_channels, seq_len_q=seq_len_q, seq_len_kv=seq_len_kv, modes=modes, activation=activation, mode_select_method=mode_select_method) self.attn4 = FourierCrossAttentionW(in_channels=in_channels, out_channels=out_channels, seq_len_q=seq_len_q, seq_len_kv=seq_len_kv, modes=modes, activation=activation, mode_select_method=mode_select_method) self.T0 = nn.Linear(k, k) self.register_buffer('ec_s', torch.Tensor( np.concatenate((H0.T, H1.T), axis=0))) self.register_buffer('ec_d', torch.Tensor( np.concatenate((G0.T, G1.T), axis=0))) self.register_buffer('rc_e', torch.Tensor( np.concatenate((H0r, G0r), axis=0))) self.register_buffer('rc_o', torch.Tensor( np.concatenate((H1r, G1r), axis=0))) self.Lk = nn.Linear(ich, c * k) self.Lq = nn.Linear(ich, c * k) self.Lv = nn.Linear(ich, c * k) self.out = nn.Linear(c * k, ich) self.modes1 = modes def forward(self, q, k, v, mask=None): B, N, H, E = q.shape # (B, N, H, E) torch.Size([3, 768, 8, 2]) _, S, _, _ = k.shape # (B, S, H, E) torch.Size([3, 96, 8, 2]) q = q.view(q.shape[0], q.shape[1], -1) k = k.view(k.shape[0], k.shape[1], -1) v = v.view(v.shape[0], v.shape[1], -1) q = self.Lq(q) q = q.view(q.shape[0], q.shape[1], self.c, self.k) k = self.Lk(k) k = k.view(k.shape[0], k.shape[1], self.c, self.k) v = self.Lv(v) v = v.view(v.shape[0], v.shape[1], self.c, self.k) if N > S: zeros = torch.zeros_like(q[:, :(N - S), :]).float() v = torch.cat([v, zeros], dim=1) k = torch.cat([k, zeros], dim=1) else: v = v[:, :N, :, :] k = k[:, :N, :, :] ns = math.floor(np.log2(N)) nl = pow(2, math.ceil(np.log2(N))) extra_q = q[:, 0:nl - N, :, :] extra_k = k[:, 0:nl - N, :, :] extra_v = v[:, 0:nl - N, :, :] q = torch.cat([q, extra_q], 1) k = torch.cat([k, extra_k], 1) v = torch.cat([v, extra_v], 1) Ud_q = torch.jit.annotate(List[Tuple[Tensor]], []) Ud_k = torch.jit.annotate(List[Tuple[Tensor]], []) Ud_v = torch.jit.annotate(List[Tuple[Tensor]], []) Us_q = torch.jit.annotate(List[Tensor], []) Us_k = torch.jit.annotate(List[Tensor], []) Us_v = torch.jit.annotate(List[Tensor], []) Ud = torch.jit.annotate(List[Tensor], []) Us = torch.jit.annotate(List[Tensor], []) # decompose for i in range(ns - self.L): # print('q shape',q.shape) d, q = self.wavelet_transform(q) Ud_q += [tuple([d, q])] Us_q += [d] for i in range(ns - self.L): d, k = self.wavelet_transform(k) Ud_k += [tuple([d, k])] Us_k += [d] for i in range(ns - self.L): d, v = self.wavelet_transform(v) Ud_v += [tuple([d, v])] Us_v += [d] for i in range(ns - self.L): dk, sk = Ud_k[i], Us_k[i] dq, sq = Ud_q[i], Us_q[i] dv, sv = Ud_v[i], Us_v[i] Ud += [self.attn1(dq[0], dk[0], dv[0], mask)[0] + self.attn2(dq[1], dk[1], dv[1], mask)[0]] Us += [self.attn3(sq, sk, sv, mask)[0]] v = self.attn4(q, k, v, mask)[0] # reconstruct for i in range(ns - 1 - self.L, -1, -1): v = v + Us[i] v = torch.cat((v, Ud[i]), -1) v = self.evenOdd(v) v = self.out(v[:, :N, :, :].contiguous().view(B, N, -1)) return (v.contiguous(), None) def wavelet_transform(self, x): xa = torch.cat([x[:, ::2, :, :], x[:, 1::2, :, :], ], -1) d = torch.matmul(xa, self.ec_d) s = torch.matmul(xa, self.ec_s) return d, s def evenOdd(self, x): B, N, c, ich = x.shape # (B, N, c, k) assert ich == 2 * self.k x_e = torch.matmul(x, self.rc_e) x_o = torch.matmul(x, self.rc_o) x = torch.zeros(B, N * 2, c, self.k, device=x.device) x[..., ::2, :, :] = x_e x[..., 1::2, :, :] = x_o return x class FourierCrossAttentionW(nn.Module): def __init__(self, in_channels, out_channels, seq_len_q, seq_len_kv, modes=16, activation='tanh', mode_select_method='random'): super(FourierCrossAttentionW, self).__init__() print('corss fourier correlation used!') self.in_channels = in_channels self.out_channels = out_channels self.modes1 = modes self.activation = activation def compl_mul1d(self, order, x, weights): x_flag = True w_flag = True if not torch.is_complex(x): x_flag = False x = torch.complex(x, torch.zeros_like(x).to(x.device)) if not torch.is_complex(weights): w_flag = False weights = torch.complex(weights, torch.zeros_like(weights).to(weights.device)) if x_flag or w_flag: return torch.complex(torch.einsum(order, x.real, weights.real) - torch.einsum(order, x.imag, weights.imag), torch.einsum(order, x.real, weights.imag) + torch.einsum(order, x.imag, weights.real)) else: return torch.einsum(order, x.real, weights.real) def forward(self, q, k, v, mask): B, L, E, H = q.shape xq = q.permute(0, 3, 2, 1) # size = [B, H, E, L] torch.Size([3, 8, 64, 512]) xk = k.permute(0, 3, 2, 1) xv = v.permute(0, 3, 2, 1) self.index_q = list(range(0, min(int(L // 2), self.modes1))) self.index_k_v = list(range(0, min(int(xv.shape[3] // 2), self.modes1))) # Compute Fourier coefficients xq_ft_ = torch.zeros(B, H, E, len(self.index_q), device=xq.device, dtype=torch.cfloat) xq_ft = torch.fft.rfft(xq, dim=-1) for i, j in enumerate(self.index_q): xq_ft_[:, :, :, i] = xq_ft[:, :, :, j] xk_ft_ = torch.zeros(B, H, E, len(self.index_k_v), device=xq.device, dtype=torch.cfloat) xk_ft = torch.fft.rfft(xk, dim=-1) for i, j in enumerate(self.index_k_v): xk_ft_[:, :, :, i] = xk_ft[:, :, :, j] xqk_ft = (self.compl_mul1d("bhex,bhey->bhxy", xq_ft_, xk_ft_)) if self.activation == 'tanh': xqk_ft = torch.complex(xqk_ft.real.tanh(), xqk_ft.imag.tanh()) elif self.activation == 'softmax': xqk_ft = torch.softmax(abs(xqk_ft), dim=-1) xqk_ft = torch.complex(xqk_ft, torch.zeros_like(xqk_ft)) else: raise Exception('{} actiation function is not implemented'.format(self.activation)) xqkv_ft = self.compl_mul1d("bhxy,bhey->bhex", xqk_ft, xk_ft_) xqkvw = xqkv_ft out_ft = torch.zeros(B, H, E, L // 2 + 1, device=xq.device, dtype=torch.cfloat) for i, j in enumerate(self.index_q): out_ft[:, :, :, j] = xqkvw[:, :, :, i] out = torch.fft.irfft(out_ft / self.in_channels / self.out_channels, n=xq.size(-1)).permute(0, 3, 2, 1) # size = [B, L, H, E] return (out, None) class sparseKernelFT1d(nn.Module): def __init__(self, k, alpha, c=1, nl=1, initializer=None, **kwargs): super(sparseKernelFT1d, self).__init__() self.modes1 = alpha self.scale = (1 / (c * k * c * k)) self.weights1 = nn.Parameter(self.scale * torch.rand(c * k, c * k, self.modes1, dtype=torch.float)) self.weights2 = nn.Parameter(self.scale * torch.rand(c * k, c * k, self.modes1, dtype=torch.float)) self.weights1.requires_grad = True self.weights2.requires_grad = True self.k = k def compl_mul1d(self, order, x, weights): x_flag = True w_flag = True if not torch.is_complex(x): x_flag = False x = torch.complex(x, torch.zeros_like(x).to(x.device)) if not torch.is_complex(weights): w_flag = False weights = torch.complex(weights, torch.zeros_like(weights).to(weights.device)) if x_flag or w_flag: return torch.complex(torch.einsum(order, x.real, weights.real) - torch.einsum(order, x.imag, weights.imag), torch.einsum(order, x.real, weights.imag) + torch.einsum(order, x.imag, weights.real)) else: return torch.einsum(order, x.real, weights.real) def forward(self, x): B, N, c, k = x.shape # (B, N, c, k) x = x.view(B, N, -1) x = x.permute(0, 2, 1) x_fft = torch.fft.rfft(x) # Multiply relevant Fourier modes l = min(self.modes1, N // 2 + 1) out_ft = torch.zeros(B, c * k, N // 2 + 1, device=x.device, dtype=torch.cfloat) out_ft[:, :, :l] = self.compl_mul1d("bix,iox->box", x_fft[:, :, :l], torch.complex(self.weights1, self.weights2)[:, :, :l]) x = torch.fft.irfft(out_ft, n=N) x = x.permute(0, 2, 1).view(B, N, c, k) return x # ## class MWT_CZ1d(nn.Module): def __init__(self, k=3, alpha=64, L=0, c=1, base='legendre', initializer=None, **kwargs): super(MWT_CZ1d, self).__init__() self.k = k self.L = L H0, H1, G0, G1, PHI0, PHI1 = get_filter(base, k) H0r = H0 @ PHI0 G0r = G0 @ PHI0 H1r = H1 @ PHI1 G1r = G1 @ PHI1 H0r[np.abs(H0r) < 1e-8] = 0 H1r[np.abs(H1r) < 1e-8] = 0 G0r[np.abs(G0r) < 1e-8] = 0 G1r[np.abs(G1r) < 1e-8] = 0 self.max_item = 3 self.A = sparseKernelFT1d(k, alpha, c) self.B = sparseKernelFT1d(k, alpha, c) self.C = sparseKernelFT1d(k, alpha, c) self.T0 = nn.Linear(k, k) self.register_buffer('ec_s', torch.Tensor( np.concatenate((H0.T, H1.T), axis=0))) self.register_buffer('ec_d', torch.Tensor( np.concatenate((G0.T, G1.T), axis=0))) self.register_buffer('rc_e', torch.Tensor( np.concatenate((H0r, G0r), axis=0))) self.register_buffer('rc_o', torch.Tensor( np.concatenate((H1r, G1r), axis=0))) def forward(self, x): B, N, c, k = x.shape # (B, N, k) ns = math.floor(np.log2(N)) nl = pow(2, math.ceil(np.log2(N))) extra_x = x[:, 0:nl - N, :, :] x = torch.cat([x, extra_x], 1) Ud = torch.jit.annotate(List[Tensor], []) Us = torch.jit.annotate(List[Tensor], []) for i in range(ns - self.L): d, x = self.wavelet_transform(x) Ud += [self.A(d) + self.B(x)] Us += [self.C(d)] x = self.T0(x) # coarsest scale transform # reconstruct for i in range(ns - 1 - self.L, -1, -1): x = x + Us[i] x = torch.cat((x, Ud[i]), -1) x = self.evenOdd(x) x = x[:, :N, :, :] return x def wavelet_transform(self, x): xa = torch.cat([x[:, ::2, :, :], x[:, 1::2, :, :], ], -1) d = torch.matmul(xa, self.ec_d) s = torch.matmul(xa, self.ec_s) return d, s def evenOdd(self, x): B, N, c, ich = x.shape # (B, N, c, k) assert ich == 2 * self.k x_e = torch.matmul(x, self.rc_e) x_o = torch.matmul(x, self.rc_o) x = torch.zeros(B, N * 2, c, self.k, device=x.device) x[..., ::2, :, :] = x_e x[..., 1::2, :, :] = x_o return x