'''
@File    :   subband_util.py
@Contact :   liu.8948@buckeyemail.osu.edu
@License :   (C)Copyright 2020-2021

@Modify Time      @Author      @Version    @Description
------------      -------      --------    -----------
2020/4/3 4:54 PM  Haohe Liu    1.0         None
'''
import torch
import torch.nn.functional as F
import torch.nn as nn
import numpy as np
import os.path as op
from scipy.io import loadmat


def load_mat2numpy(fname=""):
    '''
    Load a MATLAB ``.mat`` file through ``scipy.io.loadmat``.

    Args:
        fname: path to the ``.mat`` file. An empty string (or ``None``)
            means "no file".

    Returns:
        The dict-like object returned by ``loadmat``, or ``None`` when no
        path was given.
    '''
    # Truthiness check instead of ``len(fname) == 0`` so that ``None`` is
    # treated as "no path" rather than raising TypeError.
    if not fname:
        return None
    return loadmat(fname)


class PQMF(nn.Module):
    '''
    Sub-band analysis / synthesis filter bank with fixed (non-trainable)
    filters.

    The filter coefficients are read from two ``.mat`` files located in
    ``project_root``: ``f_<N>_<M>.mat`` (key ``'f'``, analysis filters) and
    ``h_<N>_<M>.mat`` (key ``'h'``, synthesis filters).

    Args:
        N: number of sub-bands.
        M: number of filter taps.
        project_root: directory containing the two ``.mat`` files.
    '''

    # (N, M) combinations for which no warning is printed.
    _SUPPORTED_CONFIGS = ((8, 64), (4, 64), (2, 64))

    def __init__(self, N, M, project_root):
        super().__init__()
        self.N = N  # nsubband
        self.M = M  # nfilter
        # Explicit check rather than ``try: assert ... except:`` -- asserts
        # are stripped under ``python -O`` which silently disabled the
        # warning, and the bare ``except`` could mask unrelated errors.
        if (N, M) not in self._SUPPORTED_CONFIGS:
            print("Warning:", N, "subbandand ", M, " filter is not supported")
        # Number of zero samples appended by ``analysis`` and trimmed again
        # by ``synthesis``.
        self.pad_samples = 64
        self.name = str(N) + "_" + str(M) + ".mat"

        # ---- analysis: one input channel -> N sub-band channels ----------
        self.ana_conv_filter = nn.Conv1d(
            1, out_channels=N, kernel_size=M, stride=N, bias=False
        )
        ana = load_mat2numpy(op.join(project_root, "f_" + self.name))
        ana = ana['f'].astype(np.float32) / N
        # For a 2-D (N, M) array this reverses each row, i.e. flips the
        # taps of every filter.
        ana = np.flipud(ana.T).T
        # ``.copy()`` yields a contiguous array that torch.from_numpy accepts.
        ana = np.reshape(ana, (N, 1, M)).copy()
        self.ana_pad = nn.ConstantPad1d((M - N, 0), 0)
        # The bias-free Conv1d has a single state entry, ``weight``.
        self.ana_conv_filter.load_state_dict({'weight': torch.from_numpy(ana)})

        # ---- synthesis: N sub-band channels -> N interleavable channels --
        self.syn_pad = nn.ConstantPad1d((0, M // N - 1), 0)
        self.syn_conv_filter = nn.Conv1d(
            N, out_channels=N, kernel_size=M // N, stride=1, bias=False
        )
        syn = load_mat2numpy(op.join(project_root, "h_" + self.name))
        syn = syn['h'].astype(np.float32)
        # Rearrange the N*M coefficients into the (N, N, M // N) layout of
        # the Conv1d weight: reshape, scale by N, reverse the leading
        # (length M // N) axis, then move it last.
        syn = np.transpose(np.reshape(syn, (N, M // N, N)), (1, 0, 2)) * N
        syn = np.transpose(syn[::-1, :, :], (2, 1, 0)).copy()
        self.syn_conv_filter.load_state_dict({'weight': torch.from_numpy(syn)})

        # The filters are constants loaded from disk; never train them.
        for param in self.parameters():
            param.requires_grad = False

    def __analysis_channel(self, inputs):
        '''Left-pad and run the analysis convolution on ``inputs``
        (one channel per batch item).'''
        return self.ana_conv_filter(self.ana_pad(inputs))

    def __systhesis_channel(self, inputs):
        '''Right-pad and run the synthesis convolution on one group of
        ``self.N`` sub-band channels, then interleave the N outputs along
        time into a single channel of shape ``[batch, 1, -1]``.'''
        ret = self.syn_conv_filter(self.syn_pad(inputs)).permute(0, 2, 1)
        return torch.reshape(ret, (ret.shape[0], 1, -1))

    def analysis(self, inputs):
        '''
        Split every input channel into ``self.N`` sub-band channels.

        :param inputs: [batchsize,channel,raw_wav],value:[0,1]
        :return: tensor with ``channel * self.N`` channels (sub-bands of
            channel 0 first, then channel 1, ...), or ``None`` when
            ``inputs`` has zero channels.
        '''
        # Append ``pad_samples`` zeros; ``synthesis`` trims the same amount.
        padded = F.pad(inputs, (0, self.pad_samples))
        # Collect per-channel results and concatenate once instead of
        # re-copying the growing tensor on every iteration.
        bands = [
            self.__analysis_channel(padded[:, ch : ch + 1, :])
            for ch in range(padded.size(1))
        ]
        if not bands:
            return None
        return torch.cat(bands, dim=1)

    def synthesis(self, data):
        '''
        Rebuild waveform channels from groups of ``self.N`` sub-bands.

        :param data: [batchsize,self.N*K,raw_wav_sub],value:[0,1]
        :return: tensor with one channel per group of ``self.N`` sub-band
            channels; the last ``self.pad_samples`` samples are removed.
        :raises ValueError: if ``data`` has zero channels.
        '''
        # Step through the channel axis one sub-band group at a time.
        groups = [
            self.__systhesis_channel(data[:, start : start + self.N, :])
            for start in range(0, data.size(1), self.N)
        ]
        if not groups:
            raise ValueError("synthesis() requires at least one sub-band channel")
        ret = torch.cat(groups, dim=1)
        # Drop the samples that ``analysis`` appended as padding.
        return ret[..., : -self.pad_samples]

    def forward(self, inputs):
        '''Run the padded analysis convolution directly on ``inputs``
        (no trailing ``pad_samples`` padding, no per-channel loop).'''
        return self.ana_conv_filter(self.ana_pad(inputs))


if __name__ == "__main__":
    # Manual smoke test: analyse and re-synthesise random data, then plot
    # and print the reconstruction error.
    import matplotlib.pyplot as plt
    # NOTE(review): nothing below appears to use a name from this star
    # import; kept in case the module is needed for its side effects.
    from tools.file.wav import *

    pqmf = PQMF(N=4, M=64, project_root="/Users/admin/Documents/projects")

    rs = np.random.RandomState(0)
    x = torch.tensor(rs.rand(4, 2, 32000), dtype=torch.float32)

    a1 = pqmf.analysis(x)
    a2 = pqmf.synthesis(a1)

    print(a2.size(), x.size())

    plt.subplot(211)
    plt.plot(x[0, 0, -500:])
    plt.subplot(212)
    plt.plot(a2[0, 0, -500:])
    plt.plot(x[0, 0, -500:] - a2[0, 0, -500:])
    plt.show()

    print(torch.sum(torch.abs(x[...] - a2[...])))