liuhuadai's picture
Upload 340 files
6efc863 verified
raw
history blame
No virus
7.77 kB
from ldm.data.preprocess.NAT_mel import MelNet
import os
from tqdm import tqdm
from glob import glob
import math
import pandas as pd
import logging
import math
import audioread
from tqdm.contrib.concurrent import process_map
import torch
import torch.nn as nn
import torchaudio
import numpy as np
from torch.distributed import init_process_group
from torch.utils.data import Dataset,DataLoader,DistributedSampler
import torch.multiprocessing as mp
from argparse import Namespace
from multiprocessing import Pool
import json
class tsv_dataset(Dataset):
    """Dataset of ``(audio_path, waveform)`` pairs listed in TSV manifests.

    Every manifest is a tab-separated file with an ``audio_path`` column.
    Items are loaded with torchaudio, mixed down to one channel, resampled to
    ``sr`` and, when ``mode == 'pad'``, cropped or zero-padded to a fixed
    number of samples.
    """

    def __init__(self,tsv_path,sr,mode='none',hop_size = None,target_mel_length = None) -> None:
        """Read the manifest(s) and remember the audio paths.

        Args:
            tsv_path: a single ``.tsv`` file, or a directory whose ``*.tsv``
                files are all concatenated.
            sr: sample rate every waveform is resampled to.
            mode: only ``'pad'`` changes what this class does (fixed-length
                waveforms); any other value leaves waveforms at full length.
            hop_size: hop length in samples; needed for ``mode == 'pad'``.
            target_mel_length: number of mel frames the padded waveform
                should correspond to; needed for ``mode == 'pad'``.

        Raises:
            ValueError: if ``tsv_path`` is a directory without ``.tsv`` files.
        """
        super().__init__()
        if os.path.isdir(tsv_path):
            # Sorted so that the item order does not depend on the file
            # system's listing order.
            files = sorted(glob(os.path.join(tsv_path, '*.tsv')))
            if not files:
                raise ValueError(f'no .tsv files found in directory: {tsv_path}')
            df = pd.concat([pd.read_csv(file, sep='\t') for file in files])
        else:
            df = pd.read_csv(tsv_path, sep='\t')
        self.sr = sr
        self.mode = mode
        self.target_mel_length = target_mel_length
        self.hop_size = hop_size
        # Read the column directly instead of building one tuple per row.
        self.audio_paths = df['audio_path'].tolist()

    def __len__(self):
        """Return the number of audio files listed in the manifest(s)."""
        return len(self.audio_paths)

    def pad_wav(self,wav):
        """Crop or zero-pad ``wav`` to ``(target_mel_length + 1) * hop_size`` samples.

        Args:
            wav: tensor of shape ``(1, wav_len)``.

        Returns:
            ``wav`` itself when it already has the target length or when no
            target length is configured (``target_mel_length`` or ``hop_size``
            is ``None``); otherwise a cropped view or a new zero-padded
            float32 tensor of shape ``(1, segment_length)``.

        Raises:
            AssertionError: if ``wav`` has 100 samples or fewer.
        """
        wav_length = wav.shape[-1]
        if wav_length <= 100:
            # Raised explicitly so the check survives ``python -O``.
            raise AssertionError("wav is too short, %s" % wav_length)
        if self.target_mel_length is None or self.hop_size is None:
            # No target length configured: nothing to pad or crop. Checking
            # before the multiplication avoids an opaque TypeError.
            return wav
        # One extra frame is requested because, per the original author's
        # note, the final mel drops its last frame (mel = mel[:, :-1]).
        segment_length = (self.target_mel_length + 1) * self.hop_size
        if wav_length == segment_length:
            return wav
        if wav_length > segment_length:
            return wav[:, :segment_length]
        temp_wav = torch.zeros((1, segment_length), dtype=torch.float32)
        temp_wav[:, :wav_length] = wav
        return temp_wav

    def __getitem__(self, index):
        """Load item ``index`` and return ``(audio_path, wav)``.

        ``wav`` has a single channel and is resampled to ``self.sr``; in
        ``'pad'`` mode it is additionally passed through :meth:`pad_wav`.

        Raises:
            AssertionError: in ``'pad'`` mode when ``target_mel_length`` or
                ``hop_size`` was not provided.
        """
        audio_path = self.audio_paths[index]
        wav, orisr = torchaudio.load(audio_path)
        if wav.shape[0] != 1:  # stereo to mono (2,wav_len) -> (1,wav_len)
            wav = wav.mean(0, keepdim=True)
        wav = torchaudio.functional.resample(wav, orig_freq=orisr, new_freq=self.sr)
        if self.mode == 'pad':
            if self.target_mel_length is None or self.hop_size is None:
                raise AssertionError(
                    "mode 'pad' requires both target_mel_length and hop_size")
            wav = self.pad_wav(wav)
        return audio_path, wav
def _derived_npy_path(audio_path, root_suffix, name_suffix):
    """Return ``(directory, file_path)`` of a ``.npy`` file derived from ``audio_path``.

    The first ``/``-separated component of ``audio_path`` gets ``root_suffix``
    appended, the intermediate directories are kept as they are, and the file
    extension is replaced by ``name_suffix``.
    """
    parts = audio_path.split('/')
    # splitext instead of a fixed [:-4] slice, so extensions that are not
    # exactly three letters (e.g. ".flac") do not leave a stray "." behind.
    stem = os.path.splitext(parts[-1])[0]
    out_dir = os.path.join(parts[0] + root_suffix, *parts[1:-1])
    return out_dir, os.path.join(out_dir, stem + name_suffix)


def process_audio_by_tsv(rank,args):
    """Resample the audio listed in ``args.tsv_path`` and/or save mel spectrograms.

    Intended as a per-process worker (``rank`` selects the CUDA device and the
    shard of the dataset when ``args.num_gpus > 1``).

    Args:
        rank: index of this process; also used as the CUDA device index.
        args: namespace providing ``num_gpus``, ``dist_config``,
            ``audio_sample_rate``, ``tsv_path``, ``mode``, ``hop_size``,
            ``batch_max_length``, ``save_resample``, ``save_mel`` and, when
            mels are saved, ``fft_size``. ``args.__dict__`` is handed to
            ``MelNet`` unchanged.

    Raises:
        ValueError: if a mel no longer than ``batch_max_length`` frames is
            produced while ``args.mode`` is not ``'tile'``, ``'none'`` or
            ``'pad'``.
    """
    if args.num_gpus > 1:
        init_process_group(backend=args.dist_config['dist_backend'], init_method=args.dist_config['dist_url'],
                           world_size=args.dist_config['world_size'] * args.num_gpus, rank=rank)

    sr = args.audio_sample_rate
    dataset = tsv_dataset(args.tsv_path, sr=sr, mode=args.mode, hop_size=args.hop_size,
                          target_mel_length=args.batch_max_length)
    sampler = DistributedSampler(dataset, shuffle=False) if args.num_gpus > 1 else None
    # batch_size must be 1 because waveform lengths differ between items.
    loader = DataLoader(dataset, sampler=sampler, batch_size=1, num_workers=16, drop_last=False)

    # Fall back to the CPU when CUDA is unavailable instead of failing later.
    device = torch.device('cuda:{:d}'.format(rank) if torch.cuda.is_available() else 'cpu')
    mel_net = MelNet(args.__dict__)
    mel_net.to(device)
    # NOTE: mel_net is deliberately not wrapped in DistributedDataParallel;
    # the original author recorded that doing so raises "RuntimeError:
    # DistributedDataParallel is not needed when a module doesn't have any
    # parameter that requires a gradient."

    # Only the first process shows a progress bar.
    loader = tqdm(loader) if rank == 0 else loader
    for audio_paths, wavs in loader:
        wavs = wavs.to(device)

        if args.save_resample:
            for audio_path, wav in zip(audio_paths, wavs):
                resample_dir_name, resample_path = _derived_npy_path(
                    audio_path, f'_{sr}', '_audio.npy')
                os.makedirs(resample_dir_name, exist_ok=True)
                np.save(resample_path, wav.cpu().numpy().squeeze(0))

        if args.save_mel:
            mode = args.mode
            batch_max_length = args.batch_max_length
            mel_root_suffix = f'_mel{mode}{sr}nfft{args.fft_size}'
            for audio_path, wav in zip(audio_paths, wavs):
                mel_dir_name, mel_path = _derived_npy_path(
                    audio_path, mel_root_suffix, '_mel.npy')
                if os.path.exists(mel_path):
                    # Already computed on an earlier run; skip.
                    continue
                # Pure inference: no autograd graph is needed.
                with torch.no_grad():
                    mel_spec = mel_net(wav).cpu().numpy().squeeze(0)  # (mel_bins,mel_len)
                if mel_spec.shape[1] <= batch_max_length:
                    if mode == 'tile':  # pad is done in dataset as pad wav
                        # Repeat along time until longer than batch_max_length;
                        # the crop below trims the excess.
                        n_repeat = math.ceil((batch_max_length + 1) / mel_spec.shape[1])
                        mel_spec = np.tile(mel_spec, reps=(1, n_repeat))
                    elif mode not in ('none', 'pad'):
                        raise ValueError(f'mode:{mode} is not supported')
                mel_spec = mel_spec[:, :batch_max_length]
                os.makedirs(mel_dir_name, exist_ok=True)
                np.save(mel_path, mel_spec)
def split_list(i_list,num):
    """Split ``i_list`` into ``num`` consecutive chunks of near-equal size.

    Every chunk holds ``ceil(len(i_list) / num)`` items except the trailing
    ones, which may be shorter or empty when the list does not divide evenly.

    Args:
        i_list: sliceable sequence to split.
        num: number of chunks to return; must be positive.

    Returns:
        A list of exactly ``num`` slices of ``i_list``, in order.

    Raises:
        ValueError: if ``num`` is not positive.
    """
    if num <= 0:
        raise ValueError(f'num must be a positive integer, got {num}')
    # The chunk size comes from the list's length; dividing the list object
    # itself (as the code previously did) raises TypeError.
    each_num = math.ceil(len(i_list) / num)
    return [i_list[each_num * i:each_num * (i + 1)] for i in range(num)]
def drop_bad_wav(item):
    """Check one audio file and report its index when it is unusable.

    Args:
        item: ``(index, path)`` pair.

    Returns:
        ``index`` when the file cannot be opened or inspected, or when its
        duration is below 0.1 seconds; ``False`` when the file is fine.
        A valid index may be ``0``, which compares equal to ``False``, so the
        result has to be tested by identity (``is False`` / ``is not False``).
    """
    index, path = item
    try:
        with audioread.audio_open(path) as f:
            too_short = f.duration < 0.1
    except Exception:
        # ``Exception`` rather than a bare ``except`` so that
        # KeyboardInterrupt and SystemExit still stop the run.
        print(f"corrupted wav:{path}")
        return index
    return index if too_short else False
def drop_bad_wavs(tsv_path):# 'audioset.csv'
    """Remove unusable audio files from a TSV manifest, rewriting it in place.

    Every ``audio_path`` in the manifest is checked in parallel with
    ``drop_bad_wav``. The ``(index, path)`` pairs of the rejected rows are
    printed and written to ``bad_wavs.json`` in the current directory (the
    file is overwritten on every call), and those rows are dropped from the
    manifest.

    Args:
        tsv_path: path of the tab-separated manifest to clean.
    """
    df = pd.read_csv(tsv_path, sep='\t')
    item_list = list(zip(df.index.tolist(), df['audio_path'].tolist()))
    r = process_map(drop_bad_wav, item_list, max_workers=16, chunksize=16)
    # Identity test on purpose: ``0 != False`` is False in Python, so an
    # equality filter would silently keep a bad file sitting in row 0.
    bad_indices = [x for x in r if x is not False]
    print(bad_indices)
    with open('bad_wavs.json', 'w') as f:
        json.dump([item_list[i] for i in bad_indices], f)
    df = df.drop(bad_indices, axis=0)
    df.to_csv(tsv_path, sep='\t', index=False)
if __name__ == '__main__':
    # Script entry point: clean the manifest(s), then extract mel spectrograms.
    logging.basicConfig(
        filename='example.log',
        level=logging.INFO,
        format='%(asctime)s %(message)s',
        datefmt='%m/%d/%Y %I:%M:%S %p',
    )

    tsv_path = './musiccap.tsv'
    # A directory means "every *.tsv inside it"; otherwise a single manifest.
    if os.path.isdir(tsv_path):
        manifests = glob(os.path.join(tsv_path, '*.tsv'))
    else:
        manifests = [tsv_path]
    for manifest in manifests:
        drop_bad_wavs(manifest)

    num_gpus = 1
    args = Namespace(
        audio_sample_rate=16000,
        audio_num_mel_bins=80,
        # Original author's note: "4000:512 ,16000:1024"
        # (presumably sample rate -> fft size; confirm before changing).
        fft_size=1024,
        win_size=1024,
        hop_size=256,
        fmin=0,
        fmax=8000,
        # Original author's note: "4000:312 (nfft = 512,hoplen=128,
        # mellen = 313), 16000:624 , 22050:848".
        batch_max_length=1560,
        tsv_path=tsv_path,
        num_gpus=num_gpus,
        mode='none',
        save_resample=False,
        save_mel=True,
    )
    args.dist_config = {
        "dist_backend": "nccl",
        "dist_url": "tcp://localhost:54189",
        "world_size": 1,
    }

    # One spawned worker per GPU, or a direct call in this process for one GPU.
    if args.num_gpus > 1:
        mp.spawn(process_audio_by_tsv, nprocs=args.num_gpus, args=(args,))
    else:
        process_audio_by_tsv(0, args=args)
    print("done")