# Spaces:
# Napawit
# /
# Running on Zero
#
# Zeyue7's picture
# AudioX
# 8ab1cf8
"""The 1D discrete wavelet transform for PyTorch."""
from einops import rearrange
import pywt
import torch
from torch import nn
from torch.nn import functional as F
from typing import Literal
def get_filter_bank(wavelet):
    """Return the filter bank of a PyWavelets wavelet as a 2-D tensor.

    Args:
        wavelet: Wavelet name understood by ``pywt.Wavelet`` (e.g.
            ``"bior4.4"``).

    Returns:
        A tensor with one filter per row, built from
        ``pywt.Wavelet(wavelet).filter_bank``. For names starting with
        ``"bior"`` whose filters all begin with a zero tap, that first
        column is removed (presumably padding added by PyWavelets; the
        callers in this file require an odd number of taps).
    """
    bank = torch.tensor(pywt.Wavelet(wavelet).filter_bank)
    # Only biorthogonal wavelets are candidates for trimming.
    if not wavelet.startswith("bior"):
        return bank
    # Drop the leading tap only when it is zero in every filter of the bank.
    if torch.all(bank[:, 0] == 0):
        bank = bank[:, 1:]
    return bank
class WaveletEncode1d(nn.Module):
    """Multi-level 1D wavelet analysis built on a stride-2 convolution.

    At every level the first ``channels`` channels are filtered into a
    low-pass and a high-pass band at half the length, while all remaining
    channels are folded so that pairs of adjacent time steps become
    channels. For an input with ``channels`` channels this yields
    ``channels * 2**levels`` channels at ``length / 2**levels`` samples.

    Args:
        channels: Number of leading channels that are filtered at each level.
        levels: Number of decomposition levels applied by ``forward``.
        wavelet: Wavelet name handed to ``get_filter_bank``; the returned
            filters must have an odd number of taps.
    """

    def __init__(
        self,
        channels,
        levels,
        wavelet: Literal["bior2.2", "bior2.4", "bior2.6", "bior2.8", "bior4.4", "bior6.8"] = "bior4.4",
    ):
        super().__init__()
        self.wavelet = wavelet
        self.channels = channels
        self.levels = levels

        bank = get_filter_bank(wavelet)
        taps = bank.shape[-1]
        assert taps % 2 == 1

        # The first two rows of the bank are used as the analysis pair
        # (decomposition low-pass / high-pass in PyWavelets' ordering). They
        # are reversed because conv1d computes a cross-correlation.
        analysis = torch.flip(bank[:2], dims=(-1,))

        # Dense weight acting per channel: output row ``band * channels + ch``
        # only reads input channel ``ch``; every other entry stays zero.
        weight = torch.zeros(channels * 2, channels, taps)
        for band in range(2):
            for ch in range(channels):
                weight[band * channels + ch, ch] = analysis[band]
        self.register_buffer("kernel", weight)

    def forward(self, x):
        """Apply ``levels`` rounds of analysis to ``x``.

        Args:
            x: Tensor indexed as (batch, channel, time). The time length is
                expected to be divisible by ``2**levels`` (each level folds
                the time axis by two).

        Returns:
            Tensor whose first ``channels`` channels hold the low-pass band
            of the last level, followed by the high-pass band and the folded
            bands of earlier levels. ``x`` is returned unchanged when
            ``levels`` is 0.
        """
        for _ in range(self.levels):
            approx = x[:, : self.channels]
            carried = x[:, self.channels :]

            # Reflect-pad by half the filter length so the stride-2
            # convolution halves an even-length signal exactly.
            half = self.kernel.shape[-1] // 2
            approx = F.pad(approx, (half, half), "reflect")
            approx = F.conv1d(approx, self.kernel, stride=2)

            # Bands from earlier levels are not filtered again; they are only
            # folded in time so their length matches the new bands.
            carried = rearrange(
                carried, "n (c c2) (l l2) -> n (c l2 c2) l", l2=2, c2=self.channels
            )
            x = torch.cat([approx, carried], dim=1)
        return x
class WaveletDecode1d(nn.Module):
    """Multi-level 1D wavelet synthesis built on a stride-2 transposed conv.

    Structurally the counterpart of ``WaveletEncode1d``: at every level the
    first ``2 * channels`` channels (a low-pass / high-pass pair) are merged
    into ``channels`` channels at twice the length, while all remaining
    channels are unfolded from channels back into adjacent time steps. For an
    input with ``channels * 2**levels`` channels this yields ``channels``
    channels at ``length * 2**levels`` samples.

    Args:
        channels: Number of channels produced by each synthesis step.
        levels: Number of reconstruction levels applied by ``forward``.
        wavelet: Wavelet name handed to ``get_filter_bank``; the returned
            filters must have an odd number of taps.
    """

    def __init__(
        self,
        channels,
        levels,
        wavelet: Literal["bior2.2", "bior2.4", "bior2.6", "bior2.8", "bior4.4", "bior6.8"] = "bior4.4",
    ):
        super().__init__()
        self.wavelet = wavelet
        self.channels = channels
        self.levels = levels

        bank = get_filter_bank(wavelet)
        taps = bank.shape[-1]
        assert taps % 2 == 1

        # Rows two and three of the bank are used as the synthesis pair
        # (reconstruction low-pass / high-pass in PyWavelets' ordering). No
        # flip is applied here, unlike in the encoder.
        synthesis = bank[2:]

        # Dense weight acting per channel: input row ``band * channels + ch``
        # only contributes to output channel ``ch``.
        weight = torch.zeros(channels * 2, channels, taps)
        for band in range(2):
            for ch in range(channels):
                weight[band * channels + ch, ch] = synthesis[band]
        self.register_buffer("kernel", weight)

    def forward(self, x):
        """Apply ``levels`` rounds of synthesis to ``x``.

        Args:
            x: Tensor indexed as (batch, channel, time) whose first
                ``2 * channels`` channels are the band pair to merge.

        Returns:
            Tensor with the merged signal in the first ``channels`` channels,
            followed by any still-unmerged bands. ``x`` is returned unchanged
            when ``levels`` is 0.
        """
        for _ in range(self.levels):
            bands = x[:, : self.channels * 2]
            carried = x[:, self.channels * 2 :]

            half = self.kernel.shape[-1] // 2
            margin = half + 2

            # Interleave the two bands along time so one reflect pad extends
            # both, then split them back into separate channel groups.
            # NOTE(review): the split only restores the band order when
            # ``margin`` is even - confirm for the chosen wavelet.
            bands = rearrange(bands, "n (l2 c) l -> n c (l l2)", l2=2)
            bands = F.pad(bands, (margin, margin), "reflect")
            bands = rearrange(bands, "n c (l l2) -> n (l2 c) l", l2=2)

            merged = F.conv_transpose1d(
                bands, self.kernel, stride=2, padding=half
            )
            # Crop the samples produced by the padding; what remains is twice
            # the length of the un-padded bands.
            merged = merged[..., margin - 1 : -margin]

            # Undo one level of time folding on the bands that are not merged
            # at this level.
            carried = rearrange(
                carried, "n (c l2 c2) l -> n (c c2) (l l2)", l2=2, c2=self.channels
            )
            x = torch.cat([merged, carried], dim=1)
        return x